diff --git a/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs b/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs index f62c9e4a72bd65769ffe7c53dbcc8d89ba715b49..bd24fb63e2e5c3f8912d487121968159bcf36507 100644 --- a/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs +++ b/src/Middleware/HeaderPropagation/src/HeaderPropagationMiddleware.cs @@ -33,7 +33,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation _values = values ?? throw new ArgumentNullException(nameof(values)); } - public Task Invoke(HttpContext context) + // This needs to be async as otherwise the AsyncLocal could bleed across requests, see https://github.com/aspnet/AspNetCore/issues/13991. + public async Task Invoke(HttpContext context) { // We need to intialize the headers because the message handler will use this to detect misconfiguration. var headers = _values.Headers ??= new Dictionary<string, StringValues>(StringComparer.OrdinalIgnoreCase); @@ -56,7 +57,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation } } - return _next.Invoke(context); + await _next.Invoke(context); } private static StringValues GetValue(HttpContext context, HeaderPropagationEntry entry) diff --git a/src/Middleware/HeaderPropagation/test/HeaderPropagationMiddlewareTest.cs b/src/Middleware/HeaderPropagation/test/HeaderPropagationMiddlewareTest.cs index f6576d2d688dfcc3dc04023bb08cb99c47a91204..99bb4997a5ef3f3a680457b7b19fb059ea54f45c 100644 --- a/src/Middleware/HeaderPropagation/test/HeaderPropagationMiddlewareTest.cs +++ b/src/Middleware/HeaderPropagation/test/HeaderPropagationMiddlewareTest.cs @@ -1,6 +1,8 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System; +using System.Collections.Generic; using System.Threading.Tasks; using Microsoft.AspNetCore.Http; using Microsoft.Extensions.Options; @@ -14,7 +16,11 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests public HeaderPropagationMiddlewareTest() { Context = new DefaultHttpContext(); - Next = ctx => Task.CompletedTask; + Next = ctx => + { + CapturedHeaders = State.Headers; + return Task.CompletedTask; + }; Configuration = new HeaderPropagationOptions(); State = new HeaderPropagationValues(); Middleware = new HeaderPropagationMiddleware(Next, @@ -24,8 +30,10 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests public DefaultHttpContext Context { get; set; } public RequestDelegate Next { get; set; } + public Action Assertion { get; set; } public HeaderPropagationOptions Configuration { get; set; } public HeaderPropagationValues State { get; set; } + public IDictionary<string, StringValues> CapturedHeaders { get; set; } public HeaderPropagationMiddleware Middleware { get; set; } [Fact] @@ -39,8 +47,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests await Middleware.Invoke(Context); // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal(new[] { "test" }, State.Headers["in"]); + Assert.Contains("in", CapturedHeaders.Keys); + Assert.Equal(new[] { "test" }, CapturedHeaders["in"]); } [Fact] @@ -53,7 +61,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests await Middleware.Invoke(Context); // Assert - Assert.Empty(State.Headers); + Assert.Empty(CapturedHeaders); } [Fact] @@ -66,7 +74,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests await Middleware.Invoke(Context); // Assert - Assert.Empty(State.Headers); + Assert.Empty(CapturedHeaders); } [Fact] @@ -82,10 +90,10 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests await Middleware.Invoke(Context); // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal(new[] { "test" }, State.Headers["in"]); - Assert.Contains("another", State.Headers.Keys); - Assert.Equal(new[] { "test2" }, State.Headers["another"]); + Assert.Contains("in", CapturedHeaders.Keys); + Assert.Equal(new[] { "test" }, CapturedHeaders["in"]); + Assert.Contains("another", CapturedHeaders.Keys); + Assert.Equal(new[] { "test2" }, CapturedHeaders["another"]); } [Theory] @@ -101,7 +109,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests await Middleware.Invoke(Context); // Assert - Assert.DoesNotContain("in", State.Headers.Keys); + Assert.DoesNotContain("in", CapturedHeaders.Keys); } [Theory] @@ -127,8 +135,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests await Middleware.Invoke(Context); // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal(expectedValues, State.Headers["in"]); + Assert.Contains("in", CapturedHeaders.Keys); + Assert.Equal(expectedValues, CapturedHeaders["in"]); Assert.Equal("in", receivedName); Assert.Equal(new StringValues("value"), receivedValue); Assert.Same(Context, receivedContext); @@ -145,8 +153,8 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests await Middleware.Invoke(Context); // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal("test", State.Headers["in"]); + Assert.Contains("in", CapturedHeaders.Keys); + Assert.Equal("test", CapturedHeaders["in"]); } [Fact] @@ -159,7 +167,7 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests await Middleware.Invoke(Context); // Assert - Assert.DoesNotContain("in", State.Headers.Keys); + Assert.DoesNotContain("in", CapturedHeaders.Keys); } [Fact] @@ -174,8 +182,46 @@ namespace Microsoft.AspNetCore.HeaderPropagation.Tests await Middleware.Invoke(Context); // Assert - Assert.Contains("in", State.Headers.Keys); - Assert.Equal("Test", State.Headers["in"]); + Assert.Contains("in", CapturedHeaders.Keys); + Assert.Equal("Test", CapturedHeaders["in"]); + } + + [Fact] + public async Task HeaderInRequest_WithBleedAsyncLocal_HasCorrectValue() + { + // Arrange + Configuration.Headers.Add("in"); + + // Process first request + Context.Request.Headers.Add("in", "dirty"); + await Middleware.Invoke(Context); + + // Process second request + Context = new DefaultHttpContext(); + Context.Request.Headers.Add("in", "test"); + await Middleware.Invoke(Context); + + // Assert + Assert.Contains("in", CapturedHeaders.Keys); + Assert.Equal(new[] { "test" }, CapturedHeaders["in"]); + } + + [Fact] + public async Task NoHeaderInRequest_WithBleedAsyncLocal_DoesNotHaveIt() + { + // Arrange + Configuration.Headers.Add("in"); + + // Process first request + Context.Request.Headers.Add("in", "dirty"); + await Middleware.Invoke(Context); + + // Process second request + Context = new DefaultHttpContext(); + await Middleware.Invoke(Context); + + // Assert + Assert.Empty(CapturedHeaders); } } }