diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index e2ee1514ff438d075c0ec33d7be725b8f17e9ba2..7bd9fe9392d338dbf8216c9acfaa1a712c7bb5b6 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -36,6 +36,8 @@ public static partial class RequestDelegateFactory private static readonly MethodInfo ExecuteTaskWithEmptyResultMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskWithEmptyResult), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteValueTaskWithEmptyResultMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskWithEmptyResult), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskOfT), BindingFlags.NonPublic | BindingFlags.Static)!; + private static readonly MethodInfo ExecuteTaskOfObjectMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskOfObject), BindingFlags.NonPublic | BindingFlags.Static)!; + private static readonly MethodInfo ExecuteValueTaskOfObjectMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskOfObject), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteTaskOfStringMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteValueTaskOfTMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTaskOfT), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteValueTaskMethod = typeof(RequestDelegateFactory).GetMethod(nameof(ExecuteValueTask), BindingFlags.NonPublic | BindingFlags.Static)!; @@ -749,14 +751,15 @@ public static partial class RequestDelegateFactory } else if (returnType == typeof(ValueTask<object>)) { - // REVIEW: We can avoid this box if it becomes a performance issue - var box = Expression.TypeAs(methodCall, typeof(object)); - return Expression.Call(ExecuteObjectReturnMethod, box, HttpContextExpr); + return Expression.Call(ExecuteValueTaskOfObjectMethod, + methodCall, + HttpContextExpr); } else if (returnType == typeof(Task<object>)) { - var convert = Expression.Convert(methodCall, typeof(object)); - return Expression.Call(ExecuteObjectReturnMethod, convert, HttpContextExpr); + return Expression.Call(ExecuteTaskOfObjectMethod, + methodCall, + HttpContextExpr); } else if (AwaitableInfo.IsTypeAwaitable(returnType, out _)) { @@ -1685,53 +1688,76 @@ public static partial class RequestDelegateFactory // if necessary and restart the cycle until we've reached a terminal state (unknown type). // We currently don't handle Task<unknown> or ValueTask<unknown>. We can support this later if this // ends up being a common scenario. - private static async Task ExecuteObjectReturn(object? obj, HttpContext httpContext) + private static Task ExecuteValueTaskOfObject(ValueTask<object> valueTask, HttpContext httpContext) + { + static async Task ExecuteAwaited(ValueTask<object> valueTask, HttpContext httpContext) + { + await ExecuteObjectReturn(await valueTask, httpContext); + } + + if (valueTask.IsCompletedSuccessfully) + { + return ExecuteObjectReturn(valueTask.GetAwaiter().GetResult(), httpContext); + } + + return ExecuteAwaited(valueTask, httpContext); + } + + private static Task ExecuteTaskOfObject(Task<object> task, HttpContext httpContext) + { + static async Task ExecuteAwaited(Task<object> task, HttpContext httpContext) + { + await ExecuteObjectReturn(await task, httpContext); + } + + if (task.IsCompletedSuccessfully) + { + return ExecuteObjectReturn(task.GetAwaiter().GetResult(), httpContext); + } + + return ExecuteAwaited(task, httpContext); + } + + private static Task ExecuteObjectReturn(object obj, HttpContext httpContext) { - // See if we need to unwrap Task<object> or ValueTask<object> if (obj is Task<object> taskObj) { - obj = await taskObj; + return ExecuteTaskOfObject(taskObj, httpContext); } else if (obj is ValueTask<object> valueTaskObj) { - obj = await valueTaskObj; + return ExecuteValueTaskOfObject(valueTaskObj, httpContext); } else if (obj is Task<IResult?> task) { - await ExecuteTaskResult(task, httpContext); - return; + return ExecuteTaskResult(task, httpContext); } else if (obj is ValueTask<IResult?> valueTask) { - await ExecuteValueTaskResult(valueTask, httpContext); - return; + return ExecuteValueTaskResult(valueTask, httpContext); } else if (obj is Task<string?> taskString) { - await ExecuteTaskOfString(taskString, httpContext); - return; - } + return ExecuteTaskOfString(taskString, httpContext); } else if (obj is ValueTask<string?> valueTaskString) { - await ExecuteValueTaskOfString(valueTaskString, httpContext); - return; + return ExecuteValueTaskOfString(valueTaskString, httpContext); } - // Terminal built ins - if (obj is IResult result) + else if (obj is IResult result) { - await ExecuteResultWriteResponse(result, httpContext); + return ExecuteResultWriteResponse(result, httpContext); } else if (obj is string stringValue) { SetPlaintextContentType(httpContext); - await httpContext.Response.WriteAsync(stringValue); + return httpContext.Response.WriteAsync(stringValue); } else { // Otherwise, we JSON serialize when we reach the terminal state // Call WriteAsJsonAsync<object?>() to serialize the runtime return type rather than the declared return type. - await httpContext.Response.WriteAsJsonAsync<object?>(obj); + return httpContext.Response.WriteAsJsonAsync<object?>(obj); } }