From 08aa143ff82e77e5e95dedbb3cd5d18624377ce5 Mon Sep 17 00:00:00 2001 From: James Newton-King <james@newtonking.com> Date: Thu, 8 Sep 2022 00:10:04 +0800 Subject: [PATCH] [release/7.0] Fix auth exception stopping QUIC connection listener (#43730) * Fix auth exception stopping connection listener * Update * Update * Update * Fix build * PR feedback --- .../src/Internal/QuicConnectionListener.cs | 65 +++++---- .../Transport.Quic/src/Internal/QuicLog.cs | 3 + .../test/QuicConnectionListenerTests.cs | 130 +++++++++++++++++- .../Transport.Quic/test/QuicTestHelpers.cs | 12 +- 4 files changed, 176 insertions(+), 34 deletions(-) diff --git a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionListener.cs b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionListener.cs index 68eb810f13c..11866cb4040 100644 --- a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionListener.cs +++ b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicConnectionListener.cs @@ -75,10 +75,7 @@ internal sealed class QuicConnectionListener : IMultiplexedConnectionListener, I var serverAuthenticationOptions = await _tlsConnectionCallbackOptions.OnConnection(context, cancellationToken); // If the callback didn't set protocols then use the listener's list of protocols. - if (serverAuthenticationOptions.ApplicationProtocols == null) - { - serverAuthenticationOptions.ApplicationProtocols = _tlsConnectionCallbackOptions.ApplicationProtocols; - } + serverAuthenticationOptions.ApplicationProtocols ??= _tlsConnectionCallbackOptions.ApplicationProtocols; // If the SslServerAuthenticationOptions doesn't have a cert or protocols then the // QUIC connection will fail and the client receives an unhelpful message. @@ -145,38 +142,54 @@ internal sealed class QuicConnectionListener : IMultiplexedConnectionListener, I throw new InvalidOperationException($"The listener needs to be initialized by calling {nameof(CreateListenerAsync)}."); } - try + while (!cancellationToken.IsCancellationRequested) { - var quicConnection = await _listener.AcceptConnectionAsync(cancellationToken); + try + { + var quicConnection = await _listener.AcceptConnectionAsync(cancellationToken); + + if (!_pendingConnections.TryGetValue(quicConnection, out var connectionContext)) + { + throw new InvalidOperationException("Couldn't find ConnectionContext for QuicConnection."); + } + else + { + _pendingConnections.Remove(quicConnection); + } + + // Verify the connection context was created and set correctly. + Debug.Assert(connectionContext != null); + Debug.Assert(connectionContext.GetInnerConnection() == quicConnection); + + QuicLog.AcceptedConnection(_log, connectionContext); - if (!_pendingConnections.TryGetValue(quicConnection, out var connectionContext)) + return connectionContext; + } + catch (QuicException ex) when (ex.QuicError == QuicError.OperationAborted) { - throw new InvalidOperationException("Couldn't find ConnectionContext for QuicConnection."); + // OperationAborted is reported when an accept is in-progress and the listener is unbind/disposed. + QuicLog.ConnectionListenerAborted(_log, ex); + return null; } - else + catch (ObjectDisposedException ex) { - _pendingConnections.Remove(quicConnection); + // ObjectDisposedException is reported when an accept is started after the listener is unbind/disposed. + QuicLog.ConnectionListenerAborted(_log, ex); + return null; + } + catch (Exception ex) + { + // If the client rejects the connection because of an invalid cert then AcceptConnectionAsync throws. + // An error thrown inside ConnectionOptionsCallback can also throw from AcceptConnectionAsync. + // These are recoverable errors and we don't want to stop accepting connections. + QuicLog.ConnectionListenerAcceptConnectionFailed(_log, ex); } - - // Verify the connection context was created and set correctly. - Debug.Assert(connectionContext != null); - Debug.Assert(connectionContext.GetInnerConnection() == quicConnection); - - QuicLog.AcceptedConnection(_log, connectionContext); - - return connectionContext; - } - catch (QuicException ex) when (ex.QuicError == QuicError.OperationAborted) - { - QuicLog.ConnectionListenerAborted(_log, ex); } + return null; } - public async ValueTask UnbindAsync(CancellationToken cancellationToken = default) - { - await DisposeAsync(); - } + public ValueTask UnbindAsync(CancellationToken cancellationToken = default) => DisposeAsync(); public async ValueTask DisposeAsync() { diff --git a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicLog.cs b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicLog.cs index d58061344f2..7077f79ca6f 100644 --- a/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicLog.cs +++ b/src/Servers/Kestrel/Transport.Quic/src/Internal/QuicLog.cs @@ -232,6 +232,9 @@ internal static partial class QuicLog } } + [LoggerMessage(24, LogLevel.Debug, "QUIC listener connection failed.", EventName = "ConnectionListenerAcceptConnectionFailed")] + public static partial void ConnectionListenerAcceptConnectionFailed(ILogger logger, Exception exception); + private static StreamType GetStreamType(QuicStreamContext streamContext) => streamContext.CanRead && streamContext.CanWrite ? StreamType.Bidirectional diff --git a/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionListenerTests.cs b/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionListenerTests.cs index 4416e80c470..b89615a9133 100644 --- a/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionListenerTests.cs +++ b/src/Servers/Kestrel/Transport.Quic/test/QuicConnectionListenerTests.cs @@ -13,6 +13,7 @@ using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Internal; using Microsoft.AspNetCore.Testing; +using Microsoft.Extensions.Logging; using Xunit; namespace Microsoft.AspNetCore.Server.Kestrel.Transport.Quic.Tests; @@ -24,7 +25,7 @@ public class QuicConnectionListenerTests : TestApplicationErrorLoggerLoggedTest [ConditionalFact] [MsQuicSupported] - public async Task AcceptAsync_AfterUnbind_Error() + public async Task AcceptAsync_AfterUnbind_ReturnNull() { // Arrange await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory); @@ -33,7 +34,7 @@ public class QuicConnectionListenerTests : TestApplicationErrorLoggerLoggedTest await connectionListener.UnbindAsync().DefaultTimeout(); // Assert - await Assert.ThrowsAsync<ObjectDisposedException>(() => connectionListener.AcceptAndAddFeatureAsync().AsTask()).DefaultTimeout(); + Assert.Null(await connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout()); } [ConditionalFact] @@ -51,7 +52,7 @@ public class QuicConnectionListenerTests : TestApplicationErrorLoggerLoggedTest await using var clientConnection = await QuicConnection.ConnectAsync(options); // Assert - await using var serverConnection = await acceptTask.DefaultTimeout(); + var serverConnection = await acceptTask.DefaultTimeout(); Assert.False(serverConnection.ConnectionClosed.IsCancellationRequested); await serverConnection.DisposeAsync().AsTask().DefaultTimeout(); @@ -60,6 +61,48 @@ public class QuicConnectionListenerTests : TestApplicationErrorLoggerLoggedTest Assert.False(serverConnection.ConnectionClosed.IsCancellationRequested); } + [ConditionalFact] + [MsQuicSupported] + public async Task AcceptAsync_ClientCreatesInvalidConnection_ServerContinuesToAccept() + { + await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory); + + // Act & Assert 1 + Logger.LogInformation("Client creating successful connection 1"); + var acceptTask1 = connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout(); + await using var clientConnection1 = await QuicConnection.ConnectAsync( + QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint)); + + var serverConnection1 = await acceptTask1.DefaultTimeout(); + Assert.False(serverConnection1.ConnectionClosed.IsCancellationRequested); + await serverConnection1.DisposeAsync().AsTask().DefaultTimeout(); + + // Act & Assert 2 + var serverFailureLogTask = WaitForLogMessage(m => m.EventId.Name == "ConnectionListenerAcceptConnectionFailed"); + + Logger.LogInformation("Client creating unsuccessful connection 2"); + var acceptTask2 = connectionListener.AcceptAndAddFeatureAsync().DefaultTimeout(); + var ex = await Assert.ThrowsAsync<AuthenticationException>(async () => + { + await QuicConnection.ConnectAsync( + QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint, ignoreInvalidCertificate: false)); + }); + Assert.Contains("RemoteCertificateChainErrors", ex.Message); + + Assert.False(acceptTask2.IsCompleted, "Accept doesn't return for failed client connection."); + var serverFailureLog = await serverFailureLogTask.DefaultTimeout(); + Assert.NotNull(serverFailureLog.Exception); + + // Act & Assert 3 + Logger.LogInformation("Client creating successful connection 3"); + await using var clientConnection2 = await QuicConnection.ConnectAsync( + QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint)); + + var serverConnection2 = await acceptTask2.DefaultTimeout(); + Assert.False(serverConnection2.ConnectionClosed.IsCancellationRequested); + await serverConnection2.DisposeAsync().AsTask().DefaultTimeout(); + } + [ConditionalFact] [MsQuicSupported] [OSSkipCondition(OperatingSystems.Linux | OperatingSystems.MacOSX)] @@ -145,6 +188,85 @@ public class QuicConnectionListenerTests : TestApplicationErrorLoggerLoggedTest Assert.Contains(LogMessages, m => m.EventId.Name == "ConnectionListenerApplicationProtocolsNotSpecified"); } + [ConditionalFact] + [MsQuicSupported] + public async Task AcceptAsync_UnbindAfterCall_CleanExitAndLog() + { + // Arrange + await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory); + + // Act + var acceptTask = connectionListener.AcceptAndAddFeatureAsync(); + + await connectionListener.UnbindAsync().DefaultTimeout(); + + // Assert + Assert.Null(await acceptTask.AsTask().DefaultTimeout()); + + Assert.Contains(LogMessages, m => m.EventId.Name == "ConnectionListenerAborted"); + } + + [ConditionalFact] + [MsQuicSupported] + public async Task AcceptAsync_DisposeAfterCall_CleanExitAndLog() + { + // Arrange + await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory(LoggerFactory); + + // Act + var acceptTask = connectionListener.AcceptAndAddFeatureAsync(); + + await connectionListener.DisposeAsync().DefaultTimeout(); + + // Assert + Assert.Null(await acceptTask.AsTask().DefaultTimeout()); + + Assert.Contains(LogMessages, m => m.EventId.Name == "ConnectionListenerAborted"); + } + + [ConditionalFact] + [MsQuicSupported] + public async Task AcceptAsync_ErrorFromServerCallback_CleanExitAndLog() + { + // Arrange + var throwErrorInCallback = true; + await using var connectionListener = await QuicTestHelpers.CreateConnectionListenerFactory( + new TlsConnectionCallbackOptions + { + ApplicationProtocols = new List<SslApplicationProtocol> { SslApplicationProtocol.Http3 }, + OnConnection = (context, cancellationToken) => + { + if (throwErrorInCallback) + { + throwErrorInCallback = false; + throw new Exception("An error!"); + } + + var options = new SslServerAuthenticationOptions(); + options.ServerCertificate = TestResources.GetTestCertificate(); + return ValueTask.FromResult(options); + } + }, + LoggerFactory); + + // Act + var acceptTask = connectionListener.AcceptAndAddFeatureAsync(); + + var options = QuicTestHelpers.CreateClientConnectionOptions(connectionListener.EndPoint); + + var ex = await Assert.ThrowsAsync<AuthenticationException>(() => QuicConnection.ConnectAsync(options).AsTask()).DefaultTimeout(); + Assert.Equal("Authentication failed because the remote party sent a TLS alert: 'UserCanceled'.", ex.Message); + + // Assert + Assert.False(acceptTask.IsCompleted, "Still waiting for non-errored connection."); + + await using var clientConnection = await QuicConnection.ConnectAsync(options).DefaultTimeout(); + await using var serverConnection = await acceptTask.DefaultTimeout(); + + Assert.NotNull(serverConnection); + Assert.NotNull(clientConnection); + } + [ConditionalFact] [MsQuicSupported] public async Task BindAsync_ListenersSharePort_ThrowAddressInUse() @@ -266,8 +388,8 @@ public class QuicConnectionListenerTests : TestApplicationErrorLoggerLoggedTest syncPoint.Continue(); - await Assert.ThrowsAsync<ArgumentException>(() => acceptTask.AsTask()).DefaultTimeout(); await Assert.ThrowsAsync<AuthenticationException>(() => clientConnectionTask.AsTask()).DefaultTimeout(); + Assert.False(acceptTask.IsCompleted); // Assert for (var i = 0; i < 20; i++) diff --git a/src/Servers/Kestrel/Transport.Quic/test/QuicTestHelpers.cs b/src/Servers/Kestrel/Transport.Quic/test/QuicTestHelpers.cs index 3b26da87cb0..4fa223873d5 100644 --- a/src/Servers/Kestrel/Transport.Quic/test/QuicTestHelpers.cs +++ b/src/Servers/Kestrel/Transport.Quic/test/QuicTestHelpers.cs @@ -116,9 +116,9 @@ internal static class QuicTestHelpers return true; } - public static QuicClientConnectionOptions CreateClientConnectionOptions(EndPoint remoteEndPoint) + public static QuicClientConnectionOptions CreateClientConnectionOptions(EndPoint remoteEndPoint, bool? ignoreInvalidCertificate = null) { - return new QuicClientConnectionOptions + var options = new QuicClientConnectionOptions { MaxInboundBidirectionalStreams = 200, MaxInboundUnidirectionalStreams = 200, @@ -128,12 +128,16 @@ internal static class QuicTestHelpers ApplicationProtocols = new List<SslApplicationProtocol> { SslApplicationProtocol.Http3 - }, - RemoteCertificateValidationCallback = RemoteCertificateValidationCallback + } }, DefaultStreamErrorCode = 0, DefaultCloseErrorCode = 0, }; + if (ignoreInvalidCertificate ?? true) + { + options.ClientAuthenticationOptions.RemoteCertificateValidationCallback = RemoteCertificateValidationCallback; + } + return options; } public static async Task<QuicStreamContext> CreateAndCompleteBidirectionalStreamGracefully(QuicConnection clientConnection, MultiplexedConnectionContext serverConnection, ILogger logger) -- GitLab