diff --git a/src/Middleware/ResponseCompression/src/ResponseCompressionBody.cs b/src/Middleware/ResponseCompression/src/ResponseCompressionBody.cs index 28bafedd8c8d085e4b0da763a3e072b70cfaf52a..c0140967b114605bba6019d741e5e88142372ba5 100644 --- a/src/Middleware/ResponseCompression/src/ResponseCompressionBody.cs +++ b/src/Middleware/ResponseCompression/src/ResponseCompressionBody.cs @@ -201,7 +201,11 @@ namespace Microsoft.AspNetCore.ResponseCompression } } - private void InitializeCompressionHeaders() + /// <summary> + /// Checks if the response should be compressed and sets the response headers. + /// </summary> + /// <returns>The compression provider to use if compression is enabled, otherwise null.</returns> + private ICompressionProvider? InitializeCompressionHeaders() { if (_provider.ShouldCompressResponse(_context)) { @@ -235,7 +239,11 @@ namespace Microsoft.AspNetCore.ResponseCompression headers.ContentMD5 = default; // Reset the MD5 because the content changed. headers.ContentLength = default; } + + return compressionProvider; } + + return null; } private void OnWrite() @@ -244,11 +252,11 @@ namespace Microsoft.AspNetCore.ResponseCompression { _compressionChecked = true; - InitializeCompressionHeaders(); + var compressionProvider = InitializeCompressionHeaders(); - if (_compressionProvider != null) + if (compressionProvider != null) { - _compressionStream = _compressionProvider.CreateStream(_innerStream); + _compressionStream = compressionProvider.CreateStream(_innerStream); } } } diff --git a/src/Middleware/ResponseCompression/test/ResponseCompressionMiddlewareTest.cs b/src/Middleware/ResponseCompression/test/ResponseCompressionMiddlewareTest.cs index 228ccf212da8ace664592b3c520062dbf732e2d3..d55262027d0d6981d7d8c128f968842666dee642 100644 --- a/src/Middleware/ResponseCompression/test/ResponseCompressionMiddlewareTest.cs +++ b/src/Middleware/ResponseCompression/test/ResponseCompressionMiddlewareTest.cs @@ -960,6 +960,73 @@ namespace Microsoft.AspNetCore.ResponseCompression.Tests } } + [Theory] + [MemberData(nameof(SupportedEncodings))] + public async Task UncompressedTrickleWriteAndFlushAsync_FlushesEachWrite(string encoding) + { + var responseReceived = new[] + { + new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously), + new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously), + new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously), + new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously), + new TaskCompletionSource<int>(TaskCreationOptions.RunContinuationsAsynchronously), + }; + + using var host = new HostBuilder() + .ConfigureWebHost(webHostBuilder => + { + webHostBuilder + .UseTestServer() + .ConfigureServices(services => + { + services.AddResponseCompression(); + }) + .Configure(app => + { + app.UseResponseCompression(); + app.Run(async context => + { + context.Response.Headers.ContentMD5 = "MD5"; + context.Response.ContentType = "Un/compressed"; + context.Features.Get<IHttpResponseBodyFeature>().DisableBuffering(); + + foreach (var signal in responseReceived) + { + await context.Response.WriteAsync("a"); + await context.Response.Body.FlushAsync(); + await signal.Task.TimeoutAfter(TimeSpan.FromSeconds(3)); + } + }); + }); + }).Build(); + + await host.StartAsync(); + + var server = host.GetTestServer(); + var client = server.CreateClient(); + + var request = new HttpRequestMessage(HttpMethod.Get, ""); + request.Headers.AcceptEncoding.ParseAdd(encoding); + + var response = await client.SendAsync(request, HttpCompletionOption.ResponseHeadersRead); + Assert.True(response.Content.Headers.TryGetValues(HeaderNames.ContentMD5, out var md5)); + Assert.Equal("MD5", md5.SingleOrDefault()); + Assert.Empty(response.Content.Headers.ContentEncoding); + + var body = await response.Content.ReadAsStreamAsync(); + + var data = new byte[100]; + foreach (var signal in responseReceived) + { + var read = await body.ReadAsync(data, 0, data.Length); + Assert.Equal(1, read); + Assert.Equal('a', (char)data[0]); + + signal.SetResult(0); + } + } + [Fact] public async Task SendFileAsync_DifferentContentType_NotBypassed() {