From 337217a655ba17ea7e11e3116ae2b242f1073f40 Mon Sep 17 00:00:00 2001
From: "github-actions[bot]"
 <41898282+github-actions[bot]@users.noreply.github.com>
Date: Wed, 1 Sep 2021 14:49:17 -0700
Subject: [PATCH] [release/6.0] Pass request cancellation token when using
 IAsyncEnumerable in MVC and WriteAsJsonAsync (#35993)

* Pass request cancellation token when using IAsyncEnumerable

* more

* fb

* update

* fixup assert

Co-authored-by: Brennan <brecon@microsoft.com>
---
 .../src/HttpResponseJsonExtensions.cs         |  40 ++++
 .../test/HttpResponseJsonExtensionsTests.cs   | 187 ++++++++++++++++++
 .../SystemTextJsonOutputFormatter.cs          |   8 +-
 .../Infrastructure/AsyncEnumerableReader.cs   |  12 +-
 .../SystemTextJsonResultExecutor.cs           |  19 +-
 .../SystemTextJsonOutputFormatterTest.cs      |  49 +++++
 .../AsyncEnumerableReaderTest.cs              |  44 ++++-
 .../JsonResultExecutorTestBase.cs             |  68 ++++++-
 .../SystemTextJsonResultExecutorTest.cs       |   4 +-
 ...mlDataContractSerializerOutputFormatter.cs |   6 +-
 .../src/XmlSerializerOutputFormatter.cs       |   6 +-
 ...taContractSerializerOutputFormatterTest.cs |  68 +++++++
 .../test/XmlSerializerOutputFormatterTest.cs  |  70 ++++++-
 .../src/NewtonsoftJsonOutputFormatter.cs      |  15 +-
 .../src/NewtonsoftJsonResultExecutor.cs       |  21 +-
 .../test/NewtonsoftJsonOutputFormatterTest.cs |  56 +++++-
 .../test/NewtonsoftJsonResultExecutorTest.cs  |  34 ++++
 17 files changed, 658 insertions(+), 49 deletions(-)

diff --git a/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs b/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs
index c36740385e5..1336cc8f1b1 100644
--- a/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs
+++ b/src/Http/Http.Extensions/src/HttpResponseJsonExtensions.cs
@@ -82,9 +82,28 @@ namespace Microsoft.AspNetCore.Http
             options ??= ResolveSerializerOptions(response.HttpContext);
 
             response.ContentType = contentType ?? JsonConstants.JsonContentTypeWithCharset;
+            // if no user provided token, pass the RequestAborted token and ignore OperationCanceledException
+            if (!cancellationToken.CanBeCanceled)
+            {
+                return WriteAsJsonAsyncSlow<TValue>(response.Body, value, options, response.HttpContext.RequestAborted);
+            }
+
             return JsonSerializer.SerializeAsync<TValue>(response.Body, value, options, cancellationToken);
         }
 
+        private static async Task WriteAsJsonAsyncSlow<TValue>(
+            Stream body,
+            TValue value,
+            JsonSerializerOptions? options,
+            CancellationToken cancellationToken)
+        {
+            try
+            {
+                await JsonSerializer.SerializeAsync<TValue>(body, value, options, cancellationToken);
+            }
+            catch (OperationCanceledException) { }
+        }
+
         /// <summary>
         /// Write the specified value as JSON to the response body. The response content-type will be set to
         /// <c>application/json; charset=utf-8</c>.
@@ -157,9 +176,30 @@ namespace Microsoft.AspNetCore.Http
             options ??= ResolveSerializerOptions(response.HttpContext);
 
             response.ContentType = contentType ?? JsonConstants.JsonContentTypeWithCharset;
+
+            // if no user provided token, pass the RequestAborted token and ignore OperationCanceledException
+            if (!cancellationToken.CanBeCanceled)
+            {
+                return WriteAsJsonAsyncSlow(response.Body, value, type, options, response.HttpContext.RequestAborted);
+            }
+
             return JsonSerializer.SerializeAsync(response.Body, value, type, options, cancellationToken);
         }
 
+        private static async Task WriteAsJsonAsyncSlow(
+            Stream body,
+            object? value,
+            Type type,
+            JsonSerializerOptions? options,
+            CancellationToken cancellationToken)
+        {
+            try
+            {
+                await JsonSerializer.SerializeAsync(body, value, type, options, cancellationToken);
+            }
+            catch (OperationCanceledException) { }
+        }
+
         private static JsonSerializerOptions ResolveSerializerOptions(HttpContext httpContext)
         {
             // Attempt to resolve options from DI then fallback to default options
diff --git a/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs b/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs
index 6571716ec2b..5addfa413bf 100644
--- a/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs
+++ b/src/Http/Http.Extensions/test/HttpResponseJsonExtensionsTests.cs
@@ -3,6 +3,7 @@
 
 using System;
 using System.IO;
+using System.Runtime.CompilerServices;
 using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
@@ -255,6 +256,192 @@ namespace Microsoft.AspNetCore.Http.Extensions.Tests
             Assert.Equal(StatusCodes.Status418ImATeapot, context.Response.StatusCode);
         }
 
+        [Fact]
+        public async Task WriteAsJsonAsyncGeneric_AsyncEnumerable()
+        {
+            // Arrange
+            var body = new MemoryStream();
+            var context = new DefaultHttpContext();
+            context.Response.Body = body;
+
+            // Act
+            await context.Response.WriteAsJsonAsync(AsyncEnumerable());
+
+            // Assert
+            Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType);
+            Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
+
+            Assert.Equal("[1,2]", Encoding.UTF8.GetString(body.ToArray()));
+
+            async IAsyncEnumerable<int> AsyncEnumerable()
+            {
+                await Task.Yield();
+                yield return 1;
+                yield return 2;
+            }
+        }
+
+        [Fact]
+        public async Task WriteAsJsonAsync_AsyncEnumerable()
+        {
+            // Arrange
+            var body = new MemoryStream();
+            var context = new DefaultHttpContext();
+            context.Response.Body = body;
+
+            // Act
+            await context.Response.WriteAsJsonAsync(AsyncEnumerable(), typeof(IAsyncEnumerable<int>));
+
+            // Assert
+            Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType);
+            Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
+
+            Assert.Equal("[1,2]", Encoding.UTF8.GetString(body.ToArray()));
+
+            async IAsyncEnumerable<int> AsyncEnumerable()
+            {
+                await Task.Yield();
+                yield return 1;
+                yield return 2;
+            }
+        }
+
+        [Fact]
+        public async Task WriteAsJsonAsyncGeneric_AsyncEnumerable_ClosedConnecton()
+        {
+            // Arrange
+            var cts = new CancellationTokenSource();
+            var body = new MemoryStream();
+            var context = new DefaultHttpContext();
+            context.Response.Body = body;
+            context.RequestAborted = cts.Token;
+            var iterated = false;
+
+            // Act
+            await context.Response.WriteAsJsonAsync(AsyncEnumerable());
+
+            // Assert
+            Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType);
+            Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
+
+            // System.Text.Json might write the '[' before cancellation is observed
+            Assert.InRange(body.ToArray().Length, 0, 1);
+            Assert.False(iterated);
+
+            async IAsyncEnumerable<int> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default)
+            {
+                await Task.Yield();
+                cts.Cancel();
+                for (var i = 0; i < 100 && !cancellationToken.IsCancellationRequested; i++)
+                {
+                    iterated = true;
+                    yield return i;
+                }
+            }
+        }
+
+        [Fact]
+        public async Task WriteAsJsonAsync_AsyncEnumerable_ClosedConnecton()
+        {
+            // Arrange
+            var cts = new CancellationTokenSource();
+            var body = new MemoryStream();
+            var context = new DefaultHttpContext();
+            context.Response.Body = body;
+            context.RequestAborted = cts.Token;
+            var iterated = false;
+
+            // Act
+            await context.Response.WriteAsJsonAsync(AsyncEnumerable(), typeof(IAsyncEnumerable<int>));
+
+            // Assert
+            Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType);
+            Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
+
+            // System.Text.Json might write the '[' before cancellation is observed
+            Assert.InRange(body.ToArray().Length, 0, 1);
+            Assert.False(iterated);
+
+            async IAsyncEnumerable<int> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default)
+            {
+                await Task.Yield();
+                cts.Cancel();
+                for (var i = 0; i < 100 && !cancellationToken.IsCancellationRequested; i++)
+                {
+                    iterated = true;
+                    yield return i;
+                }
+            }
+        }
+
+        [Fact]
+        public async Task WriteAsJsonAsync_AsyncEnumerable_UserPassedTokenThrows()
+        {
+            // Arrange
+            var body = new MemoryStream();
+            var context = new DefaultHttpContext();
+            context.Response.Body = body;
+            context.RequestAborted = new CancellationToken(canceled: true);
+            var cts = new CancellationTokenSource();
+            var iterated = false;
+
+            // Act
+            await Assert.ThrowsAnyAsync<OperationCanceledException>(() => context.Response.WriteAsJsonAsync(AsyncEnumerable(), typeof(IAsyncEnumerable<int>), cts.Token));
+
+            // Assert
+            Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType);
+            Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
+
+            // System.Text.Json might write the '[' before cancellation is observed
+            Assert.InRange(body.ToArray().Length, 0, 1);
+            Assert.False(iterated);
+
+            async IAsyncEnumerable<int> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default)
+            {
+                await Task.Yield();
+                cts.Cancel();
+                for (var i = 0; i < 100 && !cancellationToken.IsCancellationRequested; i++)
+                {
+                    iterated = true;
+                    yield return i;
+                }
+            }
+        }
+
+        [Fact]
+        public async Task WriteAsJsonAsyncGeneric_AsyncEnumerableG_UserPassedTokenThrows()
+        {
+            // Arrange
+            var body = new MemoryStream();
+            var context = new DefaultHttpContext();
+            context.Response.Body = body;
+            context.RequestAborted = new CancellationToken(canceled: true);
+            var cts = new CancellationTokenSource();
+            var iterated = false;
+
+            // Act
+            await Assert.ThrowsAnyAsync<OperationCanceledException>(() => context.Response.WriteAsJsonAsync(AsyncEnumerable(), cts.Token));
+
+            // Assert
+            Assert.Equal(JsonConstants.JsonContentTypeWithCharset, context.Response.ContentType);
+            Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
+
+            // System.Text.Json might write the '[' before cancellation is observed
+            Assert.InRange(body.ToArray().Length, 0, 1);
+            Assert.False(iterated);
+
+            async IAsyncEnumerable<int> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default)
+            {
+                await Task.Yield();
+                cts.Cancel();
+                for (var i = 0; i < 100 && !cancellationToken.IsCancellationRequested; i++)
+                {
+                    iterated = true;
+                    yield return i;
+                }
+            }
+        }
+
         public class TestObject
         {
             public string? StringProperty { get; set; }
diff --git a/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs b/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs
index e372782b3b9..830998717d1 100644
--- a/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs
+++ b/src/Mvc/Mvc.Core/src/Formatters/SystemTextJsonOutputFormatter.cs
@@ -79,8 +79,12 @@ namespace Microsoft.AspNetCore.Mvc.Formatters
             var responseStream = httpContext.Response.Body;
             if (selectedEncoding.CodePage == Encoding.UTF8.CodePage)
             {
-                await JsonSerializer.SerializeAsync(responseStream, context.Object, objectType, SerializerOptions);
-                await responseStream.FlushAsync();
+                try
+                {
+                    await JsonSerializer.SerializeAsync(responseStream, context.Object, objectType, SerializerOptions, httpContext.RequestAborted);
+                    await responseStream.FlushAsync(httpContext.RequestAborted);
+                }
+                catch (OperationCanceledException) { }
             }
             else
             {
diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs b/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs
index 3f8f1431c73..4748c7f31f2 100644
--- a/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs
+++ b/src/Mvc/Mvc.Core/src/Infrastructure/AsyncEnumerableReader.cs
@@ -34,7 +34,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
             nameof(ReadInternal),
             BindingFlags.NonPublic | BindingFlags.Instance)!;
 
-        private readonly ConcurrentDictionary<Type, Func<object, Task<ICollection>>?> _asyncEnumerableConverters = new();
+        private readonly ConcurrentDictionary<Type, Func<object, CancellationToken, Task<ICollection>>?> _asyncEnumerableConverters = new();
         private readonly MvcOptions _mvcOptions;
 
         /// <summary>
@@ -52,7 +52,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
         /// <param name="type">The type to read.</param>
         /// <param name="reader">A delegate that when awaited reads the <see cref="IAsyncEnumerable{T}"/>.</param>
         /// <returns><see langword="true" /> when <paramref name="type"/> is an instance of <see cref="IAsyncEnumerable{T}"/>, othwerise <see langword="false"/>.</returns>
-        public bool TryGetReader(Type type, [NotNullWhen(true)] out Func<object, Task<ICollection>>? reader)
+        public bool TryGetReader(Type type, [NotNullWhen(true)] out Func<object, CancellationToken, Task<ICollection>>? reader)
         {
             if (!_asyncEnumerableConverters.TryGetValue(type, out reader))
             {
@@ -67,9 +67,9 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
                 {
                     var enumeratedObjectType = enumerableType.GetGenericArguments()[0];
 
-                    var converter = (Func<object, Task<ICollection>>)Converter
+                    var converter = (Func<object, CancellationToken, Task<ICollection>>)Converter
                         .MakeGenericMethod(enumeratedObjectType)
-                        .CreateDelegate(typeof(Func<object, Task<ICollection>>), this);
+                        .CreateDelegate(typeof(Func<object, CancellationToken, Task<ICollection>>), this);
 
                     reader = converter;
                     _asyncEnumerableConverters.TryAdd(type, reader);
@@ -79,9 +79,9 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
             return reader != null;
         }
 
-        private async Task<ICollection> ReadInternal<T>(object value)
+        private async Task<ICollection> ReadInternal<T>(object value, CancellationToken cancellationToken)
         {
-            var asyncEnumerable = (IAsyncEnumerable<T>)value;
+            var asyncEnumerable = ((IAsyncEnumerable<T>)value).WithCancellation(cancellationToken);
             var result = new List<T>();
             var count = 0;
 
diff --git a/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs b/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs
index 4d440a2dec5..26c3de180eb 100644
--- a/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs
+++ b/src/Mvc/Mvc.Core/src/Infrastructure/SystemTextJsonResultExecutor.cs
@@ -25,16 +25,13 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
 
         private readonly JsonOptions _options;
         private readonly ILogger<SystemTextJsonResultExecutor> _logger;
-        private readonly AsyncEnumerableReader _asyncEnumerableReaderFactory;
 
         public SystemTextJsonResultExecutor(
             IOptions<JsonOptions> options,
-            ILogger<SystemTextJsonResultExecutor> logger,
-            IOptions<MvcOptions> mvcOptions)
+            ILogger<SystemTextJsonResultExecutor> logger)
         {
             _options = options.Value;
             _logger = logger;
-            _asyncEnumerableReaderFactory = new AsyncEnumerableReader(mvcOptions.Value);
         }
 
         public async Task ExecuteAsync(ActionContext context, JsonResult result)
@@ -77,8 +74,12 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
             var responseStream = response.Body;
             if (resolvedContentTypeEncoding.CodePage == Encoding.UTF8.CodePage)
             {
-                await JsonSerializer.SerializeAsync(responseStream, value, objectType, jsonSerializerOptions);
-                await responseStream.FlushAsync();
+                try
+                {
+                    await JsonSerializer.SerializeAsync(responseStream, value, objectType, jsonSerializerOptions, context.HttpContext.RequestAborted);
+                    await responseStream.FlushAsync(context.HttpContext.RequestAborted);
+                }
+                catch (OperationCanceledException) { }
             }
             else
             {
@@ -89,9 +90,11 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
                 ExceptionDispatchInfo? exceptionDispatchInfo = null;
                 try
                 {
-                    await JsonSerializer.SerializeAsync(transcodingStream, value, objectType, jsonSerializerOptions);
-                    await transcodingStream.FlushAsync();
+                    await JsonSerializer.SerializeAsync(transcodingStream, value, objectType, jsonSerializerOptions, context.HttpContext.RequestAborted);
+                    await transcodingStream.FlushAsync(context.HttpContext.RequestAborted);
                 }
+                catch (OperationCanceledException)
+                { }
                 catch (Exception ex)
                 {
                     // TranscodingStream may write to the inner stream as part of it's disposal.
diff --git a/src/Mvc/Mvc.Core/test/Formatters/SystemTextJsonOutputFormatterTest.cs b/src/Mvc/Mvc.Core/test/Formatters/SystemTextJsonOutputFormatterTest.cs
index b9fa926628c..c178f43f8cb 100644
--- a/src/Mvc/Mvc.Core/test/Formatters/SystemTextJsonOutputFormatterTest.cs
+++ b/src/Mvc/Mvc.Core/test/Formatters/SystemTextJsonOutputFormatterTest.cs
@@ -5,6 +5,7 @@ using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
+using System.Runtime.CompilerServices;
 using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
@@ -113,6 +114,54 @@ namespace Microsoft.AspNetCore.Mvc.Formatters
             Assert.Equal(expected.ToArray(), body.ToArray());
         }
 
+        [Fact]
+        public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses()
+        {
+            // Arrange
+            var formatter = GetOutputFormatter();
+            var mediaType = MediaTypeHeaderValue.Parse("application/json; charset=utf-8");
+
+            var body = new MemoryStream();
+            var actionContext = GetActionContext(mediaType, body);
+            var cts = new CancellationTokenSource();
+            actionContext.HttpContext.RequestAborted = cts.Token;
+
+            var asyncEnumerable = AsyncEnumerableClosedConnection();
+            var outputFormatterContext = new OutputFormatterWriteContext(
+                actionContext.HttpContext,
+                new TestHttpResponseStreamWriterFactory().CreateWriter,
+                asyncEnumerable.GetType(),
+                asyncEnumerable)
+            {
+                ContentType = new StringSegment(mediaType.ToString()),
+            };
+            var iterated = false;
+
+            // Act
+            await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8"));
+
+            // Assert
+            // System.Text.Json might write the '[' before cancellation is observed
+            Assert.InRange(body.ToArray().Length, 0, 1);
+            Assert.False(iterated);
+
+            async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default)
+            {
+                await Task.Yield();
+                cts.Cancel();
+                // MvcOptions.MaxIAsyncEnumerableBufferLimit is 8192. Pick some value larger than that.
+                foreach (var i in Enumerable.Range(0, 9000))
+                {
+                    if (cancellationToken.IsCancellationRequested)
+                    {
+                        yield break;
+                    }
+                    iterated = true;
+                    yield return i;
+                }
+            }
+        }
+
         private class Person
         {
             public string Name { get; set; }
diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs b/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs
index eac659f55ad..3d5afc91f9e 100644
--- a/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs
+++ b/src/Mvc/Mvc.Core/test/Infrastructure/AsyncEnumerableReaderTest.cs
@@ -4,6 +4,7 @@
 using System;
 using System.Collections.Generic;
 using System.Globalization;
+using System.Runtime.CompilerServices;
 using System.Threading;
 using System.Threading.Tasks;
 using Xunit;
@@ -43,7 +44,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
 
             // Assert
             Assert.True(result);
-            var readCollection = await reader(asyncEnumerable);
+            var readCollection = await reader(asyncEnumerable, default);
             var collection = Assert.IsAssignableFrom<ICollection<string>>(readCollection);
             Assert.Equal(new[] { "0", "1", "2", }, collection);
         }
@@ -61,7 +62,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
 
             // Assert
             Assert.True(result);
-            var readCollection = await reader(asyncEnumerable);
+            var readCollection = await reader(asyncEnumerable, default);
             var collection = Assert.IsAssignableFrom<ICollection<int>>(readCollection);
             Assert.Equal(new[] { 0, 1, 2, }, collection);
         }
@@ -114,8 +115,8 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
             Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader));
 
             // Assert
-            Assert.Equal(expected, await reader(asyncEnumerable1));
-            Assert.Equal(expected, await reader(asyncEnumerable2));
+            Assert.Equal(expected, await reader(asyncEnumerable1, default));
+            Assert.Equal(expected, await reader(asyncEnumerable2, default));
         }
 
         [Fact]
@@ -131,8 +132,8 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
             Assert.True(readerFactory.TryGetReader(asyncEnumerable1.GetType(), out var reader));
 
             // Assert
-            Assert.Equal(new[] { "0", "1", "2" }, await reader(asyncEnumerable1));
-            Assert.Equal(new[] { "0", "1", "2", "3" }, await reader(asyncEnumerable2));
+            Assert.Equal(new[] { "0", "1", "2" }, await reader(asyncEnumerable1, default));
+            Assert.Equal(new[] { "0", "1", "2", "3" }, await reader(asyncEnumerable2, default));
         }
 
         [Fact]
@@ -166,7 +167,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
 
             // Assert
             Assert.True(result);
-            var readCollection = await reader(asyncEnumerable);
+            var readCollection = await reader(asyncEnumerable, default);
             var collection = Assert.IsAssignableFrom<ICollection<object>>(readCollection);
             Assert.Equal(new[] { "0", "1", "2", }, collection);
         }
@@ -184,7 +185,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
 
             // Act
             Assert.True(readerFactory.TryGetReader(enumerable.GetType(), out var reader));
-            var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => reader(enumerable));
+            var ex = await Assert.ThrowsAsync<InvalidOperationException>(() => reader(enumerable, default));
 
             // Assert
             Assert.Equal(expected, ex.Message);
@@ -200,7 +201,32 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
 
             // Act & Assert
             Assert.True(readerFactory.TryGetReader(enumerable.GetType(), out var reader));
-            await Assert.ThrowsAsync<TimeZoneNotFoundException>(() => reader(enumerable));
+            await Assert.ThrowsAsync<TimeZoneNotFoundException>(() => reader(enumerable, default));
+        }
+
+        [Fact]
+        public async Task Reader_PassesCancellationTokenToIAsyncEnumerable()
+        {
+            // Arrange
+            var enumerable = AsyncEnumerable();
+            var options = new MvcOptions();
+            CancellationToken token = default;
+            var readerFactory = new AsyncEnumerableReader(options);
+            var cts = new CancellationTokenSource();
+
+            // Act & Assert
+            Assert.True(readerFactory.TryGetReader(enumerable.GetType(), out var reader));
+            await reader(enumerable, cts.Token);
+
+            cts.Cancel();
+            Assert.Equal(cts.Token, token);
+
+            async IAsyncEnumerable<string> AsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default)
+            {
+                token = cancellationToken;
+                await Task.Yield();
+                yield return string.Empty;
+            }
         }
 
         public static async IAsyncEnumerable<string> TestEnumerable(int count = 3)
diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs b/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs
index 9e579244a36..10d6d1f1393 100644
--- a/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs
+++ b/src/Mvc/Mvc.Core/test/Infrastructure/JsonResultExecutorTestBase.cs
@@ -5,6 +5,7 @@ using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
+using System.Runtime.CompilerServices;
 using System.Text;
 using System.Text.Json;
 using System.Threading;
@@ -365,6 +366,71 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
             Assert.Equal(expected, written);
         }
 
+        [Fact]
+        public async Task ExecuteAsync_AsyncEnumerableConnectionCloses()
+        {
+            var context = GetActionContext();
+            var cts = new CancellationTokenSource();
+            context.HttpContext.RequestAborted = cts.Token;
+            var result = new JsonResult(AsyncEnumerableClosedConnection());
+            var executor = CreateExecutor();
+            var iterated = false;
+
+            // Act
+            await executor.ExecuteAsync(context, result);
+
+            // Assert
+            var written = GetWrittenBytes(context.HttpContext);
+            // System.Text.Json might write the '[' before cancellation is observed
+            Assert.InRange(written.Length, 0, 1);
+            Assert.False(iterated);
+
+            async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default)
+            {
+                await Task.Yield();
+                cts.Cancel();
+                for (var i = 0; i < 100000 && !cancellationToken.IsCancellationRequested; i++)
+                {
+                    iterated = true;
+                    yield return i;
+                }
+            }
+        }
+
+        [Fact]
+        public async Task ExecuteAsyncWithDifferentContentType_AsyncEnumerableConnectionCloses()
+        {
+            var context = GetActionContext();
+            var cts = new CancellationTokenSource();
+            context.HttpContext.RequestAborted = cts.Token;
+            var result = new JsonResult(AsyncEnumerableClosedConnection())
+            {
+                ContentType = "text/json; charset=utf-16",
+            };
+            var executor = CreateExecutor();
+            var iterated = false;
+
+            // Act
+            await executor.ExecuteAsync(context, result);
+
+            // Assert
+            var written = GetWrittenBytes(context.HttpContext);
+            // System.Text.Json might write the '[' before cancellation is observed (utf-16 means 2 bytes per character)
+            Assert.InRange(written.Length, 0, 2);
+            Assert.False(iterated);
+
+            async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default)
+            {
+                await Task.Yield();
+                cts.Cancel();
+                for (var i = 0; i < 100000 && !cancellationToken.IsCancellationRequested; i++)
+                {
+                    iterated = true;
+                    yield return i;
+                }
+            }
+        }
+
         protected IActionResultExecutor<JsonResult> CreateExecutor() => CreateExecutor(NullLoggerFactory.Instance);
 
         protected abstract IActionResultExecutor<JsonResult> CreateExecutor(ILoggerFactory loggerFactory);
@@ -387,7 +453,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
             return new ActionContext(GetHttpContext(), new RouteData(), new ActionDescriptor());
         }
 
-        private static byte[] GetWrittenBytes(HttpContext context)
+        protected static byte[] GetWrittenBytes(HttpContext context)
         {
             context.Response.Body.Seek(0, SeekOrigin.Begin);
             return Assert.IsType<MemoryStream>(context.Response.Body).ToArray();
diff --git a/src/Mvc/Mvc.Core/test/Infrastructure/SystemTextJsonResultExecutorTest.cs b/src/Mvc/Mvc.Core/test/Infrastructure/SystemTextJsonResultExecutorTest.cs
index cd5489fd46f..403ed61b166 100644
--- a/src/Mvc/Mvc.Core/test/Infrastructure/SystemTextJsonResultExecutorTest.cs
+++ b/src/Mvc/Mvc.Core/test/Infrastructure/SystemTextJsonResultExecutorTest.cs
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System;
+using System.Runtime.CompilerServices;
 using System.Text;
 using System.Text.Json;
 using System.Text.Json.Serialization;
@@ -18,8 +19,7 @@ namespace Microsoft.AspNetCore.Mvc.Infrastructure
         {
             return new SystemTextJsonResultExecutor(
                 Options.Create(new JsonOptions()),
-                loggerFactory.CreateLogger<SystemTextJsonResultExecutor>(),
-                Options.Create(new MvcOptions()));
+                loggerFactory.CreateLogger<SystemTextJsonResultExecutor>());
         }
 
         [Fact]
diff --git a/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerOutputFormatter.cs b/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerOutputFormatter.cs
index 429f76807c6..1ecec56f503 100644
--- a/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerOutputFormatter.cs
+++ b/src/Mvc/Mvc.Formatters.Xml/src/XmlDataContractSerializerOutputFormatter.cs
@@ -261,8 +261,12 @@ namespace Microsoft.AspNetCore.Mvc.Formatters
             {
                 Log.BufferingAsyncEnumerable(_logger, value);
 
-                value = await reader(value);
+                value = await reader(value, context.HttpContext.RequestAborted);
                 valueType = value.GetType();
+                if (context.HttpContext.RequestAborted.IsCancellationRequested)
+                {
+                    return;
+                }
             }
 
             Debug.Assert(valueType is not null);
diff --git a/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerOutputFormatter.cs b/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerOutputFormatter.cs
index a74b7f05844..eba98cb2a50 100644
--- a/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerOutputFormatter.cs
+++ b/src/Mvc/Mvc.Formatters.Xml/src/XmlSerializerOutputFormatter.cs
@@ -236,8 +236,12 @@ namespace Microsoft.AspNetCore.Mvc.Formatters
             {
                 Log.BufferingAsyncEnumerable(_logger, value);
 
-                value = await reader(value);
+                value = await reader(value, context.HttpContext.RequestAborted);
                 valueType = value.GetType();
+                if (context.HttpContext.RequestAborted.IsCancellationRequested)
+                {
+                    return;
+                }
             }
 
             // Wrap the object only if there is a wrapping type.
diff --git a/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerOutputFormatterTest.cs b/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerOutputFormatterTest.cs
index 7208978b08a..44a5a72152d 100644
--- a/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerOutputFormatterTest.cs
+++ b/src/Mvc/Mvc.Formatters.Xml/test/XmlDataContractSerializerOutputFormatterTest.cs
@@ -5,6 +5,7 @@ using System;
 using System.Collections.Generic;
 using System.Globalization;
 using System.IO;
+using System.Runtime.CompilerServices;
 using System.Runtime.Serialization;
 using System.Text;
 using System.Threading.Tasks;
@@ -722,6 +723,73 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml
             Assert.Equal(expectedOutput, content);
         }
 
+        [Fact]
+        public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses()
+        {
+            // Arrange
+            var formatter = new XmlDataContractSerializerOutputFormatter();
+            var body = new MemoryStream();
+            var cts = new CancellationTokenSource();
+            var iterated = false;
+
+            var asyncEnumerable = AsyncEnumerableClosedConnection();
+            var outputFormatterContext = GetOutputFormatterContext(
+                asyncEnumerable,
+                asyncEnumerable.GetType());
+            outputFormatterContext.HttpContext.RequestAborted = cts.Token;
+            outputFormatterContext.HttpContext.Response.Body = body;
+
+            // Act
+            await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8"));
+
+            // Assert
+            Assert.Empty(body.ToArray());
+            Assert.False(iterated);
+
+            async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default)
+            {
+                await Task.Yield();
+                cts.Cancel();
+                // MvcOptions.MaxIAsyncEnumerableBufferLimit is 8192. Pick some value larger than that.
+                foreach (var i in Enumerable.Range(0, 9000))
+                {
+                    if (cancellationToken.IsCancellationRequested)
+                    {
+                        yield break;
+                    }
+                    iterated = true;
+                    yield return i;
+                }
+            }
+        }
+
+        [Fact]
+        public async Task WriteResponseBodyAsync_AsyncEnumerable()
+        {
+            // Arrange
+            var formatter = new XmlDataContractSerializerOutputFormatter();
+            var body = new MemoryStream();
+
+            var asyncEnumerable = AsyncEnumerable();
+            var outputFormatterContext = GetOutputFormatterContext(
+                asyncEnumerable,
+                asyncEnumerable.GetType());
+            outputFormatterContext.HttpContext.Response.Body = body;
+
+            // Act
+            await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8"));
+
+            // Assert
+            Assert.Contains("<int>1</int><int>2</int>", Encoding.UTF8.GetString(body.ToArray()));
+
+            async IAsyncEnumerable<int> AsyncEnumerable()
+            {
+                await Task.Yield();
+                yield return 1;
+                yield return 2;
+            }
+        }
+
         private OutputFormatterWriteContext GetOutputFormatterContext(
             object outputValue,
             Type outputType,
diff --git a/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerOutputFormatterTest.cs b/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerOutputFormatterTest.cs
index ef5f4523b7b..335c41d84c9 100644
--- a/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerOutputFormatterTest.cs
+++ b/src/Mvc/Mvc.Formatters.Xml/test/XmlSerializerOutputFormatterTest.cs
@@ -5,6 +5,7 @@ using System;
 using System.Collections.Generic;
 using System.IO;
 using System.Linq;
+using System.Runtime.CompilerServices;
 using System.Text;
 using System.Threading.Tasks;
 using System.Xml;
@@ -505,6 +506,73 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml
             Assert.False(canWriteResult);
         }
 
+        [Fact]
+        public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses()
+        {
+            // Arrange
+            var formatter = new XmlSerializerOutputFormatter();
+            var body = new MemoryStream();
+            var cts = new CancellationTokenSource();
+            var iterated = false;
+
+            var asyncEnumerable = AsyncEnumerableClosedConnection();
+            var outputFormatterContext = GetOutputFormatterContext(
+                asyncEnumerable,
+                asyncEnumerable.GetType());
+            outputFormatterContext.HttpContext.RequestAborted = cts.Token;
+            outputFormatterContext.HttpContext.Response.Body = body;
+
+            // Act
+            await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8"));
+
+            // Assert
+            Assert.Empty(body.ToArray());
+            Assert.False(iterated);
+
+            async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default)
+            {
+                await Task.Yield();
+                cts.Cancel();
+                // MvcOptions.MaxIAsyncEnumerableBufferLimit is 8192. Pick some value larger than that.
+                foreach (var i in Enumerable.Range(0, 9000))
+                {
+                    if (cancellationToken.IsCancellationRequested)
+                    {
+                        yield break;
+                    }
+                    iterated = true;
+                    yield return i;
+                }
+            }
+        }
+
+        [Fact]
+        public async Task WriteResponseBodyAsync_AsyncEnumerable()
+        {
+            // Arrange
+            var formatter = new XmlSerializerOutputFormatter();
+            var body = new MemoryStream();
+
+            var asyncEnumerable = AsyncEnumerable();
+            var outputFormatterContext = GetOutputFormatterContext(
+                asyncEnumerable,
+                asyncEnumerable.GetType());
+            outputFormatterContext.HttpContext.Response.Body = body;
+
+            // Act
+            await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8"));
+
+            // Assert
+            Assert.Contains("<int>1</int><int>2</int>", Encoding.UTF8.GetString(body.ToArray()));
+
+            async IAsyncEnumerable<int> AsyncEnumerable()
+            {
+                await Task.Yield();
+                yield return 1;
+                yield return 2;
+            }
+        }
+
         private OutputFormatterWriteContext GetOutputFormatterContext(
             object outputValue,
             Type outputType,
@@ -577,4 +645,4 @@ namespace Microsoft.AspNetCore.Mvc.Formatters.Xml
             }
         }
     }
-}
\ No newline at end of file
+}
diff --git a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonOutputFormatter.cs b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonOutputFormatter.cs
index 56623da8a3c..e471cd36bad 100644
--- a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonOutputFormatter.cs
+++ b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonOutputFormatter.cs
@@ -26,7 +26,6 @@ namespace Microsoft.AspNetCore.Mvc.Formatters
         private MvcNewtonsoftJsonOptions? _jsonOptions;
         private readonly AsyncEnumerableReader _asyncEnumerableReaderFactory;
         private JsonSerializerSettings? _serializerSettings;
-        private ILogger? _logger;
 
         /// <summary>
         /// Initializes a new <see cref="NewtonsoftJsonOutputFormatter"/> instance.
@@ -174,9 +173,17 @@ namespace Microsoft.AspNetCore.Mvc.Formatters
             var value = context.Object;
             if (value is not null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader))
             {
-                _logger ??= context.HttpContext.RequestServices.GetRequiredService<ILogger<NewtonsoftJsonOutputFormatter>>();
-                Log.BufferingAsyncEnumerable(_logger, value);
-                value = await reader(value);
+                var logger = context.HttpContext.RequestServices.GetRequiredService<ILogger<NewtonsoftJsonOutputFormatter>>();
+                Log.BufferingAsyncEnumerable(logger, value);
+                try
+                {
+                    value = await reader(value, context.HttpContext.RequestAborted);
+                }
+                catch (OperationCanceledException) { }
+                if (context.HttpContext.RequestAborted.IsCancellationRequested)
+                {
+                    return;
+                }
             }
 
             try
diff --git a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs
index 5f7d28d185f..3ee953f8d1b 100644
--- a/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs
+++ b/src/Mvc/Mvc.NewtonsoftJson/src/NewtonsoftJsonResultExecutor.cs
@@ -125,6 +125,21 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson
 
             try
             {
+                var value = result.Value;
+                if (value != null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader))
+                {
+                    Log.BufferingAsyncEnumerable(_logger, value);
+                    try
+                    {
+                        value = await reader(value, context.HttpContext.RequestAborted);
+                    }
+                    catch (OperationCanceledException) { }
+                    if (context.HttpContext.RequestAborted.IsCancellationRequested)
+                    {
+                        return;
+                    }
+                }
+
                 await using (var writer = _writerFactory.CreateWriter(responseStream, resolvedContentTypeEncoding))
                 {
                     using var jsonWriter = new JsonTextWriter(writer);
@@ -133,12 +148,6 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson
                     jsonWriter.AutoCompleteOnClose = false;
 
                     var jsonSerializer = JsonSerializer.Create(jsonSerializerSettings);
-                    var value = result.Value;
-                    if (value != null && _asyncEnumerableReaderFactory.TryGetReader(value.GetType(), out var reader))
-                    {
-                        Log.BufferingAsyncEnumerable(_logger, value);
-                        value = await reader(value);
-                    }
 
                     jsonSerializer.Serialize(jsonWriter, value);
                 }
diff --git a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonOutputFormatterTest.cs b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonOutputFormatterTest.cs
index d0d1598bf67..116da8922ae 100644
--- a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonOutputFormatterTest.cs
+++ b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonOutputFormatterTest.cs
@@ -1,22 +1,17 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
-using System;
 using System.Buffers;
-using System.Collections.Generic;
-using System.IO;
-using System.Linq;
-using System.Reflection;
+using System.Runtime.CompilerServices;
 using System.Text;
-using System.Threading;
-using System.Threading.Tasks;
 using Microsoft.AspNetCore.WebUtilities;
 using Microsoft.Extensions.DependencyInjection;
+using Microsoft.Extensions.Primitives;
+using Microsoft.Net.Http.Headers;
 using Moq;
 using Newtonsoft.Json;
 using Newtonsoft.Json.Linq;
 using Newtonsoft.Json.Serialization;
-using Xunit;
 
 namespace Microsoft.AspNetCore.Mvc.Formatters
 {
@@ -436,6 +431,51 @@ namespace Microsoft.AspNetCore.Mvc.Formatters
             }
         }
 
+        [Fact]
+        public async Task WriteResponseBodyAsync_AsyncEnumerableConnectionCloses()
+        {
+            // Arrange
+            var formatter = GetOutputFormatter();
+            var mediaType = MediaTypeHeaderValue.Parse("application/json; charset=utf-8");
+
+            var body = new MemoryStream();
+            var actionContext = GetActionContext(mediaType, body);
+            var cts = new CancellationTokenSource();
+            actionContext.HttpContext.RequestAborted = cts.Token;
+            actionContext.HttpContext.RequestServices = new ServiceCollection().AddLogging().BuildServiceProvider();
+
+            var asyncEnumerable = AsyncEnumerableClosedConnection();
+            var outputFormatterContext = new OutputFormatterWriteContext(
+                actionContext.HttpContext,
+                new TestHttpResponseStreamWriterFactory().CreateWriter,
+                asyncEnumerable.GetType(),
+                asyncEnumerable)
+            {
+                ContentType = new StringSegment(mediaType.ToString()),
+            };
+            var iterated = false;
+
+            // Act
+            await formatter.WriteResponseBodyAsync(outputFormatterContext, Encoding.GetEncoding("utf-8"));
+
+            // Assert
+            Assert.Empty(body.ToArray());
+            Assert.False(iterated);
+
+            async IAsyncEnumerable<int> AsyncEnumerableClosedConnection([EnumeratorCancellation] CancellationToken cancellationToken = default)
+            {
+                await Task.Yield();
+                cts.Cancel();
+                // MvcOptions.MaxIAsyncEnumerableBufferLimit is 8192. Pick some value larger than that.
+                foreach (var i in Enumerable.Range(0, 9000))
+                {
+                    cancellationToken.ThrowIfCancellationRequested();
+                    iterated = true;
+                    yield return i;
+                }
+            }
+        }
+
         private class TestableJsonOutputFormatter : NewtonsoftJsonOutputFormatter
         {
             public TestableJsonOutputFormatter(JsonSerializerSettings serializerSettings)
diff --git a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonResultExecutorTest.cs b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonResultExecutorTest.cs
index ffa7d2f6f17..7d88e8597b2 100644
--- a/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonResultExecutorTest.cs
+++ b/src/Mvc/Mvc.NewtonsoftJson/test/NewtonsoftJsonResultExecutorTest.cs
@@ -2,6 +2,8 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Buffers;
+using System.Runtime.CompilerServices;
+using System.Text;
 using Microsoft.AspNetCore.Mvc.Infrastructure;
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Options;
@@ -25,5 +27,37 @@ namespace Microsoft.AspNetCore.Mvc.NewtonsoftJson
         {
             return new JsonSerializerSettings { Formatting = Formatting.Indented };
         }
+
+        [Fact]
+        public async Task ExecuteAsync_AsyncEnumerableReceivesCancellationToken()
+        {
+            // Arrange
+            var expected = System.Text.Json.JsonSerializer.Serialize(new[] { "Hello", "world" });
+
+            var cts = new CancellationTokenSource();
+            var context = GetActionContext();
+            context.HttpContext.RequestAborted = cts.Token;
+            var result = new JsonResult(TestAsyncEnumerable());
+            var executor = CreateExecutor();
+            CancellationToken token = default;
+
+            // Act
+            await executor.ExecuteAsync(context, result);
+
+            // Assert
+            var written = GetWrittenBytes(context.HttpContext);
+            Assert.Equal(expected, Encoding.UTF8.GetString(written));
+
+            cts.Cancel();
+            Assert.Equal(cts.Token, token);
+
+            async IAsyncEnumerable<string> TestAsyncEnumerable([EnumeratorCancellation] CancellationToken cancellationToken = default)
+            {
+                await Task.Yield();
+                token = cancellationToken;
+                yield return "Hello";
+                yield return "world";
+            }
+        }
     }
 }
-- 
GitLab