diff --git a/src/Http/Routing/src/Internal/MapActionExpressionTreeBuilder.cs b/src/Http/Routing/src/Internal/MapActionExpressionTreeBuilder.cs index 54b931cd9875a876ee771bf5f766df7519b64fc8..1e83d92c3497c9608f763d82b5e2e155e51a88f3 100644 --- a/src/Http/Routing/src/Internal/MapActionExpressionTreeBuilder.cs +++ b/src/Http/Routing/src/Internal/MapActionExpressionTreeBuilder.cs @@ -24,7 +24,8 @@ namespace Microsoft.AspNetCore.Routing.Internal private static readonly MethodInfo ChangeTypeMethodInfo = GetMethodInfo<Func<object, Type, object>>((value, type) => Convert.ChangeType(value, type, CultureInfo.InvariantCulture)); private static readonly MethodInfo ExecuteTaskOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteTask), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteTaskOfStringMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!; - private static readonly MethodInfo ExecuteValueTaskOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTask), BindingFlags.NonPublic | BindingFlags.Static)!; + private static readonly MethodInfo ExecuteValueTaskOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTaskOfT), BindingFlags.NonPublic | BindingFlags.Static)!; + private static readonly MethodInfo ExecuteValueTaskMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTask), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteValueTaskOfStringMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTaskOfString), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteTaskResultOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!; private static readonly MethodInfo ExecuteValueResultTaskOfTMethodInfo = typeof(MapActionExpressionTreeBuilder).GetMethod(nameof(ExecuteValueTaskResult), BindingFlags.NonPublic | BindingFlags.Static)!; @@ -71,28 +72,31 @@ namespace Microsoft.AspNetCore.Routing.Internal // This argument represents the deserialized body returned from IHttpRequestReader // when the method has a FromBody attribute declared - var args = new List<Expression>(); + var methodParameters = method.GetParameters(); + var args = new List<Expression>(methodParameters.Length); - foreach (var parameter in method.GetParameters()) + foreach (var parameter in methodParameters) { Expression paramterExpression = Expression.Default(parameter.ParameterType); - if (parameter.GetCustomAttributes().OfType<IFromRouteMetadata>().FirstOrDefault() is { } routeAttribute) + var parameterCustomAttributes = parameter.GetCustomAttributes(); + + if (parameterCustomAttributes.OfType<IFromRouteMetadata>().FirstOrDefault() is { } routeAttribute) { var routeValuesProperty = Expression.Property(HttpRequestExpr, nameof(HttpRequest.RouteValues)); paramterExpression = BindParamenter(routeValuesProperty, parameter, routeAttribute.Name); } - else if (parameter.GetCustomAttributes().OfType<IFromQueryMetadata>().FirstOrDefault() is { } queryAttribute) + else if (parameterCustomAttributes.OfType<IFromQueryMetadata>().FirstOrDefault() is { } queryAttribute) { var queryProperty = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Query)); paramterExpression = BindParamenter(queryProperty, parameter, queryAttribute.Name); } - else if (parameter.GetCustomAttributes().OfType<IFromHeaderMetadata>().FirstOrDefault() is { } headerAttribute) + else if (parameterCustomAttributes.OfType<IFromHeaderMetadata>().FirstOrDefault() is { } headerAttribute) { var headersProperty = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Headers)); paramterExpression = BindParamenter(headersProperty, parameter, headerAttribute.Name); } - else if (parameter.GetCustomAttributes().OfType<IFromBodyMetadata>().FirstOrDefault() is { } bodyAttribute) + else if (parameterCustomAttributes.OfType<IFromBodyMetadata>().FirstOrDefault() is { } bodyAttribute) { if (consumeBodyDirectly) { @@ -109,7 +113,7 @@ namespace Microsoft.AspNetCore.Routing.Internal bodyType = parameter.ParameterType; paramterExpression = Expression.Convert(DeserializedBodyArg, bodyType); } - else if (parameter.GetCustomAttributes().OfType<IFromFormMetadata>().FirstOrDefault() is { } formAttribute) + else if (parameterCustomAttributes.OfType<IFromFormMetadata>().FirstOrDefault() is { } formAttribute) { if (consumeBodyDirectly) { @@ -125,27 +129,24 @@ namespace Microsoft.AspNetCore.Routing.Internal { paramterExpression = Expression.Call(GetRequiredServiceMethodInfo.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr); } - else + else if (parameter.ParameterType == typeof(IFormCollection)) { - if (parameter.ParameterType == typeof(IFormCollection)) + if (consumeBodyDirectly) { - if (consumeBodyDirectly) - { - ThrowCannotReadBodyDirectlyAndAsForm(); - } + ThrowCannotReadBodyDirectlyAndAsForm(); + } - consumeBodyAsForm = true; + consumeBodyAsForm = true; - paramterExpression = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Form)); - } - else if (parameter.ParameterType == typeof(HttpContext)) - { - paramterExpression = HttpContextParameter; - } - else if (parameter.ParameterType == typeof(CancellationToken)) - { - paramterExpression = RequestAbortedExpr; - } + paramterExpression = Expression.Property(HttpRequestExpr, nameof(HttpRequest.Form)); + } + else if (parameter.ParameterType == typeof(HttpContext)) + { + paramterExpression = HttpContextParameter; + } + else if (parameter.ParameterType == typeof(CancellationToken)) + { + paramterExpression = RequestAbortedExpr; } args.Add(paramterExpression); @@ -182,6 +183,12 @@ namespace Microsoft.AspNetCore.Routing.Internal { body = methodCall; } + else if (method.ReturnType == typeof(ValueTask)) + { + body = Expression.Call( + ExecuteValueTaskMethodInfo, + methodCall); + } else if (method.ReturnType.IsGenericType && method.ReturnType.GetGenericTypeDefinition() == typeof(Task<>)) { @@ -263,7 +270,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var box = Expression.TypeAs(methodCall, typeof(object)); body = Expression.Call(JsonResultWriteResponseAsync, HttpResponseExpr, box, Expression.Constant(CancellationToken.None)); } - else + else { body = Expression.Call(JsonResultWriteResponseAsync, HttpResponseExpr, methodCall, Expression.Constant(CancellationToken.None)); } @@ -398,10 +405,20 @@ namespace Microsoft.AspNetCore.Routing.Internal expr = Expression.Convert(expr, parameter.ParameterType); } + Expression defaultExpression; + if (parameter.HasDefaultValue) + { + defaultExpression = Expression.Constant(parameter.DefaultValue); + } + else + { + defaultExpression = Expression.Default(parameter.ParameterType); + } + // property[key] == null ? default : (ParameterType){Type}.Parse(property[key]); expr = Expression.Condition( Expression.Equal(valueArg, Expression.Constant(null)), - Expression.Default(parameter.ParameterType), + defaultExpression, expr); return expr; @@ -449,7 +466,22 @@ namespace Microsoft.AspNetCore.Routing.Internal return ExecuteAwaited(task, httpContext); } - private static Task ExecuteValueTask<T>(ValueTask<T> task, HttpContext httpContext) + private static Task ExecuteValueTask(ValueTask task) + { + static async Task ExecuteAwaited(ValueTask task) + { + await task; + } + + if (task.IsCompletedSuccessfully) + { + task.GetAwaiter().GetResult(); + } + + return ExecuteAwaited(task); + } + + private static Task ExecuteValueTaskOfT<T>(ValueTask<T> task, HttpContext httpContext) { static async Task ExecuteAwaited(ValueTask<T> task, HttpContext httpContext) { diff --git a/src/Http/Routing/test/UnitTests/Internal/MapActionExpressionTreeBuilderTest.cs b/src/Http/Routing/test/UnitTests/Internal/MapActionExpressionTreeBuilderTest.cs index d9c7754327e5cb24dfe33833165c5f2dfb50db6c..543e8a17d808c728e56f1e1561b1224e9652e78e 100644 --- a/src/Http/Routing/test/UnitTests/Internal/MapActionExpressionTreeBuilderTest.cs +++ b/src/Http/Routing/test/UnitTests/Internal/MapActionExpressionTreeBuilderTest.cs @@ -24,44 +24,186 @@ namespace Microsoft.AspNetCore.Routing.Internal { public class MapActionExpressionTreeBuilderTest { - [Fact] - public async Task RequestDelegateInvokesAction() + public static IEnumerable<object[]> NoResult { - var invoked = false; - - void TestAction() + get { - invoked = true; + void TestAction(HttpContext httpContext) + { + MarkAsInvoked(httpContext); + } + + Task TaskTestAction(HttpContext httpContext) + { + MarkAsInvoked(httpContext); + return Task.CompletedTask; + } + + ValueTask ValueTaskTestAction(HttpContext httpContext) + { + MarkAsInvoked(httpContext); + return ValueTask.CompletedTask; + } + + void StaticTestAction(HttpContext httpContext) + { + MarkAsInvoked(httpContext); + } + + Task StaticTaskTestAction(HttpContext httpContext) + { + MarkAsInvoked(httpContext); + return Task.CompletedTask; + } + + ValueTask StaticValueTaskTestAction(HttpContext httpContext) + { + MarkAsInvoked(httpContext); + return ValueTask.CompletedTask; + } + + void MarkAsInvoked(HttpContext httpContext) + { + httpContext.Items.Add("invoked", true); + } + + return new List<object[]> + { + new object[] { (Action<HttpContext>)TestAction }, + new object[] { (Func<HttpContext, Task>)TaskTestAction }, + new object[] { (Func<HttpContext, ValueTask>)ValueTaskTestAction }, + new object[] { (Action<HttpContext>)StaticTestAction }, + new object[] { (Func<HttpContext, Task>)StaticTaskTestAction }, + new object[] { (Func<HttpContext, ValueTask>)StaticValueTaskTestAction }, + }; } + } - var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate((Action)TestAction); + [Theory] + [MemberData(nameof(NoResult))] + public async Task RequestDelegateInvokesAction(Delegate @delegate) + { + var httpContext = new DefaultHttpContext(); + + var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate(@delegate); - await requestDelegate(null!); + await requestDelegate(httpContext); - Assert.True(invoked); + Assert.True(httpContext.Items["invoked"] as bool?); } - [Fact] - public async Task RequestDelegatePopulatesFromRouteParameterBasedOnParameterName() + public static IEnumerable<object[]> FromRouteResult + { + get + { + void TestAction(HttpContext httpContext, [FromRoute] int value) + { + StoreInput(httpContext, value); + }; + + Task TaskTestAction(HttpContext httpContext, [FromRoute] int value) + { + StoreInput(httpContext, value); + return Task.CompletedTask; + } + + ValueTask ValueTaskTestAction(HttpContext httpContext, [FromRoute] int value) + { + StoreInput(httpContext, value); + return ValueTask.CompletedTask; + } + + + + return new List<object[]> + { + new object[] { (Action<HttpContext, int>)TestAction }, + new object[] { (Func<HttpContext, int, Task>)TaskTestAction }, + new object[] { (Func<HttpContext, int, ValueTask>)ValueTaskTestAction }, + }; + } + } + private static void StoreInput(HttpContext httpContext, object value) + { + httpContext.Items.Add("input", value); + } + + [Theory] + [MemberData(nameof(FromRouteResult))] + public async Task RequestDelegatePopulatesFromRouteParameterBasedOnParameterName(Delegate @delegate) { const string paramName = "value"; const int originalRouteParam = 42; - int? deserializedRouteParam = null; + var httpContext = new DefaultHttpContext(); + httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); + + var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate(@delegate); + + await requestDelegate(httpContext); + + Assert.Equal(originalRouteParam, httpContext.Items["input"] as int?); + } - void TestAction([FromRoute] int value) + public static IEnumerable<object[]> FromRouteOptionalResult + { + get { - deserializedRouteParam = value; + return new List<object[]> + { + new object[] { (Action<HttpContext, int>)TestAction }, + new object[] { (Func<HttpContext, int, Task>)TaskTestAction }, + new object[] { (Func<HttpContext, int, ValueTask>)ValueTaskTestAction } + }; } + } + + private static void TestAction(HttpContext httpContext, [FromRoute] int value = 42) + { + StoreInput(httpContext, value); + } + + private static Task TaskTestAction(HttpContext httpContext, [FromRoute] int value = 42) + { + StoreInput(httpContext, value); + return Task.CompletedTask; + } + + private static ValueTask ValueTaskTestAction(HttpContext httpContext, [FromRoute] int value = 42) + { + StoreInput(httpContext, value); + return ValueTask.CompletedTask; + } + [Theory] + [MemberData(nameof(FromRouteOptionalResult))] + public async Task RequestDelegatePopulatesFromRouteOptionalParameter(Delegate @delegate) + { var httpContext = new DefaultHttpContext(); + + var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate(@delegate); + + await requestDelegate(httpContext); + + Assert.Equal(42, httpContext.Items["input"] as int?); + } + + [Theory] + [MemberData(nameof(FromRouteOptionalResult))] + public async Task RequestDelegatePopulatesFromRouteOptionalParameterBasedOnParameterName(Delegate @delegate) + { + const string paramName = "value"; + const int originalRouteParam = 47; + + var httpContext = new DefaultHttpContext(); + httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate((Action<int>)TestAction); + var requestDelegate = MapActionExpressionTreeBuilder.BuildRequestDelegate(@delegate); await requestDelegate(httpContext); - Assert.Equal(originalRouteParam, deserializedRouteParam); + Assert.Equal(47, httpContext.Items["input"] as int?); } [Fact]