From 6a8cc0d98d18de0fb390e92c8a98173cfa2da68f Mon Sep 17 00:00:00 2001
From: Damian Edwards <damian@damianedwards.com>
Date: Fri, 10 Jun 2022 16:33:40 -0700
Subject: [PATCH] Introduce IBindableFromHttpContext<TSelf> (#41100)

* Added IBindableFromHttpContext<TSelf> interface
* Discover BindAsync via IBindableFromHttpContext interface
* Added tests
Co-authored-by: Stephen Halter <halter73@gmail.com>
---
 .../src/IBindableFromHttpContextOfT.cs        |  22 +++
 .../src/PublicAPI.Unshipped.txt               |   2 +
 .../test/ParameterBindingMethodCacheTests.cs  | 146 +++++++++++++++++-
 src/Http/samples/MinimalSample/Program.cs     |  14 ++
 src/Shared/ParameterBindingMethodCache.cs     |  35 ++++-
 5 files changed, 214 insertions(+), 5 deletions(-)
 create mode 100644 src/Http/Http.Abstractions/src/IBindableFromHttpContextOfT.cs

diff --git a/src/Http/Http.Abstractions/src/IBindableFromHttpContextOfT.cs b/src/Http/Http.Abstractions/src/IBindableFromHttpContextOfT.cs
new file mode 100644
index 00000000000..0f414b849e5
--- /dev/null
+++ b/src/Http/Http.Abstractions/src/IBindableFromHttpContextOfT.cs
@@ -0,0 +1,22 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+using System.Reflection;
+
+namespace Microsoft.AspNetCore.Http;
+
+/// <summary>
+/// Defines a mechanism for creating an instance of a type from an <see cref="HttpContext"/> when binding parameters for an endpoint
+/// route handler delegate.
+/// </summary>
+/// <typeparam name="TSelf">The type that implements this interface.</typeparam>
+public interface IBindableFromHttpContext<TSelf> where TSelf : class, IBindableFromHttpContext<TSelf>
+{
+    /// <summary>
+    /// Creates an instance of <typeparamref name="TSelf"/> from the <see cref="HttpContext"/>.
+    /// </summary>
+    /// <param name="context">The <see cref="HttpContext"/> for the current request.</param>
+    /// <param name="parameter">The <see cref="ParameterInfo"/> for the parameter of the route handler delegate the returned instance will populate.</param>
+    /// <returns>The instance of <typeparamref name="TSelf"/>.</returns>
+    static abstract ValueTask<TSelf?> BindAsync(HttpContext context, ParameterInfo parameter);
+}
diff --git a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt
index b8d0ab94d18..a307310dc1a 100644
--- a/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt
+++ b/src/Http/Http.Abstractions/src/PublicAPI.Unshipped.txt
@@ -9,6 +9,8 @@ Microsoft.AspNetCore.Http.DefaultRouteHandlerInvocationContext
 Microsoft.AspNetCore.Http.DefaultRouteHandlerInvocationContext.DefaultRouteHandlerInvocationContext(Microsoft.AspNetCore.Http.HttpContext! httpContext, params object![]! arguments) -> void
 Microsoft.AspNetCore.Http.EndpointMetadataCollection.Enumerator.Current.get -> object!
 Microsoft.AspNetCore.Http.EndpointMetadataCollection.GetRequiredMetadata<T>() -> T!
+Microsoft.AspNetCore.Http.IBindableFromHttpContext<TSelf>
+Microsoft.AspNetCore.Http.IBindableFromHttpContext<TSelf>.BindAsync(Microsoft.AspNetCore.Http.HttpContext! context, System.Reflection.ParameterInfo! parameter) -> System.Threading.Tasks.ValueTask<TSelf?>
 Microsoft.AspNetCore.Http.IRouteHandlerFilter.InvokeAsync(Microsoft.AspNetCore.Http.RouteHandlerInvocationContext! context, Microsoft.AspNetCore.Http.RouteHandlerFilterDelegate! next) -> System.Threading.Tasks.ValueTask<object?>
 Microsoft.AspNetCore.Http.Metadata.IFromFormMetadata
 Microsoft.AspNetCore.Http.Metadata.IFromFormMetadata.Name.get -> string?
diff --git a/src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs b/src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs
index 6a1085ca2c4..725fb958785 100644
--- a/src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs
+++ b/src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs
@@ -3,6 +3,7 @@
 
 #nullable enable
 
+using System.Diagnostics.CodeAnalysis;
 using System.Globalization;
 using System.Linq.Expressions;
 using System.Reflection;
@@ -302,6 +303,10 @@ public class ParameterBindingMethodCacheTests
                     {
                         GetFirstParameter((BindAsyncFromInterfaceWithParameterInfo arg) => BindAsyncFromInterfaceWithParameterInfoMethod(arg))
                     },
+                    new[]
+                    {
+                        GetFirstParameter((BindAsyncFromStaticAbstractInterfaceAndBindAsync arg) => BindAsyncFromImplicitStaticAbstractInterfaceMethodInsteadOfReflectionMatchedMethod(arg))
+                    },
                 };
         }
     }
@@ -320,6 +325,27 @@ public class ParameterBindingMethodCacheTests
         Assert.True(new ParameterBindingMethodCache().HasBindAsyncMethod(parameterInfo));
     }
 
+    [Fact]
+    public void HasBindAsyncMethod_ReturnsTrueForClassImplicitlyImplementingIBindableFromHttpContext()
+    {
+        var parameterInfo = GetFirstParameter((BindAsyncFromImplicitStaticAbstractInterface arg) => BindAsyncFromImplicitStaticAbstractInterfaceMethod(arg));
+        Assert.True(new ParameterBindingMethodCache().HasBindAsyncMethod(parameterInfo));
+    }
+
+    [Fact]
+    public void HasBindAsyncMethod_ReturnsTrueForClassExplicitlyImplementingIBindableFromHttpContext()
+    {
+        var parameterInfo = GetFirstParameter((BindAsyncFromExplicitStaticAbstractInterface arg) => BindAsyncFromExplicitStaticAbstractInterfaceMethod(arg));
+        Assert.True(new ParameterBindingMethodCache().HasBindAsyncMethod(parameterInfo));
+    }
+
+    [Fact]
+    public void HasBindAsyncMethod_ReturnsTrueForClassImplementingIBindableFromHttpContextAndNonInterfaceBindAsyncMethod()
+    {
+        var parameterInfo = GetFirstParameter((BindAsyncFromStaticAbstractInterfaceAndBindAsync arg) => BindAsyncFromImplicitStaticAbstractInterfaceMethodInsteadOfReflectionMatchedMethod(arg));
+        Assert.True(new ParameterBindingMethodCache().HasBindAsyncMethod(parameterInfo));
+    }
+
     [Fact]
     public void FindBindAsyncMethod_FindsNonNullableReturningBindAsyncMethodGivenNullableType()
     {
@@ -327,6 +353,42 @@ public class ParameterBindingMethodCacheTests
         Assert.True(new ParameterBindingMethodCache().HasBindAsyncMethod(parameterInfo));
     }
 
+    [Fact]
+    public async Task FindBindAsyncMethod_FindsForClassImplicitlyImplementingIBindableFromHttpContext()
+    {
+        var parameterInfo = GetFirstParameter((BindAsyncFromImplicitStaticAbstractInterface arg) => BindAsyncFromImplicitStaticAbstractInterfaceMethod(arg));
+        var cache = new ParameterBindingMethodCache();
+        Assert.True(cache.HasBindAsyncMethod(parameterInfo));
+        var methodFound = cache.FindBindAsyncMethod(parameterInfo);
+
+        var parseHttpContext = Expression.Lambda<Func<HttpContext, ValueTask<object?>>>(methodFound.Expression!,
+            ParameterBindingMethodCache.HttpContextExpr).Compile();
+
+        var httpContext = new DefaultHttpContext();
+
+        var result = await parseHttpContext(httpContext);
+        Assert.NotNull(result);
+        Assert.IsType<BindAsyncFromImplicitStaticAbstractInterface>(result);
+    }
+
+    [Fact]
+    public async Task FindBindAsyncMethod_FindsForClassExplicitlyImplementingIBindableFromHttpContext()
+    {
+        var parameterInfo = GetFirstParameter((BindAsyncFromExplicitStaticAbstractInterface arg) => BindAsyncFromExplicitStaticAbstractInterfaceMethod(arg));
+        var cache = new ParameterBindingMethodCache();
+        Assert.True(cache.HasBindAsyncMethod(parameterInfo));
+        var methodFound = cache.FindBindAsyncMethod(parameterInfo);
+
+        var parseHttpContext = Expression.Lambda<Func<HttpContext, ValueTask<object?>>>(methodFound.Expression!,
+            ParameterBindingMethodCache.HttpContextExpr).Compile();
+
+        var httpContext = new DefaultHttpContext();
+
+        var result = await parseHttpContext(httpContext);
+        Assert.NotNull(result);
+        Assert.IsType<BindAsyncFromExplicitStaticAbstractInterface>(result);
+    }
+
     [Fact]
     public async Task FindBindAsyncMethod_FindsFallbackMethodWhenPreferredMethodsReturnTypeIsWrong()
     {
@@ -359,6 +421,25 @@ public class ParameterBindingMethodCacheTests
         Assert.Null(await parseHttpContext(httpContext));
     }
 
+    [Fact]
+    public async Task FindBindAsyncMethod_FindsMethodFromStaticAbstractInterfaceWhenValidNonInterfaceMethodAlsoExists()
+    {
+        var parameterInfo = GetFirstParameter((BindAsyncFromStaticAbstractInterfaceAndBindAsync arg) => BindAsyncFromImplicitStaticAbstractInterfaceMethodInsteadOfReflectionMatchedMethod(arg));
+        var cache = new ParameterBindingMethodCache();
+        Assert.True(cache.HasBindAsyncMethod(parameterInfo));
+        var methodFound = cache.FindBindAsyncMethod(parameterInfo);
+
+        var parseHttpContext = Expression.Lambda<Func<HttpContext, ValueTask<object>>>(methodFound.Expression!,
+            ParameterBindingMethodCache.HttpContextExpr).Compile();
+
+        var httpContext = new DefaultHttpContext();
+        var result = await parseHttpContext(httpContext);
+
+        Assert.NotNull(result);
+        Assert.IsType<BindAsyncFromStaticAbstractInterfaceAndBindAsync>(result);
+        Assert.Equal(BindAsyncSource.InterfaceStaticAbstractImplicit, ((BindAsyncFromStaticAbstractInterfaceAndBindAsync)result).BoundFrom);
+    }
+
     [Theory]
     [InlineData(typeof(ClassWithParameterlessConstructor))]
     [InlineData(typeof(RecordClassParameterlessConstructor))]
@@ -499,6 +580,7 @@ public class ParameterBindingMethodCacheTests
                 typeof(BindAsyncWithParameterInfoWrongTypeInherit),
                 typeof(BindAsyncWrongTypeFromInterface),
                 typeof(BindAsyncBothBadMethods),
+                typeof(BindAsyncFromStaticAbstractInterfaceWrongType)
             };
         }
     }
@@ -627,7 +709,6 @@ public class ParameterBindingMethodCacheTests
     private static void BindAsyncStructMethod(BindAsyncStruct arg) { }
     private static void BindAsyncNullableStructMethod(BindAsyncStruct? arg) { }
     private static void NullableReturningBindAsyncStructMethod(NullableReturningBindAsyncStruct arg) { }
-
     private static void BindAsyncSingleArgRecordMethod(BindAsyncSingleArgRecord arg) { }
     private static void BindAsyncSingleArgStructMethod(BindAsyncSingleArgStruct arg) { }
     private static void InheritBindAsyncMethod(InheritBindAsync arg) { }
@@ -639,6 +720,10 @@ public class ParameterBindingMethodCacheTests
     private static void BindAsyncFromInterfaceWithParameterInfoMethod(BindAsyncFromInterfaceWithParameterInfo args) { }
     private static void BindAsyncFallbackMethod(BindAsyncFallsBack? arg) { }
     private static void BindAsyncBadMethodMethod(BindAsyncBadMethod? arg) { }
+    private static void BindAsyncFromImplicitStaticAbstractInterfaceMethod(BindAsyncFromImplicitStaticAbstractInterface arg) { }
+    private static void BindAsyncFromExplicitStaticAbstractInterfaceMethod(BindAsyncFromExplicitStaticAbstractInterface arg) { }
+    private static void BindAsyncFromImplicitStaticAbstractInterfaceMethodInsteadOfReflectionMatchedMethod(BindAsyncFromStaticAbstractInterfaceAndBindAsync arg) { }
+    private static void BindAsyncFromStaticAbstractInterfaceWrongTypeMethod(BindAsyncFromStaticAbstractInterfaceWrongType arg) { }
 
     private static ParameterInfo GetFirstParameter<T>(Expression<Action<T>> expr)
     {
@@ -646,6 +731,12 @@ public class ParameterBindingMethodCacheTests
         return mc.Method.GetParameters()[0];
     }
 
+    private static ParameterInfo GetParameterAtIndex<T>(Expression<Action<T>> expr, int paramIndex)
+    {
+        var mc = (MethodCallExpression)expr.Body;
+        return mc.Method.GetParameters()[paramIndex];
+    }
+
     private record TryParseStringRecord(int Value)
     {
         public static bool TryParse(string? value, IFormatProvider formatProvider, out TryParseStringRecord? result)
@@ -1347,6 +1438,59 @@ public class ParameterBindingMethodCacheTests
         }
     }
 
+    private class BindAsyncFromImplicitStaticAbstractInterface : IBindableFromHttpContext<BindAsyncFromImplicitStaticAbstractInterface>
+    {
+        public static ValueTask<BindAsyncFromImplicitStaticAbstractInterface?> BindAsync(HttpContext context, ParameterInfo parameter)
+        {
+            return ValueTask.FromResult<BindAsyncFromImplicitStaticAbstractInterface?>(new());
+        }
+    }
+
+    private class BindAsyncFromExplicitStaticAbstractInterface : IBindableFromHttpContext<BindAsyncFromExplicitStaticAbstractInterface>
+    {
+        static ValueTask<BindAsyncFromExplicitStaticAbstractInterface?> IBindableFromHttpContext<BindAsyncFromExplicitStaticAbstractInterface>.BindAsync(HttpContext context, ParameterInfo parameter)
+        {
+            return ValueTask.FromResult<BindAsyncFromExplicitStaticAbstractInterface?>(new());
+        }
+    }
+
+    private class BindAsyncFromStaticAbstractInterfaceAndBindAsync : IBindableFromHttpContext<BindAsyncFromStaticAbstractInterfaceAndBindAsync>
+    {
+        public BindAsyncFromStaticAbstractInterfaceAndBindAsync(BindAsyncSource boundFrom)
+        {
+            BoundFrom = boundFrom;
+        }
+
+        public BindAsyncSource BoundFrom { get; }
+
+        // Implicit interface implementation
+        public static ValueTask<BindAsyncFromStaticAbstractInterfaceAndBindAsync?> BindAsync(HttpContext context, ParameterInfo parameter)
+        {
+            return ValueTask.FromResult<BindAsyncFromStaticAbstractInterfaceAndBindAsync?>(new(BindAsyncSource.InterfaceStaticAbstractImplicit));
+        }
+
+        // Late-bound pattern based match in RequestDelegateFactory
+        public static ValueTask<BindAsyncFromStaticAbstractInterfaceAndBindAsync?> BindAsync(HttpContext context)
+        {
+            return ValueTask.FromResult<BindAsyncFromStaticAbstractInterfaceAndBindAsync?>(new(BindAsyncSource.Reflection));
+        }
+    }
+
+    private class BindAsyncFromStaticAbstractInterfaceWrongType : IBindableFromHttpContext<BindAsyncFromImplicitStaticAbstractInterface>
+    {
+        public static ValueTask<BindAsyncFromImplicitStaticAbstractInterface?> BindAsync(HttpContext context, ParameterInfo parameter)
+        {
+            return ValueTask.FromResult<BindAsyncFromImplicitStaticAbstractInterface?>(new());
+        }
+    }
+
+    private enum BindAsyncSource
+    {
+        Reflection,
+        InterfaceStaticAbstractImplicit,
+        InterfaceStaticAbstractExplicit
+    }
+
     private class MockParameterInfo : ParameterInfo
     {
         public MockParameterInfo(Type type, string name)
diff --git a/src/Http/samples/MinimalSample/Program.cs b/src/Http/samples/MinimalSample/Program.cs
index a15f6798608..27923bed88b 100644
--- a/src/Http/samples/MinimalSample/Program.cs
+++ b/src/Http/samples/MinimalSample/Program.cs
@@ -2,6 +2,7 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using Microsoft.AspNetCore.Http.HttpResults;
+using System.Reflection;
 using Microsoft.AspNetCore.Mvc;
 
 var builder = WebApplication.CreateBuilder(args);
@@ -63,6 +64,19 @@ app.MapGet("/problem/{problemType}", (string problemType) => problemType switch
 
     });
 
+app.MapPost("/todos", (TodoBindable todo) => todo);
+
 app.Run();
 
 internal record Todo(int Id, string Title);
+public class TodoBindable : IBindableFromHttpContext<TodoBindable>
+{
+    public int Id { get; set; }
+    public string Title { get; set; }
+    public bool IsComplete { get; set; }
+
+    public static ValueTask<TodoBindable> BindAsync(HttpContext context, ParameterInfo parameter)
+    {
+        return ValueTask.FromResult(new TodoBindable { Id = 1, Title = "I was bound from IBindableFromHttpContext<TodoBindable>.BindAsync!" });
+    }
+}
diff --git a/src/Shared/ParameterBindingMethodCache.cs b/src/Shared/ParameterBindingMethodCache.cs
index 99418b00261..eb5750a8fd3 100644
--- a/src/Shared/ParameterBindingMethodCache.cs
+++ b/src/Shared/ParameterBindingMethodCache.cs
@@ -24,6 +24,7 @@ internal sealed class ParameterBindingMethodCache
 {
     private static readonly MethodInfo ConvertValueTaskMethod = typeof(ParameterBindingMethodCache).GetMethod(nameof(ConvertValueTask), BindingFlags.NonPublic | BindingFlags.Static)!;
     private static readonly MethodInfo ConvertValueTaskOfNullableResultMethod = typeof(ParameterBindingMethodCache).GetMethod(nameof(ConvertValueTaskOfNullableResult), BindingFlags.NonPublic | BindingFlags.Static)!;
+    private static readonly MethodInfo BindAsyncMethod = typeof(ParameterBindingMethodCache).GetMethod(nameof(BindAsync), BindingFlags.NonPublic | BindingFlags.Static)!;
 
     internal static readonly ParameterExpression TempSourceStringExpr = Expression.Variable(typeof(string), "tempSourceString");
     internal static readonly ParameterExpression HttpContextExpr = Expression.Parameter(typeof(HttpContext), "httpContext");
@@ -185,12 +186,18 @@ internal sealed class ParameterBindingMethodCache
         (Func<ParameterInfo, Expression>?, int) Finder(Type nonNullableParameterType)
         {
             var hasParameterInfo = true;
-            // There should only be one BindAsync method with these parameters since C# does not allow overloading on return type.
-            var methodInfo = GetStaticMethodFromHierarchy(nonNullableParameterType, "BindAsync", new[] { typeof(HttpContext), typeof(ParameterInfo) }, ValidateReturnType);
+            var methodInfo = GetIBindableFromHttpContextMethod(nonNullableParameterType);
+
             if (methodInfo is null)
             {
-                hasParameterInfo = false;
-                methodInfo = GetStaticMethodFromHierarchy(nonNullableParameterType, "BindAsync", new[] { typeof(HttpContext) }, ValidateReturnType);
+                // There should only be one BindAsync method with these parameters since C# does not allow overloading on return type.
+                methodInfo = GetStaticMethodFromHierarchy(nonNullableParameterType, "BindAsync", new[] { typeof(HttpContext), typeof(ParameterInfo) }, ValidateReturnType);
+
+                if (methodInfo is null)
+                {
+                    hasParameterInfo = false;
+                    methodInfo = GetStaticMethodFromHierarchy(nonNullableParameterType, "BindAsync", new[] { typeof(HttpContext) }, ValidateReturnType);
+                }
             }
 
             // We're looking for a method with the following signatures:
@@ -373,6 +380,26 @@ internal sealed class ParameterBindingMethodCache
         throw new InvalidOperationException($"No public parameterless constructor found for type '{TypeNameHelper.GetTypeDisplayName(type, fullName: false)}'.");
     }
 
+    private static MethodInfo? GetIBindableFromHttpContextMethod(Type type)
+    {
+        // Check if parameter is bindable via static abstract method on IBindableFromHttpContext<TSelf>
+        foreach (var i in type.GetInterfaces())
+        {
+            if (i.IsGenericType && i.GetGenericTypeDefinition() == typeof(IBindableFromHttpContext<>) && i.GetGenericArguments()[0] == type)
+            {
+                return BindAsyncMethod.MakeGenericMethod(type);
+            }
+        }
+
+        return null;
+    }
+
+    private static ValueTask<TValue?> BindAsync<TValue>(HttpContext httpContext, ParameterInfo parameter)
+        where TValue : class?, IBindableFromHttpContext<TValue>
+    {
+        return TValue.BindAsync(httpContext, parameter);
+    }
+        
     private MethodInfo? GetStaticMethodFromHierarchy(Type type, string name, Type[] parameterTypes, Func<MethodInfo, bool> validateReturnType)
     {
         bool IsMatch(MethodInfo? method) => method is not null && !method.IsAbstract && validateReturnType(method);
-- 
GitLab