Skip to content
代码片段 群组 项目
未验证 提交 d05c9f46 编辑于 作者: Alessio Franceschelli's avatar Alessio Franceschelli 提交者: GitHub
浏览文件

HeaderPropagation: reset AsyncLocal per request (#18300)

As Kestrel can bleed the AsyncLocal across requests,
see https://github.com/aspnet/AspNetCore/issues/13991.
上级 7fa6d196
No related branches found
No related tags found
无相关合并请求
...@@ -33,7 +33,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation ...@@ -33,7 +33,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation
_values = values ?? throw new ArgumentNullException(nameof(values)); _values = values ?? throw new ArgumentNullException(nameof(values));
} }
public Task Invoke(HttpContext context) // This needs to be async as otherwise the AsyncLocal could bleed across requests, see https://github.com/aspnet/AspNetCore/issues/13991.
public async Task Invoke(HttpContext context)
{ {
// We need to intialize the headers because the message handler will use this to detect misconfiguration. // We need to intialize the headers because the message handler will use this to detect misconfiguration.
var headers = _values.Headers ??= new Dictionary<string, StringValues>(StringComparer.OrdinalIgnoreCase); var headers = _values.Headers ??= new Dictionary<string, StringValues>(StringComparer.OrdinalIgnoreCase);
...@@ -56,7 +57,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation ...@@ -56,7 +57,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation
} }
} }
return _next.Invoke(context); await _next.Invoke(context);
} }
private static StringValues GetValue(HttpContext context, HeaderPropagationEntry entry) private static StringValues GetValue(HttpContext context, HeaderPropagationEntry entry)
......
// Copyright (c) .NET Foundation. All rights reserved. // Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Threading.Tasks; using System.Threading.Tasks;
using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Http;
using Microsoft.Extensions.Options; using Microsoft.Extensions.Options;
...@@ -14,7 +16,11 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests ...@@ -14,7 +16,11 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests
public HeaderPropagationMiddlewareTest() public HeaderPropagationMiddlewareTest()
{ {
Context = new DefaultHttpContext(); Context = new DefaultHttpContext();
Next = ctx => Task.CompletedTask; Next = ctx =>
{
CapturedHeaders = State.Headers;
return Task.CompletedTask;
};
Configuration = new HeaderPropagationOptions(); Configuration = new HeaderPropagationOptions();
State = new HeaderPropagationValues(); State = new HeaderPropagationValues();
Middleware = new HeaderPropagationMiddleware(Next, Middleware = new HeaderPropagationMiddleware(Next,
...@@ -24,8 +30,10 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests ...@@ -24,8 +30,10 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests
public DefaultHttpContext Context { get; set; } public DefaultHttpContext Context { get; set; }
public RequestDelegate Next { get; set; } public RequestDelegate Next { get; set; }
public Action Assertion { get; set; }
public HeaderPropagationOptions Configuration { get; set; } public HeaderPropagationOptions Configuration { get; set; }
public HeaderPropagationValues State { get; set; } public HeaderPropagationValues State { get; set; }
public IDictionary<string, StringValues> CapturedHeaders { get; set; }
public HeaderPropagationMiddleware Middleware { get; set; } public HeaderPropagationMiddleware Middleware { get; set; }
[Fact] [Fact]
...@@ -39,8 +47,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests ...@@ -39,8 +47,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests
await Middleware.Invoke(Context); await Middleware.Invoke(Context);
// Assert // Assert
Assert.Contains("in", State.Headers.Keys); Assert.Contains("in", CapturedHeaders.Keys);
Assert.Equal(new[] { "test" }, State.Headers["in"]); Assert.Equal(new[] { "test" }, CapturedHeaders["in"]);
} }
[Fact] [Fact]
...@@ -53,7 +61,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests ...@@ -53,7 +61,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests
await Middleware.Invoke(Context); await Middleware.Invoke(Context);
// Assert // Assert
Assert.Empty(State.Headers); Assert.Empty(CapturedHeaders);
} }
[Fact] [Fact]
...@@ -66,7 +74,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests ...@@ -66,7 +74,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests
await Middleware.Invoke(Context); await Middleware.Invoke(Context);
// Assert // Assert
Assert.Empty(State.Headers); Assert.Empty(CapturedHeaders);
} }
[Fact] [Fact]
...@@ -82,10 +90,10 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests ...@@ -82,10 +90,10 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests
await Middleware.Invoke(Context); await Middleware.Invoke(Context);
// Assert // Assert
Assert.Contains("in", State.Headers.Keys); Assert.Contains("in", CapturedHeaders.Keys);
Assert.Equal(new[] { "test" }, State.Headers["in"]); Assert.Equal(new[] { "test" }, CapturedHeaders["in"]);
Assert.Contains("another", State.Headers.Keys); Assert.Contains("another", CapturedHeaders.Keys);
Assert.Equal(new[] { "test2" }, State.Headers["another"]); Assert.Equal(new[] { "test2" }, CapturedHeaders["another"]);
} }
[Theory] [Theory]
...@@ -101,7 +109,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests ...@@ -101,7 +109,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests
await Middleware.Invoke(Context); await Middleware.Invoke(Context);
// Assert // Assert
Assert.DoesNotContain("in", State.Headers.Keys); Assert.DoesNotContain("in", CapturedHeaders.Keys);
} }
[Theory] [Theory]
...@@ -127,8 +135,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests ...@@ -127,8 +135,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests
await Middleware.Invoke(Context); await Middleware.Invoke(Context);
// Assert // Assert
Assert.Contains("in", State.Headers.Keys); Assert.Contains("in", CapturedHeaders.Keys);
Assert.Equal(expectedValues, State.Headers["in"]); Assert.Equal(expectedValues, CapturedHeaders["in"]);
Assert.Equal("in", receivedName); Assert.Equal("in", receivedName);
Assert.Equal(new StringValues("value"), receivedValue); Assert.Equal(new StringValues("value"), receivedValue);
Assert.Same(Context, receivedContext); Assert.Same(Context, receivedContext);
...@@ -145,8 +153,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests ...@@ -145,8 +153,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests
await Middleware.Invoke(Context); await Middleware.Invoke(Context);
// Assert // Assert
Assert.Contains("in", State.Headers.Keys); Assert.Contains("in", CapturedHeaders.Keys);
Assert.Equal("test", State.Headers["in"]); Assert.Equal("test", CapturedHeaders["in"]);
} }
[Fact] [Fact]
...@@ -159,7 +167,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests ...@@ -159,7 +167,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests
await Middleware.Invoke(Context); await Middleware.Invoke(Context);
// Assert // Assert
Assert.DoesNotContain("in", State.Headers.Keys); Assert.DoesNotContain("in", CapturedHeaders.Keys);
} }
[Fact] [Fact]
...@@ -174,8 +182,46 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests ...@@ -174,8 +182,46 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests
await Middleware.Invoke(Context); await Middleware.Invoke(Context);
// Assert // Assert
Assert.Contains("in", State.Headers.Keys); Assert.Contains("in", CapturedHeaders.Keys);
Assert.Equal("Test", State.Headers["in"]); Assert.Equal("Test", CapturedHeaders["in"]);
}
[Fact]
public async Task HeaderInRequest_WithBleedAsyncLocal_HasCorrectValue()
{
// Arrange
Configuration.Headers.Add("in");
// Process first request
Context.Request.Headers.Add("in", "dirty");
await Middleware.Invoke(Context);
// Process second request
Context = new DefaultHttpContext();
Context.Request.Headers.Add("in", "test");
await Middleware.Invoke(Context);
// Assert
Assert.Contains("in", CapturedHeaders.Keys);
Assert.Equal(new[] { "test" }, CapturedHeaders["in"]);
}
[Fact]
public async Task NoHeaderInRequest_WithBleedAsyncLocal_DoesNotHaveIt()
{
// Arrange
Configuration.Headers.Add("in");
// Process first request
Context.Request.Headers.Add("in", "dirty");
await Middleware.Invoke(Context);
// Process second request
Context = new DefaultHttpContext();
await Middleware.Invoke(Context);
// Assert
Assert.Empty(CapturedHeaders);
} }
} }
} }
0% 加载中 .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册