Skip to content
代码片段 群组 项目
未验证 提交 8ae29f4c 编辑于 作者: Stephen Halter's avatar Stephen Halter 提交者: GitHub
浏览文件

Produce more efficient TryParse code gen for minimal actions (#31739)

* Produce more efficient TryParse code gen.

* Fix optional string parameter handling

* Remove innerTempSourceString

* Update comments

* Add test cases for IEnumerable<TService>

* fix tests
上级 5e71c23a
No related branches found
No related tags found
无相关合并请求
......@@ -3,6 +3,7 @@
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Linq.Expressions;
......@@ -53,6 +54,8 @@ namespace Microsoft.AspNetCore.Http
private static readonly MemberExpression StatusCodeExpr = Expression.Property(HttpResponseExpr, nameof(HttpResponse.StatusCode));
private static readonly MemberExpression CompletedTaskExpr = Expression.Property(null, (PropertyInfo)GetMemberInfo<Func<Task>>(() => Task.CompletedTask));
private static readonly BinaryExpression TempSourceStringNotNullExpr = Expression.NotEqual(TempSourceStringExpr, Expression.Constant(null));
private static readonly ConcurrentDictionary<Type, MethodInfo?> TryParseMethodCache = new();
/// <summary>
......@@ -155,10 +158,15 @@ namespace Microsoft.AspNetCore.Http
var arguments = CreateArguments(methodInfo.GetParameters(), factoryContext);
var responseWritingMethodCall = factoryContext.CheckForTryParseFailure ?
CreateTryParseCheckingResponseWritingMethodCall(methodInfo, targetExpression, arguments) :
var responseWritingMethodCall = factoryContext.TryParseParams.Count > 0 ?
CreateTryParseCheckingResponseWritingMethodCall(methodInfo, targetExpression, arguments, factoryContext) :
CreateResponseWritingMethodCall(methodInfo, targetExpression, arguments);
if (factoryContext.UsingTempSourceString)
{
responseWritingMethodCall = Expression.Block(new[] { TempSourceStringExpr }, responseWritingMethodCall);
}
return HandleRequestBodyAndCompileRequestDelegate(responseWritingMethodCall, factoryContext);
}
......@@ -242,30 +250,30 @@ namespace Microsoft.AspNetCore.Http
}
// If we're calling TryParse and wasTryParseFailure indicates it failed, set a 400 StatusCode instead of calling the method.
private static Expression CreateTryParseCheckingResponseWritingMethodCall(MethodInfo methodInfo, Expression? target, Expression[] arguments)
private static Expression CreateTryParseCheckingResponseWritingMethodCall(
MethodInfo methodInfo, Expression? target, Expression[] arguments, FactoryContext factoryContext)
{
// {
// bool wasTryParseFailure = false;
// string tempSourceString;
// bool wasTryParseFailure = false;
//
// // Assume "[FromRoute] int id" is the first parameter.
//
// tempSourceString = httpContext.RequestValue["id"];
// int param1 = tempSourceString == null ? default :
// {
// int parsedValue = default;
// // Assume "int param1" is the first parameter, "[FromRoute] int? param2 = 42" is the second parameter ...
// int param1_local;
// int? param2_local;
// // ...
//
// if (!int.TryParse(tempSourceString, out parsedValue))
// {
// wasTryParseFailure = true;
// Log.ParameterBindingFailed(httpContext, "Int32", "id", tempSourceString)
// }
// tempSourceString = httpContext.RouteValue["param1"] ?? httpContext.Query["param1"];
//
// return parsedValue;
// };
// if (tempSourceString != null)
// {
// if (!int.TryParse(tempSourceString, out param1_local))
// {
// wasTryParseFailure = true;
// Log.ParameterBindingFailed(httpContext, "Int32", "id", tempSourceString)
// }
// }
//
// tempSourceString = httpContext.RequestValue["param2"];
// int param2 = tempSourceString == null ? default :
// tempSourceString = httpContext.RouteValue["param2"];
// // ...
//
// return wasTryParseFailure ?
......@@ -274,41 +282,33 @@ namespace Microsoft.AspNetCore.Http
// return Task.CompletedTask;
// } :
// {
// // Logic generated by AddResponseWritingToMethodCall() that calls action(param1, parm2, ...)
// // Logic generated by AddResponseWritingToMethodCall() that calls action(param1_local, param2_local, ...)
// };
// }
var parameters = methodInfo.GetParameters();
var storedArguments = new ParameterExpression[parameters.Length];
var localVariables = new ParameterExpression[parameters.Length + 2];
var localVariables = new ParameterExpression[factoryContext.TryParseParams.Count + 1];
var tryParseAndCallMethod = new Expression[factoryContext.TryParseParams.Count + 1];
for (var i = 0; i < parameters.Length; i++)
for (var i = 0; i < factoryContext.TryParseParams.Count; i++)
{
storedArguments[i] = localVariables[i] = Expression.Parameter(parameters[i].ParameterType);
(localVariables[i], tryParseAndCallMethod[i]) = factoryContext.TryParseParams[i];
}
localVariables[parameters.Length] = WasTryParseFailureExpr;
localVariables[parameters.Length + 1] = TempSourceStringExpr;
var assignAndCall = new Expression[parameters.Length + 1];
for (var i = 0; i < parameters.Length; i++)
{
assignAndCall[i] = Expression.Assign(localVariables[i], arguments[i]);
}
localVariables[factoryContext.TryParseParams.Count] = WasTryParseFailureExpr;
var set400StatusAndReturnCompletedTask = Expression.Block(
Expression.Assign(StatusCodeExpr, Expression.Constant(400)),
CompletedTaskExpr);
var methodCall = CreateMethodCall(methodInfo, target, storedArguments);
var methodCall = CreateMethodCall(methodInfo, target, arguments);
var checkWasTryParseFailure = Expression.Condition(WasTryParseFailureExpr,
set400StatusAndReturnCompletedTask,
AddResponseWritingToMethodCall(methodCall, methodInfo.ReturnType));
assignAndCall[parameters.Length] = checkWasTryParseFailure;
tryParseAndCallMethod[factoryContext.TryParseParams.Count] = checkWasTryParseFailure;
return Expression.Block(localVariables, assignAndCall);
return Expression.Block(localVariables, tryParseAndCallMethod);
}
private static Expression AddResponseWritingToMethodCall(Expression methodCall, Type returnType)
......@@ -540,10 +540,24 @@ namespace Microsoft.AspNetCore.Http
{
if (parameter.ParameterType == typeof(string))
{
return valueExpression;
if (!parameter.HasDefaultValue)
{
return valueExpression;
}
factoryContext.UsingTempSourceString = true;
return Expression.Block(
Expression.Assign(TempSourceStringExpr, valueExpression),
Expression.Condition(TempSourceStringNotNullExpr,
TempSourceStringExpr,
Expression.Constant(parameter.DefaultValue)));
}
factoryContext.UsingTempSourceString = true;
var underlyingNullableType = Nullable.GetUnderlyingType(parameter.ParameterType);
var isNotNullable = underlyingNullableType is null;
var nonNullableParameterType = underlyingNullableType ?? parameter.ParameterType;
var tryParseMethod = FindTryParseMethod(nonNullableParameterType);
......@@ -552,28 +566,47 @@ namespace Microsoft.AspNetCore.Http
throw new InvalidOperationException($"No public static bool {parameter.ParameterType.Name}.TryParse(string, out {parameter.ParameterType.Name}) method found for {parameter.Name}.");
}
// bool wasTryParseFailure = false;
// string tempSourceString;
// bool wasTryParseFailure = false;
//
// // Assume "[FromRoute] int id" is the first parameter.
// tempSourceString = httpContext.RequestValue["id"];
// // Assume "int param1" is the first parameter and "[FromRoute] int? param2 = 42" is the second parameter.
// int param1_local;
// int? param2_local;
//
// int param1 = tempSourceString == null ? default :
// tempSourceString = httpContext.RouteValue["param1"] ?? httpContext.Query["param1"];
//
// if (tempSourceString != null)
// {
// int parsedValue = default;
// if (!int.TryParse(tempSourceString, out param1_local))
// {
// wasTryParseFailure = true;
// Log.ParameterBindingFailed(httpContext, "Int32", "id", tempSourceString)
// }
// }
//
// if (!int.TryParse(tempSourceString, out parsedValue))
// {
// wasTryParseFailure = true;
// Log.ParameterBindingFailed(httpContext, "Int32", "id", tempSourceString)
// }
// tempSourceString = httpContext.RouteValue["param2"];
//
// return parsedValue;
// };
// if (tempSourceString != null)
// {
// if (int.TryParse(tempSourceString, out int parsedValue))
// {
// param2_local = parsedValue;
// }
// else
// {
// wasTryParseFailure = true;
// Log.ParameterBindingFailed(httpContext, "Int32", "id", tempSourceString)
// }
// }
// else
// {
// param2_local = 42;
// }
factoryContext.CheckForTryParseFailure = true;
var argument = Expression.Variable(parameter.ParameterType, $"{parameter.Name}_local");
var parsedValue = Expression.Variable(nonNullableParameterType);
// If the parameter is nullable, create a "parsedValue" local to TryParse into since we cannot the parameter directly.
var parsedValue = isNotNullable ? argument : Expression.Variable(nonNullableParameterType, "parsedValue");
var parameterTypeNameConstant = Expression.Constant(parameter.ParameterType.Name);
var parameterNameConstant = Expression.Constant(parameter.Name);
......@@ -584,32 +617,30 @@ namespace Microsoft.AspNetCore.Http
HttpContextExpr, parameterTypeNameConstant, parameterNameConstant, TempSourceStringExpr));
var tryParseCall = Expression.Call(tryParseMethod, TempSourceStringExpr, parsedValue);
var ifFailExpression = Expression.IfThen(Expression.Not(tryParseCall), failBlock);
Expression parsedValueExpression = Expression.Block(new[] { parsedValue },
ifFailExpression,
parsedValue);
// Convert back to nullable if necessary.
if (underlyingNullableType is not null)
{
parsedValueExpression = Expression.Convert(parsedValueExpression, parameter.ParameterType);
}
Expression defaultExpression = parameter.HasDefaultValue ?
Expression.Constant(parameter.DefaultValue) :
Expression.Default(parameter.ParameterType);
// tempSourceString = httpContext.RequestValue["id"];
var storeValueToTemp = Expression.Assign(TempSourceStringExpr, valueExpression);
// int param1 = tempSourcString == null ? default : ...
var ternary = Expression.Condition(
Expression.Equal(TempSourceStringExpr, Expression.Constant(null)),
defaultExpression,
parsedValueExpression);
return Expression.Block(storeValueToTemp, ternary);
// If the parameter is nullable, we need to assign the "parsedValue" local to the nullable parameter on success.
Expression tryParseExpression = isNotNullable ?
Expression.IfThen(Expression.Not(tryParseCall), failBlock) :
Expression.Block(new[] { parsedValue },
Expression.IfThenElse(tryParseCall,
Expression.Assign(argument, Expression.Convert(parsedValue, parameter.ParameterType)),
failBlock));
var ifNotNullTryParse = !parameter.HasDefaultValue ?
Expression.IfThen(TempSourceStringNotNullExpr, tryParseExpression) :
Expression.IfThenElse(TempSourceStringNotNullExpr,
tryParseExpression,
Expression.Assign(argument, Expression.Constant(parameter.DefaultValue)));
var fullTryParseBlock = Expression.Block(
// tempSourceString = httpContext.RequestValue["id"];
Expression.Assign(TempSourceStringExpr, valueExpression),
// if (tempSourceString != null) { ... }
ifNotNullTryParse);
factoryContext.TryParseParams.Add((argument, fullTryParseBlock));
return argument;
}
private static Expression BindParameterFromProperty(ParameterInfo parameter, MemberExpression property, string key, FactoryContext factoryContext) =>
......@@ -747,7 +778,8 @@ namespace Microsoft.AspNetCore.Http
public Type? JsonRequestBodyType { get; set; }
public bool AllowEmptyRequestBody { get; set; }
public bool CheckForTryParseFailure { get; set; }
public bool UsingTempSourceString { get; set; }
public List<(ParameterExpression, Expression)> TryParseParams { get; } = new();
}
private static class Log
......
......@@ -7,6 +7,7 @@ using System;
using System.Collections.Generic;
using System.Globalization;
using System.IO;
using System.Linq;
using System.Linq.Expressions;
using System.Net;
using System.Net.Sockets;
......@@ -212,7 +213,17 @@ namespace Microsoft.AspNetCore.Routing.Internal
Assert.Equal(originalRouteParam, httpContext.Items["input"]);
}
private static void TestAction(HttpContext httpContext, [FromRoute] int value = 42)
private static void TestOptional(HttpContext httpContext, [FromRoute] int value = 42)
{
httpContext.Items.Add("input", value);
}
private static void TestOptionalNullable(HttpContext httpContext, int? value = 42)
{
httpContext.Items.Add("input", value);
}
private static void TestOptionalString(HttpContext httpContext, string value = "default")
{
httpContext.Items.Add("input", value);
}
......@@ -222,13 +233,37 @@ namespace Microsoft.AspNetCore.Routing.Internal
{
var httpContext = new DefaultHttpContext();
var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)TestAction);
var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)TestOptional);
await requestDelegate(httpContext);
Assert.Equal(42, httpContext.Items["input"]);
}
[Fact]
public async Task RequestDelegatePopulatesFromNullableOptionalParameter()
{
var httpContext = new DefaultHttpContext();
var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)TestOptional);
await requestDelegate(httpContext);
Assert.Equal(42, httpContext.Items["input"]);
}
[Fact]
public async Task RequestDelegatePopulatesFromOptionalStringParameter()
{
var httpContext = new DefaultHttpContext();
var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, string>)TestOptionalString);
await requestDelegate(httpContext);
Assert.Equal("default", httpContext.Items["input"]);
}
[Fact]
public async Task RequestDelegatePopulatesFromRouteOptionalParameterBasedOnParameterName()
{
......@@ -239,7 +274,7 @@ namespace Microsoft.AspNetCore.Routing.Internal
httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo);
var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)TestAction);
var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)TestOptional);
await requestDelegate(httpContext);
......@@ -727,15 +762,27 @@ namespace Microsoft.AspNetCore.Routing.Internal
httpContext.Items.Add("service", myService);
}
void TestExplicitFromIEnumerableService(HttpContext httpContext, [FromService] IEnumerable<MyService> myServices)
{
httpContext.Items.Add("service", myServices.Single());
}
void TestImpliedFromService(HttpContext httpContext, IMyService myService)
{
httpContext.Items.Add("service", myService);
}
return new[]
void TestImpliedIEnumerableFromService(HttpContext httpContext, IEnumerable<MyService> myServices)
{
httpContext.Items.Add("service", myServices.Single());
}
return new object[][]
{
new[] { (Action<HttpContext, MyService>)TestExplicitFromService },
new[] { (Action<HttpContext, MyService>)TestImpliedFromService },
new[] { (Action<HttpContext, IEnumerable<MyService>>)TestExplicitFromIEnumerableService },
new[] { (Action<HttpContext, IMyService>)TestImpliedFromService },
new[] { (Action<HttpContext, IEnumerable<MyService>>)TestImpliedIEnumerableFromService },
};
}
}
......@@ -753,7 +800,7 @@ namespace Microsoft.AspNetCore.Routing.Internal
var httpContext = new DefaultHttpContext();
httpContext.RequestServices = serviceCollection.BuildServiceProvider();
var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, MyService>)action);
var requestDelegate = RequestDelegateFactory.Create(action);
await requestDelegate(httpContext);
......@@ -767,7 +814,7 @@ namespace Microsoft.AspNetCore.Routing.Internal
var httpContext = new DefaultHttpContext();
httpContext.RequestServices = new ServiceCollection().BuildServiceProvider();
var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, MyService>)action);
var requestDelegate = RequestDelegateFactory.Create(action);
await Assert.ThrowsAsync<InvalidOperationException>(() => requestDelegate(httpContext));
}
......
0% 加载中 .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册