From d38e2f67d782868df34c69115af9caada787f3ea Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 10 Sep 2021 17:46:39 -0700 Subject: [PATCH] [release/6.0] Implement 100 Continue for HTTP/3 (#36395) * Implement 100 Continue for HTTP/3. * Review feedback. Co-authored-by: Aditya Mandaleeka <adityam@microsoft.com> --- .../src/Internal/Http3/Http3FrameWriter.cs | 25 ++++++++ .../src/Internal/Http3/Http3OutputProducer.cs | 12 +++- .../Http2/Http2TestBase.cs | 2 +- .../Http3/Http3ConnectionTests.cs | 53 ++++++++++++++++ .../Http3/Http3Helpers.cs | 8 ++- .../Http3/Http3RequestTests.cs | 63 ++++++++++++++++--- 6 files changed, 151 insertions(+), 12 deletions(-) diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3FrameWriter.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3FrameWriter.cs index 495d5ad338f..f0278038d33 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3FrameWriter.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3FrameWriter.cs @@ -22,6 +22,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 { internal class Http3FrameWriter { + // These bytes represent a ":status: 100" continue response header frame encoded with + // QPACK. To arrive at this, we first take the index in the QPACK static table for status + // 100 (https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#appendix-A), which + // is 63, and encode it to get ff 00 (see QPackEncoder.EncodeStaticIndexedHeaderField). + // The two zero bytes are for the section prefix + // (https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#header-prefix) + private static ReadOnlySpan<byte> ContinueBytes => new byte[] { 0x00, 0x00, 0xff, 0x00 }; + // Size based on HTTP/2 default frame size private const int MaxDataFrameSize = 16 * 1024; private const int HeaderBufferSize = 16 * 1024; @@ -256,6 +264,23 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 _unflushedBytes += headerLength + _outgoingFrame.Length; } + public ValueTask<FlushResult> Write100ContinueAsync() + { + lock (_writeLock) + { + if (_completed) + { + return default; + } + + _outgoingFrame.PrepareHeaders(); + _outgoingFrame.Length = ContinueBytes.Length; + WriteHeaderUnsynchronized(); + _outputWriter.Write(ContinueBytes); + return TimeFlushUnsynchronizedAsync(); + } + } + internal static int WriteHeader(Http3FrameType frameType, long frameLength, PipeWriter output) { // max size of the header is 16, most likely it will be smaller. diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3OutputProducer.cs index 52b05501688..27c92f8eae3 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3OutputProducer.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3OutputProducer.cs @@ -295,7 +295,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3 public ValueTask<FlushResult> Write100ContinueAsync() { - throw new NotImplementedException(); + lock (_dataWriterLock) + { + ThrowIfSuffixSent(); + + if (_streamCompleted) + { + return default; + } + + return _frameWriter.Write100ContinueAsync(); + } } public ValueTask<FlushResult> WriteChunkAsync(ReadOnlySpan<byte> data, CancellationToken cancellationToken) diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs index e8dafc49645..e0d92c63c8b 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs @@ -69,7 +69,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests new KeyValuePair<string, string>(HeaderNames.Path, "/"), new KeyValuePair<string, string>(HeaderNames.Authority, "127.0.0.1"), new KeyValuePair<string, string>(HeaderNames.Scheme, "http"), - new KeyValuePair<string, string>("expect", "100-continue"), + new KeyValuePair<string, string>(HeaderNames.Expect, "100-continue"), }; protected static readonly IEnumerable<KeyValuePair<string, string>> _requestTrailers = new[] diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3ConnectionTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3ConnectionTests.cs index 0c2a984db3d..ead3a1656bd 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3ConnectionTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3ConnectionTests.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Buffers; using System.Collections.Generic; using System.Globalization; using System.Net.Http; @@ -73,6 +74,58 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests Assert.True(requestStream.Disposed); } + [Fact] + public async Task HEADERS_Received_ContainsExpect100Continue_100ContinueSent() + { + await Http3Api.InitializeConnectionAsync(async context => + { + var buffer = new byte[16 * 1024]; + var received = 0; + + while ((received = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length)) > 0) + { + await context.Response.Body.WriteAsync(buffer, 0, received); + } + }); + + await Http3Api.CreateControlStream(); + await Http3Api.GetInboundControlStream(); + + var requestStream = await Http3Api.CreateRequestStream(); + + var expectContinueRequestHeaders = new[] + { + new KeyValuePair<string, string>(HeaderNames.Method, "POST"), + new KeyValuePair<string, string>(HeaderNames.Path, "/"), + new KeyValuePair<string, string>(HeaderNames.Authority, "127.0.0.1"), + new KeyValuePair<string, string>(HeaderNames.Scheme, "http"), + new KeyValuePair<string, string>(HeaderNames.Expect, "100-continue"), + }; + + await requestStream.SendHeadersAsync(expectContinueRequestHeaders); + + var frame = await requestStream.ReceiveFrameAsync(); + Assert.Equal(Http3FrameType.Headers, frame.Type); + + var continueBytesQpackEncoded = new byte[] { 0x00, 0x00, 0xff, 0x00 }; + Assert.Equal(continueBytesQpackEncoded, frame.PayloadSequence.ToArray()); + + await requestStream.SendDataAsync(Encoding.ASCII.GetBytes("Hello world"), endStream: false); + var headers = await requestStream.ExpectHeadersAsync(); + Assert.Equal("200", headers[HeaderNames.Status]); + + var responseData = await requestStream.ExpectDataAsync(); + Assert.Equal("Hello world", Encoding.ASCII.GetString(responseData.ToArray())); + + Assert.False(requestStream.Disposed, "Request is in progress and shouldn't be disposed."); + + await requestStream.SendDataAsync(Encoding.ASCII.GetBytes($"End"), endStream: true); + responseData = await requestStream.ExpectDataAsync(); + Assert.Equal($"End", Encoding.ASCII.GetString(responseData.ToArray())); + + await requestStream.ExpectReceiveEndOfStream(); + } + [Theory] [InlineData(0, 0)] [InlineData(1, 4)] diff --git a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3Helpers.cs b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3Helpers.cs index 644444b86d3..ef101b0bc01 100644 --- a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3Helpers.cs +++ b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3Helpers.cs @@ -19,7 +19,7 @@ namespace Interop.FunctionalTests.Http3 { public static class Http3Helpers { - public static HttpMessageInvoker CreateClient(TimeSpan? idleTimeout = null, bool includeClientCert = false) + public static HttpMessageInvoker CreateClient(TimeSpan? idleTimeout = null, TimeSpan? expect100ContinueTimeout = null, bool includeClientCert = false) { var handler = new SocketsHttpHandler(); handler.SslOptions = new System.Net.Security.SslClientAuthenticationOptions @@ -28,6 +28,12 @@ namespace Interop.FunctionalTests.Http3 TargetHost = "targethost", ClientCertificates = !includeClientCert ? null : new X509CertificateCollection() { TestResources.GetTestCertificate() }, }; + + if (expect100ContinueTimeout != null) + { + handler.Expect100ContinueTimeout = expect100ContinueTimeout.Value; + } + if (idleTimeout != null) { handler.PooledConnectionIdleTimeout = idleTimeout.Value; diff --git a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs index 322b59f09d4..68d8a53e0bb 100644 --- a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs +++ b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs @@ -25,7 +25,7 @@ namespace Interop.FunctionalTests.Http3 [QuarantinedTest("https://github.com/dotnet/aspnetcore/issues/35070")] public class Http3RequestTests : LoggedTest { - private class StreamingHttpContext : HttpContent + private class StreamingHttpContent : HttpContent { private readonly TaskCompletionSource _completeTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously); private readonly TaskCompletionSource<Stream> _getStreamTcs = new TaskCompletionSource<Stream>(TaskCreationOptions.RunContinuationsAsynchronously); @@ -234,7 +234,7 @@ namespace Interop.FunctionalTests.Http3 { await host.StartAsync(); - var requestContent = new StreamingHttpContext(); + var requestContent = new StreamingHttpContent(); var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/"); request.Content = requestContent; @@ -278,7 +278,7 @@ namespace Interop.FunctionalTests.Http3 { await host.StartAsync().DefaultTimeout(); - var requestContent = new StreamingHttpContext(); + var requestContent = new StreamingHttpContent(); var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/"); request.Content = requestContent; @@ -347,7 +347,7 @@ namespace Interop.FunctionalTests.Http3 await host.StartAsync().DefaultTimeout(); var cts = new CancellationTokenSource(); - var requestContent = new StreamingHttpContext(); + var requestContent = new StreamingHttpContent(); var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/"); request.Content = requestContent; @@ -418,7 +418,7 @@ namespace Interop.FunctionalTests.Http3 { await host.StartAsync().DefaultTimeout(); - var requestContent = new StreamingHttpContext(); + var requestContent = new StreamingHttpContent(); var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/"); request.Content = requestContent; @@ -511,6 +511,51 @@ namespace Interop.FunctionalTests.Http3 } } + [ConditionalFact] + [MsQuicSupported] + public async Task POST_Expect100Continue_Get100Continue() + { + // Arrange + var builder = CreateHostBuilder(async context => + { + var body = context.Request.Body; + + var data = await body.ReadAtLeastLengthAsync(TestData.Length).DefaultTimeout(); + + await context.Response.Body.WriteAsync(data); + }); + + using (var host = builder.Build()) + using (var client = Http3Helpers.CreateClient(expect100ContinueTimeout: TimeSpan.FromMinutes(20))) + { + await host.StartAsync().DefaultTimeout(); + + var requestContent = new StringContent("Hello world"); + + var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/"); + request.Content = requestContent; + request.Version = HttpVersion.Version30; + request.VersionPolicy = HttpVersionPolicy.RequestVersionExact; + request.Headers.ExpectContinue = true; + + // Act + using var cts = new CancellationTokenSource(); + cts.CancelAfter(TimeSpan.FromSeconds(1)); + var responseTask = client.SendAsync(request, cts.Token); + + var response = await responseTask.DefaultTimeout(); + + // Assert + response.EnsureSuccessStatusCode(); + Assert.Equal(HttpVersion.Version30, response.Version); + var responseText = await response.Content.ReadAsStringAsync().DefaultTimeout(); + Assert.Equal("Hello world", responseText); + + await host.StopAsync().DefaultTimeout(); + } + } + + private static Version GetProtocol(HttpProtocols protocol) { switch (protocol) @@ -632,7 +677,7 @@ namespace Interop.FunctionalTests.Http3 await host.StartAsync().DefaultTimeout(); var cts = new CancellationTokenSource(); - var requestContent = new StreamingHttpContext(); + var requestContent = new StreamingHttpContent(); var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/"); request.Content = requestContent; @@ -729,7 +774,7 @@ namespace Interop.FunctionalTests.Http3 await host.StartAsync().DefaultTimeout(); var cts = new CancellationTokenSource(); - var requestContent = new StreamingHttpContext(); + var requestContent = new StreamingHttpContent(); var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/"); request.Content = requestContent; @@ -1456,7 +1501,7 @@ namespace Interop.FunctionalTests.Http3 { await host.StartAsync().DefaultTimeout(); - var requestContent = new StreamingHttpContext(); + var requestContent = new StreamingHttpContent(); var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/"); request.Content = requestContent; @@ -1551,7 +1596,7 @@ namespace Interop.FunctionalTests.Http3 { await host.StartAsync().DefaultTimeout(); - var requestContent = new StreamingHttpContext(); + var requestContent = new StreamingHttpContent(); var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/"); request.Content = requestContent; -- GitLab