diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Connection.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Connection.cs index db9be53ed3a6eb4c6154dd38ecf0b7a5cda24328..fb22e0558c5d1dac05fc9f9dd6c9ade66d6e444a 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Connection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Connection.cs @@ -23,7 +23,9 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 { internal class Http3Connection : ITimeoutHandler, IHttp3StreamLifetimeHandler { + // Internal for unit testing internal readonly Dictionary<long, IHttp3Stream> _streams = new Dictionary<long, IHttp3Stream>(); + internal IHttp3StreamLifetimeHandler _streamLifetimeHandler; private long _highestOpenedStreamId; private readonly object _sync = new object(); @@ -49,6 +51,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 _systemClock = context.ServiceContext.SystemClock; _timeoutControl = new TimeoutControl(this); _context.TimeoutControl ??= _timeoutControl; + _streamLifetimeHandler = this; _errorCodeFeature = context.ConnectionFeatures.Get<IProtocolErrorCodeFeature>()!; @@ -303,7 +306,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 streamContext.LocalEndPoint as IPEndPoint, streamContext.RemoteEndPoint as IPEndPoint, streamContext.Transport, - this, + _streamLifetimeHandler, streamContext, _serverSettings); httpConnectionContext.TimeoutControl = _context.TimeoutControl; @@ -314,10 +317,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 { // Unidirectional stream var stream = new Http3ControlStream<TContext>(application, httpConnectionContext); - lock (_streams) - { - _streams[streamId] = stream; - } + _streamLifetimeHandler.OnStreamCreated(stream); ThreadPool.UnsafeQueueUserWorkItem(stream, preferLocal: false); } @@ -327,11 +327,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 UpdateHighestStreamId(streamId); var stream = new Http3Stream<TContext>(application, httpConnectionContext); - lock (_streams) - { - _activeRequestCount++; - _streams[streamId] = stream; - } + _streamLifetimeHandler.OnStreamCreated(stream); KestrelEventSource.Log.RequestQueuedStart(stream, AspNetCore.Http.HttpProtocol.Http3); ThreadPool.UnsafeQueueUserWorkItem(stream, preferLocal: false); @@ -470,7 +466,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 streamContext.LocalEndPoint as IPEndPoint, streamContext.RemoteEndPoint as IPEndPoint, streamContext.Transport, - this, + _streamLifetimeHandler, streamContext, _serverSettings); httpConnectionContext.TimeoutControl = _context.TimeoutControl; @@ -490,7 +486,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 return default; } - public bool OnInboundControlStream(Http3ControlStream stream) + bool IHttp3StreamLifetimeHandler.OnInboundControlStream(Http3ControlStream stream) { lock (_sync) { @@ -503,7 +499,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 } } - public bool OnInboundEncoderStream(Http3ControlStream stream) + bool IHttp3StreamLifetimeHandler.OnInboundEncoderStream(Http3ControlStream stream) { lock (_sync) { @@ -516,7 +512,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 } } - public bool OnInboundDecoderStream(Http3ControlStream stream) + bool IHttp3StreamLifetimeHandler.OnInboundDecoderStream(Http3ControlStream stream) { lock (_sync) { @@ -529,24 +525,42 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 } } - public void OnStreamCompleted(IHttp3Stream stream) + void IHttp3StreamLifetimeHandler.OnStreamCreated(IHttp3Stream stream) { lock (_streams) { - _activeRequestCount--; + if (stream.IsRequestStream) + { + _activeRequestCount++; + } + _streams[stream.StreamId] = stream; + } + } + + void IHttp3StreamLifetimeHandler.OnStreamCompleted(IHttp3Stream stream) + { + lock (_streams) + { + if (stream.IsRequestStream) + { + _activeRequestCount--; + } _streams.Remove(stream.StreamId); } - _streamCompletionAwaitable.Complete(); + if (stream.IsRequestStream) + { + _streamCompletionAwaitable.Complete(); + } } - public void OnStreamConnectionError(Http3ConnectionErrorException ex) + void IHttp3StreamLifetimeHandler.OnStreamConnectionError(Http3ConnectionErrorException ex) { Log.Http3ConnectionError(ConnectionId, ex); Abort(new ConnectionAbortedException(ex.Message, ex), ex.ErrorCode); } - public void OnInboundControlStreamSetting(Http3SettingType type, long value) + void IHttp3StreamLifetimeHandler.OnInboundControlStreamSetting(Http3SettingType type, long value) { switch (type) { @@ -561,6 +575,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 } } + void IHttp3StreamLifetimeHandler.OnStreamHeaderReceived(IHttp3Stream stream) + { + Debug.Assert(stream.ReceivedHeader); + } + private static class GracefulCloseInitiator { public const int None = 0; diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3ControlStream.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3ControlStream.cs index 7ed60c621a5b2e10cf75197046ae43285f1b96fc..bf07468fa70b3340a01ff41a7967d092037e3ea4 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3ControlStream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3ControlStream.cs @@ -154,6 +154,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 try { _headerType = await TryReadStreamHeaderAsync(); + _context.StreamLifetimeHandler.OnStreamHeaderReceived(this); switch (_headerType) { @@ -195,6 +196,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 _errorCodeFeature.Error = (long)ex.ErrorCode; _context.StreamLifetimeHandler.OnStreamConnectionError(ex); } + finally + { + _context.StreamLifetimeHandler.OnStreamCompleted(this); + } } private async Task HandleControlStream() diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs index 1549bd3d1d0687d22ea5c0eca3d4db7fc3c8b1e7..de05f7fba75b41fd7db7dc65b4062d48c1a686d9 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3Stream.cs @@ -549,6 +549,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 } _appCompleted = new TaskCompletionSource(); + _context.StreamLifetimeHandler.OnStreamHeaderReceived(this); ThreadPool.UnsafeQueueUserWorkItem(this, preferLocal: false); } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/IHttp3StreamLifetimeHandler.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/IHttp3StreamLifetimeHandler.cs index 57bbbf99c5973e7a69c9eeef5f1957e39f6d6653..9748d6bb89e5be2e3174d4e27354fbbf1360e78f 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/IHttp3StreamLifetimeHandler.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/IHttp3StreamLifetimeHandler.cs @@ -5,6 +5,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 { internal interface IHttp3StreamLifetimeHandler { + void OnStreamCreated(IHttp3Stream stream); + void OnStreamHeaderReceived(IHttp3Stream stream); void OnStreamCompleted(IHttp3Stream stream); void OnStreamConnectionError(Http3ConnectionErrorException ex); diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3TestBase.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3TestBase.cs index b3136f43c012fd8c672b805f180245d2dc538137..e9986ea4e3a094b1f67af15819fdb2adc21b3927 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3TestBase.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3TestBase.cs @@ -3,6 +3,7 @@ using System; using System.Buffers; +using System.Collections.Concurrent; using System.Collections.Generic; using System.Diagnostics; using System.IO; @@ -47,6 +48,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests protected Task _connectionTask; protected readonly TaskCompletionSource _closedStateReached = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + internal readonly ConcurrentDictionary<long, Http3StreamBase> _runningStreams = new ConcurrentDictionary<long, Http3StreamBase>(); protected readonly RequestDelegate _noopApplication; protected readonly RequestDelegate _echoApplication; @@ -242,10 +244,74 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests httpConnectionContext.TimeoutControl = _mockTimeoutControl.Object; Connection = new Http3Connection(httpConnectionContext); + Connection._streamLifetimeHandler = new LifetimeHandlerInterceptor(Connection, this); + _mockTimeoutHandler.Setup(h => h.OnTimeout(It.IsAny<TimeoutReason>())) .Callback<TimeoutReason>(r => Connection.OnTimeout(r)); } + private class LifetimeHandlerInterceptor : IHttp3StreamLifetimeHandler + { + private readonly IHttp3StreamLifetimeHandler _inner; + private readonly Http3TestBase _http3TestBase; + + public LifetimeHandlerInterceptor(IHttp3StreamLifetimeHandler inner, Http3TestBase http3TestBase) + { + _inner = inner; + _http3TestBase = http3TestBase; + } + + public bool OnInboundControlStream(Internal.Http3.Http3ControlStream stream) + { + return _inner.OnInboundControlStream(stream); + } + + public void OnInboundControlStreamSetting(Internal.Http3.Http3SettingType type, long value) + { + _inner.OnInboundControlStreamSetting(type, value); + } + + public bool OnInboundDecoderStream(Internal.Http3.Http3ControlStream stream) + { + return _inner.OnInboundDecoderStream(stream); + } + + public bool OnInboundEncoderStream(Internal.Http3.Http3ControlStream stream) + { + return _inner.OnInboundEncoderStream(stream); + } + + public void OnStreamCompleted(IHttp3Stream stream) + { + _inner.OnStreamCompleted(stream); + } + + public void OnStreamConnectionError(Http3ConnectionErrorException ex) + { + _inner.OnStreamConnectionError(ex); + } + + public void OnStreamCreated(IHttp3Stream stream) + { + _inner.OnStreamCreated(stream); + + if (_http3TestBase._runningStreams.TryGetValue(stream.StreamId, out var testStream)) + { + testStream._onStreamCreatedTcs.TrySetResult(); + } + } + + public void OnStreamHeaderReceived(IHttp3Stream stream) + { + _inner.OnStreamHeaderReceived(stream); + + if (_http3TestBase._runningStreams.TryGetValue(stream.StreamId, out var testStream)) + { + testStream._onHeaderReceivedTcs.TrySetResult(); + } + } + } + protected void ConnectionClosed() { @@ -294,6 +360,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests public async ValueTask<Http3ControlStream> CreateControlStream(int? id) { var stream = new Http3ControlStream(this, StreamInitiator.Client); + _runningStreams[stream.StreamId] = stream; + MultiplexedConnectionContext.ToServerAcceptQueue.Writer.TryWrite(stream.StreamContext); if (id != null) { @@ -305,6 +373,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests internal ValueTask<Http3RequestStream> CreateRequestStream() { var stream = new Http3RequestStream(this, Connection); + _runningStreams[stream.StreamId] = stream; + MultiplexedConnectionContext.ToServerAcceptQueue.Writer.TryWrite(stream.StreamContext); return new ValueTask<Http3RequestStream>(stream); } @@ -318,12 +388,18 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests public class Http3StreamBase : IProtocolErrorCodeFeature { + internal TaskCompletionSource _onStreamCreatedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + internal TaskCompletionSource _onHeaderReceivedTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); + internal DuplexPipe.DuplexPipePair _pair; internal Http3TestBase _testBase; internal Http3Connection _connection; private long _bytesReceived; public long Error { get; set; } + public Task OnStreamCreatedTask => _onStreamCreatedTcs.Task; + public Task OnHeaderReceivedTask => _onHeaderReceivedTcs.Task; + protected Task SendAsync(ReadOnlySpan<byte> span) { var writableBuffer = _pair.Application.Output; diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3TimeoutTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3TimeoutTests.cs index eae0d751dc5a7711b4fc37b03f14162b9068f6a3..f52320ca77328d2bb46777987fff62eaf5829476 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3TimeoutTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3TimeoutTests.cs @@ -6,8 +6,10 @@ using System.Collections.Generic; using System.Linq; using System.Net.Http; using System.Threading.Tasks; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3; using Microsoft.AspNetCore.Testing; using Microsoft.Net.Http.Headers; +using Moq; using Xunit; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests @@ -26,9 +28,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var controlStream = await GetInboundControlStream().DefaultTimeout(); await controlStream.ExpectSettingsAsync().DefaultTimeout(); - await AssertIsTrueRetryAsync( - () => Connection._streams.Count == 2, - "Wait until streams have been created.").DefaultTimeout(); + await requestStream.OnStreamCreatedTask.DefaultTimeout(); var serverRequestStream = Connection._streams[requestStream.StreamId]; @@ -63,9 +63,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var controlStream = await GetInboundControlStream().DefaultTimeout(); await controlStream.ExpectSettingsAsync().DefaultTimeout(); - await AssertIsTrueRetryAsync( - () => Connection._streams.Count == 2, - "Wait until streams have been created.").DefaultTimeout(); + await requestStream.OnStreamCreatedTask.DefaultTimeout(); var serverRequestStream = Connection._streams[requestStream.StreamId]; @@ -76,14 +74,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests await requestStream.SendHeadersAsync(headers).DefaultTimeout(); - await AssertIsTrueRetryAsync( - () => serverRequestStream.ReceivedHeader, - "Request stream has read headers.").DefaultTimeout(); + await requestStream.OnHeaderReceivedTask.DefaultTimeout(); TriggerTick(now + limits.RequestHeadersTimeout + TimeSpan.FromTicks(1)); await requestStream.SendDataAsync(Memory<byte>.Empty, endStream: true); + await requestStream.ExpectHeadersAsync(); + await requestStream.ExpectReceiveEndOfStream(); } @@ -107,9 +105,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var outboundControlStream = await CreateControlStream(id: null); - await AssertIsTrueRetryAsync( - () => Connection._streams.Count == 1, - "Wait until streams have been created.").DefaultTimeout(); + await outboundControlStream.OnStreamCreatedTask.DefaultTimeout(); var serverInboundControlStream = Connection._streams[outboundControlStream.StreamId]; @@ -143,20 +139,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var outboundControlStream = await CreateControlStream(id: null); - await AssertIsTrueRetryAsync( - () => Connection._streams.Count == 1, - "Wait until streams have been created.").DefaultTimeout(); - - var serverInboundControlStream = Connection._streams[outboundControlStream.StreamId]; + await outboundControlStream.OnStreamCreatedTask.DefaultTimeout(); TriggerTick(now); TriggerTick(now + limits.RequestHeadersTimeout); await outboundControlStream.WriteStreamIdAsync(id: 0); - await AssertIsTrueRetryAsync( - () => serverInboundControlStream.ReceivedHeader, - "Control stream has read header.").DefaultTimeout(); + await outboundControlStream.OnHeaderReceivedTask.DefaultTimeout(); TriggerTick(now + limits.RequestHeadersTimeout + TimeSpan.FromTicks(1)); } @@ -183,9 +173,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests var outboundControlStream = await CreateControlStream(id: null); - await AssertIsTrueRetryAsync( - () => Connection._streams.Count == 1, - "Wait until streams have been created.").DefaultTimeout(); + await outboundControlStream.OnStreamCreatedTask.DefaultTimeout(); var serverInboundControlStream = Connection._streams[outboundControlStream.StreamId]; @@ -193,25 +181,5 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests Assert.Equal(TimeSpan.MaxValue.Ticks, serverInboundControlStream.HeaderTimeoutTicks); } - - private static async Task AssertIsTrueRetryAsync(Func<bool> assert, string message) - { - const int Retries = 10; - - for (var i = 0; i < Retries; i++) - { - if (i > 0) - { - await Task.Delay((i + 1) * 10); - } - - if (assert()) - { - return; - } - } - - throw new Exception($"Assert failed after {Retries} retries: {message}"); - } } }