From d2ab01b41fe341c67b3f7da3f9dc623dac06ffd1 Mon Sep 17 00:00:00 2001 From: David Fowler <davidfowl@gmail.com> Date: Sat, 12 Jun 2021 18:16:12 -0700 Subject: [PATCH] Detect services based on service provider (#32737) * Detect services based on service provider - Use IServiceProviderIsService to detect if a parameter is a service. - As a final fallback, try to detect services from the DI container before falling back to body behavior. --- .../src/PublicAPI.Unshipped.txt | 6 +- .../src/RequestDelegateFactory.cs | 32 +++-- .../test/RequestDelegateFactoryTests.cs | 127 +++++++++++------- ...malActionEndpointRouteBuilderExtensions.cs | 2 +- ...ctionEndpointRouteBuilderExtensionsTest.cs | 35 ++++- 5 files changed, 140 insertions(+), 62 deletions(-) diff --git a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt index 33854973715..c940f762c58 100644 --- a/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt +++ b/src/Http/Http.Extensions/src/PublicAPI.Unshipped.txt @@ -169,9 +169,9 @@ static Microsoft.AspNetCore.Http.HeaderDictionaryTypeExtensions.AppendList<T>(th static Microsoft.AspNetCore.Http.HeaderDictionaryTypeExtensions.GetTypedHeaders(this Microsoft.AspNetCore.Http.HttpRequest! request) -> Microsoft.AspNetCore.Http.Headers.RequestHeaders! static Microsoft.AspNetCore.Http.HeaderDictionaryTypeExtensions.GetTypedHeaders(this Microsoft.AspNetCore.Http.HttpResponse! response) -> Microsoft.AspNetCore.Http.Headers.ResponseHeaders! static Microsoft.AspNetCore.Http.HttpContextServerVariableExtensions.GetServerVariable(this Microsoft.AspNetCore.Http.HttpContext! context, string! variableName) -> string? -static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Delegate! action) -> Microsoft.AspNetCore.Http.RequestDelegate! -static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo) -> Microsoft.AspNetCore.Http.RequestDelegate! -static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo, System.Func<Microsoft.AspNetCore.Http.HttpContext!, object!>! targetFactory) -> Microsoft.AspNetCore.Http.RequestDelegate! +static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Delegate! action, System.IServiceProvider? serviceProvider) -> Microsoft.AspNetCore.Http.RequestDelegate! +static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo, System.IServiceProvider? serviceProvider) -> Microsoft.AspNetCore.Http.RequestDelegate! +static Microsoft.AspNetCore.Http.RequestDelegateFactory.Create(System.Reflection.MethodInfo! methodInfo, System.IServiceProvider? serviceProvider, System.Func<Microsoft.AspNetCore.Http.HttpContext!, object!>! targetFactory) -> Microsoft.AspNetCore.Http.RequestDelegate! static Microsoft.AspNetCore.Http.ResponseExtensions.Clear(this Microsoft.AspNetCore.Http.HttpResponse! response) -> void static Microsoft.AspNetCore.Http.ResponseExtensions.Redirect(this Microsoft.AspNetCore.Http.HttpResponse! response, string! location, bool permanent, bool preserveMethod) -> void static Microsoft.AspNetCore.Http.SendFileResponseExtensions.SendFileAsync(this Microsoft.AspNetCore.Http.HttpResponse! response, Microsoft.Extensions.FileProviders.IFileInfo! file, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) -> System.Threading.Tasks.Task! diff --git a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs index 6232a5479a0..b14ef5fb3d4 100644 --- a/src/Http/Http.Extensions/src/RequestDelegateFactory.cs +++ b/src/Http/Http.Extensions/src/RequestDelegateFactory.cs @@ -62,8 +62,9 @@ namespace Microsoft.AspNetCore.Http /// Creates a <see cref="RequestDelegate"/> implementation for <paramref name="action"/>. /// </summary> /// <param name="action">A request handler with any number of custom parameters that often produces a response with its return value.</param> + /// <param name="serviceProvider">The <see cref="IServiceProvider"/> instance used to detect which parameters are services.</param> /// <returns>The <see cref="RequestDelegate"/>.</returns> - public static RequestDelegate Create(Delegate action) + public static RequestDelegate Create(Delegate action, IServiceProvider? serviceProvider) { if (action is null) { @@ -76,7 +77,7 @@ namespace Microsoft.AspNetCore.Http null => null, }; - var targetableRequestDelegate = CreateTargetableRequestDelegate(action.Method, targetExpression); + var targetableRequestDelegate = CreateTargetableRequestDelegate(action.Method, serviceProvider, targetExpression); return httpContext => { @@ -88,15 +89,16 @@ namespace Microsoft.AspNetCore.Http /// Creates a <see cref="RequestDelegate"/> implementation for <paramref name="methodInfo"/>. /// </summary> /// <param name="methodInfo">A static request handler with any number of custom parameters that often produces a response with its return value.</param> + /// <param name="serviceProvider">The <see cref="IServiceProvider"/> instance used to detect which parameters are services.</param> /// <returns>The <see cref="RequestDelegate"/>.</returns> - public static RequestDelegate Create(MethodInfo methodInfo) + public static RequestDelegate Create(MethodInfo methodInfo, IServiceProvider? serviceProvider) { if (methodInfo is null) { throw new ArgumentNullException(nameof(methodInfo)); } - var targetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, targetExpression: null); + var targetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, serviceProvider, targetExpression: null); return httpContext => { @@ -108,9 +110,10 @@ namespace Microsoft.AspNetCore.Http /// Creates a <see cref="RequestDelegate"/> implementation for <paramref name="methodInfo"/>. /// </summary> /// <param name="methodInfo">A request handler with any number of custom parameters that often produces a response with its return value.</param> + /// <param name="serviceProvider">The <see cref="IServiceProvider"/> instance used to detect which parameters are services.</param> /// <param name="targetFactory">Creates the <see langword="this"/> for the non-static method.</param> /// <returns>The <see cref="RequestDelegate"/>.</returns> - public static RequestDelegate Create(MethodInfo methodInfo, Func<HttpContext, object> targetFactory) + public static RequestDelegate Create(MethodInfo methodInfo, IServiceProvider? serviceProvider, Func<HttpContext, object> targetFactory) { if (methodInfo is null) { @@ -128,7 +131,7 @@ namespace Microsoft.AspNetCore.Http } var targetExpression = Expression.Convert(TargetExpr, methodInfo.DeclaringType); - var targetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, targetExpression); + var targetableRequestDelegate = CreateTargetableRequestDelegate(methodInfo, serviceProvider, targetExpression); return httpContext => { @@ -136,7 +139,7 @@ namespace Microsoft.AspNetCore.Http }; } - private static Func<object?, HttpContext, Task> CreateTargetableRequestDelegate(MethodInfo methodInfo, Expression? targetExpression) + private static Func<object?, HttpContext, Task> CreateTargetableRequestDelegate(MethodInfo methodInfo, IServiceProvider? serviceProvider, Expression? targetExpression) { // Non void return type @@ -154,7 +157,10 @@ namespace Microsoft.AspNetCore.Http // return default; // } - var factoryContext = new FactoryContext(); + var factoryContext = new FactoryContext() + { + ServiceProvider = serviceProvider + }; var arguments = CreateArguments(methodInfo.GetParameters(), factoryContext); @@ -234,6 +240,15 @@ namespace Microsoft.AspNetCore.Http } else { + if (factoryContext.ServiceProvider?.GetService<IServiceProviderIsService>() is IServiceProviderIsService serviceProviderIsService) + { + // If the parameter resolves as a service then get it from services + if (serviceProviderIsService.IsService(parameter.ParameterType)) + { + return Expression.Call(GetRequiredServiceMethod.MakeGenericMethod(parameter.ParameterType), RequestServicesExpr); + } + } + return BindParameterFromBody(parameter.ParameterType, allowEmpty: false, factoryContext); } } @@ -788,6 +803,7 @@ namespace Microsoft.AspNetCore.Http { public Type? JsonRequestBodyType { get; set; } public bool AllowEmptyRequestBody { get; set; } + public IServiceProvider? ServiceProvider { get; init; } public bool UsingTempSourceString { get; set; } public List<(ParameterExpression, Expression)> TryParseParams { get; } = new(); diff --git a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs index ae85b27e3ee..bac3c83201d 100644 --- a/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs +++ b/src/Http/Http.Extensions/test/RequestDelegateFactoryTests.cs @@ -92,7 +92,7 @@ namespace Microsoft.AspNetCore.Routing.Internal { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -112,7 +112,7 @@ namespace Microsoft.AspNetCore.Routing.Internal BindingFlags.NonPublic | BindingFlags.Static, new[] { typeof(HttpContext) }); - var requestDelegate = RequestDelegateFactory.Create(methodInfo!); + var requestDelegate = RequestDelegateFactory.Create(methodInfo!, new EmptyServiceProvdier()); var httpContext = new DefaultHttpContext(); @@ -158,7 +158,7 @@ namespace Microsoft.AspNetCore.Routing.Internal return new TestNonStaticActionClass(2); } - var requestDelegate = RequestDelegateFactory.Create(methodInfo!, _ => GetTarget()); + var requestDelegate = RequestDelegateFactory.Create(methodInfo!, new EmptyServiceProvdier(), _ => GetTarget()); var httpContext = new DefaultHttpContext(); @@ -181,10 +181,12 @@ namespace Microsoft.AspNetCore.Routing.Internal BindingFlags.NonPublic | BindingFlags.Static, new[] { typeof(HttpContext) }); - var exNullAction = Assert.Throws<ArgumentNullException>(() => RequestDelegateFactory.Create(action: null!)); - var exNullMethodInfo1 = Assert.Throws<ArgumentNullException>(() => RequestDelegateFactory.Create(methodInfo: null!)); - var exNullMethodInfo2 = Assert.Throws<ArgumentNullException>(() => RequestDelegateFactory.Create(methodInfo: null!, _ => 0)); - var exNullTargetFactory = Assert.Throws<ArgumentNullException>(() => RequestDelegateFactory.Create(methodInfo!, targetFactory: null!)); + var serviceProvider = new EmptyServiceProvdier(); + + var exNullAction = Assert.Throws<ArgumentNullException>(() => RequestDelegateFactory.Create(action: null!, serviceProvider)); + var exNullMethodInfo1 = Assert.Throws<ArgumentNullException>(() => RequestDelegateFactory.Create(methodInfo: null!, serviceProvider)); + var exNullMethodInfo2 = Assert.Throws<ArgumentNullException>(() => RequestDelegateFactory.Create(methodInfo: null!, serviceProvider, _ => 0)); + var exNullTargetFactory = Assert.Throws<ArgumentNullException>(() => RequestDelegateFactory.Create(methodInfo!, serviceProvider, targetFactory: null!)); Assert.Equal("action", exNullAction.ParamName); Assert.Equal("methodInfo", exNullMethodInfo1.ParamName); @@ -206,7 +208,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)TestAction); + var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)TestAction, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -233,7 +235,7 @@ namespace Microsoft.AspNetCore.Routing.Internal { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)TestOptional); + var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)TestOptional, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -245,7 +247,7 @@ namespace Microsoft.AspNetCore.Routing.Internal { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)TestOptional); + var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)TestOptional, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -257,7 +259,7 @@ namespace Microsoft.AspNetCore.Routing.Internal { var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, string>)TestOptionalString); + var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, string>)TestOptionalString, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -274,7 +276,7 @@ namespace Microsoft.AspNetCore.Routing.Internal httpContext.Request.RouteValues[paramName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)TestOptional); + var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)TestOptional, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -297,7 +299,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues[specifiedName] = originalRouteParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create((Action<int>)TestAction); + var requestDelegate = RequestDelegateFactory.Create((Action<int>)TestAction, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -320,7 +322,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues[unmatchedName] = unmatchedRouteParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create((Action<int>)TestAction); + var requestDelegate = RequestDelegateFactory.Create((Action<int>)TestAction, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -399,7 +401,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var httpContext = new DefaultHttpContext(); httpContext.Request.RouteValues["tryParsable"] = routeValue; - var requestDelegate = RequestDelegateFactory.Create(action); + var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -416,7 +418,7 @@ namespace Microsoft.AspNetCore.Routing.Internal ["tryParsable"] = routeValue }); - var requestDelegate = RequestDelegateFactory.Create(action); + var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -437,8 +439,9 @@ namespace Microsoft.AspNetCore.Routing.Internal var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext, int>)((httpContext, tryParsable) => { - httpContext.Items["tryParsable"] = tryParsable; - })); + httpContext.Items["tryParsable"] = tryParsable; + }), + new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -466,7 +469,7 @@ namespace Microsoft.AspNetCore.Routing.Internal [MemberData(nameof(DelegatesWithAttributesOnNotTryParsableParameters))] public void CreateThrowsInvalidOperationExceptionWhenAttributeRequiresTryParseMethodThatDoesNotExist(Delegate action) { - var ex = Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create(action)); + var ex = Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create(action, new EmptyServiceProvdier())); Assert.Equal("No public static bool Object.TryParse(string, out Object) method found for notTryParsable.", ex.Message); } @@ -475,7 +478,7 @@ namespace Microsoft.AspNetCore.Routing.Internal { var unnamedParameter = Expression.Parameter(typeof(int)); var lambda = Expression.Lambda(Expression.Block(), unnamedParameter); - var ex = Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create((Action<int>)lambda.Compile())); + var ex = Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create((Action<int>)lambda.Compile(), new EmptyServiceProvdier())); Assert.Equal("A parameter does not have a name! Was it genererated? All parameters must be named.", ex.Message); } @@ -498,7 +501,7 @@ namespace Microsoft.AspNetCore.Routing.Internal httpContext.Features.Set<IHttpRequestLifetimeFeature>(new TestHttpRequestLifetimeFeature()); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create((Action<int, int>)TestAction); + var requestDelegate = RequestDelegateFactory.Create((Action<int, int>)TestAction, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -540,7 +543,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var httpContext = new DefaultHttpContext(); httpContext.Request.Query = query; - var requestDelegate = RequestDelegateFactory.Create((Action<int>)TestAction); + var requestDelegate = RequestDelegateFactory.Create((Action<int>)TestAction, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -563,7 +566,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var httpContext = new DefaultHttpContext(); httpContext.Request.Headers[customHeaderName] = originalHeaderParam.ToString(NumberFormatInfo.InvariantInfo); - var requestDelegate = RequestDelegateFactory.Create((Action<int>)TestAction); + var requestDelegate = RequestDelegateFactory.Create((Action<int>)TestAction, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -607,7 +610,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var requestBodyBytes = JsonSerializer.SerializeToUtf8Bytes(originalTodo); httpContext.Request.Body = new MemoryStream(requestBodyBytes); - var requestDelegate = RequestDelegateFactory.Create(action); + var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -624,7 +627,7 @@ namespace Microsoft.AspNetCore.Routing.Internal httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.Headers["Content-Length"] = "0"; - var requestDelegate = RequestDelegateFactory.Create(action); + var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvdier()); await Assert.ThrowsAsync<JsonException>(() => requestDelegate(httpContext)); } @@ -643,7 +646,7 @@ namespace Microsoft.AspNetCore.Routing.Internal httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.Headers["Content-Length"] = "0"; - var requestDelegate = RequestDelegateFactory.Create((Action<Todo>)TestAction); + var requestDelegate = RequestDelegateFactory.Create((Action<Todo>)TestAction, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -667,7 +670,7 @@ namespace Microsoft.AspNetCore.Routing.Internal httpContext.Request.Headers["Content-Type"] = "application/json"; httpContext.Request.Headers["Content-Length"] = "0"; - var requestDelegate = RequestDelegateFactory.Create((Action<BodyStruct>)TestAction); + var requestDelegate = RequestDelegateFactory.Create((Action<BodyStruct>)TestAction, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -694,7 +697,7 @@ namespace Microsoft.AspNetCore.Routing.Internal httpContext.Features.Set<IHttpRequestLifetimeFeature>(new TestHttpRequestLifetimeFeature()); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create((Action<Todo>)TestAction); + var requestDelegate = RequestDelegateFactory.Create((Action<Todo>)TestAction, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -727,7 +730,7 @@ namespace Microsoft.AspNetCore.Routing.Internal httpContext.Features.Set<IHttpRequestLifetimeFeature>(new TestHttpRequestLifetimeFeature()); httpContext.RequestServices = serviceCollection.BuildServiceProvider(); - var requestDelegate = RequestDelegateFactory.Create((Action<Todo>)TestAction); + var requestDelegate = RequestDelegateFactory.Create((Action<Todo>)TestAction, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -748,9 +751,9 @@ namespace Microsoft.AspNetCore.Routing.Internal void TestInferredInvalidAction(Todo value1, Todo value2) { } void TestBothInvalidAction(Todo value1, [FromBody] int value2) { } - Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create((Action<int, int>)TestAttributedInvalidAction)); - Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create((Action<Todo, Todo>)TestInferredInvalidAction)); - Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create((Action<Todo, int>)TestBothInvalidAction)); + Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create((Action<int, int>)TestAttributedInvalidAction, new EmptyServiceProvdier())); + Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create((Action<Todo, Todo>)TestInferredInvalidAction, new EmptyServiceProvdier())); + Assert.Throws<InvalidOperationException>(() => RequestDelegateFactory.Create((Action<Todo, int>)TestBothInvalidAction, new EmptyServiceProvdier())); } public static object[][] FromServiceActions @@ -777,12 +780,18 @@ namespace Microsoft.AspNetCore.Routing.Internal httpContext.Items.Add("service", myServices.Single()); } + void TestImpliedFromServiceBasedOnContainer(HttpContext httpContext, MyService myService) + { + httpContext.Items.Add("service", myService); + } + return new object[][] { new[] { (Action<HttpContext, MyService>)TestExplicitFromService }, new[] { (Action<HttpContext, IEnumerable<MyService>>)TestExplicitFromIEnumerableService }, new[] { (Action<HttpContext, IMyService>)TestImpliedFromService }, new[] { (Action<HttpContext, IEnumerable<MyService>>)TestImpliedIEnumerableFromService }, + new[] { (Action<HttpContext, MyService>)TestImpliedFromServiceBasedOnContainer }, }; } } @@ -797,10 +806,14 @@ namespace Microsoft.AspNetCore.Routing.Internal serviceCollection.AddSingleton(myOriginalService); serviceCollection.AddSingleton<IMyService>(myOriginalService); + var services = serviceCollection.BuildServiceProvider(); + + using var requestScoped = services.CreateScope(); + var httpContext = new DefaultHttpContext(); - httpContext.RequestServices = serviceCollection.BuildServiceProvider(); + httpContext.RequestServices = requestScoped.ServiceProvider; - var requestDelegate = RequestDelegateFactory.Create(action); + var requestDelegate = RequestDelegateFactory.Create(action, services); await requestDelegate(httpContext); @@ -812,9 +825,9 @@ namespace Microsoft.AspNetCore.Routing.Internal public async Task RequestDelegateRequiresServiceForAllFromServiceParameters(Delegate action) { var httpContext = new DefaultHttpContext(); - httpContext.RequestServices = new ServiceCollection().BuildServiceProvider(); + httpContext.RequestServices = new EmptyServiceProvdier(); - var requestDelegate = RequestDelegateFactory.Create(action); + var requestDelegate = RequestDelegateFactory.Create(action, new EmptyServiceProvdier()); await Assert.ThrowsAsync<InvalidOperationException>(() => requestDelegate(httpContext)); } @@ -831,7 +844,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var httpContext = new DefaultHttpContext(); - var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext>)TestAction); + var requestDelegate = RequestDelegateFactory.Create((Action<HttpContext>)TestAction, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -854,7 +867,7 @@ namespace Microsoft.AspNetCore.Routing.Internal RequestAborted = cts.Token }; - var requestDelegate = RequestDelegateFactory.Create((Action<CancellationToken>)TestAction); + var requestDelegate = RequestDelegateFactory.Create((Action<CancellationToken>)TestAction, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -898,7 +911,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -947,7 +960,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -990,7 +1003,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -1031,7 +1044,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -1072,7 +1085,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var requestDelegate = RequestDelegateFactory.Create(@delegate, new EmptyServiceProvdier()); await requestDelegate(httpContext); @@ -1110,7 +1123,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var requestDelegate = RequestDelegateFactory.Create(@delegate, serviceProvider: null); var exception = await Assert.ThrowsAnyAsync<InvalidOperationException>(async () => await requestDelegate(httpContext)); Assert.Contains(message, exception.Message); @@ -1155,7 +1168,7 @@ namespace Microsoft.AspNetCore.Routing.Internal var responseBodyStream = new MemoryStream(); httpContext.Response.Body = responseBodyStream; - var requestDelegate = RequestDelegateFactory.Create(@delegate); + var requestDelegate = RequestDelegateFactory.Create(@delegate, serviceProvider: null); await requestDelegate(httpContext); @@ -1268,6 +1281,30 @@ namespace Microsoft.AspNetCore.Routing.Internal } } + private class EmptyServiceProvdier : IServiceScope, IServiceProvider, IServiceScopeFactory + { + public IServiceProvider ServiceProvider => this; + + public IServiceScope CreateScope() + { + return new EmptyServiceProvdier(); + } + + public void Dispose() + { + + } + + public object? GetService(Type serviceType) + { + if (serviceType == typeof(IServiceScopeFactory)) + { + return this; + } + return null; + } + } + private class TestHttpRequestLifetimeFeature : IHttpRequestLifetimeFeature { private readonly CancellationTokenSource _requestAbortedCts = new CancellationTokenSource(); diff --git a/src/Http/Routing/src/Builder/MinimalActionEndpointRouteBuilderExtensions.cs b/src/Http/Routing/src/Builder/MinimalActionEndpointRouteBuilderExtensions.cs index 3db9fbfa98f..fb538f4d528 100644 --- a/src/Http/Routing/src/Builder/MinimalActionEndpointRouteBuilderExtensions.cs +++ b/src/Http/Routing/src/Builder/MinimalActionEndpointRouteBuilderExtensions.cs @@ -159,7 +159,7 @@ namespace Microsoft.AspNetCore.Builder const int defaultOrder = 0; var builder = new RouteEndpointBuilder( - RequestDelegateFactory.Create(action), + RequestDelegateFactory.Create(action, endpoints.ServiceProvider), pattern, defaultOrder) { diff --git a/src/Http/Routing/test/UnitTests/Builder/MinimalActionEndpointRouteBuilderExtensionsTest.cs b/src/Http/Routing/test/UnitTests/Builder/MinimalActionEndpointRouteBuilderExtensionsTest.cs index 02542f3e8eb..fec8da0ce7c 100644 --- a/src/Http/Routing/test/UnitTests/Builder/MinimalActionEndpointRouteBuilderExtensionsTest.cs +++ b/src/Http/Routing/test/UnitTests/Builder/MinimalActionEndpointRouteBuilderExtensionsTest.cs @@ -8,6 +8,7 @@ using System.Collections.Generic; using System.Linq; using System.Runtime.CompilerServices; using Microsoft.AspNetCore.Routing; +using Microsoft.Extensions.DependencyInjection; using Moq; using Xunit; @@ -28,7 +29,7 @@ namespace Microsoft.AspNetCore.Builder [Fact] public void MapEndpoint_PrecedenceOfMetadata_BuilderMetadataReturned() { - var builder = new DefaultEndpointRouteBuilder(Mock.Of<IApplicationBuilder>()); + var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvdier())); [HttpMethod("ATTRIBUTE")] void TestAction() @@ -60,7 +61,7 @@ namespace Microsoft.AspNetCore.Builder [Fact] public void MapGet_BuildsEndpointWithCorrectMethod() { - var builder = new DefaultEndpointRouteBuilder(Mock.Of<IApplicationBuilder>()); + var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvdier())); _ = builder.MapGet("/", (Action)(() => { })); var dataSource = GetBuilderEndpointDataSource(builder); @@ -80,7 +81,7 @@ namespace Microsoft.AspNetCore.Builder [Fact] public void MapPost_BuildsEndpointWithCorrectMethod() { - var builder = new DefaultEndpointRouteBuilder(Mock.Of<IApplicationBuilder>()); + var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvdier())); _ = builder.MapPost("/", (Action)(() => { })); var dataSource = GetBuilderEndpointDataSource(builder); @@ -100,7 +101,7 @@ namespace Microsoft.AspNetCore.Builder [Fact] public void MapPut_BuildsEndpointWithCorrectMethod() { - var builder = new DefaultEndpointRouteBuilder(Mock.Of<IApplicationBuilder>()); + var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvdier())); _ = builder.MapPut("/", (Action)(() => { })); var dataSource = GetBuilderEndpointDataSource(builder); @@ -120,7 +121,7 @@ namespace Microsoft.AspNetCore.Builder [Fact] public void MapDelete_BuildsEndpointWithCorrectMethod() { - var builder = new DefaultEndpointRouteBuilder(Mock.Of<IApplicationBuilder>()); + var builder = new DefaultEndpointRouteBuilder(new ApplicationBuilder(new EmptyServiceProvdier())); _ = builder.MapDelete("/", (Action)(() => { })); var dataSource = GetBuilderEndpointDataSource(builder); @@ -148,5 +149,29 @@ namespace Microsoft.AspNetCore.Builder HttpMethods = httpMethods; } } + + private class EmptyServiceProvdier : IServiceScope, IServiceProvider, IServiceScopeFactory + { + public IServiceProvider ServiceProvider => this; + + public IServiceScope CreateScope() + { + return new EmptyServiceProvdier(); + } + + public void Dispose() + { + + } + + public object? GetService(Type serviceType) + { + if (serviceType == typeof(IServiceScopeFactory)) + { + return this; + } + return null; + } + } } } -- GitLab