From 337217a655ba17ea7e11e3116ae2b242f1073f40 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 1 Sep 2021 14:49:17 -0700 Subject: [PATCH] [release/6.0] Pass request cancellation token when using IAsyncEnumerable in MVC and WriteAsJsonAsync (#35993) * Pass request cancellation token when using IAsyncEnumerable * more * fb * update * fixup assert Co-authored-by: Brennan <brecon@microsoft.com> --- .../src/HttpResponseJsonExtensions.cs | 40 ++++ .../test/HttpResponseJsonExtensionsTests.cs | 187 ++++++++++++++++++ .../SystemTextJsonOutputFormatter.cs | 8 +- .../Infrastructure/AsyncEnumerableReader.cs | 12 +- .../SystemTextJsonResultExecutor.cs | 19 +- .../SystemTextJsonOutputFormatterTest.cs | 49 +++++ .../AsyncEnumerableReaderTest.cs | 44 ++++- .../JsonResultExecutorTestBase.cs | 68 ++++++- .../SystemTextJsonResultExecutorTest.cs | 4 +- ...mlDataContractSerializerOutputFormatter.cs | 6 +- .../src/XmlSerializerOutputFormatter.cs | 6 +- ...taContractSerializerOutputFormatterTest.cs | 68 +++++++ .../test/XmlSerializerOutputFormatterTest.cs | 70 ++++++- .../src/NewtonsoftJsonOutputFormatter.cs | 15 +- .../src/NewtonsoftJsonResultExecutor.cs | 21 +- .../test/NewtonsoftJsonOutputFormatterTest.cs | 56 +++++- .../test/NewtonsoftJsonResultExecutorTest.cs | 34 ++++ 17 files changed, 658 insertions(+), 49 deletions(-) diff --git a/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs b/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs index c36740385e5..1336cc8f1b1 100644 --- a/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs +++ b/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs @@ -82,9 +82,28 @@ namespace Microsoft.AspNetCore.Http options ??= ResolveSerializerOptions(response.HttpContext); response.ContentType = contentType ?? JsonConstants.JsonContentTypeWithCharset; + // if no user provided token, pass the RequestAborted token and ignore OperationCanceledException + if (!cancellationToken.CanBeCanceled) + { + return WriteAsJsonAsyncSlow<TValue>(response.Body, value, options, response.HttpContext.RequestAborted); + } + return JsonSerializer.SerializeAsync<TValue>(response.Body, value, options, cancellationToken); } + private static async Task WriteAsJsonAsyncSlow<TValue>( + Stream body, + TValue value, + JsonSerializerOptions? options, + CancellationToken cancellationToken) + { + try + { + await JsonSerializer.SerializeAsync<TValue>(body, value, options, cancellationToken); + } + catch (OperationCanceledException) { } + } + /// <summary> /// Write the specified value as JSON to the response body. The response content-type will be set to /// <c>application/json; charset=utf-8</c>. @@ -157,9 +176,30 @@ namespace Microsoft.AspNetCore.Http options ??= ResolveSerializerOptions(response.HttpContext); response.ContentType = contentType ?? JsonConstants.JsonContentTypeWithCharset; + + // if no user provided token, pass the RequestAborted token and ignore OperationCanceledException + if (!cancellationToken.CanBeCanceled) + { + return WriteAsJsonAsyncSlow(response.Body, value, type, options, response.HttpContext.RequestAborted); + } + return JsonSerializer.SerializeAsync(response.Body, value, type, options, cancellationToken); } + private static async Task WriteAsJsonAsyncSlow( + Stream body, + object? value, + Type type, + JsonSerializerOptions? options, + CancellationToken cancellationToken) + { + try + { + await JsonSerializer.SerializeAsync(body, value, type, options, cancellationToken); + } + catch (OperationCanceledException) { } + } + private static JsonSerializerOptions ResolveSerializerOptions(HttpContext httpContext) { // Attempt to resolve options from DI then fallback to default options diff --git a/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs b/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs index 6571716ec2b..5addfa413bf 100644 --- a/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs +++ b/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs @@ -3,6 +3,7 @@ using System; using System.IO; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; @@ -255,6 +256,192 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests Assert.Equal(StatusCodes.Status418ImATeapot, context.Response.StatusCode); } + [Fact] + public async Task WriteAsJsonAsyncGeneric_AsyncEnumerable() + { + // Arrange + var body = new MemoryStream(); + var context = new DefaultHttpContext(); + context.Response.Body = body; + + // Act + await context.Response.WriteAsJsonAsync(AsyncEnumerable()); + + // Assert + Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + Assert.Equal("[1,2]", Encoding.UTF8.GetString(body.ToArray())); + + async IAsyncEnumerable<int> AsyncEnumerable() + { + await Task.Yield(); + yield return 1; + yield return 2; + } + } + + [Fact] + public async Task WriteAsJsonAsync_AsyncEnumerable() + { + // Arrange + var body = new MemoryStream(); + var context = new DefaultHttpContext(); + context.Response.Body = body; + + // Act + await context.Response.WriteAsJsonAsync(AsyncEnumerable(), typeof(IAsyncEnumerable<int>)); + + // Assert + Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + Assert.Equal("[1,2]", Encoding.UTF8.GetString(body.ToArray())); + + async IAsyncEnumerable<int> AsyncEnumerable() + { + await Task.Yield(); + yield return 1; + yield return 2; + } + } + + [Fact] + public async Task WriteAsJsonAsyncGeneric_AsyncEnumerable_ClosedConnecton() + { + // Arrange + var cts = new CancellationTokenSource(); + var body = new MemoryStream(); + var context = new DefaultHttpContext(); + context.Response.Body = body; + context.RequestAborted = cts.Token; + var iterated = false; + + // Act + await context.Response.WriteAsJsonAsync(AsyncEnumerable()); + + // Assert + Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + // System.Text.Json might write the '[' before cancellation is observed + Assert.InRange(body.ToArray().Length, 0, 1); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + for (var i = 0; i < 100 && !cancellationToken.IsCancellationRequested; i++) + { + iterated = true; + yield return i; + } + } + } + + [Fact] + public async Task WriteAsJsonAsync_AsyncEnumerable_ClosedConnecton() + { + // Arrange + var cts = new CancellationTokenSource(); + var body = new MemoryStream(); + var context = new DefaultHttpContext(); + context.Response.Body = body; + context.RequestAborted = cts.Token; + var iterated = false; + + // Act + await context.Response.WriteAsJsonAsync(AsyncEnumerable(), typeof(IAsyncEnumerable<int>)); + + // Assert + Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + // System.Text.Json might write the '[' before cancellation is observed + Assert.InRange(body.ToArray().Length, 0, 1); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + for (var i = 0; i < 100 && !cancellationToken.IsCancellationRequested; i++) + { + iterated = true; + yield return i; + } + } + } + + [Fact] + public async Task WriteAsJsonAsync_AsyncEnumerable_UserPassedTokenThrows() + { + // Arrange + var body = new MemoryStream(); + var context = new DefaultHttpContext(); + context.Response.Body = body; + context.RequestAborted = new CancellationToken(canceled: true); + var cts = new CancellationTokenSource(); + var iterated = false; + + // Act + await Assert.ThrowsAnyAsync<OperationCanceledException>(() => context.Response.WriteAsJsonAsync(AsyncEnumerable(), typeof(IAsyncEnumerable<int>), cts.Token)); + + // Assert + Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + // System.Text.Json might write the '[' before cancellation is observed + Assert.InRange(body.ToArray().Length, 0, 1); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + for (var i = 0; i < 100 && !cancellationToken.IsCancellationRequested; i++) + { + iterated = true; + yield return i; + } + } + } + + [Fact] + public async Task WriteAsJsonAsyncGeneric_AsyncEnumerableG_UserPassedTokenThrows() + { + // Arrange + var body = new MemoryStream(); + var context = new DefaultHttpContext(); + context.Response.Body = body; + context.RequestAborted = new CancellationToken(canceled: true); + var cts = new CancellationTokenSource(); + var iterated = false; + + // Act + await Assert.ThrowsAnyAsync<OperationCanceledException>(() => context.Response.WriteAsJsonAsync(AsyncEnumerable(), cts.Token)); + + // Assert + Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + // System.Text.Json might write the '[' before cancellation is observed + Assert.InRange(body.ToArray().Length, 0, 1); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + for (var i = 0; i < 100 && !cancellationToken.IsCancellationRequested; i++) + { + iterated = true; + yield return i; + } + } + } + public class TestObject { public string? StringProperty { get; set; } diff --git a/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs b/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs index e372782b3b9..830998717d1 100644 --- a/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs +++ b/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs @@ -79,8 +79,12 @@ namespace Microsoft.AspNetCore.Mvc.Formatters var responseStream = httpContext.Response.Body; if (selectedEncoding.CodePage == Encoding.UTF8.CodePage) { - await JsonSerializer.SerializeAsync(responseStream, context.Object, objectType, SerializerOptions); - await responseStream.FlushAsync(); + try + { + await JsonSerializer.SerializeAsync(responseStream, context.Object, objectType, SerializerOptions, httpContext.RequestAborted); + await responseStream.FlushAsync(httpContext.RequestAborted); + } + catch (OperationCanceledException) { } } else { diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs b/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs index 3f8f1431c73..4748c7f31f2 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs @@ -34,7 +34,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure nameof(ReadInternal), BindingFlags.NonPublic | BindingFlags.Instance)!; - private readonly ConcurrentDictionary<Type, Func<object, Task<ICollection>>?> _asyncEnumerableConverters = new(); + private readonly ConcurrentDictionary<Type, Func<object, CancellationToken, Task<ICollection>>?> _asyncEnumerableConverters = new(); private readonly MvcOptions _mvcOptions; /// <summary> @@ -52,7 +52,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure /// <param name="type">The type to read.</param> /// <param name="reader">A delegate that when awaited reads the <see cref="IAsyncEnumerable{T}"/>.</param> /// <returns><see langword="true" /> when <paramref name="type"/> is an instance of <see cref="IAsyncEnumerable{T}"/>, othwerise <see langword="false"/>.</returns> - public bool TryGetReader(Type type, [NotNullWhen(true)] out Func<object, Task<ICollection>>? reader) + public bool TryGetReader(Type type, [NotNullWhen(true)] out Func<object, CancellationToken, Task<ICollection>>? reader) { if (!_asyncEnumerableConverters.TryGetValue(type, out reader)) { @@ -67,9 +67,9 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure { var enumeratedObjectType = enumerableType.GetGenericArguments()[0]; - var converter = (Func<object, Task<ICollection>>)Converter + var converter = (Func<object, CancellationToken, Task<ICollection>>)Converter .MakeGenericMethod(enumeratedObjectType) - .CreateDelegate(typeof(Func<object, Task<ICollection>>), this); + .CreateDelegate(typeof(Func<object, CancellationToken, Task<ICollection>>), this); reader = converter; _asyncEnumerableConverters.TryAdd(type, reader); @@ -79,9 +79,9 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure return reader != null; } - private async Task<ICollection> ReadInternal<T>(object value) + private async Task<ICollection> ReadInternal<T>(object value, CancellationToken cancellationToken) { - var asyncEnumerable = (IAsyncEnumerable<T>)value; + var asyncEnumerable = ((IAsyncEnumerable<T>)value).WithCancellation(cancellationToken); var result = new List<T>(); var count = 0; diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs b/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs index 4d440a2dec5..26c3de180eb 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs @@ -25,16 +25,13 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure private readonly JsonOptions _options; private readonly ILogger<SystemTextJsonResultExecutor> _logger; - private readonly AsyncEnumerableReader _asyncEnumerableReaderFactory; public SystemTextJsonResultExecutor( IOptions<JsonOptions> options, - ILogger<SystemTextJsonResultExecutor> logger, - IOptions<MvcOptions> mvcOptions) + ILogger<SystemTextJsonResultExecutor> logger) { _options = options.Value; _logger = logger; - _asyncEnumerableReaderFactory = new AsyncEnumerableReader(mvcOptions.Value); } public async Task ExecuteAsync(ActionContext context, JsonResult result) @@ -77,8 +74,12 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure var responseStream = response.Body; if (resolvedContentTypeEncoding.CodePage == Encoding.UTF8.CodePage) { - await JsonSerializer.SerializeAsync(responseStream, value, objectType, jsonSerializerOptions); - await responseStream.FlushAsync(); + try + { + await JsonSerializer.SerializeAsync(responseStream, value, objectType, jsonSerializerOptions, context.HttpContext.RequestAborted); + await responseStream.FlushAsync(context.HttpContext.RequestAborted); + } + catch (OperationCanceledException) { } } else { @@ -89,9 +90,11 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure ExceptionDispatchInfo? exceptionDispatchInfo = null; try { - await JsonSerializer.SerializeAsync(transcodingStream, value, objectType, jsonSerializerOptions); - await transcodingStream.FlushAsync(); + await JsonSerializer.SerializeAsync(transcodingStream, value, objectType, jsonSerializerOptions, context.HttpContext.RequestAborted); + await transcodingStream.FlushAsync(context.HttpContext.RequestAborted); } + catch (OperationCanceledException) + { } catch (Exception ex) { // TranscodingStream may write to the inner stream as part of it's disposal. diff --git a/src/Mvc/Mvc.Core/test/Formatters/SystemTextJsonOutputFormatterTest.cs b/src/Mvc/Mvc.Core/test/Formatters/SystemTextJsonOutputFormatterTest.cs index b9fa926628c..c178f43f8cb 100644 --- a/src/Mvc/Mvc.Core/test/Formatters/SystemTextJsonOutputFormatterTest.cs +++ b/src/Mvc/Mvc.Core/test/Formatters/SystemTextJsonOutputFormatterTest.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; @@ -113,6 +114,54 @@ namespace Microsoft.AspNetCore.Mvc.Formatters Assert.Equal(expected.ToArray(), body.ToArray()); } + [Fact] + public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses() + { + // Arrange + var formatter = GetOutputFormatter(); + var mediaType = MediaTypeHeaderValue.Parse("application/json; charset=utf-8"); + + var body = new MemoryStream(); + var actionContext = GetActionContext(mediaType, body); + var cts = new CancellationTokenSource(); + actionContext.HttpContext.RequestAborted = cts.Token; + + var asyncEnumerable = AsyncEnumerableClosedConnection(); + var outputFormatterContext = new OutputFormatterWriteContext( + actionContext.HttpContext, + new TestHttpResponseStreamWriterFactory().CreateWriter, + asyncEnumerable.GetType(), + asyncEnumerable) + { + ContentType = new StringSegment(mediaType.ToString()), + }; + var iterated = false; + + // Act + await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8")); + + // Assert + // System.Text.Json might write the '[' before cancellation is observed + Assert.InRange(body.ToArray().Length, 0, 1); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + // MvcOptions.MaxIAsyncEnumerableBufferLimit is 8192. Pick some value larger than that. + foreach (var i in Enumerable.Range(0, 9000)) + { + if (cancellationToken.IsCancellationRequested) + { + yield break; + } + iterated = true; + yield return i; + } + } + } + private class Person { public string Name { get; set; } diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs b/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs index eac659f55ad..3d5afc91f9e 100644 --- a/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs +++ b/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Globalization; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -43,7 +44,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure // Assert Assert.True(result); - var readCollection = await reader(asyncEnumerable); + var readCollection = await reader(asyncEnumerable, default); var collection = Assert.IsAssignableFrom<ICollection<string>>(readCollection); Assert.Equal(new[] { "0", "1", "2", }, collection); } @@ -61,7 +62,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure // Assert Assert.True(result); - var readCollection = await reader(asyncEnumerable); + var readCollection = await reader(asyncEnumerable, default); var collection = Assert.IsAssignableFrom<ICollection<int>>(readCollection); Assert.Equal(new[] { 0, 1, 2, }, collection); } @@ -114,8 +115,8 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader)); // Assert - Assert.Equal(expected, await reader(asyncEnumerable1)); - Assert.Equal(expected, await reader(asyncEnumerable2)); + Assert.Equal(expected, await reader(asyncEnumerable1, default)); + Assert.Equal(expected, await reader(asyncEnumerable2, default)); } [Fact] @@ -131,8 +132,8 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader)); // Assert - Assert.Equal(new[] { "0", "1", "2" }, await reader(asyncEnumerable1)); - Assert.Equal(new[] { "0", "1", "2", "3" }, await reader(asyncEnumerable2)); + Assert.Equal(new[] { "0", "1", "2" }, await reader(asyncEnumerable1, default)); + Assert.Equal(new[] { "0", "1", "2", "3" }, await reader(asyncEnumerable2, default)); } [Fact] @@ -166,7 +167,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure // Assert Assert.True(result); - var readCollection = await reader(asyncEnumerable); + var readCollection = await reader(asyncEnumerable, default); var collection = Assert.IsAssignableFrom<ICollection<object>>(readCollection); Assert.Equal(new[] { "0", "1", "2", }, collection); } @@ -184,7 +185,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure // Act Assert.True(readerFactory.TryGetReader(enumerable.GetType(), out var reader)); - var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => reader(enumerable)); + var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => reader(enumerable, default)); // Assert Assert.Equal(expected, ex.Message); @@ -200,7 +201,32 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure // Act & Assert Assert.True(readerFactory.TryGetReader(enumerable.GetType(), out var reader)); - await Assert.ThrowsAsync<TimeZoneNotFoundException>(() => reader(enumerable)); + await Assert.ThrowsAsync<TimeZoneNotFoundException>(() => reader(enumerable, default)); + } + + [Fact] + public async Task Reader_PassesCancellationTokenToIAsyncEnumerable() + { + // Arrange + var enumerable = AsyncEnumerable(); + var options = new MvcOptions(); + CancellationToken token = default; + var readerFactory = new AsyncEnumerableReader(options); + var cts = new CancellationTokenSource(); + + // Act & Assert + Assert.True(readerFactory.TryGetReader(enumerable.GetType(), out var reader)); + await reader(enumerable, cts.Token); + + cts.Cancel(); + Assert.Equal(cts.Token, token); + + async IAsyncEnumerable<string> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + token = cancellationToken; + await Task.Yield(); + yield return string.Empty; + } } public static async IAsyncEnumerable<string> TestEnumerable(int count = 3) diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs b/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs index 9e579244a36..10d6d1f1393 100644 --- a/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs +++ b/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Threading; @@ -365,6 +366,71 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure Assert.Equal(expected, written); } + [Fact] + public async Task ExecuteAsync_AsyncEnumerableConnectionCloses() + { + var context = GetActionContext(); + var cts = new CancellationTokenSource(); + context.HttpContext.RequestAborted = cts.Token; + var result = new JsonResult(AsyncEnumerableClosedConnection()); + var executor = CreateExecutor(); + var iterated = false; + + // Act + await executor.ExecuteAsync(context, result); + + // Assert + var written = GetWrittenBytes(context.HttpContext); + // System.Text.Json might write the '[' before cancellation is observed + Assert.InRange(written.Length, 0, 1); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + for (var i = 0; i < 100000 && !cancellationToken.IsCancellationRequested; i++) + { + iterated = true; + yield return i; + } + } + } + + [Fact] + public async Task ExecuteAsyncWithDifferentContentType_AsyncEnumerableConnectionCloses() + { + var context = GetActionContext(); + var cts = new CancellationTokenSource(); + context.HttpContext.RequestAborted = cts.Token; + var result = new JsonResult(AsyncEnumerableClosedConnection()) + { + ContentType = "text/json; charset=utf-16", + }; + var executor = CreateExecutor(); + var iterated = false; + + // Act + await executor.ExecuteAsync(context, result); + + // Assert + var written = GetWrittenBytes(context.HttpContext); + // System.Text.Json might write the '[' before cancellation is observed (utf-16 means 2 bytes per character) + Assert.InRange(written.Length, 0, 2); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + for (var i = 0; i < 100000 && !cancellationToken.IsCancellationRequested; i++) + { + iterated = true; + yield return i; + } + } + } + protected IActionResultExecutor<JsonResult> CreateExecutor() => CreateExecutor(NullLoggerFactory.Instance); protected abstract IActionResultExecutor<JsonResult> CreateExecutor(ILoggerFactory loggerFactory); @@ -387,7 +453,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure return new ActionContext(GetHttpContext(), new RouteData(), new ActionDescriptor()); } - private static byte[] GetWrittenBytes(HttpContext context) + protected static byte[] GetWrittenBytes(HttpContext context) { context.Response.Body.Seek(0, SeekOrigin.Begin); return Assert.IsType<MemoryStream>(context.Response.Body).ToArray(); diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/SystemTextJsonResultExecutorTest.cs b/src/Mvc/Mvc.Core/test/Infrastructure/SystemTextJsonResultExecutorTest.cs index cd5489fd46f..403ed61b166 100644 --- a/src/Mvc/Mvc.Core/test/Infrastructure/SystemTextJsonResultExecutorTest.cs +++ b/src/Mvc/Mvc.Core/test/Infrastructure/SystemTextJsonResultExecutorTest.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; @@ -18,8 +19,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure { return new SystemTextJsonResultExecutor( Options.Create(new JsonOptions()), - loggerFactory.CreateLogger<SystemTextJsonResultExecutor>(), - Options.Create(new MvcOptions())); + loggerFactory.CreateLogger<SystemTextJsonResultExecutor>()); } [Fact] diff --git a/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerOutputFormatter.cs b/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerOutputFormatter.cs index 429f76807c6..1ecec56f503 100644 --- a/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerOutputFormatter.cs +++ b/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerOutputFormatter.cs @@ -261,8 +261,12 @@ namespace Microsoft.AspNetCore.Mvc.Formatters { Log.BufferingAsyncEnumerable(_logger, value); - value = await reader(value); + value = await reader(value, context.HttpContext.RequestAborted); valueType = value.GetType(); + if (context.HttpContext.RequestAborted.IsCancellationRequested) + { + return; + } } Debug.Assert(valueType is not null); diff --git a/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerOutputFormatter.cs b/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerOutputFormatter.cs index a74b7f05844..eba98cb2a50 100644 --- a/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerOutputFormatter.cs +++ b/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerOutputFormatter.cs @@ -236,8 +236,12 @@ namespace Microsoft.AspNetCore.Mvc.Formatters { Log.BufferingAsyncEnumerable(_logger, value); - value = await reader(value); + value = await reader(value, context.HttpContext.RequestAborted); valueType = value.GetType(); + if (context.HttpContext.RequestAborted.IsCancellationRequested) + { + return; + } } // Wrap the object only if there is a wrapping type. diff --git a/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerOutputFormatterTest.cs b/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerOutputFormatterTest.cs index 7208978b08a..44a5a72152d 100644 --- a/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerOutputFormatterTest.cs +++ b/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerOutputFormatterTest.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Globalization; using System.IO; +using System.Runtime.CompilerServices; using System.Runtime.Serialization; using System.Text; using System.Threading.Tasks; @@ -722,6 +723,73 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml Assert.Equal(expectedOutput, content); } + [Fact] + public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses() + { + // Arrange + var formatter = new XmlDataContractSerializerOutputFormatter(); + var body = new MemoryStream(); + var cts = new CancellationTokenSource(); + var iterated = false; + + var asyncEnumerable = AsyncEnumerableClosedConnection(); + var outputFormatterContext = GetOutputFormatterContext( + asyncEnumerable, + asyncEnumerable.GetType()); + outputFormatterContext.HttpContext.RequestAborted = cts.Token; + outputFormatterContext.HttpContext.Response.Body = body; + + // Act + await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8")); + + // Assert + Assert.Empty(body.ToArray()); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + // MvcOptions.MaxIAsyncEnumerableBufferLimit is 8192. Pick some value larger than that. + foreach (var i in Enumerable.Range(0, 9000)) + { + if (cancellationToken.IsCancellationRequested) + { + yield break; + } + iterated = true; + yield return i; + } + } + } + + [Fact] + public async Task WriteResponseBodyAsync_AsyncEnumerable() + { + // Arrange + var formatter = new XmlDataContractSerializerOutputFormatter(); + var body = new MemoryStream(); + + var asyncEnumerable = AsyncEnumerable(); + var outputFormatterContext = GetOutputFormatterContext( + asyncEnumerable, + asyncEnumerable.GetType()); + outputFormatterContext.HttpContext.Response.Body = body; + + // Act + await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8")); + + // Assert + Assert.Contains("<int>1</int><int>2</int>", Encoding.UTF8.GetString(body.ToArray())); + + async IAsyncEnumerable<int> AsyncEnumerable() + { + await Task.Yield(); + yield return 1; + yield return 2; + } + } + private OutputFormatterWriteContext GetOutputFormatterContext( object outputValue, Type outputType, diff --git a/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerOutputFormatterTest.cs b/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerOutputFormatterTest.cs index ef5f4523b7b..335c41d84c9 100644 --- a/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerOutputFormatterTest.cs +++ b/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerOutputFormatterTest.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Threading.Tasks; using System.Xml; @@ -505,6 +506,73 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml Assert.False(canWriteResult); } + [Fact] + public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses() + { + // Arrange + var formatter = new XmlSerializerOutputFormatter(); + var body = new MemoryStream(); + var cts = new CancellationTokenSource(); + var iterated = false; + + var asyncEnumerable = AsyncEnumerableClosedConnection(); + var outputFormatterContext = GetOutputFormatterContext( + asyncEnumerable, + asyncEnumerable.GetType()); + outputFormatterContext.HttpContext.RequestAborted = cts.Token; + outputFormatterContext.HttpContext.Response.Body = body; + + // Act + await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8")); + + // Assert + Assert.Empty(body.ToArray()); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + // MvcOptions.MaxIAsyncEnumerableBufferLimit is 8192. Pick some value larger than that. + foreach (var i in Enumerable.Range(0, 9000)) + { + if (cancellationToken.IsCancellationRequested) + { + yield break; + } + iterated = true; + yield return i; + } + } + } + + [Fact] + public async Task WriteResponseBodyAsync_AsyncEnumerable() + { + // Arrange + var formatter = new XmlSerializerOutputFormatter(); + var body = new MemoryStream(); + + var asyncEnumerable = AsyncEnumerable(); + var outputFormatterContext = GetOutputFormatterContext( + asyncEnumerable, + asyncEnumerable.GetType()); + outputFormatterContext.HttpContext.Response.Body = body; + + // Act + await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8")); + + // Assert + Assert.Contains("<int>1</int><int>2</int>", Encoding.UTF8.GetString(body.ToArray())); + + async IAsyncEnumerable<int> AsyncEnumerable() + { + await Task.Yield(); + yield return 1; + yield return 2; + } + } + private OutputFormatterWriteContext GetOutputFormatterContext( object outputValue, Type outputType, @@ -577,4 +645,4 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml } } } -} \ No newline at end of file +} diff --git a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonOutputFormatter.cs b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonOutputFormatter.cs index 56623da8a3c..e471cd36bad 100644 --- a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonOutputFormatter.cs +++ b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonOutputFormatter.cs @@ -26,7 +26,6 @@ namespace Microsoft.AspNetCore.Mvc.Formatters private MvcNewtonsoftJsonOptions? _jsonOptions; private readonly AsyncEnumerableReader _asyncEnumerableReaderFactory; private JsonSerializerSettings? _serializerSettings; - private ILogger? _logger; /// <summary> /// Initializes a new <see cref="NewtonsoftJsonOutputFormatter"/> instance. @@ -174,9 +173,17 @@ namespace Microsoft.AspNetCore.Mvc.Formatters var value = context.Object; if (value is not null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader)) { - _logger ??= context.HttpContext.RequestServices.GetRequiredService<ILogger<NewtonsoftJsonOutputFormatter>>(); - Log.BufferingAsyncEnumerable(_logger, value); - value = await reader(value); + var logger = context.HttpContext.RequestServices.GetRequiredService<ILogger<NewtonsoftJsonOutputFormatter>>(); + Log.BufferingAsyncEnumerable(logger, value); + try + { + value = await reader(value, context.HttpContext.RequestAborted); + } + catch (OperationCanceledException) { } + if (context.HttpContext.RequestAborted.IsCancellationRequested) + { + return; + } } try diff --git a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs index 5f7d28d185f..3ee953f8d1b 100644 --- a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs +++ b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs @@ -125,6 +125,21 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson try { + var value = result.Value; + if (value != null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader)) + { + Log.BufferingAsyncEnumerable(_logger, value); + try + { + value = await reader(value, context.HttpContext.RequestAborted); + } + catch (OperationCanceledException) { } + if (context.HttpContext.RequestAborted.IsCancellationRequested) + { + return; + } + } + await using (var writer = _writerFactory.CreateWriter(responseStream, resolvedContentTypeEncoding)) { using var jsonWriter = new JsonTextWriter(writer); @@ -133,12 +148,6 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson jsonWriter.AutoCompleteOnClose = false; var jsonSerializer = JsonSerializer.Create(jsonSerializerSettings); - var value = result.Value; - if (value != null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader)) - { - Log.BufferingAsyncEnumerable(_logger, value); - value = await reader(value); - } jsonSerializer.Serialize(jsonWriter, value); } diff --git a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonOutputFormatterTest.cs b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonOutputFormatterTest.cs index d0d1598bf67..116da8922ae 100644 --- a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonOutputFormatterTest.cs +++ b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonOutputFormatterTest.cs @@ -1,22 +1,17 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Reflection; +using System.Runtime.CompilerServices; using System.Text; -using System.Threading; -using System.Threading.Tasks; using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Primitives; +using Microsoft.Net.Http.Headers; using Moq; using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Newtonsoft.Json.Serialization; -using Xunit; namespace Microsoft.AspNetCore.Mvc.Formatters { @@ -436,6 +431,51 @@ namespace Microsoft.AspNetCore.Mvc.Formatters } } + [Fact] + public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses() + { + // Arrange + var formatter = GetOutputFormatter(); + var mediaType = MediaTypeHeaderValue.Parse("application/json; charset=utf-8"); + + var body = new MemoryStream(); + var actionContext = GetActionContext(mediaType, body); + var cts = new CancellationTokenSource(); + actionContext.HttpContext.RequestAborted = cts.Token; + actionContext.HttpContext.RequestServices = new ServiceCollection().AddLogging().BuildServiceProvider(); + + var asyncEnumerable = AsyncEnumerableClosedConnection(); + var outputFormatterContext = new OutputFormatterWriteContext( + actionContext.HttpContext, + new TestHttpResponseStreamWriterFactory().CreateWriter, + asyncEnumerable.GetType(), + asyncEnumerable) + { + ContentType = new StringSegment(mediaType.ToString()), + }; + var iterated = false; + + // Act + await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8")); + + // Assert + Assert.Empty(body.ToArray()); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + // MvcOptions.MaxIAsyncEnumerableBufferLimit is 8192. Pick some value larger than that. + foreach (var i in Enumerable.Range(0, 9000)) + { + cancellationToken.ThrowIfCancellationRequested(); + iterated = true; + yield return i; + } + } + } + private class TestableJsonOutputFormatter : NewtonsoftJsonOutputFormatter { public TestableJsonOutputFormatter(JsonSerializerSettings serializerSettings) diff --git a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonResultExecutorTest.cs b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonResultExecutorTest.cs index ffa7d2f6f17..7d88e8597b2 100644 --- a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonResultExecutorTest.cs +++ b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonResultExecutorTest.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers; +using System.Runtime.CompilerServices; +using System.Text; using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -25,5 +27,37 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson { return new JsonSerializerSettings { Formatting = Formatting.Indented }; } + + [Fact] + public async Task ExecuteAsync_AsyncEnumerableReceivesCancellationToken() + { + // Arrange + var expected = System.Text.Json.JsonSerializer.Serialize(new[] { "Hello", "world" }); + + var cts = new CancellationTokenSource(); + var context = GetActionContext(); + context.HttpContext.RequestAborted = cts.Token; + var result = new JsonResult(TestAsyncEnumerable()); + var executor = CreateExecutor(); + CancellationToken token = default; + + // Act + await executor.ExecuteAsync(context, result); + + // Assert + var written = GetWrittenBytes(context.HttpContext); + Assert.Equal(expected, Encoding.UTF8.GetString(written)); + + cts.Cancel(); + Assert.Equal(cts.Token, token); + + async IAsyncEnumerable<string> TestAsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + token = cancellationToken; + yield return "Hello"; + yield return "world"; + } + } } } -- GitLab