From d38e2f67d782868df34c69115af9caada787f3ea Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
 <41898282+github-actions[bot]@users.noreply.github.com>
Date: Fri, 10 Sep 2021 17:46:39 -0700
Subject: [PATCH] [release/6.0] Implement 100 Continue for HTTP/3 (#36395)

* Implement 100 Continue for HTTP/3.

* Review feedback.

Co-authored-by: Aditya Mandaleeka <adityam@microsoft.com>
---
 .../src/Internal/Http3/Http3FrameWriter.cs    | 25 ++++++++
 .../src/Internal/Http3/Http3OutputProducer.cs | 12 +++-
 .../Http2/Http2TestBase.cs                    |  2 +-
 .../Http3/Http3ConnectionTests.cs             | 53 ++++++++++++++++
 .../Http3/Http3Helpers.cs                     |  8 ++-
 .../Http3/Http3RequestTests.cs                | 63 ++++++++++++++++---
 6 files changed, 151 insertions(+), 12 deletions(-)

diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3FrameWriter.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3FrameWriter.cs
index 495d5ad338f..f0278038d33 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3FrameWriter.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3FrameWriter.cs
@@ -22,6 +22,14 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3
 {
     internal class Http3FrameWriter
     {
+        // These bytes represent a ":status: 100" continue response header frame encoded with
+        // QPACK. To arrive at this, we first take the index in the QPACK static table for status
+        // 100 (https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#appendix-A), which
+        // is 63, and encode it to get ff 00 (see QPackEncoder.EncodeStaticIndexedHeaderField).
+        // The two zero bytes are for the section prefix
+        // (https://quicwg.org/base-drafts/draft-ietf-quic-qpack.html#header-prefix)
+        private static ReadOnlySpan<byte> ContinueBytes => new byte[] { 0x00, 0x00, 0xff, 0x00 };
+
         // Size based on HTTP/2 default frame size
         private const int MaxDataFrameSize = 16 * 1024;
         private const int HeaderBufferSize = 16 * 1024;
@@ -256,6 +264,23 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3
             _unflushedBytes += headerLength + _outgoingFrame.Length;
         }
 
+        public ValueTask<FlushResult> Write100ContinueAsync()
+        {
+            lock (_writeLock)
+            {
+                if (_completed)
+                {
+                    return default;
+                }
+
+                _outgoingFrame.PrepareHeaders();
+                _outgoingFrame.Length = ContinueBytes.Length;
+                WriteHeaderUnsynchronized();
+                _outputWriter.Write(ContinueBytes);
+                return TimeFlushUnsynchronizedAsync();
+            }
+        }
+
         internal static int WriteHeader(Http3FrameType frameType, long frameLength, PipeWriter output)
         {
             // max size of the header is 16, most likely it will be smaller.
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3OutputProducer.cs
index 52b05501688..27c92f8eae3 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http3/Http3OutputProducer.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http3/Http3OutputProducer.cs
@@ -295,7 +295,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http3
 
         public ValueTask<FlushResult> Write100ContinueAsync()
         {
-            throw new NotImplementedException();
+            lock (_dataWriterLock)
+            {
+                ThrowIfSuffixSent();
+
+                if (_streamCompleted)
+                {
+                    return default;
+                }
+
+                return _frameWriter.Write100ContinueAsync();
+            }
         }
 
         public ValueTask<FlushResult> WriteChunkAsync(ReadOnlySpan<byte> data, CancellationToken cancellationToken)
diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs
index e8dafc49645..e0d92c63c8b 100644
--- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs
+++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2TestBase.cs
@@ -69,7 +69,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
             new KeyValuePair<string, string>(HeaderNames.Path, "/"),
             new KeyValuePair<string, string>(HeaderNames.Authority, "127.0.0.1"),
             new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
-            new KeyValuePair<string, string>("expect", "100-continue"),
+            new KeyValuePair<string, string>(HeaderNames.Expect, "100-continue"),
         };
 
         protected static readonly IEnumerable<KeyValuePair<string, string>> _requestTrailers = new[]
diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3ConnectionTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3ConnectionTests.cs
index 0c2a984db3d..ead3a1656bd 100644
--- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3ConnectionTests.cs
+++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http3/Http3ConnectionTests.cs
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
+using System.Buffers;
 using System.Collections.Generic;
 using System.Globalization;
 using System.Net.Http;
@@ -73,6 +74,58 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests
             Assert.True(requestStream.Disposed);
         }
 
+        [Fact]
+        public async Task HEADERS_Received_ContainsExpect100Continue_100ContinueSent()
+        {
+            await Http3Api.InitializeConnectionAsync(async context =>
+            {
+                var buffer = new byte[16 * 1024];
+                var received = 0;
+
+                while ((received = await context.Request.Body.ReadAsync(buffer, 0, buffer.Length)) > 0)
+                {
+                    await context.Response.Body.WriteAsync(buffer, 0, received);
+                }
+            });
+
+            await Http3Api.CreateControlStream();
+            await Http3Api.GetInboundControlStream();
+
+            var requestStream = await Http3Api.CreateRequestStream();
+
+            var expectContinueRequestHeaders = new[]
+            {
+                new KeyValuePair<string, string>(HeaderNames.Method, "POST"),
+                new KeyValuePair<string, string>(HeaderNames.Path, "/"),
+                new KeyValuePair<string, string>(HeaderNames.Authority, "127.0.0.1"),
+                new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
+                new KeyValuePair<string, string>(HeaderNames.Expect, "100-continue"),
+            };
+
+            await requestStream.SendHeadersAsync(expectContinueRequestHeaders);
+
+            var frame = await requestStream.ReceiveFrameAsync();
+            Assert.Equal(Http3FrameType.Headers, frame.Type);
+
+            var continueBytesQpackEncoded = new byte[] { 0x00, 0x00, 0xff, 0x00 };
+            Assert.Equal(continueBytesQpackEncoded, frame.PayloadSequence.ToArray());
+
+            await requestStream.SendDataAsync(Encoding.ASCII.GetBytes("Hello world"), endStream: false);
+            var headers = await requestStream.ExpectHeadersAsync();
+            Assert.Equal("200", headers[HeaderNames.Status]);
+
+            var responseData = await requestStream.ExpectDataAsync();
+            Assert.Equal("Hello world", Encoding.ASCII.GetString(responseData.ToArray()));
+
+            Assert.False(requestStream.Disposed, "Request is in progress and shouldn't be disposed.");
+
+            await requestStream.SendDataAsync(Encoding.ASCII.GetBytes($"End"), endStream: true);
+            responseData = await requestStream.ExpectDataAsync();
+            Assert.Equal($"End", Encoding.ASCII.GetString(responseData.ToArray()));
+
+            await requestStream.ExpectReceiveEndOfStream();
+        }
+
         [Theory]
         [InlineData(0, 0)]
         [InlineData(1, 4)]
diff --git a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3Helpers.cs b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3Helpers.cs
index 644444b86d3..ef101b0bc01 100644
--- a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3Helpers.cs
+++ b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3Helpers.cs
@@ -19,7 +19,7 @@ namespace Interop.FunctionalTests.Http3
 {
     public static class Http3Helpers
     {
-        public static HttpMessageInvoker CreateClient(TimeSpan? idleTimeout = null, bool includeClientCert = false)
+        public static HttpMessageInvoker CreateClient(TimeSpan? idleTimeout = null, TimeSpan? expect100ContinueTimeout = null, bool includeClientCert = false)
         {
             var handler = new SocketsHttpHandler();
             handler.SslOptions = new System.Net.Security.SslClientAuthenticationOptions
@@ -28,6 +28,12 @@ namespace Interop.FunctionalTests.Http3
                 TargetHost = "targethost",
                 ClientCertificates = !includeClientCert ? null : new X509CertificateCollection() { TestResources.GetTestCertificate() },
             };
+
+            if (expect100ContinueTimeout != null)
+            {
+                handler.Expect100ContinueTimeout = expect100ContinueTimeout.Value;
+            }
+
             if (idleTimeout != null)
             {
                 handler.PooledConnectionIdleTimeout = idleTimeout.Value;
diff --git a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs
index 322b59f09d4..68d8a53e0bb 100644
--- a/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs
+++ b/src/Servers/Kestrel/test/Interop.FunctionalTests/Http3/Http3RequestTests.cs
@@ -25,7 +25,7 @@ namespace Interop.FunctionalTests.Http3
     [QuarantinedTest("https://github.com/dotnet/aspnetcore/issues/35070")]
     public class Http3RequestTests : LoggedTest
     {
-        private class StreamingHttpContext : HttpContent
+        private class StreamingHttpContent : HttpContent
         {
             private readonly TaskCompletionSource _completeTcs = new TaskCompletionSource(TaskCreationOptions.RunContinuationsAsynchronously);
             private readonly TaskCompletionSource<Stream> _getStreamTcs = new TaskCompletionSource<Stream>(TaskCreationOptions.RunContinuationsAsynchronously);
@@ -234,7 +234,7 @@ namespace Interop.FunctionalTests.Http3
             {
                 await host.StartAsync();
 
-                var requestContent = new StreamingHttpContext();
+                var requestContent = new StreamingHttpContent();
 
                 var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/");
                 request.Content = requestContent;
@@ -278,7 +278,7 @@ namespace Interop.FunctionalTests.Http3
             {
                 await host.StartAsync().DefaultTimeout();
 
-                var requestContent = new StreamingHttpContext();
+                var requestContent = new StreamingHttpContent();
 
                 var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/");
                 request.Content = requestContent;
@@ -347,7 +347,7 @@ namespace Interop.FunctionalTests.Http3
                 await host.StartAsync().DefaultTimeout();
 
                 var cts = new CancellationTokenSource();
-                var requestContent = new StreamingHttpContext();
+                var requestContent = new StreamingHttpContent();
 
                 var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/");
                 request.Content = requestContent;
@@ -418,7 +418,7 @@ namespace Interop.FunctionalTests.Http3
             {
                 await host.StartAsync().DefaultTimeout();
 
-                var requestContent = new StreamingHttpContext();
+                var requestContent = new StreamingHttpContent();
 
                 var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/");
                 request.Content = requestContent;
@@ -511,6 +511,51 @@ namespace Interop.FunctionalTests.Http3
             }
         }
 
+        [ConditionalFact]
+        [MsQuicSupported]
+        public async Task POST_Expect100Continue_Get100Continue()
+        {
+            // Arrange
+            var builder = CreateHostBuilder(async context =>
+            {
+                var body = context.Request.Body;
+
+                var data = await body.ReadAtLeastLengthAsync(TestData.Length).DefaultTimeout();
+
+                await context.Response.Body.WriteAsync(data);
+            });
+
+            using (var host = builder.Build())
+            using (var client = Http3Helpers.CreateClient(expect100ContinueTimeout: TimeSpan.FromMinutes(20)))
+            {
+                await host.StartAsync().DefaultTimeout();
+
+                var requestContent = new StringContent("Hello world");
+
+                var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/");
+                request.Content = requestContent;
+                request.Version = HttpVersion.Version30;
+                request.VersionPolicy = HttpVersionPolicy.RequestVersionExact;
+                request.Headers.ExpectContinue = true;
+
+                // Act
+                using var cts = new CancellationTokenSource();
+                cts.CancelAfter(TimeSpan.FromSeconds(1));
+                var responseTask = client.SendAsync(request, cts.Token);
+
+                var response = await responseTask.DefaultTimeout();
+
+                // Assert
+                response.EnsureSuccessStatusCode();
+                Assert.Equal(HttpVersion.Version30, response.Version);
+                var responseText = await response.Content.ReadAsStringAsync().DefaultTimeout();
+                Assert.Equal("Hello world", responseText);
+
+                await host.StopAsync().DefaultTimeout();
+            }
+        }
+
+
         private static Version GetProtocol(HttpProtocols protocol)
         {
             switch (protocol)
@@ -632,7 +677,7 @@ namespace Interop.FunctionalTests.Http3
                 await host.StartAsync().DefaultTimeout();
 
                 var cts = new CancellationTokenSource();
-                var requestContent = new StreamingHttpContext();
+                var requestContent = new StreamingHttpContent();
 
                 var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/");
                 request.Content = requestContent;
@@ -729,7 +774,7 @@ namespace Interop.FunctionalTests.Http3
                 await host.StartAsync().DefaultTimeout();
 
                 var cts = new CancellationTokenSource();
-                var requestContent = new StreamingHttpContext();
+                var requestContent = new StreamingHttpContent();
 
                 var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/");
                 request.Content = requestContent;
@@ -1456,7 +1501,7 @@ namespace Interop.FunctionalTests.Http3
             {
                 await host.StartAsync().DefaultTimeout();
 
-                var requestContent = new StreamingHttpContext();
+                var requestContent = new StreamingHttpContent();
 
                 var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/");
                 request.Content = requestContent;
@@ -1551,7 +1596,7 @@ namespace Interop.FunctionalTests.Http3
             {
                 await host.StartAsync().DefaultTimeout();
 
-                var requestContent = new StreamingHttpContext();
+                var requestContent = new StreamingHttpContent();
 
                 var request = new HttpRequestMessage(HttpMethod.Post, $"https://127.0.0.1:{host.GetPort()}/");
                 request.Content = requestContent;
-- 
GitLab