From 99600c7bc3df310a4aa22a6aca789fd5babb387b Mon Sep 17 00:00:00 2001
From: Brennan <brecon@microsoft.com>
Date: Thu, 8 Sep 2022 14:20:03 -0700
Subject: [PATCH] [SignalR] Avoid unobserved tasks in the .NET client (#43545)

* [SignalR] Avoid unobserved tasks in the .NET client
* fb
---
 eng/targets/CSharp.Common.targets             |  2 +-
 src/Shared/Nullable/NullableAttributes.cs     |  6 ++
 .../Client.Core/src/HubConnection.Log.cs      |  6 ++
 .../csharp/Client.Core/src/HubConnection.cs   | 64 ++++++++++++----
 .../test/UnitTests/HubConnectionTests.cs      | 75 +++++++++++++++++--
 5 files changed, 133 insertions(+), 20 deletions(-)

diff --git a/eng/targets/CSharp.Common.targets b/eng/targets/CSharp.Common.targets
index bb7a316e0db..9f553dd4682 100644
--- a/eng/targets/CSharp.Common.targets
+++ b/eng/targets/CSharp.Common.targets
@@ -119,7 +119,7 @@
      -->
     <When Condition=" ('$(Nullable)' == 'annotations' OR '$(Nullable)' == 'enable') AND
         '$(SuppressNullableAttributesImport)' != 'true' AND
-        (('$(TargetFrameworkIdentifier)' == '.NETStandard' AND $([MSBuild]::VersionLessThanOrEquals('$(TargetFrameworkVersion)', '2.0'))) OR '$(TargetFrameworkIdentifier)' == '.NETFramework')">
+        (('$(TargetFrameworkIdentifier)' == '.NETStandard' AND $([MSBuild]::VersionLessThanOrEquals('$(TargetFrameworkVersion)', '2.1'))) OR '$(TargetFrameworkIdentifier)' == '.NETFramework')">
       <PropertyGroup>
         <DefineConstants>$(DefineConstants),INTERNAL_NULLABLE_ATTRIBUTES</DefineConstants>
         <NoWarn>$(NoWarn);nullable</NoWarn>
diff --git a/src/Shared/Nullable/NullableAttributes.cs b/src/Shared/Nullable/NullableAttributes.cs
index ae9a0b28782..860d82a37d0 100644
--- a/src/Shared/Nullable/NullableAttributes.cs
+++ b/src/Shared/Nullable/NullableAttributes.cs
@@ -5,6 +5,8 @@
 
 namespace System.Diagnostics.CodeAnalysis;
 
+// Attributes added in netstandard2.1
+#if !NETSTANDARD2_1_OR_GREATER
 /// <summary>Specifies that null is allowed as an input even if the corresponding type disallows it.</summary>
 [AttributeUsage(AttributeTargets.Field | AttributeTargets.Parameter | AttributeTargets.Property, Inherited = false)]
 #if INTERNAL_NULLABLE_ATTRIBUTES
@@ -131,7 +133,10 @@ internal
     /// <summary>Gets the condition parameter value.</summary>
     public bool ParameterValue { get; }
 }
+#endif
 
+// Attributes added in 5.0
+#if NETSTANDARD || NETFRAMEWORK
 /// <summary>Specifies that the method or property will ensure that the listed field and property members have not-null values.</summary>
 [AttributeUsage(AttributeTargets.Method | AttributeTargets.Property, Inherited = false, AllowMultiple = true)]
 #if INTERNAL_NULLABLE_ATTRIBUTES
@@ -198,3 +203,4 @@ internal
     /// <summary>Gets field or property member names.</summary>
     public string[] Members { get; }
 }
+#endif
diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs
index 3f66fd3af76..22e1cbfd59d 100644
--- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs
+++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.Log.cs
@@ -316,5 +316,11 @@ public partial class HubConnection
 
         [LoggerMessage(86, LogLevel.Warning, "Result given for '{Target}' method but server is not expecting a result.", EventName = "ResultNotExpected")]
         public static partial void ResultNotExpected(ILogger logger, string target);
+
+        [LoggerMessage(87, LogLevel.Trace, "Completion message for stream '{StreamId}' was not sent because the connection is closed.", EventName = "CompletingStreamNotSent")]
+        public static partial void CompletingStreamNotSent(ILogger logger, string streamId);
+
+        [LoggerMessage(88, LogLevel.Warning, "Error returning result for invocation '{InvocationId}' for method '{Target}' because the underlying connection is closed.", EventName = "ErrorSendingInvocationResult")]
+        public static partial void ErrorSendingInvocationResult(ILogger logger, string invocationId, string target, Exception exception);
     }
 }
diff --git a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs
index 5479ceebe71..8e3c57b6a90 100644
--- a/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs
+++ b/src/SignalR/clients/csharp/Client.Core/src/HubConnection.cs
@@ -815,11 +815,26 @@ public partial class HubConnection : IAsyncDisposable
             responseError = $"Stream errored by client: '{ex}'";
         }
 
-        Log.CompletingStream(_logger, streamId);
-
         // Don't use cancellation token here
         // this is triggered by a cancellation token to tell the server that the client is done streaming
-        await SendWithLock(connectionState, CompletionMessage.WithError(streamId, responseError), cancellationToken: default).ConfigureAwait(false);
+        await _state.WaitConnectionLockAsync(token: default).ConfigureAwait(false);
+        try
+        {
+            // Avoid sending when the connection isn't active, likely happens if there is an active stream when the connection closes
+            if (_state.IsConnectionActive())
+            {
+                Log.CompletingStream(_logger, streamId);
+                await SendHubMessage(connectionState, CompletionMessage.WithError(streamId, responseError), cancellationToken: default).ConfigureAwait(false);
+            }
+            else
+            {
+                Log.CompletingStreamNotSent(_logger, streamId);
+            }
+        }
+        finally
+        {
+            _state.ReleaseConnectionLock();
+        }
     }
 
     private async Task<object?> InvokeCoreAsyncCore(string methodName, Type returnType, object?[] args, CancellationToken cancellationToken)
@@ -1025,7 +1040,14 @@ public partial class HubConnection : IAsyncDisposable
             if (expectsResult)
             {
                 Log.MissingResultHandler(_logger, invocation.Target);
-                await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false);
+                try
+                {
+                    await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false);
+                }
+                catch (Exception ex)
+                {
+                    Log.ErrorSendingInvocationResult(_logger, invocation.InvocationId!, invocation.Target, ex);
+                }
             }
             else
             {
@@ -1066,18 +1088,25 @@ public partial class HubConnection : IAsyncDisposable
 
         if (expectsResult)
         {
-            if (resultException is not null)
-            {
-                await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, resultException.Message), cancellationToken: default).ConfigureAwait(false);
-            }
-            else if (hasResult)
+            try
             {
-                await SendWithLock(connectionState, CompletionMessage.WithResult(invocation.InvocationId!, result), cancellationToken: default).ConfigureAwait(false);
+                if (resultException is not null)
+                {
+                    await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, resultException.Message), cancellationToken: default).ConfigureAwait(false);
+                }
+                else if (hasResult)
+                {
+                    await SendWithLock(connectionState, CompletionMessage.WithResult(invocation.InvocationId!, result), cancellationToken: default).ConfigureAwait(false);
+                }
+                else
+                {
+                    Log.MissingResultHandler(_logger, invocation.Target);
+                    await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false);
+                }
             }
-            else
+            catch (Exception ex)
             {
-                Log.MissingResultHandler(_logger, invocation.Target);
-                await SendWithLock(connectionState, CompletionMessage.WithError(invocation.InvocationId!, "Client didn't provide a result."), cancellationToken: default).ConfigureAwait(false);
+                Log.ErrorSendingInvocationResult(_logger, invocation.InvocationId!, invocation.Target, ex);
             }
         }
         else if (hasResult)
@@ -2076,7 +2105,7 @@ public partial class HubConnection : IAsyncDisposable
         {
             await WaitConnectionLockAsync(token, methodName).ConfigureAwait(false);
 
-            if (CurrentConnectionStateUnsynchronized == null || CurrentConnectionStateUnsynchronized.Stopping)
+            if (!IsConnectionActive())
             {
                 ReleaseConnectionLock(methodName);
                 throw new InvalidOperationException($"The '{methodName}' method cannot be called if the connection is not active");
@@ -2085,6 +2114,13 @@ public partial class HubConnection : IAsyncDisposable
             return CurrentConnectionStateUnsynchronized;
         }
 
+        [MemberNotNullWhen(true, nameof(CurrentConnectionStateUnsynchronized))]
+        public bool IsConnectionActive()
+        {
+            AssertInConnectionLock();
+            return CurrentConnectionStateUnsynchronized is not null && !CurrentConnectionStateUnsynchronized.Stopping;
+        }
+
         public void ReleaseConnectionLock([CallerMemberName] string? memberName = null,
             [CallerFilePath] string? filePath = null, [CallerLineNumber] int lineNumber = 0)
         {
diff --git a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs
index 0151f8462a7..2ba8c4ba3a1 100644
--- a/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs
+++ b/src/SignalR/clients/csharp/Client/test/UnitTests/HubConnectionTests.cs
@@ -1,15 +1,11 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
-using System;
 using System.Buffers;
-using System.IO;
 using System.IO.Pipelines;
 using System.Net;
 using System.Net.WebSockets;
-using System.Threading;
 using System.Threading.Channels;
-using System.Threading.Tasks;
 using Microsoft.AspNetCore.Connections;
 using Microsoft.AspNetCore.Connections.Features;
 using Microsoft.AspNetCore.Http.Connections.Client;
@@ -20,7 +16,6 @@ using Microsoft.Extensions.DependencyInjection;
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Logging.Testing;
 using Moq;
-using Xunit;
 
 namespace Microsoft.AspNetCore.SignalR.Client.Tests;
 
@@ -568,6 +563,40 @@ public partial class HubConnectionTests : VerifiableLoggedTest
         }
     }
 
+    [Fact]
+    [LogLevel(LogLevel.Trace)]
+    public async Task ActiveUploadStreamWhenConnectionClosesObservesException()
+    {
+        using (StartVerifiableLog())
+        {
+            var connection = new TestConnection();
+            var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory);
+            await hubConnection.StartAsync().DefaultTimeout();
+
+            var channel = Channel.CreateUnbounded<int>();
+            var invokeTask = hubConnection.InvokeAsync<object>("UploadMethod", channel.Reader);
+
+            var invokeMessage = await connection.ReadSentJsonAsync().DefaultTimeout();
+            Assert.Equal(HubProtocolConstants.InvocationMessageType, invokeMessage["type"]);
+
+            // Not sure how to test for unobserved task exceptions, best I could come up with is to check that we log where there once was an unobserved task exception
+            var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+            TestSink.MessageLogged += wc =>
+            {
+                if (wc.EventId.Name == "CompletingStreamNotSent")
+                {
+                    tcs.SetResult();
+                }
+            };
+
+            await hubConnection.StopAsync();
+
+            await Assert.ThrowsAsync<TaskCanceledException>(() => invokeTask).DefaultTimeout();
+
+            await tcs.Task.DefaultTimeout();
+        }
+    }
+
     [Fact]
     [LogLevel(LogLevel.Trace)]
     public async Task InvocationCanCompleteBeforeStreamCompletes()
@@ -738,6 +767,42 @@ public partial class HubConnectionTests : VerifiableLoggedTest
         }
     }
 
+    [Fact]
+    [LogLevel(LogLevel.Trace)]
+    public async Task ClientResultResponseAfterConnectionCloseObservesException()
+    {
+        using (StartVerifiableLog())
+        {
+            var connection = new TestConnection();
+            var hubConnection = CreateHubConnection(connection, loggerFactory: LoggerFactory);
+            await hubConnection.StartAsync().DefaultTimeout();
+
+            var resultTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+            hubConnection.On("Result", async () =>
+            {
+                await resultTcs.Task;
+                return 1;
+            });
+
+            await connection.ReceiveTextAsync("{\"type\":1,\"invocationId\":\"1\",\"target\":\"Result\",\"arguments\":[]}\u001e").DefaultTimeout();
+
+            // Not sure how to test for unobserved task exceptions, best I could come up with is to check that we log where there once was an unobserved task exception
+            var tcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
+            TestSink.MessageLogged += wc =>
+            {
+                if (wc.EventId.Name == "ErrorSendingInvocationResult")
+                {
+                    tcs.SetResult();
+                }
+            };
+
+            await hubConnection.StopAsync();
+            resultTcs.SetResult();
+
+            await tcs.Task.DefaultTimeout();
+        }
+    }
+
     [Fact]
     public async Task HubConnectionIsMockable()
     {
-- 
GitLab