From 375f5a4ea060b54aec6352339237b0a12caccedb Mon Sep 17 00:00:00 2001
From: Stephen Halter <halter73@gmail.com>
Date: Fri, 17 Jun 2022 00:04:01 -0700
Subject: [PATCH] Fix nullability of TryUpdateModelAsync (#41959)

---
 src/Mvc/Mvc.Core/src/ControllerBase.cs        |  4 +-
 .../DefaultPropertyFilterProvider.cs          |  4 +-
 .../src/ModelBinding/ModelBindingHelper.cs    |  6 +-
 src/Mvc/Mvc.Core/src/PublicAPI.Unshipped.txt  |  6 ++
 src/Mvc/Mvc.Core/test/ControllerBaseTest.cs   | 63 ++++++++++++++++++-
 src/Mvc/Mvc.RazorPages/src/PageBase.cs        |  4 +-
 src/Mvc/Mvc.RazorPages/src/PageModel.cs       |  4 +-
 .../src/PublicAPI.Unshipped.txt               |  8 +++
 8 files changed, 86 insertions(+), 13 deletions(-)

diff --git a/src/Mvc/Mvc.Core/src/ControllerBase.cs b/src/Mvc/Mvc.Core/src/ControllerBase.cs
index d1c9a840d75..5fddac183df 100644
--- a/src/Mvc/Mvc.Core/src/ControllerBase.cs
+++ b/src/Mvc/Mvc.Core/src/ControllerBase.cs
@@ -2619,7 +2619,7 @@ public abstract class ControllerBase
     public async Task<bool> TryUpdateModelAsync<TModel>(
         TModel model,
         string prefix,
-        params Expression<Func<TModel, object>>[] includeExpressions)
+        params Expression<Func<TModel, object?>>[] includeExpressions)
        where TModel : class
     {
         if (model == null)
@@ -2710,7 +2710,7 @@ public abstract class ControllerBase
         TModel model,
         string prefix,
         IValueProvider valueProvider,
-        params Expression<Func<TModel, object>>[] includeExpressions)
+        params Expression<Func<TModel, object?>>[] includeExpressions)
        where TModel : class
     {
         if (model == null)
diff --git a/src/Mvc/Mvc.Core/src/ModelBinding/DefaultPropertyFilterProvider.cs b/src/Mvc/Mvc.Core/src/ModelBinding/DefaultPropertyFilterProvider.cs
index 95324dc1809..d30748e7de8 100644
--- a/src/Mvc/Mvc.Core/src/ModelBinding/DefaultPropertyFilterProvider.cs
+++ b/src/Mvc/Mvc.Core/src/ModelBinding/DefaultPropertyFilterProvider.cs
@@ -27,7 +27,7 @@ public class DefaultPropertyFilterProvider<TModel> : IPropertyFilterProvider
     /// Expressions which can be used to generate property filter which can filter model
     /// properties.
     /// </summary>
-    public virtual IEnumerable<Expression<Func<TModel, object>>>? PropertyIncludeExpressions => null;
+    public virtual IEnumerable<Expression<Func<TModel, object?>>>? PropertyIncludeExpressions => null;
 
     /// <inheritdoc />
     public virtual Func<ModelMetadata, bool> PropertyFilter
@@ -45,7 +45,7 @@ public class DefaultPropertyFilterProvider<TModel> : IPropertyFilterProvider
     }
 
     private static Func<ModelMetadata, bool> GetPropertyFilterFromExpression(
-        IEnumerable<Expression<Func<TModel, object>>> includeExpressions)
+        IEnumerable<Expression<Func<TModel, object?>>> includeExpressions)
     {
         var expression = ModelBindingHelper.GetPropertyFilterExpression(includeExpressions.ToArray());
         return expression.Compile();
diff --git a/src/Mvc/Mvc.Core/src/ModelBinding/ModelBindingHelper.cs b/src/Mvc/Mvc.Core/src/ModelBinding/ModelBindingHelper.cs
index 236a243d7b8..34dc03f9c3b 100644
--- a/src/Mvc/Mvc.Core/src/ModelBinding/ModelBindingHelper.cs
+++ b/src/Mvc/Mvc.Core/src/ModelBinding/ModelBindingHelper.cs
@@ -82,7 +82,7 @@ internal static class ModelBindingHelper
         IModelBinderFactory modelBinderFactory,
         IValueProvider valueProvider,
         IObjectModelValidator objectModelValidator,
-        params Expression<Func<TModel, object>>[] includeExpressions)
+        params Expression<Func<TModel, object?>>[] includeExpressions)
        where TModel : class
     {
         if (includeExpressions == null)
@@ -363,7 +363,7 @@ internal static class ModelBindingHelper
     /// <param name="expressions">Expressions identifying the properties to allow for binding.</param>
     /// <returns>An expression which can be used with <see cref="IPropertyFilterProvider"/>.</returns>
     public static Expression<Func<ModelMetadata, bool>> GetPropertyFilterExpression<TModel>(
-        Expression<Func<TModel, object>>[] expressions)
+        Expression<Func<TModel, object?>>[] expressions)
     {
         if (expressions.Length == 0)
         {
@@ -385,7 +385,7 @@ internal static class ModelBindingHelper
     }
 
     private static Expression<Func<ModelMetadata, bool>> GetPredicateExpression<TModel>(
-        Expression<Func<TModel, object>> expression)
+        Expression<Func<TModel, object?>> expression)
     {
         var propertyName = GetPropertyName(expression.Body);
 
diff --git a/src/Mvc/Mvc.Core/src/PublicAPI.Unshipped.txt b/src/Mvc/Mvc.Core/src/PublicAPI.Unshipped.txt
index 820ab9896cb..dc8c6f2136b 100644
--- a/src/Mvc/Mvc.Core/src/PublicAPI.Unshipped.txt
+++ b/src/Mvc/Mvc.Core/src/PublicAPI.Unshipped.txt
@@ -3,6 +3,10 @@
 Microsoft.AspNetCore.Mvc.ApiBehaviorOptions.DisableImplicitFromServicesParameters.get -> bool
 Microsoft.AspNetCore.Mvc.ApiBehaviorOptions.DisableImplicitFromServicesParameters.set -> void
 Microsoft.AspNetCore.Mvc.ApplicationModels.InferParameterBindingInfoConvention.InferParameterBindingInfoConvention(Microsoft.AspNetCore.Mvc.ModelBinding.IModelMetadataProvider! modelMetadataProvider, Microsoft.Extensions.DependencyInjection.IServiceProviderIsService! serviceProviderIsService) -> void
+*REMOVED*Microsoft.AspNetCore.Mvc.ControllerBase.TryUpdateModelAsync<TModel>(TModel! model, string! prefix, Microsoft.AspNetCore.Mvc.ModelBinding.IValueProvider! valueProvider, params System.Linq.Expressions.Expression<System.Func<TModel!, object!>!>![]! includeExpressions) -> System.Threading.Tasks.Task<bool>!
+*REMOVED*Microsoft.AspNetCore.Mvc.ControllerBase.TryUpdateModelAsync<TModel>(TModel! model, string! prefix, params System.Linq.Expressions.Expression<System.Func<TModel!, object!>!>![]! includeExpressions) -> System.Threading.Tasks.Task<bool>!
+Microsoft.AspNetCore.Mvc.ControllerBase.TryUpdateModelAsync<TModel>(TModel! model, string! prefix, Microsoft.AspNetCore.Mvc.ModelBinding.IValueProvider! valueProvider, params System.Linq.Expressions.Expression<System.Func<TModel!, object?>!>![]! includeExpressions) -> System.Threading.Tasks.Task<bool>!
+Microsoft.AspNetCore.Mvc.ControllerBase.TryUpdateModelAsync<TModel>(TModel! model, string! prefix, params System.Linq.Expressions.Expression<System.Func<TModel!, object?>!>![]! includeExpressions) -> System.Threading.Tasks.Task<bool>!
 Microsoft.AspNetCore.Mvc.ModelBinding.Binders.TryParseModelBinderProvider
 Microsoft.AspNetCore.Mvc.ModelBinding.Binders.TryParseModelBinderProvider.GetBinder(Microsoft.AspNetCore.Mvc.ModelBinding.ModelBinderProviderContext! context) -> Microsoft.AspNetCore.Mvc.ModelBinding.IModelBinder?
 Microsoft.AspNetCore.Mvc.ModelBinding.Binders.TryParseModelBinderProvider.TryParseModelBinderProvider() -> void
@@ -15,3 +19,5 @@ Microsoft.AspNetCore.Mvc.ModelBinding.Metadata.ValidationMetadata.ValidationMode
 Microsoft.AspNetCore.Mvc.ModelBinding.Metadata.ValidationMetadata.ValidationModelName.set -> void
 virtual Microsoft.AspNetCore.Mvc.Infrastructure.ConfigureCompatibilityOptions<TOptions>.PostConfigure(string? name, TOptions! options) -> void
 static Microsoft.AspNetCore.Mvc.ControllerBase.Empty.get -> Microsoft.AspNetCore.Mvc.EmptyResult!
+*REMOVED*virtual Microsoft.AspNetCore.Mvc.ModelBinding.DefaultPropertyFilterProvider<TModel>.PropertyIncludeExpressions.get -> System.Collections.Generic.IEnumerable<System.Linq.Expressions.Expression<System.Func<TModel!, object!>!>!>?
+virtual Microsoft.AspNetCore.Mvc.ModelBinding.DefaultPropertyFilterProvider<TModel>.PropertyIncludeExpressions.get -> System.Collections.Generic.IEnumerable<System.Linq.Expressions.Expression<System.Func<TModel!, object?>!>!>?
diff --git a/src/Mvc/Mvc.Core/test/ControllerBaseTest.cs b/src/Mvc/Mvc.Core/test/ControllerBaseTest.cs
index deeda17dd25..3b6020a52dc 100644
--- a/src/Mvc/Mvc.Core/test/ControllerBaseTest.cs
+++ b/src/Mvc/Mvc.Core/test/ControllerBaseTest.cs
@@ -2728,8 +2728,7 @@ public class ControllerBaseTest
     [Theory]
     [InlineData("")]
     [InlineData("prefix")]
-    public async Task
-        TryUpdateModel_IncludeExpressionWithValueProviderOverload_UsesPassedArguments(string prefix)
+    public async Task TryUpdateModel_IncludeExpressionWithValueProviderOverload_UsesPassedArguments(string prefix)
     {
         // Arrange
         var valueProvider = new Mock<IValueProvider>();
@@ -2758,6 +2757,57 @@ public class ControllerBaseTest
         Assert.NotEqual(0, binder.BindModelCount);
     }
 
+#nullable enable
+    [Fact]
+    public async Task TryUpdateModel_SupportsNullableExpressions()
+    {
+        // Arrange
+        var valueProvider = new Mock<IValueProvider>();
+        valueProvider.Setup(v => v.ContainsPrefix(""))
+            .Returns(true);
+
+        StubModelBinder CreateBinder() => new StubModelBinder(context =>
+        {
+            Assert.Same(
+                valueProvider.Object,
+                Assert.IsType<CompositeValueProvider>(context.ValueProvider)[0]);
+
+            Assert.NotNull(context.PropertyFilter);
+
+            bool InvokePropertyFilter(string propertyName)
+            {
+                var modelMetadata = context.ModelMetadata.Properties[propertyName];
+                Assert.NotNull(modelMetadata);
+                return context.PropertyFilter!(modelMetadata!);
+            }
+
+            Assert.True(InvokePropertyFilter("Include"));
+            Assert.False(InvokePropertyFilter("Exclude"));
+        });
+
+        var binder1 = CreateBinder();
+        var controller1 = GetController(binder1, valueProvider.Object);
+        var model1 = new MyNullableModel();
+
+        // Act
+        await controller1.TryUpdateModelAsync(model1, prefix: "", m => m.Include);
+
+        // Assert
+        Assert.NotEqual(0, binder1.BindModelCount);
+
+        // Arrange (IModelBinder overload)
+        var binder2 = CreateBinder();
+        var controller2 = GetController(binder2, valueProvider.Object);
+        var model2 = new MyNullableModel();
+
+        // Act (IModelBinder overload)
+        await controller2.TryUpdateModelAsync(model2, prefix: "", m => m.Include);
+
+        // Assert (IModelBinder overload)
+        Assert.NotEqual(0, binder2.BindModelCount);
+    }
+#nullable restore
+
     [Fact]
     public async Task TryUpdateModelNonGeneric_PropertyFilterWithValueProviderOverload_UsesPassedArguments()
     {
@@ -3114,6 +3164,15 @@ public class ControllerBaseTest
         public string Property3 { get; set; }
     }
 
+#nullable enable
+    private class MyNullableModel
+    {
+        public string? Include { get; set; }
+
+        public string? Exclude { get; set; }
+    }
+#nullable restore
+
     private class TryValidateModelModel
     {
         public int IntegerProperty { get; set; }
diff --git a/src/Mvc/Mvc.RazorPages/src/PageBase.cs b/src/Mvc/Mvc.RazorPages/src/PageBase.cs
index e6a57f73487..24c116012f1 100644
--- a/src/Mvc/Mvc.RazorPages/src/PageBase.cs
+++ b/src/Mvc/Mvc.RazorPages/src/PageBase.cs
@@ -1397,7 +1397,7 @@ public abstract class PageBase : RazorPageBase
     public async Task<bool> TryUpdateModelAsync<TModel>(
         TModel model,
         string prefix,
-        params Expression<Func<TModel, object>>[] includeExpressions)
+        params Expression<Func<TModel, object?>>[] includeExpressions)
        where TModel : class
     {
         if (model == null)
@@ -1486,7 +1486,7 @@ public abstract class PageBase : RazorPageBase
         TModel model,
         string prefix,
         IValueProvider valueProvider,
-        params Expression<Func<TModel, object>>[] includeExpressions)
+        params Expression<Func<TModel, object?>>[] includeExpressions)
        where TModel : class
     {
         if (model == null)
diff --git a/src/Mvc/Mvc.RazorPages/src/PageModel.cs b/src/Mvc/Mvc.RazorPages/src/PageModel.cs
index e9876a78a54..a7f37c85af3 100644
--- a/src/Mvc/Mvc.RazorPages/src/PageModel.cs
+++ b/src/Mvc/Mvc.RazorPages/src/PageModel.cs
@@ -286,7 +286,7 @@ public abstract class PageModel : IAsyncPageFilter, IPageFilter
     protected internal async Task<bool> TryUpdateModelAsync<TModel>(
         TModel model,
         string name,
-        params Expression<Func<TModel, object>>[] includeExpressions)
+        params Expression<Func<TModel, object?>>[] includeExpressions)
        where TModel : class
     {
         if (model == null)
@@ -375,7 +375,7 @@ public abstract class PageModel : IAsyncPageFilter, IPageFilter
         TModel model,
         string name,
         IValueProvider valueProvider,
-        params Expression<Func<TModel, object>>[] includeExpressions)
+        params Expression<Func<TModel, object?>>[] includeExpressions)
        where TModel : class
     {
         if (model == null)
diff --git a/src/Mvc/Mvc.RazorPages/src/PublicAPI.Unshipped.txt b/src/Mvc/Mvc.RazorPages/src/PublicAPI.Unshipped.txt
index 7dc5c58110b..ed7d713068c 100644
--- a/src/Mvc/Mvc.RazorPages/src/PublicAPI.Unshipped.txt
+++ b/src/Mvc/Mvc.RazorPages/src/PublicAPI.Unshipped.txt
@@ -1 +1,9 @@
 #nullable enable
+*REMOVED*Microsoft.AspNetCore.Mvc.RazorPages.PageBase.TryUpdateModelAsync<TModel>(TModel! model, string! prefix, Microsoft.AspNetCore.Mvc.ModelBinding.IValueProvider! valueProvider, params System.Linq.Expressions.Expression<System.Func<TModel!, object!>!>![]! includeExpressions) -> System.Threading.Tasks.Task<bool>!
+*REMOVED*Microsoft.AspNetCore.Mvc.RazorPages.PageBase.TryUpdateModelAsync<TModel>(TModel! model, string! prefix, params System.Linq.Expressions.Expression<System.Func<TModel!, object!>!>![]! includeExpressions) -> System.Threading.Tasks.Task<bool>!
+*REMOVED*Microsoft.AspNetCore.Mvc.RazorPages.PageModel.TryUpdateModelAsync<TModel>(TModel! model, string! name, Microsoft.AspNetCore.Mvc.ModelBinding.IValueProvider! valueProvider, params System.Linq.Expressions.Expression<System.Func<TModel!, object!>!>![]! includeExpressions) -> System.Threading.Tasks.Task<bool>!
+*REMOVED*Microsoft.AspNetCore.Mvc.RazorPages.PageModel.TryUpdateModelAsync<TModel>(TModel! model, string! name, params System.Linq.Expressions.Expression<System.Func<TModel!, object!>!>![]! includeExpressions) -> System.Threading.Tasks.Task<bool>!
+Microsoft.AspNetCore.Mvc.RazorPages.PageBase.TryUpdateModelAsync<TModel>(TModel! model, string! prefix, Microsoft.AspNetCore.Mvc.ModelBinding.IValueProvider! valueProvider, params System.Linq.Expressions.Expression<System.Func<TModel!, object?>!>![]! includeExpressions) -> System.Threading.Tasks.Task<bool>!
+Microsoft.AspNetCore.Mvc.RazorPages.PageBase.TryUpdateModelAsync<TModel>(TModel! model, string! prefix, params System.Linq.Expressions.Expression<System.Func<TModel!, object?>!>![]! includeExpressions) -> System.Threading.Tasks.Task<bool>!
+Microsoft.AspNetCore.Mvc.RazorPages.PageModel.TryUpdateModelAsync<TModel>(TModel! model, string! name, Microsoft.AspNetCore.Mvc.ModelBinding.IValueProvider! valueProvider, params System.Linq.Expressions.Expression<System.Func<TModel!, object?>!>![]! includeExpressions) -> System.Threading.Tasks.Task<bool>!
+Microsoft.AspNetCore.Mvc.RazorPages.PageModel.TryUpdateModelAsync<TModel>(TModel! model, string! name, params System.Linq.Expressions.Expression<System.Func<TModel!, object?>!>![]! includeExpressions) -> System.Threading.Tasks.Task<bool>!
-- 
GitLab