Skip to content
代码片段 群组 项目
未验证 提交 a382a0fe 编辑于 作者: Chris Ross's avatar Chris Ross 提交者: GitHub
浏览文件

Make H2 WebSockets ignore MaxRequestBodySize and MinRequestBodyDataRate #42101 (#42269)

上级 12822154
No related branches found
No related tags found
无相关合并请求
显示
242 个添加117 个删除
......@@ -6,6 +6,7 @@
</PropertyGroup>
<ItemGroup>
<Reference Include="Microsoft.AspNetCore" />
<Reference Include="Microsoft.AspNetCore.Diagnostics" />
<Reference Include="Microsoft.AspNetCore.Server.IISIntegration" />
<Reference Include="Microsoft.AspNetCore.Server.Kestrel" />
......@@ -14,4 +15,10 @@
<Reference Include="Microsoft.Extensions.Logging.Console" />
</ItemGroup>
<ItemGroup>
<Content Update="appsettings.json">
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
</Content>
</ItemGroup>
</Project>
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Net.WebSockets;
using System.Text;
using Microsoft.AspNetCore.Http.Features;
namespace EchoApp;
public class Program
{
public static Task Main(string[] args)
{
var host = new HostBuilder()
.ConfigureWebHost(webHostBuilder =>
var builder = WebApplication.CreateBuilder();
var app = builder.Build();
app.UseDeveloperExceptionPage();
app.UseWebSockets();
app.Use(async (context, next) =>
{
if (context.WebSockets.IsWebSocketRequest)
{
var webSocket = await context.WebSockets.AcceptWebSocketAsync(new WebSocketAcceptContext() { DangerousEnableCompression = true });
await Echo(context, webSocket, app.Logger);
return;
}
await next(context);
});
app.UseFileServer();
return app.RunAsync();
}
private static async Task Echo(HttpContext context, WebSocket webSocket, ILogger logger)
{
var buffer = new byte[1024 * 4];
var result = await webSocket.ReceiveAsync(buffer.AsMemory(), CancellationToken.None);
LogFrame(logger, webSocket, result, buffer);
while (result.MessageType != WebSocketMessageType.Close)
{
// If the client send "ServerClose", then they want a server-originated close to occur
string content = "<<binary>>";
if (result.MessageType == WebSocketMessageType.Text)
{
content = Encoding.UTF8.GetString(buffer, 0, result.Count);
if (content.Equals("ServerClose"))
{
await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing from Server", CancellationToken.None);
logger.LogDebug($"Sent Frame Close: {WebSocketCloseStatus.NormalClosure} Closing from Server");
return;
}
else if (content.Equals("ServerAbort"))
{
context.Abort();
}
}
await webSocket.SendAsync(new ArraySegment<byte>(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, CancellationToken.None);
logger.LogDebug($"Sent Frame {result.MessageType}: Len={result.Count}, Fin={result.EndOfMessage}: {content}");
result = await webSocket.ReceiveAsync(buffer.AsMemory(), CancellationToken.None);
LogFrame(logger, webSocket, result, buffer);
}
await webSocket.CloseAsync(webSocket.CloseStatus.Value, webSocket.CloseStatusDescription, CancellationToken.None);
}
private static void LogFrame(ILogger logger, WebSocket webSocket, ValueWebSocketReceiveResult frame, byte[] buffer)
{
var close = frame.MessageType == WebSocketMessageType.Close;
string message;
if (close)
{
message = $"Close: {webSocket.CloseStatus.Value} {webSocket.CloseStatusDescription}";
}
else
{
string content = "<<binary>>";
if (frame.MessageType == WebSocketMessageType.Text)
{
webHostBuilder
.UseKestrel()
.UseContentRoot(Directory.GetCurrentDirectory())
.UseIISIntegration()
.UseStartup<Startup>();
})
.Build();
return host.RunAsync();
content = Encoding.UTF8.GetString(buffer, 0, frame.Count);
}
message = $"{frame.MessageType}: Len={frame.Count}, Fin={frame.EndOfMessage}: {content}";
}
logger.LogDebug("Received Frame " + message);
}
}
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
using System.Net.WebSockets;
using System.Text;
namespace EchoApp;
public class Startup
{
// This method gets called by the runtime. Use this method to add services to the container.
// For more information on how to configure your application, visit http://go.microsoft.com/fwlink/?LinkID=398940
public void ConfigureServices(IServiceCollection services)
{
services.AddLogging(builder => builder.AddConsole());
}
// This method gets called by the runtime. Use this method to configure the HTTP request pipeline.
public void Configure(IApplicationBuilder app, IWebHostEnvironment env, ILoggerFactory loggerFactory)
{
if (env.IsDevelopment())
{
app.UseDeveloperExceptionPage();
}
app.UseWebSockets();
app.Use(async (context, next) =>
{
if (context.WebSockets.IsWebSocketRequest)
{
var webSocket = await context.WebSockets.AcceptWebSocketAsync(new WebSocketAcceptContext() { DangerousEnableCompression = true });
await Echo(context, webSocket, loggerFactory.CreateLogger("Echo"));
}
else
{
await next(context);
}
});
app.UseFileServer();
}
private async Task Echo(HttpContext context, WebSocket webSocket, ILogger logger)
{
var buffer = new byte[1024 * 4];
var result = await webSocket.ReceiveAsync(buffer.AsMemory(), CancellationToken.None);
LogFrame(logger, webSocket, result, buffer);
while (result.MessageType != WebSocketMessageType.Close)
{
// If the client send "ServerClose", then they want a server-originated close to occur
string content = "<<binary>>";
if (result.MessageType == WebSocketMessageType.Text)
{
content = Encoding.UTF8.GetString(buffer, 0, result.Count);
if (content.Equals("ServerClose"))
{
await webSocket.CloseAsync(WebSocketCloseStatus.NormalClosure, "Closing from Server", CancellationToken.None);
logger.LogDebug($"Sent Frame Close: {WebSocketCloseStatus.NormalClosure} Closing from Server");
return;
}
else if (content.Equals("ServerAbort"))
{
context.Abort();
}
}
await webSocket.SendAsync(new ArraySegment<byte>(buffer, 0, result.Count), result.MessageType, result.EndOfMessage, CancellationToken.None);
logger.LogDebug($"Sent Frame {result.MessageType}: Len={result.Count}, Fin={result.EndOfMessage}: {content}");
result = await webSocket.ReceiveAsync(buffer.AsMemory(), CancellationToken.None);
LogFrame(logger, webSocket, result, buffer);
}
await webSocket.CloseAsync(webSocket.CloseStatus.Value, webSocket.CloseStatusDescription, CancellationToken.None);
}
private static void LogFrame(ILogger logger, WebSocket webSocket, ValueWebSocketReceiveResult frame, byte[] buffer)
{
var close = frame.MessageType == WebSocketMessageType.Close;
string message;
if (close)
{
message = $"Close: {webSocket.CloseStatus.Value} {webSocket.CloseStatusDescription}";
}
else
{
string content = "<<binary>>";
if (frame.MessageType == WebSocketMessageType.Text)
{
content = Encoding.UTF8.GetString(buffer, 0, frame.Count);
}
message = $"{frame.MessageType}: Len={frame.Count}, Fin={frame.EndOfMessage}: {content}";
}
logger.LogDebug("Received Frame " + message);
}
}
{
"Logging": {
"LogLevel": {
"Default": "Debug",
"Microsoft.AspNetCore": "Trace"
}
}
}
......@@ -149,7 +149,7 @@ internal partial class HttpProtocol
bool IHttpUpgradeFeature.IsUpgradableRequest => IsUpgradableRequest;
bool IHttpExtendedConnectFeature.IsExtendedConnect => IsConnectRequest;
bool IHttpExtendedConnectFeature.IsExtendedConnect => IsExtendedConnectRequest;
string? IHttpExtendedConnectFeature.Protocol => ConnectProtocol;
......@@ -195,7 +195,7 @@ internal partial class HttpProtocol
set => AllowSynchronousIO = value;
}
bool IHttpMaxRequestBodySizeFeature.IsReadOnly => HasStartedConsumingRequestBody || IsUpgraded;
bool IHttpMaxRequestBodySizeFeature.IsReadOnly => HasStartedConsumingRequestBody || IsUpgraded || IsExtendedConnectRequest;
long? IHttpMaxRequestBodySizeFeature.MaxRequestBodySize
{
......@@ -290,12 +290,12 @@ internal partial class HttpProtocol
async ValueTask<Stream> IHttpExtendedConnectFeature.AcceptAsync()
{
if (!IsConnectRequest)
if (!IsExtendedConnectRequest)
{
throw new InvalidOperationException(CoreStrings.CannotAcceptNonConnectRequest);
}
if (IsUpgraded)
if (IsExtendedConnectAccepted)
{
throw new InvalidOperationException(CoreStrings.AcceptCannotBeCalledMultipleTimes);
}
......@@ -305,7 +305,7 @@ internal partial class HttpProtocol
throw new InvalidOperationException(CoreStrings.ConnectStatusMustBe200);
}
IsUpgraded = true;
IsExtendedConnectAccepted = true;
await FlushAsync();
......
......@@ -127,7 +127,8 @@ internal abstract partial class HttpProtocol : IHttpResponseControl
public bool IsUpgradableRequest { get; private set; }
public bool IsUpgraded { get; set; }
public bool IsConnectRequest { get; set; }
public bool IsExtendedConnectRequest { get; set; }
public bool IsExtendedConnectAccepted { get; set; }
public IPAddress? RemoteIpAddress { get; set; }
public int RemotePort { get; set; }
public IPAddress? LocalIpAddress { get; set; }
......@@ -359,6 +360,8 @@ internal abstract partial class HttpProtocol : IHttpResponseControl
_statusCode = StatusCodes.Status200OK;
_reasonPhrase = null;
IsUpgraded = false;
IsExtendedConnectRequest = false;
IsExtendedConnectAccepted = false;
var remoteEndPoint = RemoteEndPoint;
RemoteIpAddress = remoteEndPoint?.Address;
......
......@@ -37,6 +37,8 @@ internal abstract class MessageBody
public bool RequestUpgrade { get; protected set; }
public bool ExtendedConnect { get; protected set; }
public virtual bool IsEmpty => false;
protected KestrelTrace Log => _context.ServiceContext.Log;
......@@ -122,7 +124,7 @@ internal abstract class MessageBody
OnReadStarting();
_context.HasStartedConsumingRequestBody = true;
if (!RequestUpgrade)
if (!RequestUpgrade && !ExtendedConnect)
{
// Accessing TraceIdentifier will lazy-allocate a string ID.
// Don't access TraceIdentifer unless logging is enabled.
......@@ -150,7 +152,7 @@ internal abstract class MessageBody
_stopped = true;
if (!RequestUpgrade)
if (!RequestUpgrade && !ExtendedConnect)
{
// Accessing TraceIdentifier will lazy-allocate a string ID
// Don't access TraceIdentifer unless logging is enabled.
......
......@@ -19,6 +19,7 @@ internal sealed class Http2MessageBody : MessageBody
: base(context)
{
_context = context;
ExtendedConnect = _context.IsExtendedConnectRequest;
}
protected override void OnReadStarting()
......@@ -51,6 +52,7 @@ internal sealed class Http2MessageBody : MessageBody
{
base.Reset();
_readResult = default;
ExtendedConnect = _context.IsExtendedConnectRequest;
}
public override void AdvanceTo(SequencePosition consumed, SequencePosition examined)
......@@ -62,7 +64,12 @@ internal sealed class Http2MessageBody : MessageBody
// The HTTP/2 flow control window cannot be larger than 2^31-1 which limits bytesRead.
_context.OnDataRead((int)newlyExaminedBytes);
AddAndCheckObservedBytes(newlyExaminedBytes);
// Don't limit extended CONNECT requests to the MaxRequestBodySize.
if (!ExtendedConnect)
{
AddAndCheckObservedBytes(newlyExaminedBytes);
}
}
public override bool TryRead(out ReadResult readResult)
......
......@@ -239,7 +239,7 @@ internal abstract partial class Http2Stream : HttpProtocol, IThreadPoolWorkItem,
return false;
}
ConnectProtocol = HttpRequestHeaders.HeaderProtocol;
IsConnectRequest = true;
IsExtendedConnectRequest = true;
}
// CONNECT - :scheme and :path must be excluded
else if (!StringValues.IsNullOrEmpty(HttpRequestHeaders.HeaderScheme) || !StringValues.IsNullOrEmpty(HttpRequestHeaders.HeaderPath))
......
......@@ -11,9 +11,11 @@ using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Features;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http2;
using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.Logging;
using Microsoft.Net.Http.Headers;
using Moq;
namespace Microsoft.AspNetCore.Server.Kestrel.Core.Tests;
......@@ -310,4 +312,130 @@ public class Http2WebSocketTests : Http2TestBase
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
}
[Fact]
public async Task ExtendedCONNECT_AcceptAsyncStream_IsNotLimitedByMinRequestBodyDataRate()
{
var limits = _serviceContext.ServerOptions.Limits;
// Use non-default value to ensure the min request and response rates aren't mixed up.
limits.MinRequestBodyDataRate = new MinDataRate(480, TimeSpan.FromSeconds(2.5));
await InitializeConnectionAsync(async context =>
{
var connectFeature = context.Features.Get<IHttpExtendedConnectFeature>();
var stream = await connectFeature.AcceptAsync();
Assert.Equal(0, await stream.ReadAsync(new byte[1]));
await stream.WriteAsync(new byte[] { 0x01 });
});
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "CONNECT"),
new KeyValuePair<string, string>(HeaderNames.Protocol, "websocket"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
new KeyValuePair<string, string>(HeaderNames.Path, "/chat"),
new KeyValuePair<string, string>(HeaderNames.Authority, "server.example.com"),
new KeyValuePair<string, string>(HeaderNames.WebSocketSubProtocols, "chat, superchat"),
new KeyValuePair<string, string>(HeaderNames.SecWebSocketExtensions, "permessage-deflate"),
new KeyValuePair<string, string>(HeaderNames.SecWebSocketVersion, "13"),
new KeyValuePair<string, string>(HeaderNames.Origin, "http://www.example.com"),
};
await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, headers);
// Don't send any more data and advance just to and then past the grace period.
AdvanceClock(limits.MinRequestBodyDataRate.GracePeriod + TimeSpan.FromTicks(1));
_mockTimeoutHandler.Verify(h => h.OnTimeout(It.IsAny<TimeoutReason>()), Times.Never);
await SendDataAsync(1, Array.Empty<byte>(), endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 32,
withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS,
withStreamId: 1);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
Assert.Equal(2, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
var dataFrame = await ExpectAsync(Http2FrameType.DATA,
withLength: 1,
withFlags: (byte)Http2DataFrameFlags.NONE,
withStreamId: 1);
Assert.Equal(0x01, dataFrame.Payload.Span[0]);
dataFrame = await ExpectAsync(Http2FrameType.DATA,
withLength: 0,
withFlags: (byte)Http2DataFrameFlags.END_STREAM,
withStreamId: 1);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
}
[Fact]
public async Task ExtendedCONNECT_AcceptAsyncStream_IsNotLimitedByMaxRequestBodySize()
{
var limits = _serviceContext.ServerOptions.Limits;
// We're going to send more than the MaxRequestBodySize bytes from the client to the server over the connection
// Since this is not a request body, this should be allowed like it would be for an upgraded connection.
limits.MaxRequestBodySize = 5;
await InitializeConnectionAsync(async context =>
{
var connectFeature = context.Features.Get<IHttpExtendedConnectFeature>();
var maxRequestBodySizeFeature = context.Features.Get<IHttpMaxRequestBodySizeFeature>();
// Extended connects don't have a meaningful request body to limit.
Assert.True(maxRequestBodySizeFeature.IsReadOnly);
var stream = await connectFeature.AcceptAsync();
Assert.True(maxRequestBodySizeFeature.IsReadOnly);
using var memoryStream = new MemoryStream();
await stream.CopyToAsync(memoryStream);
Assert.Equal(_serviceContext.ServerOptions.Limits.MaxRequestBodySize + 1, memoryStream.Length);
await stream.WriteAsync(new byte[] { 0x01 });
});
var headers = new[]
{
new KeyValuePair<string, string>(HeaderNames.Method, "CONNECT"),
new KeyValuePair<string, string>(HeaderNames.Protocol, "websocket"),
new KeyValuePair<string, string>(HeaderNames.Scheme, "http"),
new KeyValuePair<string, string>(HeaderNames.Path, "/chat"),
new KeyValuePair<string, string>(HeaderNames.Authority, "server.example.com"),
new KeyValuePair<string, string>(HeaderNames.WebSocketSubProtocols, "chat, superchat"),
new KeyValuePair<string, string>(HeaderNames.SecWebSocketExtensions, "permessage-deflate"),
new KeyValuePair<string, string>(HeaderNames.SecWebSocketVersion, "13"),
new KeyValuePair<string, string>(HeaderNames.Origin, "http://www.example.com"),
};
await SendHeadersAsync(1, Http2HeadersFrameFlags.END_HEADERS, headers);
await SendDataAsync(1, new byte[(int)limits.MaxRequestBodySize + 1], endStream: true);
var headersFrame = await ExpectAsync(Http2FrameType.HEADERS,
withLength: 32,
withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS,
withStreamId: 1);
_hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this);
Assert.Equal(2, _decodedHeaders.Count);
Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase);
Assert.Equal("200", _decodedHeaders[HeaderNames.Status]);
var dataFrame = await ExpectAsync(Http2FrameType.DATA,
withLength: 1,
withFlags: (byte)Http2DataFrameFlags.NONE,
withStreamId: 1);
Assert.Equal(0x01, dataFrame.Payload.Span[0]);
dataFrame = await ExpectAsync(Http2FrameType.DATA,
withLength: 0,
withFlags: (byte)Http2DataFrameFlags.END_STREAM,
withStreamId: 1);
await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false);
}
}
0% 加载中 .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册