From 08aa143ff82e77e5e95dedbb3cd5d18624377ce5 Mon Sep 17 00:00:00 2001
From: James Newton-King <james@newtonking.com>
Date: Thu, 8 Sep 2022 00:10:04 +0800
Subject: [PATCH] [release/7.0] Fix auth exception stopping QUIC connection
 listener (#43730)

* Fix auth exception stopping connection listener

* Update

* Update

* Update

* Fix build

* PR feedback
---
 .../src/Internal/QuicConnectionListener.cs    |  65 +++++----
 .../Transport.Quic/src/Internal/QuicLog.cs    |   3 +
 .../test/QuicConnectionListenerTests.cs       | 130 +++++++++++++++++-
 .../Transport.Quic/test/QuicTestHelpers.cs    |  12 +-
 4 files changed, 176 insertions(+), 34 deletions(-)

diff --git a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionListener.cs b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionListener.cs
index 68eb810f13c..11866cb4040 100644
--- a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionListener.cs
+++ b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionListener.cs
@@ -75,10 +75,7 @@ internal sealed class QuicConnectionListener : IMultiplexedConnectionListener, I
                 var serverAuthenticationOptions = await _tlsConnectionCallbackOptions.OnConnection(context, cancellationToken);
 
                 // If the callback didn't set protocols then use the listener's list of protocols.
-                if (serverAuthenticationOptions.ApplicationProtocols == null)
-                {
-                    serverAuthenticationOptions.ApplicationProtocols = _tlsConnectionCallbackOptions.ApplicationProtocols;
-                }
+                serverAuthenticationOptions.ApplicationProtocols ??= _tlsConnectionCallbackOptions.ApplicationProtocols;
 
                 // If the SslServerAuthenticationOptions doesn't have a cert or protocols then the
                 // QUIC connection will fail and the client receives an unhelpful message.
@@ -145,38 +142,54 @@ internal sealed class QuicConnectionListener : IMultiplexedConnectionListener, I
             throw new InvalidOperationException($"The listener needs to be initialized by calling {nameof(CreateListenerAsync)}.");
         }
 
-        try
+        while (!cancellationToken.IsCancellationRequested)
         {
-            var quicConnection = await _listener.AcceptConnectionAsync(cancellationToken);
+            try
+            {
+                var quicConnection = await _listener.AcceptConnectionAsync(cancellationToken);
+
+                if (!_pendingConnections.TryGetValue(quicConnection, out var connectionContext))
+                {
+                    throw new InvalidOperationException("Couldn't find ConnectionContext for QuicConnection.");
+                }
+                else
+                {
+                    _pendingConnections.Remove(quicConnection);
+                }
+
+                // Verify the connection context was created and set correctly.
+                Debug.Assert(connectionContext != null);
+                Debug.Assert(connectionContext.GetInnerConnection() == quicConnection);
+
+                QuicLog.AcceptedConnection(_log, connectionContext);
 
-            if (!_pendingConnections.TryGetValue(quicConnection, out var connectionContext))
+                return connectionContext;
+            }
+            catch (QuicException ex) when (ex.QuicError == QuicError.OperationAborted)
             {
-                throw new InvalidOperationException("Couldn't find ConnectionContext for QuicConnection.");
+                // OperationAborted is reported when an accept is in-progress and the listener is unbind/disposed.
+                QuicLog.ConnectionListenerAborted(_log, ex);
+                return null;
             }
-            else
+            catch (ObjectDisposedException ex)
             {
-                _pendingConnections.Remove(quicConnection);
+                // ObjectDisposedException is reported when an accept is started after the listener is unbind/disposed.
+                QuicLog.ConnectionListenerAborted(_log, ex);
+                return null;
+            }
+            catch (Exception ex)
+            {
+                // If the client rejects the connection because of an invalid cert then AcceptConnectionAsync throws.
+                // An error thrown inside ConnectionOptionsCallback can also throw from AcceptConnectionAsync.
+                // These are recoverable errors and we don't want to stop accepting connections.
+                QuicLog.ConnectionListenerAcceptConnectionFailed(_log, ex);
             }
-
-            // Verify the connection context was created and set correctly.
-            Debug.Assert(connectionContext != null);
-            Debug.Assert(connectionContext.GetInnerConnection() == quicConnection);
-
-            QuicLog.AcceptedConnection(_log, connectionContext);
-
-            return connectionContext;
-        }
-        catch (QuicException ex) when (ex.QuicError == QuicError.OperationAborted)
-        {
-            QuicLog.ConnectionListenerAborted(_log, ex);
         }
+
         return null;
     }
 
-    public async ValueTask UnbindAsync(CancellationToken cancellationToken = default)
-    {
-        await DisposeAsync();
-    }
+    public ValueTask UnbindAsync(CancellationToken cancellationToken = default) => DisposeAsync();
 
     public async ValueTask DisposeAsync()
     {
diff --git a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicLog.cs b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicLog.cs
index d58061344f2..7077f79ca6f 100644
--- a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicLog.cs
+++ b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicLog.cs
@@ -232,6 +232,9 @@ internal static partial class QuicLog
         }
     }
 
+    [LoggerMessage(24, LogLevel.Debug, "QUIC listener connection failed.", EventName = "ConnectionListenerAcceptConnectionFailed")]
+    public static partial void ConnectionListenerAcceptConnectionFailed(ILogger logger, Exception exception);
+
     private static StreamType GetStreamType(QuicStreamContext streamContext) =>
         streamContext.CanRead && streamContext.CanWrite
             ? StreamType.Bidirectional
diff --git a/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionListenerTests.cs b/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionListenerTests.cs
index 4416e80c470..b89615a9133 100644
--- a/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionListenerTests.cs
+++ b/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionListenerTests.cs
@@ -13,6 +13,7 @@ using Microsoft.AspNetCore.Connections;
 using Microsoft.AspNetCore.Http.Features;
 using Microsoft.AspNetCore.Internal;
 using Microsoft.AspNetCore.Testing;
+using Microsoft.Extensions.Logging;
 using Xunit;
 
 namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Quic.Tests;
@@ -24,7 +25,7 @@ public class QuicConnectionListenerTests : TestApplicationErrorLoggerLoggedTest
 
     [ConditionalFact]
     [MsQuicSupported]
-    public async Task AcceptAsync_AfterUnbind_Error()
+    public async Task AcceptAsync_AfterUnbind_ReturnNull()
     {
         // Arrange
         await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory);
@@ -33,7 +34,7 @@ public class QuicConnectionListenerTests : TestApplicationErrorLoggerLoggedTest
         await connectionListener.UnbindAsync().DefaultTimeout();
 
         // Assert
-        await Assert.ThrowsAsync<ObjectDisposedException>(() => connectionListener.AcceptAndAddFeatureAsync().AsTask()).DefaultTimeout();
+        Assert.Null(await connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout());
     }
 
     [ConditionalFact]
@@ -51,7 +52,7 @@ public class QuicConnectionListenerTests : TestApplicationErrorLoggerLoggedTest
         await using var clientConnection = await QuicConnection.ConnectAsync(options);
 
         // Assert
-        await using var serverConnection = await acceptTask.DefaultTimeout();
+        var serverConnection = await acceptTask.DefaultTimeout();
         Assert.False(serverConnection.ConnectionClosed.IsCancellationRequested);
 
         await serverConnection.DisposeAsync().AsTask().DefaultTimeout();
@@ -60,6 +61,48 @@ public class QuicConnectionListenerTests : TestApplicationErrorLoggerLoggedTest
         Assert.False(serverConnection.ConnectionClosed.IsCancellationRequested);
     }
 
+    [ConditionalFact]
+    [MsQuicSupported]
+    public async Task AcceptAsync_ClientCreatesInvalidConnection_ServerContinuesToAccept()
+    {
+        await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory);
+
+        // Act & Assert 1
+        Logger.LogInformation("Client creating successful connection 1");
+        var acceptTask1 = connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout();
+        await using var clientConnection1 = await QuicConnection.ConnectAsync(
+            QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint));
+
+        var serverConnection1 = await acceptTask1.DefaultTimeout();
+        Assert.False(serverConnection1.ConnectionClosed.IsCancellationRequested);
+        await serverConnection1.DisposeAsync().AsTask().DefaultTimeout();
+
+        // Act & Assert 2
+        var serverFailureLogTask = WaitForLogMessage(m => m.EventId.Name == "ConnectionListenerAcceptConnectionFailed");
+
+        Logger.LogInformation("Client creating unsuccessful connection 2");
+        var acceptTask2 = connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout();
+        var ex = await Assert.ThrowsAsync<AuthenticationException>(async () =>
+        {
+            await QuicConnection.ConnectAsync(
+                QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint, ignoreInvalidCertificate: false));
+        });
+        Assert.Contains("RemoteCertificateChainErrors", ex.Message);
+
+        Assert.False(acceptTask2.IsCompleted, "Accept doesn't return for failed client connection.");
+        var serverFailureLog = await serverFailureLogTask.DefaultTimeout();
+        Assert.NotNull(serverFailureLog.Exception);
+
+        // Act & Assert 3
+        Logger.LogInformation("Client creating successful connection 3");
+        await using var clientConnection2 = await QuicConnection.ConnectAsync(
+            QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint));
+
+        var serverConnection2 = await acceptTask2.DefaultTimeout();
+        Assert.False(serverConnection2.ConnectionClosed.IsCancellationRequested);
+        await serverConnection2.DisposeAsync().AsTask().DefaultTimeout();
+    }
+
     [ConditionalFact]
     [MsQuicSupported]
     [OSSkipCondition(OperatingSystems.Linux | OperatingSystems.MacOSX)]
@@ -145,6 +188,85 @@ public class QuicConnectionListenerTests : TestApplicationErrorLoggerLoggedTest
         Assert.Contains(LogMessages, m => m.EventId.Name == "ConnectionListenerApplicationProtocolsNotSpecified");
     }
 
+    [ConditionalFact]
+    [MsQuicSupported]
+    public async Task AcceptAsync_UnbindAfterCall_CleanExitAndLog()
+    {
+        // Arrange
+        await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory);
+
+        // Act
+        var acceptTask = connectionListener.AcceptAndAddFeatureAsync();
+
+        await connectionListener.UnbindAsync().DefaultTimeout();
+
+        // Assert
+        Assert.Null(await acceptTask.AsTask().DefaultTimeout());
+
+        Assert.Contains(LogMessages, m => m.EventId.Name == "ConnectionListenerAborted");
+    }
+
+    [ConditionalFact]
+    [MsQuicSupported]
+    public async Task AcceptAsync_DisposeAfterCall_CleanExitAndLog()
+    {
+        // Arrange
+        await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory);
+
+        // Act
+        var acceptTask = connectionListener.AcceptAndAddFeatureAsync();
+
+        await connectionListener.DisposeAsync().DefaultTimeout();
+
+        // Assert
+        Assert.Null(await acceptTask.AsTask().DefaultTimeout());
+
+        Assert.Contains(LogMessages, m => m.EventId.Name == "ConnectionListenerAborted");
+    }
+
+    [ConditionalFact]
+    [MsQuicSupported]
+    public async Task AcceptAsync_ErrorFromServerCallback_CleanExitAndLog()
+    {
+        // Arrange
+        var throwErrorInCallback = true;
+        await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(
+            new TlsConnectionCallbackOptions
+            {
+                ApplicationProtocols = new List<SslApplicationProtocol> { SslApplicationProtocol.Http3 },
+                OnConnection = (context, cancellationToken) =>
+                {
+                    if (throwErrorInCallback)
+                    {
+                        throwErrorInCallback = false;
+                        throw new Exception("An error!");
+                    }
+
+                    var options = new SslServerAuthenticationOptions();
+                    options.ServerCertificate = TestResources.GetTestCertificate();
+                    return ValueTask.FromResult(options);
+                }
+            },
+            LoggerFactory);
+
+        // Act
+        var acceptTask = connectionListener.AcceptAndAddFeatureAsync();
+
+        var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint);
+
+        var ex = await Assert.ThrowsAsync<AuthenticationException>(() => QuicConnection.ConnectAsync(options).AsTask()).DefaultTimeout();
+        Assert.Equal("Authentication failed because the remote party sent a TLS alert: 'UserCanceled'.", ex.Message);
+
+        // Assert
+        Assert.False(acceptTask.IsCompleted, "Still waiting for non-errored connection.");
+
+        await using var clientConnection = await QuicConnection.ConnectAsync(options).DefaultTimeout();
+        await using var serverConnection = await acceptTask.DefaultTimeout();
+
+        Assert.NotNull(serverConnection);
+        Assert.NotNull(clientConnection);
+    }
+
     [ConditionalFact]
     [MsQuicSupported]
     public async Task BindAsync_ListenersSharePort_ThrowAddressInUse()
@@ -266,8 +388,8 @@ public class QuicConnectionListenerTests : TestApplicationErrorLoggerLoggedTest
 
         syncPoint.Continue();
 
-        await Assert.ThrowsAsync<ArgumentException>(() => acceptTask.AsTask()).DefaultTimeout();
         await Assert.ThrowsAsync<AuthenticationException>(() => clientConnectionTask.AsTask()).DefaultTimeout();
+        Assert.False(acceptTask.IsCompleted);
 
         // Assert
         for (var i = 0; i < 20; i++)
diff --git a/src/Servers/Kestrel/Transport.Quic/test/QuicTestHelpers.cs b/src/Servers/Kestrel/Transport.Quic/test/QuicTestHelpers.cs
index 3b26da87cb0..4fa223873d5 100644
--- a/src/Servers/Kestrel/Transport.Quic/test/QuicTestHelpers.cs
+++ b/src/Servers/Kestrel/Transport.Quic/test/QuicTestHelpers.cs
@@ -116,9 +116,9 @@ internal static class QuicTestHelpers
         return true;
     }
 
-    public static QuicClientConnectionOptions CreateClientConnectionOptions(EndPoint remoteEndPoint)
+    public static QuicClientConnectionOptions CreateClientConnectionOptions(EndPoint remoteEndPoint, bool? ignoreInvalidCertificate = null)
     {
-        return new QuicClientConnectionOptions
+        var options = new QuicClientConnectionOptions
         {
             MaxInboundBidirectionalStreams = 200,
             MaxInboundUnidirectionalStreams = 200,
@@ -128,12 +128,16 @@ internal static class QuicTestHelpers
                 ApplicationProtocols = new List<SslApplicationProtocol>
                 {
                     SslApplicationProtocol.Http3
-                },
-                RemoteCertificateValidationCallback = RemoteCertificateValidationCallback
+                }
             },
             DefaultStreamErrorCode = 0,
             DefaultCloseErrorCode = 0,
         };
+        if (ignoreInvalidCertificate ?? true)
+        {
+            options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = RemoteCertificateValidationCallback;
+        }
+        return options;
     }
 
     public static async Task<QuicStreamContext> CreateAndCompleteBidirectionalStreamGracefully(QuicConnection clientConnection, MultiplexedConnectionContext serverConnection, ILogger logger)
-- 
GitLab