From 123bd0611ac6548bc4e483f8fb6af787dd45d4f7 Mon Sep 17 00:00:00 2001
From: David Fowler <davidfowl@gmail.com>
Date: Wed, 11 Aug 2021 16:01:00 -0700
Subject: [PATCH] Set IsRequired on ApiDescriptions for endpoints (#35233)

* Set isRequired on ApiDescriptions for endpoints
- Use the same logic we have in RequestDelegateFactory.Create to
determine if a method parameter is required or not. We then set the
IsRequired property on the ApiParameterDesciption.
---
 .../EndpointMetadataApiDescriptionProvider.cs | 28 ++++---
 ...pointMetadataApiDescriptionProviderTest.cs | 77 ++++++++++++++++---
 ...oft.AspNetCore.Mvc.ApiExplorer.Test.csproj |  1 +
 3 files changed, 83 insertions(+), 23 deletions(-)

diff --git a/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs b/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs
index c659bf540f5..fc3f69701a1 100644
--- a/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs
+++ b/src/Mvc/Mvc.ApiExplorer/src/EndpointMetadataApiDescriptionProvider.cs
@@ -28,6 +28,7 @@ namespace Microsoft.AspNetCore.Mvc.ApiExplorer
         private readonly IHostEnvironment _environment;
         private readonly IServiceProviderIsService? _serviceProviderIsService;
         private readonly TryParseMethodCache TryParseMethodCache = new();
+        private readonly NullabilityInfoContext NullabilityContext = new();
 
         // Executes before MVC's DefaultApiDescriptionProvider and GrpcHttpApiDescriptionProvider for no particular reason.
         public int Order => -1100;
@@ -132,7 +133,7 @@ namespace Microsoft.AspNetCore.Mvc.ApiExplorer
 
         private ApiParameterDescription? CreateApiParameterDescription(ParameterInfo parameter, RoutePattern pattern)
         {
-            var (source, name) = GetBindingSourceAndName(parameter, pattern);
+            var (source, name, allowEmpty) = GetBindingSourceAndName(parameter, pattern);
 
             // Services are ignored because they are not request parameters.
             if (source == BindingSource.Services)
@@ -140,6 +141,10 @@ namespace Microsoft.AspNetCore.Mvc.ApiExplorer
                 return null;
             }
 
+            // Determine the "requiredness" based on nullability, default value or if allowEmpty is set
+            var nullability = NullabilityContext.Create(parameter);
+            var isOptional = parameter.HasDefaultValue || nullability.ReadState == NullabilityState.Nullable || allowEmpty;
+
             return new ApiParameterDescription
             {
                 Name = name,
@@ -147,30 +152,31 @@ namespace Microsoft.AspNetCore.Mvc.ApiExplorer
                 Source = source,
                 DefaultValue = parameter.DefaultValue,
                 Type = parameter.ParameterType,
+                IsRequired = !isOptional
             };
         }
 
         // TODO: Share more of this logic with RequestDelegateFactory.CreateArgument(...) using RequestDelegateFactoryUtilities
         // which is shared source.
-        private (BindingSource, string) GetBindingSourceAndName(ParameterInfo parameter, RoutePattern pattern)
+        private (BindingSource, string, bool) GetBindingSourceAndName(ParameterInfo parameter, RoutePattern pattern)
         {
             var attributes = parameter.GetCustomAttributes();
 
             if (attributes.OfType<IFromRouteMetadata>().FirstOrDefault() is { } routeAttribute)
             {
-                return (BindingSource.Path, routeAttribute.Name ?? parameter.Name ?? string.Empty);
+                return (BindingSource.Path, routeAttribute.Name ?? parameter.Name ?? string.Empty, false);
             }
             else if (attributes.OfType<IFromQueryMetadata>().FirstOrDefault() is { } queryAttribute)
             {
-                return (BindingSource.Query, queryAttribute.Name ?? parameter.Name ?? string.Empty);
+                return (BindingSource.Query, queryAttribute.Name ?? parameter.Name ?? string.Empty, false);
             }
             else if (attributes.OfType<IFromHeaderMetadata>().FirstOrDefault() is { } headerAttribute)
             {
-                return (BindingSource.Header, headerAttribute.Name ?? parameter.Name ?? string.Empty);
+                return (BindingSource.Header, headerAttribute.Name ?? parameter.Name ?? string.Empty, false);
             }
-            else if (parameter.CustomAttributes.Any(a => typeof(IFromBodyMetadata).IsAssignableFrom(a.AttributeType)))
+            else if (attributes.OfType<IFromBodyMetadata>().FirstOrDefault() is { } fromBodyAttribute)
             {
-                return (BindingSource.Body, parameter.Name ?? string.Empty);
+                return (BindingSource.Body, parameter.Name ?? string.Empty, fromBodyAttribute.AllowEmpty);
             }
             else if (parameter.CustomAttributes.Any(a => typeof(IFromServiceMetadata).IsAssignableFrom(a.AttributeType)) ||
                      parameter.ParameterType == typeof(HttpContext) ||
@@ -180,23 +186,23 @@ namespace Microsoft.AspNetCore.Mvc.ApiExplorer
                      parameter.ParameterType == typeof(CancellationToken) ||
                      _serviceProviderIsService?.IsService(parameter.ParameterType) == true)
             {
-                return (BindingSource.Services, parameter.Name ?? string.Empty);
+                return (BindingSource.Services, parameter.Name ?? string.Empty, false);
             }
             else if (parameter.ParameterType == typeof(string) || TryParseMethodCache.HasTryParseMethod(parameter))
             {
                 // Path vs query cannot be determined by RequestDelegateFactory at startup currently because of the layering, but can be done here.
                 if (parameter.Name is { } name && pattern.GetParameter(name) is not null)
                 {
-                    return (BindingSource.Path, name);
+                    return (BindingSource.Path, name, false);
                 }
                 else
                 {
-                    return (BindingSource.Query, parameter.Name ?? string.Empty);
+                    return (BindingSource.Query, parameter.Name ?? string.Empty, false);
                 }
             }
             else
             {
-                return (BindingSource.Body, parameter.Name ?? string.Empty);
+                return (BindingSource.Body, parameter.Name ?? string.Empty, false);
             }
         }
 
diff --git a/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs b/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs
index e390d4c85ef..3452974d8a2 100644
--- a/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs
+++ b/src/Mvc/Mvc.ApiExplorer/test/EndpointMetadataApiDescriptionProviderTest.cs
@@ -76,20 +76,22 @@ namespace Microsoft.AspNetCore.Mvc.ApiExplorer
         [Fact]
         public void AddsRequestFormatFromMetadata()
         {
-            static void AssertustomRequestFormat(ApiDescription apiDescription)
+            static void AssertCustomRequestFormat(ApiDescription apiDescription)
             {
                 var requestFormat = Assert.Single(apiDescription.SupportedRequestFormats);
                 Assert.Equal("application/custom", requestFormat.MediaType);
                 Assert.Null(requestFormat.Formatter);
             }
 
-            AssertustomRequestFormat(GetApiDescription(
+            AssertCustomRequestFormat(GetApiDescription(
                 [Consumes("application/custom")]
-                (InferredJsonClass fromBody) => { }));
+            (InferredJsonClass fromBody) =>
+                { }));
 
-            AssertustomRequestFormat(GetApiDescription(
+            AssertCustomRequestFormat(GetApiDescription(
                 [Consumes("application/custom")]
-                ([FromBody] int fromBody) => { }));
+            ([FromBody] int fromBody) =>
+                { }));
         }
 
         [Fact]
@@ -97,7 +99,8 @@ namespace Microsoft.AspNetCore.Mvc.ApiExplorer
         {
             var apiDescription = GetApiDescription(
                 [Consumes("application/custom0", "application/custom1")]
-                (InferredJsonClass fromBody) => { });
+            (InferredJsonClass fromBody) =>
+                { });
 
             Assert.Equal(2, apiDescription.SupportedRequestFormats.Count);
 
@@ -167,8 +170,8 @@ namespace Microsoft.AspNetCore.Mvc.ApiExplorer
         {
             var apiDescription = GetApiDescription(
                 [ProducesResponseType(typeof(TimeSpan), StatusCodes.Status201Created)]
-                [Produces("application/custom")]
-                () => new InferredJsonClass());
+            [Produces("application/custom")]
+            () => new InferredJsonClass());
 
             var responseType = Assert.Single(apiDescription.SupportedResponseTypes);
 
@@ -185,8 +188,8 @@ namespace Microsoft.AspNetCore.Mvc.ApiExplorer
         {
             var apiDescription = GetApiDescription(
                 [ProducesResponseType(typeof(TimeSpan), StatusCodes.Status201Created)]
-                [ProducesResponseType(StatusCodes.Status400BadRequest)]
-                () => new InferredJsonClass());
+            [ProducesResponseType(StatusCodes.Status400BadRequest)]
+            () => new InferredJsonClass());
 
             Assert.Equal(2, apiDescription.SupportedResponseTypes.Count);
 
@@ -214,8 +217,8 @@ namespace Microsoft.AspNetCore.Mvc.ApiExplorer
         {
             var apiDescription = GetApiDescription(
                 [ProducesResponseType(typeof(InferredJsonClass), StatusCodes.Status201Created)]
-                [ProducesResponseType(StatusCodes.Status400BadRequest)]
-                () => Results.Ok(new InferredJsonClass()));
+            [ProducesResponseType(StatusCodes.Status400BadRequest)]
+            () => Results.Ok(new InferredJsonClass()));
 
             Assert.Equal(2, apiDescription.SupportedResponseTypes.Count);
 
@@ -324,18 +327,68 @@ namespace Microsoft.AspNetCore.Mvc.ApiExplorer
             Assert.Equal(typeof(int), fooParam.Type);
             Assert.Equal(typeof(int), fooParam.ModelMetadata.ModelType);
             Assert.Equal(BindingSource.Path, fooParam.Source);
+            Assert.True(fooParam.IsRequired);
 
             var barParam = apiDescription.ParameterDescriptions[1];
             Assert.Equal(typeof(int), barParam.Type);
             Assert.Equal(typeof(int), barParam.ModelMetadata.ModelType);
             Assert.Equal(BindingSource.Query, barParam.Source);
+            Assert.True(barParam.IsRequired);
 
             var fromBodyParam = apiDescription.ParameterDescriptions[2];
             Assert.Equal(typeof(InferredJsonClass), fromBodyParam.Type);
             Assert.Equal(typeof(InferredJsonClass), fromBodyParam.ModelMetadata.ModelType);
             Assert.Equal(BindingSource.Body, fromBodyParam.Source);
+            Assert.True(fromBodyParam.IsRequired);
         }
 
+        [Fact]
+        public void TestParameterIsRequired()
+        {
+            var apiDescription = GetApiDescription(([FromRoute] int foo, int? bar) => { });
+            Assert.Equal(2, apiDescription.ParameterDescriptions.Count);
+
+            var fooParam = apiDescription.ParameterDescriptions[0];
+            Assert.Equal(typeof(int), fooParam.Type);
+            Assert.Equal(typeof(int), fooParam.ModelMetadata.ModelType);
+            Assert.Equal(BindingSource.Path, fooParam.Source);
+            Assert.True(fooParam.IsRequired);
+
+            var barParam = apiDescription.ParameterDescriptions[1];
+            Assert.Equal(typeof(int?), barParam.Type);
+            Assert.Equal(typeof(int?), barParam.ModelMetadata.ModelType);
+            Assert.Equal(BindingSource.Query, barParam.Source);
+            Assert.False(barParam.IsRequired);
+        }
+
+#nullable enable
+
+        [Fact]
+        public void TestIsRequiredFromBody()
+        {
+            var apiDescription0 = GetApiDescription(([FromBody(EmptyBodyBehavior = EmptyBodyBehavior.Allow)] InferredJsonClass fromBody) => { });
+            var apiDescription1 = GetApiDescription((InferredJsonClass? fromBody) => { });
+            Assert.Equal(1, apiDescription0.ParameterDescriptions.Count);
+            Assert.Equal(1, apiDescription1.ParameterDescriptions.Count);
+
+            var fromBodyParam0 = apiDescription0.ParameterDescriptions[0];
+            Assert.Equal(typeof(InferredJsonClass), fromBodyParam0.Type);
+            Assert.Equal(typeof(InferredJsonClass), fromBodyParam0.ModelMetadata.ModelType);
+            Assert.Equal(BindingSource.Body, fromBodyParam0.Source);
+            Assert.False(fromBodyParam0.IsRequired);
+
+            var fromBodyParam1 = apiDescription1.ParameterDescriptions[0];
+            Assert.Equal(typeof(InferredJsonClass), fromBodyParam1.Type);
+            Assert.Equal(typeof(InferredJsonClass), fromBodyParam1.ModelMetadata.ModelType);
+            Assert.Equal(BindingSource.Body, fromBodyParam1.Source);
+            Assert.False(fromBodyParam1.IsRequired);
+        }
+
+        // This is necessary for TestIsRequiredFromBody to pass until https://github.com/dotnet/roslyn/issues/55254 is resolved.
+        private object RandomMethod() => throw new NotImplementedException();
+
+#nullable disable
+
         [Fact]
         public void AddsDisplayNameFromRouteEndpoint()
         {
diff --git a/src/Mvc/Mvc.ApiExplorer/test/Microsoft.AspNetCore.Mvc.ApiExplorer.Test.csproj b/src/Mvc/Mvc.ApiExplorer/test/Microsoft.AspNetCore.Mvc.ApiExplorer.Test.csproj
index 8eb953ba570..282f058453d 100644
--- a/src/Mvc/Mvc.ApiExplorer/test/Microsoft.AspNetCore.Mvc.ApiExplorer.Test.csproj
+++ b/src/Mvc/Mvc.ApiExplorer/test/Microsoft.AspNetCore.Mvc.ApiExplorer.Test.csproj
@@ -3,6 +3,7 @@
   <PropertyGroup>
     <TargetFramework>$(DefaultNetCoreTargetFramework)</TargetFramework>
     <LangVersion>Preview</LangVersion>
+    <Features>$(Features.Replace('nullablePublicOnly', '')</Features>
   </PropertyGroup>
 
   <ItemGroup>
-- 
GitLab