diff --git a/src/Middleware/WebSockets/samples/EchoApp/EchoApp.csproj b/src/Middleware/WebSockets/samples/EchoApp/EchoApp.csproj index c92354b6f5cd7531e291e547d0f803d1f869ee3d..203f749b8cb803ffb835c77ff468d3ccb14d3c08 100644 --- a/src/Middleware/WebSockets/samples/EchoApp/EchoApp.csproj +++ b/src/Middleware/WebSockets/samples/EchoApp/EchoApp.csproj @@ -6,6 +6,7 @@ </PropertyGroup> <ItemGroup> + <Reference Include="Microsoft.AspNetCore" /> <Reference Include="Microsoft.AspNetCore.Diagnostics" /> <Reference Include="Microsoft.AspNetCore.Server.IISIntegration" /> <Reference Include="Microsoft.AspNetCore.Server.Kestrel" /> @@ -14,4 +15,10 @@ <Reference Include="Microsoft.Extensions.Logging.Console" /> </ItemGroup> + <ItemGroup> + <Content Update="appsettings.json"> + <CopyToOutputDirectory>Always</CopyToOutputDirectory> + </Content> + </ItemGroup> + </Project> diff --git a/src/Middleware/WebSockets/samples/EchoApp/Program.cs b/src/Middleware/WebSockets/samples/EchoApp/Program.cs index 6b65ab2e764ceaf1e8b9ede08e009830689aef9d..e4860fd1416c93cbeddded6777321c4e3f917a9f 100644 --- a/src/Middleware/WebSockets/samples/EchoApp/Program.cs +++ b/src/Middleware/WebSockets/samples/EchoApp/Program.cs @@ -1,23 +1,89 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using System.Net.WebSockets; +using System.Text; +using Microsoft.AspNetCore.Http.Features; + namespace EchoApp; public class Program { public static Task Main(string[] args) { - var host = new HostBuilder() - .ConfigureWebHost(webHostBuilder => + var builder = WebApplication.CreateBuilder(); + var app = builder.Build(); + + app.UseDeveloperExceptionPage(); + app.UseWebSockets(); + + app.Use(async (context, next) => + { + if (context.WebSockets.IsWebSocketRequest) + { + var webSocket = await context.WebSockets.AcceptWebSocketAsync(new WebSocketAcceptContext() { DangerousEnableCompression = true }); + await Echo(context, webSocket, app.Logger); + return; + } + + await next(context); + }); + + app.UseFileServer(); + + return app.RunAsync(); + } + + private static async Task Echo(HttpContext context, WebSocket webSocket, ILogger logger) + { + var buffer = new byte[1024 * 4]; + var result = await webSocket.ReceiveAsync(buffer.AsMemory(), CancellationToken.None); + LogFrame(logger, webSocket, result, buffer); + while (result.MessageType != WebSocketMessageType.Close) + { + // If the client send "ServerClose", then they want a server-originated close to occur + string content = "<<binary>>"; + if (result.MessageType == WebSocketMessageType.Text) + { + content = Encoding.UTF8.GetString(buffer, 0, result.Count); + if (content.Equals("ServerClose")) + { + await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing from Server", CancellationToken.None); + logger.LogDebug($"Sent Frame Close: {WebSocketCloseStatus.NormalClosure} Closing from Server"); + return; + } + else if (content.Equals("ServerAbort")) + { + context.Abort(); + } + } + + await webSocket.SendAsync(new ArraySegment<byte>(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, CancellationToken.None); + logger.LogDebug($"Sent Frame {result.MessageType}: Len={result.Count}, Fin={result.EndOfMessage}: {content}"); + + result = await webSocket.ReceiveAsync(buffer.AsMemory(), CancellationToken.None); + LogFrame(logger, webSocket, result, buffer); + } + await webSocket.CloseAsync(webSocket.CloseStatus.Value, webSocket.CloseStatusDescription, CancellationToken.None); + } + + private static void LogFrame(ILogger logger, WebSocket webSocket, ValueWebSocketReceiveResult frame, byte[] buffer) + { + var close = frame.MessageType == WebSocketMessageType.Close; + string message; + if (close) + { + message = $"Close: {webSocket.CloseStatus.Value} {webSocket.CloseStatusDescription}"; + } + else + { + string content = "<<binary>>"; + if (frame.MessageType == WebSocketMessageType.Text) { - webHostBuilder - .UseKestrel() - .UseContentRoot(Directory.GetCurrentDirectory()) - .UseIISIntegration() - .UseStartup<Startup>(); - }) - .Build(); - - return host.RunAsync(); + content = Encoding.UTF8.GetString(buffer, 0, frame.Count); + } + message = $"{frame.MessageType}: Len={frame.Count}, Fin={frame.EndOfMessage}: {content}"; + } + logger.LogDebug("Received Frame " + message); } } diff --git a/src/Middleware/WebSockets/samples/EchoApp/Startup.cs b/src/Middleware/WebSockets/samples/EchoApp/Startup.cs deleted file mode 100644 index 2adef419664d935563fa21b9bf018606f54edd3e..0000000000000000000000000000000000000000 --- a/src/Middleware/WebSockets/samples/EchoApp/Startup.cs +++ /dev/null @@ -1,96 +0,0 @@ -// Licensed to the .NET Foundation under one or more agreements. -// The .NET Foundation licenses this file to you under the MIT license. - -using System.Net.WebSockets; -using System.Text; - -namespace EchoApp; - -public class Startup -{ - // This method gets called by the runtime. Use this method to add services to the container. - // For more information on how to configure your application, visit http://go.microsoft.com/fwlink/?LinkID=398940 - public void ConfigureServices(IServiceCollection services) - { - services.AddLogging(builder => builder.AddConsole()); - } - - // This method gets called by the runtime. Use this method to configure the HTTP request pipeline. - public void Configure(IApplicationBuilder app, IWebHostEnvironment env, ILoggerFactory loggerFactory) - { - if (env.IsDevelopment()) - { - app.UseDeveloperExceptionPage(); - } - - app.UseWebSockets(); - - app.Use(async (context, next) => - { - if (context.WebSockets.IsWebSocketRequest) - { - var webSocket = await context.WebSockets.AcceptWebSocketAsync(new WebSocketAcceptContext() { DangerousEnableCompression = true }); - await Echo(context, webSocket, loggerFactory.CreateLogger("Echo")); - } - else - { - await next(context); - } - }); - - app.UseFileServer(); - } - - private async Task Echo(HttpContext context, WebSocket webSocket, ILogger logger) - { - var buffer = new byte[1024 * 4]; - var result = await webSocket.ReceiveAsync(buffer.AsMemory(), CancellationToken.None); - LogFrame(logger, webSocket, result, buffer); - while (result.MessageType != WebSocketMessageType.Close) - { - // If the client send "ServerClose", then they want a server-originated close to occur - string content = "<<binary>>"; - if (result.MessageType == WebSocketMessageType.Text) - { - content = Encoding.UTF8.GetString(buffer, 0, result.Count); - if (content.Equals("ServerClose")) - { - await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing from Server", CancellationToken.None); - logger.LogDebug($"Sent Frame Close: {WebSocketCloseStatus.NormalClosure} Closing from Server"); - return; - } - else if (content.Equals("ServerAbort")) - { - context.Abort(); - } - } - - await webSocket.SendAsync(new ArraySegment<byte>(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, CancellationToken.None); - logger.LogDebug($"Sent Frame {result.MessageType}: Len={result.Count}, Fin={result.EndOfMessage}: {content}"); - - result = await webSocket.ReceiveAsync(buffer.AsMemory(), CancellationToken.None); - LogFrame(logger, webSocket, result, buffer); - } - await webSocket.CloseAsync(webSocket.CloseStatus.Value, webSocket.CloseStatusDescription, CancellationToken.None); - } - - private static void LogFrame(ILogger logger, WebSocket webSocket, ValueWebSocketReceiveResult frame, byte[] buffer) - { - var close = frame.MessageType == WebSocketMessageType.Close; - string message; - if (close) - { - message = $"Close: {webSocket.CloseStatus.Value} {webSocket.CloseStatusDescription}"; - } - else - { - string content = "<<binary>>"; - if (frame.MessageType == WebSocketMessageType.Text) - { - content = Encoding.UTF8.GetString(buffer, 0, frame.Count); - } - message = $"{frame.MessageType}: Len={frame.Count}, Fin={frame.EndOfMessage}: {content}"; - } - logger.LogDebug("Received Frame " + message); - } -} diff --git a/src/Middleware/WebSockets/samples/EchoApp/appsettings.json b/src/Middleware/WebSockets/samples/EchoApp/appsettings.json new file mode 100644 index 0000000000000000000000000000000000000000..2b15773f26c8b7c596742c5d8d508e2385022983 --- /dev/null +++ b/src/Middleware/WebSockets/samples/EchoApp/appsettings.json @@ -0,0 +1,8 @@ +{ + "Logging": { + "LogLevel": { + "Default": "Debug", + "Microsoft.AspNetCore": "Trace" + } + } +} diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs index 4f3340a425612e2cea61aeb6bc5442613c8f552f..c4e4d186521aeb108bc578dfde8237ccedac2d86 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs @@ -149,7 +149,7 @@ internal partial class HttpProtocol bool IHttpUpgradeFeature.IsUpgradableRequest => IsUpgradableRequest; - bool IHttpExtendedConnectFeature.IsExtendedConnect => IsConnectRequest; + bool IHttpExtendedConnectFeature.IsExtendedConnect => IsExtendedConnectRequest; string? IHttpExtendedConnectFeature.Protocol => ConnectProtocol; @@ -195,7 +195,7 @@ internal partial class HttpProtocol set => AllowSynchronousIO = value; } - bool IHttpMaxRequestBodySizeFeature.IsReadOnly => HasStartedConsumingRequestBody || IsUpgraded; + bool IHttpMaxRequestBodySizeFeature.IsReadOnly => HasStartedConsumingRequestBody || IsUpgraded || IsExtendedConnectRequest; long? IHttpMaxRequestBodySizeFeature.MaxRequestBodySize { @@ -290,12 +290,12 @@ internal partial class HttpProtocol async ValueTask<Stream> IHttpExtendedConnectFeature.AcceptAsync() { - if (!IsConnectRequest) + if (!IsExtendedConnectRequest) { throw new InvalidOperationException(CoreStrings.CannotAcceptNonConnectRequest); } - if (IsUpgraded) + if (IsExtendedConnectAccepted) { throw new InvalidOperationException(CoreStrings.AcceptCannotBeCalledMultipleTimes); } @@ -305,7 +305,7 @@ internal partial class HttpProtocol throw new InvalidOperationException(CoreStrings.ConnectStatusMustBe200); } - IsUpgraded = true; + IsExtendedConnectAccepted = true; await FlushAsync(); diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs index d80e529415b48aa4f5af1233fa3a6bdb1921423c..e48fb710ed226948bb2033a8d1b6b890d840fd7e 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs @@ -127,7 +127,8 @@ internal abstract partial class HttpProtocol : IHttpResponseControl public bool IsUpgradableRequest { get; private set; } public bool IsUpgraded { get; set; } - public bool IsConnectRequest { get; set; } + public bool IsExtendedConnectRequest { get; set; } + public bool IsExtendedConnectAccepted { get; set; } public IPAddress? RemoteIpAddress { get; set; } public int RemotePort { get; set; } public IPAddress? LocalIpAddress { get; set; } @@ -359,6 +360,8 @@ internal abstract partial class HttpProtocol : IHttpResponseControl _statusCode = StatusCodes.Status200OK; _reasonPhrase = null; IsUpgraded = false; + IsExtendedConnectRequest = false; + IsExtendedConnectAccepted = false; var remoteEndPoint = RemoteEndPoint; RemoteIpAddress = remoteEndPoint?.Address; diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs index 5208a35bcde1909daa865503ab2ca689146bf367..4898d874d42d09d12dfbb40eb150f3620650866d 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/MessageBody.cs @@ -37,6 +37,8 @@ internal abstract class MessageBody public bool RequestUpgrade { get; protected set; } + public bool ExtendedConnect { get; protected set; } + public virtual bool IsEmpty => false; protected KestrelTrace Log => _context.ServiceContext.Log; @@ -122,7 +124,7 @@ internal abstract class MessageBody OnReadStarting(); _context.HasStartedConsumingRequestBody = true; - if (!RequestUpgrade) + if (!RequestUpgrade && !ExtendedConnect) { // Accessing TraceIdentifier will lazy-allocate a string ID. // Don't access TraceIdentifer unless logging is enabled. @@ -150,7 +152,7 @@ internal abstract class MessageBody _stopped = true; - if (!RequestUpgrade) + if (!RequestUpgrade && !ExtendedConnect) { // Accessing TraceIdentifier will lazy-allocate a string ID // Don't access TraceIdentifer unless logging is enabled. diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs index 126bdb5651c6ff0d6578b8ae0384ee26e4aab0a5..398baf674fb42f444ff89a9657f3ef2baacf810f 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2MessageBody.cs @@ -19,6 +19,7 @@ internal sealed class Http2MessageBody : MessageBody : base(context) { _context = context; + ExtendedConnect = _context.IsExtendedConnectRequest; } protected override void OnReadStarting() @@ -51,6 +52,7 @@ internal sealed class Http2MessageBody : MessageBody { base.Reset(); _readResult = default; + ExtendedConnect = _context.IsExtendedConnectRequest; } public override void AdvanceTo(SequencePosition consumed, SequencePosition examined) @@ -62,7 +64,12 @@ internal sealed class Http2MessageBody : MessageBody // The HTTP/2 flow control window cannot be larger than 2^31-1 which limits bytesRead. _context.OnDataRead((int)newlyExaminedBytes); - AddAndCheckObservedBytes(newlyExaminedBytes); + + // Don't limit extended CONNECT requests to the MaxRequestBodySize. + if (!ExtendedConnect) + { + AddAndCheckObservedBytes(newlyExaminedBytes); + } } public override bool TryRead(out ReadResult readResult) diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs index bcf7f223bc2a8e5d91d9b668bbc505bd5a14392f..6df8039ae77ee27475aff87fe7b7697f314da138 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2Stream.cs @@ -239,7 +239,7 @@ internal abstract partial class Http2Stream : HttpProtocol, IThreadPoolWorkItem, return false; } ConnectProtocol = HttpRequestHeaders.HeaderProtocol; - IsConnectRequest = true; + IsExtendedConnectRequest = true; } // CONNECT - :scheme and :path must be excluded else if (!StringValues.IsNullOrEmpty(HttpRequestHeaders.HeaderScheme) || !StringValues.IsNullOrEmpty(HttpRequestHeaders.HeaderPath)) diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2WebSocketTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2WebSocketTests.cs index b11838501d60b597f22fec5f9527b11c498ac873..95c92cb39409f48099a5de60d55db01c4e68fc50 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2WebSocketTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2WebSocketTests.cs @@ -11,9 +11,11 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Features; using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2; +using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.Logging; using Microsoft.Net.Http.Headers; +using Moq; namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests; @@ -310,4 +312,130 @@ public class Http2WebSocketTests : Http2TestBase await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); } + + [Fact] + public async Task ExtendedCONNECT_AcceptAsyncStream_IsNotLimitedByMinRequestBodyDataRate() + { + var limits = _serviceContext.ServerOptions.Limits; + + // Use non-default value to ensure the min request and response rates aren't mixed up. + limits.MinRequestBodyDataRate = new MinDataRate(480, TimeSpan.FromSeconds(2.5)); + + await InitializeConnectionAsync(async context => + { + var connectFeature = context.Features.Get<IHttpExtendedConnectFeature>(); + var stream = await connectFeature.AcceptAsync(); + Assert.Equal(0, await stream.ReadAsync(new byte[1])); + await stream.WriteAsync(new byte[] { 0x01 }); + }); + + var headers = new[] + { + new KeyValuePair<string, string>(HeaderNames.Method, "CONNECT"), + new KeyValuePair<string, string>(HeaderNames.Protocol, "websocket"), + new KeyValuePair<string, string>(HeaderNames.Scheme, "http"), + new KeyValuePair<string, string>(HeaderNames.Path, "/chat"), + new KeyValuePair<string, string>(HeaderNames.Authority, "server.example.com"), + new KeyValuePair<string, string>(HeaderNames.WebSocketSubProtocols, "chat, superchat"), + new KeyValuePair<string, string>(HeaderNames.SecWebSocketExtensions, "permessage-deflate"), + new KeyValuePair<string, string>(HeaderNames.SecWebSocketVersion, "13"), + new KeyValuePair<string, string>(HeaderNames.Origin, "http://www.example.com"), + }; + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, headers); + + // Don't send any more data and advance just to and then past the grace period. + AdvanceClock(limits.MinRequestBodyDataRate.GracePeriod + TimeSpan.FromTicks(1)); + + _mockTimeoutHandler.Verify(h => h.OnTimeout(It.IsAny<TimeoutReason>()), Times.Never); + + await SendDataAsync(1, Array.Empty<byte>(), endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 32, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(2, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + + var dataFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 1, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + Assert.Equal(0x01, dataFrame.Payload.Span[0]); + + dataFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } + + [Fact] + public async Task ExtendedCONNECT_AcceptAsyncStream_IsNotLimitedByMaxRequestBodySize() + { + var limits = _serviceContext.ServerOptions.Limits; + + // We're going to send more than the MaxRequestBodySize bytes from the client to the server over the connection + // Since this is not a request body, this should be allowed like it would be for an upgraded connection. + limits.MaxRequestBodySize = 5; + + await InitializeConnectionAsync(async context => + { + var connectFeature = context.Features.Get<IHttpExtendedConnectFeature>(); + var maxRequestBodySizeFeature = context.Features.Get<IHttpMaxRequestBodySizeFeature>(); + // Extended connects don't have a meaningful request body to limit. + Assert.True(maxRequestBodySizeFeature.IsReadOnly); + var stream = await connectFeature.AcceptAsync(); + Assert.True(maxRequestBodySizeFeature.IsReadOnly); + using var memoryStream = new MemoryStream(); + await stream.CopyToAsync(memoryStream); + Assert.Equal(_serviceContext.ServerOptions.Limits.MaxRequestBodySize + 1, memoryStream.Length); + await stream.WriteAsync(new byte[] { 0x01 }); + }); + + var headers = new[] + { + new KeyValuePair<string, string>(HeaderNames.Method, "CONNECT"), + new KeyValuePair<string, string>(HeaderNames.Protocol, "websocket"), + new KeyValuePair<string, string>(HeaderNames.Scheme, "http"), + new KeyValuePair<string, string>(HeaderNames.Path, "/chat"), + new KeyValuePair<string, string>(HeaderNames.Authority, "server.example.com"), + new KeyValuePair<string, string>(HeaderNames.WebSocketSubProtocols, "chat, superchat"), + new KeyValuePair<string, string>(HeaderNames.SecWebSocketExtensions, "permessage-deflate"), + new KeyValuePair<string, string>(HeaderNames.SecWebSocketVersion, "13"), + new KeyValuePair<string, string>(HeaderNames.Origin, "http://www.example.com"), + }; + await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, headers); + + await SendDataAsync(1, new byte[(int)limits.MaxRequestBodySize + 1], endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 32, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(2, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + + var dataFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 1, + withFlags: (byte)Http2DataFrameFlags.NONE, + withStreamId: 1); + Assert.Equal(0x01, dataFrame.Payload.Span[0]); + + dataFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + } }