diff --git a/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs b/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs index c36740385e55ad961266374860e095e689f2284e..1336cc8f1b1435977a3b7d43fc180537689b1ea8 100644 --- a/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs +++ b/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs @@ -82,9 +82,28 @@ namespace Microsoft.AspNetCore.Http options ??= ResolveSerializerOptions(response.HttpContext); response.ContentType = contentType ?? JsonConstants.JsonContentTypeWithCharset; + // if no user provided token, pass the RequestAborted token and ignore OperationCanceledException + if (!cancellationToken.CanBeCanceled) + { + return WriteAsJsonAsyncSlow<TValue>(response.Body, value, options, response.HttpContext.RequestAborted); + } + return JsonSerializer.SerializeAsync<TValue>(response.Body, value, options, cancellationToken); } + private static async Task WriteAsJsonAsyncSlow<TValue>( + Stream body, + TValue value, + JsonSerializerOptions? options, + CancellationToken cancellationToken) + { + try + { + await JsonSerializer.SerializeAsync<TValue>(body, value, options, cancellationToken); + } + catch (OperationCanceledException) { } + } + /// <summary> /// Write the specified value as JSON to the response body. The response content-type will be set to /// <c>application/json; charset=utf-8</c>. @@ -157,9 +176,30 @@ namespace Microsoft.AspNetCore.Http options ??= ResolveSerializerOptions(response.HttpContext); response.ContentType = contentType ?? JsonConstants.JsonContentTypeWithCharset; + + // if no user provided token, pass the RequestAborted token and ignore OperationCanceledException + if (!cancellationToken.CanBeCanceled) + { + return WriteAsJsonAsyncSlow(response.Body, value, type, options, response.HttpContext.RequestAborted); + } + return JsonSerializer.SerializeAsync(response.Body, value, type, options, cancellationToken); } + private static async Task WriteAsJsonAsyncSlow( + Stream body, + object? value, + Type type, + JsonSerializerOptions? options, + CancellationToken cancellationToken) + { + try + { + await JsonSerializer.SerializeAsync(body, value, type, options, cancellationToken); + } + catch (OperationCanceledException) { } + } + private static JsonSerializerOptions ResolveSerializerOptions(HttpContext httpContext) { // Attempt to resolve options from DI then fallback to default options diff --git a/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs b/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs index 6571716ec2bb0531a9b6bb86c371283cacec3bc9..5addfa413bfc8a06e414d9537fd5769ef0244203 100644 --- a/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs +++ b/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs @@ -3,6 +3,7 @@ using System; using System.IO; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; @@ -255,6 +256,192 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests Assert.Equal(StatusCodes.Status418ImATeapot, context.Response.StatusCode); } + [Fact] + public async Task WriteAsJsonAsyncGeneric_AsyncEnumerable() + { + // Arrange + var body = new MemoryStream(); + var context = new DefaultHttpContext(); + context.Response.Body = body; + + // Act + await context.Response.WriteAsJsonAsync(AsyncEnumerable()); + + // Assert + Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + Assert.Equal("[1,2]", Encoding.UTF8.GetString(body.ToArray())); + + async IAsyncEnumerable<int> AsyncEnumerable() + { + await Task.Yield(); + yield return 1; + yield return 2; + } + } + + [Fact] + public async Task WriteAsJsonAsync_AsyncEnumerable() + { + // Arrange + var body = new MemoryStream(); + var context = new DefaultHttpContext(); + context.Response.Body = body; + + // Act + await context.Response.WriteAsJsonAsync(AsyncEnumerable(), typeof(IAsyncEnumerable<int>)); + + // Assert + Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + Assert.Equal("[1,2]", Encoding.UTF8.GetString(body.ToArray())); + + async IAsyncEnumerable<int> AsyncEnumerable() + { + await Task.Yield(); + yield return 1; + yield return 2; + } + } + + [Fact] + public async Task WriteAsJsonAsyncGeneric_AsyncEnumerable_ClosedConnecton() + { + // Arrange + var cts = new CancellationTokenSource(); + var body = new MemoryStream(); + var context = new DefaultHttpContext(); + context.Response.Body = body; + context.RequestAborted = cts.Token; + var iterated = false; + + // Act + await context.Response.WriteAsJsonAsync(AsyncEnumerable()); + + // Assert + Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + // System.Text.Json might write the '[' before cancellation is observed + Assert.InRange(body.ToArray().Length, 0, 1); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + for (var i = 0; i < 100 && !cancellationToken.IsCancellationRequested; i++) + { + iterated = true; + yield return i; + } + } + } + + [Fact] + public async Task WriteAsJsonAsync_AsyncEnumerable_ClosedConnecton() + { + // Arrange + var cts = new CancellationTokenSource(); + var body = new MemoryStream(); + var context = new DefaultHttpContext(); + context.Response.Body = body; + context.RequestAborted = cts.Token; + var iterated = false; + + // Act + await context.Response.WriteAsJsonAsync(AsyncEnumerable(), typeof(IAsyncEnumerable<int>)); + + // Assert + Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + // System.Text.Json might write the '[' before cancellation is observed + Assert.InRange(body.ToArray().Length, 0, 1); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + for (var i = 0; i < 100 && !cancellationToken.IsCancellationRequested; i++) + { + iterated = true; + yield return i; + } + } + } + + [Fact] + public async Task WriteAsJsonAsync_AsyncEnumerable_UserPassedTokenThrows() + { + // Arrange + var body = new MemoryStream(); + var context = new DefaultHttpContext(); + context.Response.Body = body; + context.RequestAborted = new CancellationToken(canceled: true); + var cts = new CancellationTokenSource(); + var iterated = false; + + // Act + await Assert.ThrowsAnyAsync<OperationCanceledException>(() => context.Response.WriteAsJsonAsync(AsyncEnumerable(), typeof(IAsyncEnumerable<int>), cts.Token)); + + // Assert + Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + // System.Text.Json might write the '[' before cancellation is observed + Assert.InRange(body.ToArray().Length, 0, 1); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + for (var i = 0; i < 100 && !cancellationToken.IsCancellationRequested; i++) + { + iterated = true; + yield return i; + } + } + } + + [Fact] + public async Task WriteAsJsonAsyncGeneric_AsyncEnumerableG_UserPassedTokenThrows() + { + // Arrange + var body = new MemoryStream(); + var context = new DefaultHttpContext(); + context.Response.Body = body; + context.RequestAborted = new CancellationToken(canceled: true); + var cts = new CancellationTokenSource(); + var iterated = false; + + // Act + await Assert.ThrowsAnyAsync<OperationCanceledException>(() => context.Response.WriteAsJsonAsync(AsyncEnumerable(), cts.Token)); + + // Assert + Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType); + Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode); + + // System.Text.Json might write the '[' before cancellation is observed + Assert.InRange(body.ToArray().Length, 0, 1); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + for (var i = 0; i < 100 && !cancellationToken.IsCancellationRequested; i++) + { + iterated = true; + yield return i; + } + } + } + public class TestObject { public string? StringProperty { get; set; } diff --git a/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs b/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs index e372782b3b91ce72ec62f8e73d5d0f84b844be6c..830998717d10e1082b2cf8501b947d735174e892 100644 --- a/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs +++ b/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs @@ -79,8 +79,12 @@ namespace Microsoft.AspNetCore.Mvc.Formatters var responseStream = httpContext.Response.Body; if (selectedEncoding.CodePage == Encoding.UTF8.CodePage) { - await JsonSerializer.SerializeAsync(responseStream, context.Object, objectType, SerializerOptions); - await responseStream.FlushAsync(); + try + { + await JsonSerializer.SerializeAsync(responseStream, context.Object, objectType, SerializerOptions, httpContext.RequestAborted); + await responseStream.FlushAsync(httpContext.RequestAborted); + } + catch (OperationCanceledException) { } } else { diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs b/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs index 3f8f1431c736db03f1e5548f5a79a903b9a5db10..4748c7f31f27fa0b6da70f85f9c9bb760628d627 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs @@ -34,7 +34,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure nameof(ReadInternal), BindingFlags.NonPublic | BindingFlags.Instance)!; - private readonly ConcurrentDictionary<Type, Func<object, Task<ICollection>>?> _asyncEnumerableConverters = new(); + private readonly ConcurrentDictionary<Type, Func<object, CancellationToken, Task<ICollection>>?> _asyncEnumerableConverters = new(); private readonly MvcOptions _mvcOptions; /// <summary> @@ -52,7 +52,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure /// <param name="type">The type to read.</param> /// <param name="reader">A delegate that when awaited reads the <see cref="IAsyncEnumerable{T}"/>.</param> /// <returns><see langword="true" /> when <paramref name="type"/> is an instance of <see cref="IAsyncEnumerable{T}"/>, othwerise <see langword="false"/>.</returns> - public bool TryGetReader(Type type, [NotNullWhen(true)] out Func<object, Task<ICollection>>? reader) + public bool TryGetReader(Type type, [NotNullWhen(true)] out Func<object, CancellationToken, Task<ICollection>>? reader) { if (!_asyncEnumerableConverters.TryGetValue(type, out reader)) { @@ -67,9 +67,9 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure { var enumeratedObjectType = enumerableType.GetGenericArguments()[0]; - var converter = (Func<object, Task<ICollection>>)Converter + var converter = (Func<object, CancellationToken, Task<ICollection>>)Converter .MakeGenericMethod(enumeratedObjectType) - .CreateDelegate(typeof(Func<object, Task<ICollection>>), this); + .CreateDelegate(typeof(Func<object, CancellationToken, Task<ICollection>>), this); reader = converter; _asyncEnumerableConverters.TryAdd(type, reader); @@ -79,9 +79,9 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure return reader != null; } - private async Task<ICollection> ReadInternal<T>(object value) + private async Task<ICollection> ReadInternal<T>(object value, CancellationToken cancellationToken) { - var asyncEnumerable = (IAsyncEnumerable<T>)value; + var asyncEnumerable = ((IAsyncEnumerable<T>)value).WithCancellation(cancellationToken); var result = new List<T>(); var count = 0; diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs b/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs index 4d440a2dec56498600f4ad3b9b1ee6ec37c3ba20..26c3de180eb50c5d44891f67a7b0eded5f02287b 100644 --- a/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs +++ b/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs @@ -25,16 +25,13 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure private readonly JsonOptions _options; private readonly ILogger<SystemTextJsonResultExecutor> _logger; - private readonly AsyncEnumerableReader _asyncEnumerableReaderFactory; public SystemTextJsonResultExecutor( IOptions<JsonOptions> options, - ILogger<SystemTextJsonResultExecutor> logger, - IOptions<MvcOptions> mvcOptions) + ILogger<SystemTextJsonResultExecutor> logger) { _options = options.Value; _logger = logger; - _asyncEnumerableReaderFactory = new AsyncEnumerableReader(mvcOptions.Value); } public async Task ExecuteAsync(ActionContext context, JsonResult result) @@ -77,8 +74,12 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure var responseStream = response.Body; if (resolvedContentTypeEncoding.CodePage == Encoding.UTF8.CodePage) { - await JsonSerializer.SerializeAsync(responseStream, value, objectType, jsonSerializerOptions); - await responseStream.FlushAsync(); + try + { + await JsonSerializer.SerializeAsync(responseStream, value, objectType, jsonSerializerOptions, context.HttpContext.RequestAborted); + await responseStream.FlushAsync(context.HttpContext.RequestAborted); + } + catch (OperationCanceledException) { } } else { @@ -89,9 +90,11 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure ExceptionDispatchInfo? exceptionDispatchInfo = null; try { - await JsonSerializer.SerializeAsync(transcodingStream, value, objectType, jsonSerializerOptions); - await transcodingStream.FlushAsync(); + await JsonSerializer.SerializeAsync(transcodingStream, value, objectType, jsonSerializerOptions, context.HttpContext.RequestAborted); + await transcodingStream.FlushAsync(context.HttpContext.RequestAborted); } + catch (OperationCanceledException) + { } catch (Exception ex) { // TranscodingStream may write to the inner stream as part of it's disposal. diff --git a/src/Mvc/Mvc.Core/test/Formatters/SystemTextJsonOutputFormatterTest.cs b/src/Mvc/Mvc.Core/test/Formatters/SystemTextJsonOutputFormatterTest.cs index b9fa926628c7f651507a07492b31b82542f73212..c178f43f8cb6c016d9455d2dc90f16049e853f2b 100644 --- a/src/Mvc/Mvc.Core/test/Formatters/SystemTextJsonOutputFormatterTest.cs +++ b/src/Mvc/Mvc.Core/test/Formatters/SystemTextJsonOutputFormatterTest.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; @@ -113,6 +114,54 @@ namespace Microsoft.AspNetCore.Mvc.Formatters Assert.Equal(expected.ToArray(), body.ToArray()); } + [Fact] + public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses() + { + // Arrange + var formatter = GetOutputFormatter(); + var mediaType = MediaTypeHeaderValue.Parse("application/json; charset=utf-8"); + + var body = new MemoryStream(); + var actionContext = GetActionContext(mediaType, body); + var cts = new CancellationTokenSource(); + actionContext.HttpContext.RequestAborted = cts.Token; + + var asyncEnumerable = AsyncEnumerableClosedConnection(); + var outputFormatterContext = new OutputFormatterWriteContext( + actionContext.HttpContext, + new TestHttpResponseStreamWriterFactory().CreateWriter, + asyncEnumerable.GetType(), + asyncEnumerable) + { + ContentType = new StringSegment(mediaType.ToString()), + }; + var iterated = false; + + // Act + await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8")); + + // Assert + // System.Text.Json might write the '[' before cancellation is observed + Assert.InRange(body.ToArray().Length, 0, 1); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + // MvcOptions.MaxIAsyncEnumerableBufferLimit is 8192. Pick some value larger than that. + foreach (var i in Enumerable.Range(0, 9000)) + { + if (cancellationToken.IsCancellationRequested) + { + yield break; + } + iterated = true; + yield return i; + } + } + } + private class Person { public string Name { get; set; } diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs b/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs index eac659f55adff5912525b77df41c9cb0f3aa2664..3d5afc91f9eb61f828cbd13f6e8fd95c00a84fa9 100644 --- a/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs +++ b/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Generic; using System.Globalization; +using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; using Xunit; @@ -43,7 +44,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure // Assert Assert.True(result); - var readCollection = await reader(asyncEnumerable); + var readCollection = await reader(asyncEnumerable, default); var collection = Assert.IsAssignableFrom<ICollection<string>>(readCollection); Assert.Equal(new[] { "0", "1", "2", }, collection); } @@ -61,7 +62,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure // Assert Assert.True(result); - var readCollection = await reader(asyncEnumerable); + var readCollection = await reader(asyncEnumerable, default); var collection = Assert.IsAssignableFrom<ICollection<int>>(readCollection); Assert.Equal(new[] { 0, 1, 2, }, collection); } @@ -114,8 +115,8 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader)); // Assert - Assert.Equal(expected, await reader(asyncEnumerable1)); - Assert.Equal(expected, await reader(asyncEnumerable2)); + Assert.Equal(expected, await reader(asyncEnumerable1, default)); + Assert.Equal(expected, await reader(asyncEnumerable2, default)); } [Fact] @@ -131,8 +132,8 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader)); // Assert - Assert.Equal(new[] { "0", "1", "2" }, await reader(asyncEnumerable1)); - Assert.Equal(new[] { "0", "1", "2", "3" }, await reader(asyncEnumerable2)); + Assert.Equal(new[] { "0", "1", "2" }, await reader(asyncEnumerable1, default)); + Assert.Equal(new[] { "0", "1", "2", "3" }, await reader(asyncEnumerable2, default)); } [Fact] @@ -166,7 +167,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure // Assert Assert.True(result); - var readCollection = await reader(asyncEnumerable); + var readCollection = await reader(asyncEnumerable, default); var collection = Assert.IsAssignableFrom<ICollection<object>>(readCollection); Assert.Equal(new[] { "0", "1", "2", }, collection); } @@ -184,7 +185,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure // Act Assert.True(readerFactory.TryGetReader(enumerable.GetType(), out var reader)); - var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => reader(enumerable)); + var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => reader(enumerable, default)); // Assert Assert.Equal(expected, ex.Message); @@ -200,7 +201,32 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure // Act & Assert Assert.True(readerFactory.TryGetReader(enumerable.GetType(), out var reader)); - await Assert.ThrowsAsync<TimeZoneNotFoundException>(() => reader(enumerable)); + await Assert.ThrowsAsync<TimeZoneNotFoundException>(() => reader(enumerable, default)); + } + + [Fact] + public async Task Reader_PassesCancellationTokenToIAsyncEnumerable() + { + // Arrange + var enumerable = AsyncEnumerable(); + var options = new MvcOptions(); + CancellationToken token = default; + var readerFactory = new AsyncEnumerableReader(options); + var cts = new CancellationTokenSource(); + + // Act & Assert + Assert.True(readerFactory.TryGetReader(enumerable.GetType(), out var reader)); + await reader(enumerable, cts.Token); + + cts.Cancel(); + Assert.Equal(cts.Token, token); + + async IAsyncEnumerable<string> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + token = cancellationToken; + await Task.Yield(); + yield return string.Empty; + } } public static async IAsyncEnumerable<string> TestEnumerable(int count = 3) diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs b/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs index 9e579244a36ed6d346afd7bcb094b0b6be2fa952..10d6d1f139383ccc9357026ca44c298e7d906876 100644 --- a/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs +++ b/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Threading; @@ -365,6 +366,71 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure Assert.Equal(expected, written); } + [Fact] + public async Task ExecuteAsync_AsyncEnumerableConnectionCloses() + { + var context = GetActionContext(); + var cts = new CancellationTokenSource(); + context.HttpContext.RequestAborted = cts.Token; + var result = new JsonResult(AsyncEnumerableClosedConnection()); + var executor = CreateExecutor(); + var iterated = false; + + // Act + await executor.ExecuteAsync(context, result); + + // Assert + var written = GetWrittenBytes(context.HttpContext); + // System.Text.Json might write the '[' before cancellation is observed + Assert.InRange(written.Length, 0, 1); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + for (var i = 0; i < 100000 && !cancellationToken.IsCancellationRequested; i++) + { + iterated = true; + yield return i; + } + } + } + + [Fact] + public async Task ExecuteAsyncWithDifferentContentType_AsyncEnumerableConnectionCloses() + { + var context = GetActionContext(); + var cts = new CancellationTokenSource(); + context.HttpContext.RequestAborted = cts.Token; + var result = new JsonResult(AsyncEnumerableClosedConnection()) + { + ContentType = "text/json; charset=utf-16", + }; + var executor = CreateExecutor(); + var iterated = false; + + // Act + await executor.ExecuteAsync(context, result); + + // Assert + var written = GetWrittenBytes(context.HttpContext); + // System.Text.Json might write the '[' before cancellation is observed (utf-16 means 2 bytes per character) + Assert.InRange(written.Length, 0, 2); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + for (var i = 0; i < 100000 && !cancellationToken.IsCancellationRequested; i++) + { + iterated = true; + yield return i; + } + } + } + protected IActionResultExecutor<JsonResult> CreateExecutor() => CreateExecutor(NullLoggerFactory.Instance); protected abstract IActionResultExecutor<JsonResult> CreateExecutor(ILoggerFactory loggerFactory); @@ -387,7 +453,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure return new ActionContext(GetHttpContext(), new RouteData(), new ActionDescriptor()); } - private static byte[] GetWrittenBytes(HttpContext context) + protected static byte[] GetWrittenBytes(HttpContext context) { context.Response.Body.Seek(0, SeekOrigin.Begin); return Assert.IsType<MemoryStream>(context.Response.Body).ToArray(); diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/SystemTextJsonResultExecutorTest.cs b/src/Mvc/Mvc.Core/test/Infrastructure/SystemTextJsonResultExecutorTest.cs index cd5489fd46f01ca4f642bc644e428fd7754239b3..403ed61b1662ad90570bf70b2bfdbcdb4ee062ab 100644 --- a/src/Mvc/Mvc.Core/test/Infrastructure/SystemTextJsonResultExecutorTest.cs +++ b/src/Mvc/Mvc.Core/test/Infrastructure/SystemTextJsonResultExecutorTest.cs @@ -2,6 +2,7 @@ // The .NET Foundation licenses this file to you under the MIT license. using System; +using System.Runtime.CompilerServices; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; @@ -18,8 +19,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure { return new SystemTextJsonResultExecutor( Options.Create(new JsonOptions()), - loggerFactory.CreateLogger<SystemTextJsonResultExecutor>(), - Options.Create(new MvcOptions())); + loggerFactory.CreateLogger<SystemTextJsonResultExecutor>()); } [Fact] diff --git a/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerOutputFormatter.cs b/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerOutputFormatter.cs index 429f76807c681e549639f4e50f71de31336ca5ca..1ecec56f503c1f13939421c2b193e8cc47359055 100644 --- a/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerOutputFormatter.cs +++ b/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerOutputFormatter.cs @@ -261,8 +261,12 @@ namespace Microsoft.AspNetCore.Mvc.Formatters { Log.BufferingAsyncEnumerable(_logger, value); - value = await reader(value); + value = await reader(value, context.HttpContext.RequestAborted); valueType = value.GetType(); + if (context.HttpContext.RequestAborted.IsCancellationRequested) + { + return; + } } Debug.Assert(valueType is not null); diff --git a/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerOutputFormatter.cs b/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerOutputFormatter.cs index a74b7f0584492f7b0dbe297d86ca810e297f02a8..eba98cb2a50b3126e62d5057bae833b2cd0e1432 100644 --- a/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerOutputFormatter.cs +++ b/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerOutputFormatter.cs @@ -236,8 +236,12 @@ namespace Microsoft.AspNetCore.Mvc.Formatters { Log.BufferingAsyncEnumerable(_logger, value); - value = await reader(value); + value = await reader(value, context.HttpContext.RequestAborted); valueType = value.GetType(); + if (context.HttpContext.RequestAborted.IsCancellationRequested) + { + return; + } } // Wrap the object only if there is a wrapping type. diff --git a/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerOutputFormatterTest.cs b/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerOutputFormatterTest.cs index 7208978b08a3a0b05ab24dbdc1aceb88cc2d97f0..44a5a72152d7bfe4e218408f5084e4f21938300a 100644 --- a/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerOutputFormatterTest.cs +++ b/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerOutputFormatterTest.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.Globalization; using System.IO; +using System.Runtime.CompilerServices; using System.Runtime.Serialization; using System.Text; using System.Threading.Tasks; @@ -722,6 +723,73 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml Assert.Equal(expectedOutput, content); } + [Fact] + public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses() + { + // Arrange + var formatter = new XmlDataContractSerializerOutputFormatter(); + var body = new MemoryStream(); + var cts = new CancellationTokenSource(); + var iterated = false; + + var asyncEnumerable = AsyncEnumerableClosedConnection(); + var outputFormatterContext = GetOutputFormatterContext( + asyncEnumerable, + asyncEnumerable.GetType()); + outputFormatterContext.HttpContext.RequestAborted = cts.Token; + outputFormatterContext.HttpContext.Response.Body = body; + + // Act + await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8")); + + // Assert + Assert.Empty(body.ToArray()); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + // MvcOptions.MaxIAsyncEnumerableBufferLimit is 8192. Pick some value larger than that. + foreach (var i in Enumerable.Range(0, 9000)) + { + if (cancellationToken.IsCancellationRequested) + { + yield break; + } + iterated = true; + yield return i; + } + } + } + + [Fact] + public async Task WriteResponseBodyAsync_AsyncEnumerable() + { + // Arrange + var formatter = new XmlDataContractSerializerOutputFormatter(); + var body = new MemoryStream(); + + var asyncEnumerable = AsyncEnumerable(); + var outputFormatterContext = GetOutputFormatterContext( + asyncEnumerable, + asyncEnumerable.GetType()); + outputFormatterContext.HttpContext.Response.Body = body; + + // Act + await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8")); + + // Assert + Assert.Contains("<int>1</int><int>2</int>", Encoding.UTF8.GetString(body.ToArray())); + + async IAsyncEnumerable<int> AsyncEnumerable() + { + await Task.Yield(); + yield return 1; + yield return 2; + } + } + private OutputFormatterWriteContext GetOutputFormatterContext( object outputValue, Type outputType, diff --git a/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerOutputFormatterTest.cs b/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerOutputFormatterTest.cs index ef5f4523b7b3063146287574c84840003a6e007a..335c41d84c9050a5ebae2662fe1e6830781c61e9 100644 --- a/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerOutputFormatterTest.cs +++ b/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerOutputFormatterTest.cs @@ -5,6 +5,7 @@ using System; using System.Collections.Generic; using System.IO; using System.Linq; +using System.Runtime.CompilerServices; using System.Text; using System.Threading.Tasks; using System.Xml; @@ -505,6 +506,73 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml Assert.False(canWriteResult); } + [Fact] + public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses() + { + // Arrange + var formatter = new XmlSerializerOutputFormatter(); + var body = new MemoryStream(); + var cts = new CancellationTokenSource(); + var iterated = false; + + var asyncEnumerable = AsyncEnumerableClosedConnection(); + var outputFormatterContext = GetOutputFormatterContext( + asyncEnumerable, + asyncEnumerable.GetType()); + outputFormatterContext.HttpContext.RequestAborted = cts.Token; + outputFormatterContext.HttpContext.Response.Body = body; + + // Act + await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8")); + + // Assert + Assert.Empty(body.ToArray()); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + // MvcOptions.MaxIAsyncEnumerableBufferLimit is 8192. Pick some value larger than that. + foreach (var i in Enumerable.Range(0, 9000)) + { + if (cancellationToken.IsCancellationRequested) + { + yield break; + } + iterated = true; + yield return i; + } + } + } + + [Fact] + public async Task WriteResponseBodyAsync_AsyncEnumerable() + { + // Arrange + var formatter = new XmlSerializerOutputFormatter(); + var body = new MemoryStream(); + + var asyncEnumerable = AsyncEnumerable(); + var outputFormatterContext = GetOutputFormatterContext( + asyncEnumerable, + asyncEnumerable.GetType()); + outputFormatterContext.HttpContext.Response.Body = body; + + // Act + await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8")); + + // Assert + Assert.Contains("<int>1</int><int>2</int>", Encoding.UTF8.GetString(body.ToArray())); + + async IAsyncEnumerable<int> AsyncEnumerable() + { + await Task.Yield(); + yield return 1; + yield return 2; + } + } + private OutputFormatterWriteContext GetOutputFormatterContext( object outputValue, Type outputType, @@ -577,4 +645,4 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml } } } -} \ No newline at end of file +} diff --git a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonOutputFormatter.cs b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonOutputFormatter.cs index 56623da8a3cb248f70079fc4f2ecbd60193af0f8..e471cd36bad00bfa5b918e6d768fae167ccd843f 100644 --- a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonOutputFormatter.cs +++ b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonOutputFormatter.cs @@ -26,7 +26,6 @@ namespace Microsoft.AspNetCore.Mvc.Formatters private MvcNewtonsoftJsonOptions? _jsonOptions; private readonly AsyncEnumerableReader _asyncEnumerableReaderFactory; private JsonSerializerSettings? _serializerSettings; - private ILogger? _logger; /// <summary> /// Initializes a new <see cref="NewtonsoftJsonOutputFormatter"/> instance. @@ -174,9 +173,17 @@ namespace Microsoft.AspNetCore.Mvc.Formatters var value = context.Object; if (value is not null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader)) { - _logger ??= context.HttpContext.RequestServices.GetRequiredService<ILogger<NewtonsoftJsonOutputFormatter>>(); - Log.BufferingAsyncEnumerable(_logger, value); - value = await reader(value); + var logger = context.HttpContext.RequestServices.GetRequiredService<ILogger<NewtonsoftJsonOutputFormatter>>(); + Log.BufferingAsyncEnumerable(logger, value); + try + { + value = await reader(value, context.HttpContext.RequestAborted); + } + catch (OperationCanceledException) { } + if (context.HttpContext.RequestAborted.IsCancellationRequested) + { + return; + } } try diff --git a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs index 5f7d28d185ff18074ff4e40903d8345404573edb..3ee953f8d1b8bd4e8293768a8ff0b6ac13cb2a02 100644 --- a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs +++ b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs @@ -125,6 +125,21 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson try { + var value = result.Value; + if (value != null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader)) + { + Log.BufferingAsyncEnumerable(_logger, value); + try + { + value = await reader(value, context.HttpContext.RequestAborted); + } + catch (OperationCanceledException) { } + if (context.HttpContext.RequestAborted.IsCancellationRequested) + { + return; + } + } + await using (var writer = _writerFactory.CreateWriter(responseStream, resolvedContentTypeEncoding)) { using var jsonWriter = new JsonTextWriter(writer); @@ -133,12 +148,6 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson jsonWriter.AutoCompleteOnClose = false; var jsonSerializer = JsonSerializer.Create(jsonSerializerSettings); - var value = result.Value; - if (value != null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader)) - { - Log.BufferingAsyncEnumerable(_logger, value); - value = await reader(value); - } jsonSerializer.Serialize(jsonWriter, value); } diff --git a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonOutputFormatterTest.cs b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonOutputFormatterTest.cs index d0d1598bf677edc515f41a464488c55d477c3b70..116da8922aea3704ab6d6ffb271b575d403ab52a 100644 --- a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonOutputFormatterTest.cs +++ b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonOutputFormatterTest.cs @@ -1,22 +1,17 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System; using System.Buffers; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Reflection; +using System.Runtime.CompilerServices; using System.Text; -using System.Threading; -using System.Threading.Tasks; using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Primitives; +using Microsoft.Net.Http.Headers; using Moq; using Newtonsoft.Json; using Newtonsoft.Json.Linq; using Newtonsoft.Json.Serialization; -using Xunit; namespace Microsoft.AspNetCore.Mvc.Formatters { @@ -436,6 +431,51 @@ namespace Microsoft.AspNetCore.Mvc.Formatters } } + [Fact] + public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses() + { + // Arrange + var formatter = GetOutputFormatter(); + var mediaType = MediaTypeHeaderValue.Parse("application/json; charset=utf-8"); + + var body = new MemoryStream(); + var actionContext = GetActionContext(mediaType, body); + var cts = new CancellationTokenSource(); + actionContext.HttpContext.RequestAborted = cts.Token; + actionContext.HttpContext.RequestServices = new ServiceCollection().AddLogging().BuildServiceProvider(); + + var asyncEnumerable = AsyncEnumerableClosedConnection(); + var outputFormatterContext = new OutputFormatterWriteContext( + actionContext.HttpContext, + new TestHttpResponseStreamWriterFactory().CreateWriter, + asyncEnumerable.GetType(), + asyncEnumerable) + { + ContentType = new StringSegment(mediaType.ToString()), + }; + var iterated = false; + + // Act + await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8")); + + // Assert + Assert.Empty(body.ToArray()); + Assert.False(iterated); + + async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + cts.Cancel(); + // MvcOptions.MaxIAsyncEnumerableBufferLimit is 8192. Pick some value larger than that. + foreach (var i in Enumerable.Range(0, 9000)) + { + cancellationToken.ThrowIfCancellationRequested(); + iterated = true; + yield return i; + } + } + } + private class TestableJsonOutputFormatter : NewtonsoftJsonOutputFormatter { public TestableJsonOutputFormatter(JsonSerializerSettings serializerSettings) diff --git a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonResultExecutorTest.cs b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonResultExecutorTest.cs index ffa7d2f6f17c17961b32ebb028122bf572e5889a..7d88e8597b27d8df03021e7fcfd68b99153e1af5 100644 --- a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonResultExecutorTest.cs +++ b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonResultExecutorTest.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Buffers; +using System.Runtime.CompilerServices; +using System.Text; using Microsoft.AspNetCore.Mvc.Infrastructure; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; @@ -25,5 +27,37 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson { return new JsonSerializerSettings { Formatting = Formatting.Indented }; } + + [Fact] + public async Task ExecuteAsync_AsyncEnumerableReceivesCancellationToken() + { + // Arrange + var expected = System.Text.Json.JsonSerializer.Serialize(new[] { "Hello", "world" }); + + var cts = new CancellationTokenSource(); + var context = GetActionContext(); + context.HttpContext.RequestAborted = cts.Token; + var result = new JsonResult(TestAsyncEnumerable()); + var executor = CreateExecutor(); + CancellationToken token = default; + + // Act + await executor.ExecuteAsync(context, result); + + // Assert + var written = GetWrittenBytes(context.HttpContext); + Assert.Equal(expected, Encoding.UTF8.GetString(written)); + + cts.Cancel(); + Assert.Equal(cts.Token, token); + + async IAsyncEnumerable<string> TestAsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default) + { + await Task.Yield(); + token = cancellationToken; + yield return "Hello"; + yield return "world"; + } + } } }