From 99600c7bc3df310a4aa22a6aca789fd5babb387b Mon Sep 17 00:00:00 2001 From: Brennan <brecon@microsoft.com> Date: Thu, 8 Sep 2022 14:20:03 -0700 Subject: [PATCH] [SignalR] Avoid unobserved tasks in the .NET client (#43545) * [SignalR] Avoid unobserved tasks in the .NET client * fb --- eng/targets/CSharp.Common.targets | 2 +- src/Shared/Nullable/NullableAttributes.cs | 6 ++ .../Client.Core/src/HubConnection.Log.cs | 6 ++ .../csharp/Client.Core/src/HubConnection.cs | 64 ++++++++++++---- .../test/UnitTests/HubConnectionTests.cs | 75 +++++++++++++++++-- 5 files changed, 133 insertions(+), 20 deletions(-) diff --git a/eng/targets/CSharp.Common.targets b/eng/targets/CSharp.Common.targets index bb7a316e0db..9f553dd4682 100644 --- a/eng/targets/CSharp.Common.targets +++ b/eng/targets/CSharp.Common.targets @@ -119,7 +119,7 @@ --> <When Condition=" ('$(Nullable)' == 'annotations' OR '$(Nullable)' == 'enable') AND '$(SuppressNullableAttributesImport)' != 'true' AND - (('$(TargetFrameworkIdentifier)' == '.NETStandard' AND $([MSBuild]::VersionLessThanOrEquals('$(TargetFrameworkVersion)', '2.0'))) OR '$(TargetFrameworkIdentifier)' == '.NETFramework')"> + (('$(TargetFrameworkIdentifier)' == '.NETStandard' AND $([MSBuild]::VersionLessThanOrEquals('$(TargetFrameworkVersion)', '2.1'))) OR '$(TargetFrameworkIdentifier)' == '.NETFramework')"> <PropertyGroup> <DefineConstants>$(DefineConstants),INTERNAL_NULLABLE_ATTRIBUTES</DefineConstants> <NoWarn>$(NoWarn);nullable</NoWarn> diff --git a/src/Shared/Nullable/NullableAttributes.cs b/src/Shared/Nullable/NullableAttributes.cs index ae9a0b28782..860d82a37d0 100644 --- a/src/Shared/Nullable/NullableAttributes.cs +++ b/src/Shared/Nullable/NullableAttributes.cs @@ -5,6 +5,8 @@ namespace System.Diagnostics.CodeAnalysis; +// Attributes added in netstandard2.1 +#if !NETSTANDARD2_1_OR_GREATER /// <summary>Specifies that null is allowed as an input even if the corresponding type disallows it.</summary> [AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property, Inherited = false)] #if INTERNAL_NULLABLE_ATTRIBUTES @@ -131,7 +133,10 @@ internal /// <summary>Gets the condition parameter value.</summary> public bool ParameterValue { get; } } +#endif +// Attributes added in 5.0 +#if NETSTANDARD || NETFRAMEWORK /// <summary>Specifies that the method or property will ensure that the listed field and property members have not-null values.</summary> [AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, Inherited = false, AllowMultiple = true)] #if INTERNAL_NULLABLE_ATTRIBUTES @@ -198,3 +203,4 @@ internal /// <summary>Gets field or property member names.</summary> public string[] Members { get; } } +#endif diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs index 3f66fd3af76..22e1cbfd59d 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs @@ -316,5 +316,11 @@ public partial class HubConnection [LoggerMessage(86, LogLevel.Warning, "Result given for '{Target}' method but server is not expecting a result.", EventName = "ResultNotExpected")] public static partial void ResultNotExpected(ILogger logger, string target); + + [LoggerMessage(87, LogLevel.Trace, "Completion message for stream '{StreamId}' was not sent because the connection is closed.", EventName = "CompletingStreamNotSent")] + public static partial void CompletingStreamNotSent(ILogger logger, string streamId); + + [LoggerMessage(88, LogLevel.Warning, "Error returning result for invocation '{InvocationId}' for method '{Target}' because the underlying connection is closed.", EventName = "ErrorSendingInvocationResult")] + public static partial void ErrorSendingInvocationResult(ILogger logger, string invocationId, string target, Exception exception); } } diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs index 5479ceebe71..8e3c57b6a90 100644 --- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs +++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs @@ -815,11 +815,26 @@ public partial class HubConnection : IAsyncDisposable responseError = $"Stream errored by client: '{ex}'"; } - Log.CompletingStream(_logger, streamId); - // Don't use cancellation token here // this is triggered by a cancellation token to tell the server that the client is done streaming - await SendWithLock(connectionState, CompletionMessage.WithError(streamId, responseError), cancellationToken: default).ConfigureAwait(false); + await _state.WaitConnectionLockAsync(token: default).ConfigureAwait(false); + try + { + // Avoid sending when the connection isn't active, likely happens if there is an active stream when the connection closes + if (_state.IsConnectionActive()) + { + Log.CompletingStream(_logger, streamId); + await SendHubMessage(connectionState, CompletionMessage.WithError(streamId, responseError), cancellationToken: default).ConfigureAwait(false); + } + else + { + Log.CompletingStreamNotSent(_logger, streamId); + } + } + finally + { + _state.ReleaseConnectionLock(); + } } private async Task<object?> InvokeCoreAsyncCore(string methodName, Type returnType, object?[] args, CancellationToken cancellationToken) @@ -1025,7 +1040,14 @@ public partial class HubConnection : IAsyncDisposable if (expectsResult) { Log.MissingResultHandler(_logger, invocation.Target); - await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false); + try + { + await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false); + } + catch (Exception ex) + { + Log.ErrorSendingInvocationResult(_logger, invocation.InvocationId!, invocation.Target, ex); + } } else { @@ -1066,18 +1088,25 @@ public partial class HubConnection : IAsyncDisposable if (expectsResult) { - if (resultException is not null) - { - await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, resultException.Message), cancellationToken: default).ConfigureAwait(false); - } - else if (hasResult) + try { - await SendWithLock(connectionState, CompletionMessage.WithResult(invocation.InvocationId!, result), cancellationToken: default).ConfigureAwait(false); + if (resultException is not null) + { + await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, resultException.Message), cancellationToken: default).ConfigureAwait(false); + } + else if (hasResult) + { + await SendWithLock(connectionState, CompletionMessage.WithResult(invocation.InvocationId!, result), cancellationToken: default).ConfigureAwait(false); + } + else + { + Log.MissingResultHandler(_logger, invocation.Target); + await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false); + } } - else + catch (Exception ex) { - Log.MissingResultHandler(_logger, invocation.Target); - await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false); + Log.ErrorSendingInvocationResult(_logger, invocation.InvocationId!, invocation.Target, ex); } } else if (hasResult) @@ -2076,7 +2105,7 @@ public partial class HubConnection : IAsyncDisposable { await WaitConnectionLockAsync(token, methodName).ConfigureAwait(false); - if (CurrentConnectionStateUnsynchronized == null || CurrentConnectionStateUnsynchronized.Stopping) + if (!IsConnectionActive()) { ReleaseConnectionLock(methodName); throw new InvalidOperationException($"The '{methodName}' method cannot be called if the connection is not active"); @@ -2085,6 +2114,13 @@ public partial class HubConnection : IAsyncDisposable return CurrentConnectionStateUnsynchronized; } + [MemberNotNullWhen(true, nameof(CurrentConnectionStateUnsynchronized))] + public bool IsConnectionActive() + { + AssertInConnectionLock(); + return CurrentConnectionStateUnsynchronized is not null && !CurrentConnectionStateUnsynchronized.Stopping; + } + public void ReleaseConnectionLock([CallerMemberName] string? memberName = null, [CallerFilePath] string? filePath = null, [CallerLineNumber] int lineNumber = 0) { diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs index 0151f8462a7..2ba8c4ba3a1 100644 --- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs +++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs @@ -1,15 +1,11 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers; -using System.IO; using System.IO.Pipelines; using System.Net; using System.Net.WebSockets; -using System.Threading; using System.Threading.Channels; -using System.Threading.Tasks; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Connections.Features; using Microsoft.AspNetCore.Http.Connections.Client; @@ -20,7 +16,6 @@ using Microsoft.Extensions.DependencyInjection; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Testing; using Moq; -using Xunit; namespace Microsoft.AspNetCore.SignalR.Client.Tests; @@ -568,6 +563,40 @@ public partial class HubConnectionTests : VerifiableLoggedTest } } + [Fact] + [LogLevel(LogLevel.Trace)] + public async Task ActiveUploadStreamWhenConnectionClosesObservesException() + { + using (StartVerifiableLog()) + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory); + await hubConnection.StartAsync().DefaultTimeout(); + + var channel = Channel.CreateUnbounded<int>(); + var invokeTask = hubConnection.InvokeAsync<object>("UploadMethod", channel.Reader); + + var invokeMessage = await connection.ReadSentJsonAsync().DefaultTimeout(); + Assert.Equal(HubProtocolConstants.InvocationMessageType, invokeMessage["type"]); + + // Not sure how to test for unobserved task exceptions, best I could come up with is to check that we log where there once was an unobserved task exception + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + TestSink.MessageLogged += wc => + { + if (wc.EventId.Name == "CompletingStreamNotSent") + { + tcs.SetResult(); + } + }; + + await hubConnection.StopAsync(); + + await Assert.ThrowsAsync<TaskCanceledException>(() => invokeTask).DefaultTimeout(); + + await tcs.Task.DefaultTimeout(); + } + } + [Fact] [LogLevel(LogLevel.Trace)] public async Task InvocationCanCompleteBeforeStreamCompletes() @@ -738,6 +767,42 @@ public partial class HubConnectionTests : VerifiableLoggedTest } } + [Fact] + [LogLevel(LogLevel.Trace)] + public async Task ClientResultResponseAfterConnectionCloseObservesException() + { + using (StartVerifiableLog()) + { + var connection = new TestConnection(); + var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory); + await hubConnection.StartAsync().DefaultTimeout(); + + var resultTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + hubConnection.On("Result", async () => + { + await resultTcs.Task; + return 1; + }); + + await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout(); + + // Not sure how to test for unobserved task exceptions, best I could come up with is to check that we log where there once was an unobserved task exception + var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + TestSink.MessageLogged += wc => + { + if (wc.EventId.Name == "ErrorSendingInvocationResult") + { + tcs.SetResult(); + } + }; + + await hubConnection.StopAsync(); + resultTcs.SetResult(); + + await tcs.Task.DefaultTimeout(); + } + } + [Fact] public async Task HubConnectionIsMockable() { -- GitLab