diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 7762e2e1467f597c0cf0657fde272545404dda60..6232a5479a00e59489ec2f4554710e045d9f35c7 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -4,6 +4,7 @@ using System; using System.Collections.Concurrent; using System.Collections.Generic; +using System.Globalization; using System.IO; using System.Linq; using System.Linq.Expressions; @@ -30,7 +31,7 @@ namespace Microsoft.AspNetCore.Http private static readonly MethodInfo ExecuteTaskResultOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteValueResultTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo GetRequiredServiceMethod = typeof(ServiceProviderServiceExtensions).GetMethod(nameof(ServiceProviderServiceExtensions.GetRequiredService), BindingFlags.Public | BindingFlags.Static, new Type[] { typeof(IServiceProvider) })!; - private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(IResult).GetMethod(nameof(IResult.ExecuteAsync), BindingFlags.Public | BindingFlags.Instance)!; + private static readonly MethodInfo ResultWriteResponseAsyncMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteResultWriteResponse), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo StringResultWriteResponseAsyncMethod = GetMethodInfo<Func<HttpResponse, string, Task>>((response, text) => HttpResponseWritingExtensions.WriteAsync(response, text, default)); private static readonly MethodInfo JsonResultWriteResponseAsyncMethod = GetMethodInfo<Func<HttpResponse, object, Task>>((response, value) => HttpResponseJsonExtensions.WriteAsJsonAsync(response, value, default)); private static readonly MethodInfo EnumTryParseMethod = GetEnumTryParseMethod(); @@ -393,7 +394,7 @@ namespace Microsoft.AspNetCore.Http } else if (typeof(IResult).IsAssignableFrom(returnType)) { - return Expression.Call(methodCall, ResultWriteResponseAsyncMethod, HttpContextExpr); + return Expression.Call(ResultWriteResponseAsyncMethod, methodCall, HttpContextExpr); } else if (returnType == typeof(string)) { @@ -679,6 +680,8 @@ namespace Microsoft.AspNetCore.Http private static Task ExecuteTask<T>(Task<T> task, HttpContext httpContext) { + EnsureRequestTaskNotNull(task); + static async Task ExecuteAwaited(Task<T> task, HttpContext httpContext) { await httpContext.Response.WriteAsJsonAsync(await task); @@ -692,8 +695,10 @@ namespace Microsoft.AspNetCore.Http return ExecuteAwaited(task, httpContext); } - private static Task ExecuteTaskOfString(Task<string> task, HttpContext httpContext) + private static Task ExecuteTaskOfString(Task<string?> task, HttpContext httpContext) { + EnsureRequestTaskNotNull(task); + static async Task ExecuteAwaited(Task<string> task, HttpContext httpContext) { await httpContext.Response.WriteAsync(await task); @@ -701,10 +706,10 @@ namespace Microsoft.AspNetCore.Http if (task.IsCompletedSuccessfully) { - return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()); + return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()!); } - return ExecuteAwaited(task, httpContext); + return ExecuteAwaited(task!, httpContext); } private static Task ExecuteValueTask(ValueTask task) @@ -737,7 +742,7 @@ namespace Microsoft.AspNetCore.Http return ExecuteAwaited(task, httpContext); } - private static Task ExecuteValueTaskOfString(ValueTask<string> task, HttpContext httpContext) + private static Task ExecuteValueTaskOfString(ValueTask<string?> task, HttpContext httpContext) { static async Task ExecuteAwaited(ValueTask<string> task, HttpContext httpContext) { @@ -746,30 +751,37 @@ namespace Microsoft.AspNetCore.Http if (task.IsCompletedSuccessfully) { - return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()); + return httpContext.Response.WriteAsync(task.GetAwaiter().GetResult()!); } - return ExecuteAwaited(task, httpContext); + return ExecuteAwaited(task!, httpContext); } - private static Task ExecuteValueTaskResult<T>(ValueTask<T> task, HttpContext httpContext) where T : IResult + private static Task ExecuteValueTaskResult<T>(ValueTask<T?> task, HttpContext httpContext) where T : IResult { static async Task ExecuteAwaited(ValueTask<T> task, HttpContext httpContext) { - await (await task).ExecuteAsync(httpContext); + await EnsureRequestResultNotNull(await task)!.ExecuteAsync(httpContext); } if (task.IsCompletedSuccessfully) { - return task.GetAwaiter().GetResult().ExecuteAsync(httpContext); + return EnsureRequestResultNotNull(task.GetAwaiter().GetResult())!.ExecuteAsync(httpContext); } - return ExecuteAwaited(task, httpContext); + return ExecuteAwaited(task!, httpContext); } - private static async Task ExecuteTaskResult<T>(Task<T> task, HttpContext httpContext) where T : IResult + private static async Task ExecuteTaskResult<T>(Task<T?> task, HttpContext httpContext) where T : IResult { - await (await task).ExecuteAsync(httpContext); + EnsureRequestTaskOfNotNull(task); + + await EnsureRequestResultNotNull(await task)!.ExecuteAsync(httpContext); + } + + private static async Task ExecuteResultWriteResponse(IResult result, HttpContext httpContext) + { + await EnsureRequestResultNotNull(result)!.ExecuteAsync(httpContext); } private class FactoryContext @@ -819,5 +831,31 @@ namespace Microsoft.AspNetCore.Http return loggerFactory.CreateLogger(typeof(RequestDelegateFactory)); } } + + private static void EnsureRequestTaskOfNotNull<T>(Task<T?> task) where T : IResult + { + if (task is null) + { + throw new InvalidOperationException("The IResult in Task<IResult> response must not be null."); + } + } + + private static void EnsureRequestTaskNotNull(Task? task) + { + if (task is null) + { + throw new InvalidOperationException("The Task returned by the Delegate must not be null."); + } + } + + private static IResult EnsureRequestResultNotNull(IResult? result) + { + if (result is null) + { + throw new InvalidOperationException("The IResult returned by the Delegate must not be null."); + } + + return result; + } } } diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index 32e115a80038b26b344c4ea8574c931754a9fdde..ae85b27e3ee351c4e4b40749e83f262054a2c5d7 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -1081,6 +1081,89 @@ namespace Microsoft.AspNetCore.Routing.Internal Assert.Equal("true", responseBody); } + public static IEnumerable<object[]> NullResult + { + get + { + IResult? TestAction() => null; + Task<bool?>? TaskBoolAction() => null; + Task<IResult?>? TaskNullAction() => null; + Task<IResult?> TaskTestAction() => Task.FromResult<IResult?>(null); + ValueTask<IResult?> ValueTaskTestAction() => ValueTask.FromResult<IResult?>(null); + + return new List<object[]> + { + new object[] { (Func<IResult?>)TestAction, "The IResult returned by the Delegate must not be null." }, + new object[] { (Func<Task<IResult?>?>)TaskNullAction, "The IResult in Task<IResult> response must not be null." }, + new object[] { (Func<Task<bool?>?>)TaskBoolAction, "The Task returned by the Delegate must not be null." }, + new object[] { (Func<Task<IResult?>>)TaskTestAction, "The IResult returned by the Delegate must not be null." }, + new object[] { (Func<ValueTask<IResult?>>)ValueTaskTestAction, "The IResult returned by the Delegate must not be null." }, + }; + } + } + + [Theory] + [MemberData(nameof(NullResult))] + public async Task RequestDelegateThrowsInvalidOperationExceptionOnNullDelegate(Delegate @delegate, string message) + { + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + var requestDelegate = RequestDelegateFactory.Create(@delegate); + + var exception = await Assert.ThrowsAnyAsync<InvalidOperationException>(async () => await requestDelegate(httpContext)); + Assert.Contains(message, exception.Message); + } + + public static IEnumerable<object[]> NullContentResult + { + get + { + bool? TestBoolAction() => null; + Task<bool?> TaskTestBoolAction() => Task.FromResult<bool?>(null); + ValueTask<bool?> ValueTaskTestBoolAction() => ValueTask.FromResult<bool?>(null); + + int? TestIntAction() => null; + Task<int?> TaskTestIntAction() => Task.FromResult<int?>(null); + ValueTask<int?> ValueTaskTestIntAction() => ValueTask.FromResult<int?>(null); + + Todo? TestTodoAction() => null; + Task<Todo?> TaskTestTodoAction() => Task.FromResult<Todo?>(null); + ValueTask<Todo?> ValueTaskTestTodoAction() => ValueTask.FromResult<Todo?>(null); + + return new List<object[]> + { + new object[] { (Func<bool?>)TestBoolAction }, + new object[] { (Func<Task<bool?>>)TaskTestBoolAction }, + new object[] { (Func<ValueTask<bool?>>)ValueTaskTestBoolAction }, + new object[] { (Func<int?>)TestIntAction }, + new object[] { (Func<Task<int?>>)TaskTestIntAction }, + new object[] { (Func<ValueTask<int?>>)ValueTaskTestIntAction }, + new object[] { (Func<Todo?>)TestTodoAction }, + new object[] { (Func<Task<Todo?>>)TaskTestTodoAction }, + new object[] { (Func<ValueTask<Todo?>>)ValueTaskTestTodoAction }, + }; + } + } + + [Theory] + [MemberData(nameof(NullContentResult))] + public async Task RequestDelegateWritesNullReturnNullValue(Delegate @delegate) + { + var httpContext = new DefaultHttpContext(); + var responseBodyStream = new MemoryStream(); + httpContext.Response.Body = responseBodyStream; + + var requestDelegate = RequestDelegateFactory.Create(@delegate); + + await requestDelegate(httpContext); + + var responseBody = Encoding.UTF8.GetString(responseBodyStream.ToArray()); + + Assert.Equal("null", responseBody); + } + private class Todo { public int Id { get; set; }