From c9eeba8ba1cbb5c6ed38764961976b916bcb0232 Mon Sep 17 00:00:00 2001
From: Aditya Mandaleeka <adityamandaleeka@users.noreply.github.com>
Date: Tue, 10 Aug 2021 10:26:59 -0700
Subject: [PATCH] Add feature and DiagnosticSource event for rejected request
 info (#34783)

Add feature and DiagnosticSource event for rejected request info.
---
 .../src/IBadRequestExceptionFeature.cs        | 16 ++++++
 .../Http.Features/src/PublicAPI.Unshipped.txt |  2 +
 .../Http/HttpProtocol.FeatureCollection.cs    |  5 ++
 .../Internal/Http/HttpProtocol.Generated.cs   | 25 ++++++++-
 .../Core/src/Internal/Http/HttpProtocol.cs    |  8 ++-
 .../Core/src/Internal/KestrelServerImpl.cs    | 19 +++++--
 .../Core/src/Internal/ServiceContext.cs       |  3 ++
 .../BadHttpRequestTests.cs                    | 54 +++++++++++++++++--
 .../HttpProtocolFeatureCollection.cs          |  4 +-
 9 files changed, 124 insertions(+), 12 deletions(-)
 create mode 100644 src/Http/Http.Features/src/IBadRequestExceptionFeature.cs

diff --git a/src/Http/Http.Features/src/IBadRequestExceptionFeature.cs b/src/Http/Http.Features/src/IBadRequestExceptionFeature.cs
new file mode 100644
index 00000000000..056696f1822
--- /dev/null
+++ b/src/Http/Http.Features/src/IBadRequestExceptionFeature.cs
@@ -0,0 +1,16 @@
+// Licensed to the .NET Foundation under one or more agreements.
+// The .NET Foundation licenses this file to you under the MIT license.
+
+namespace Microsoft.AspNetCore.Http.Features
+{
+    /// <summary>
+    /// Provides information about rejected HTTP requests.
+    /// </summary>
+    public interface IBadRequestExceptionFeature
+    {
+        /// <summary>
+        /// Synchronously retrieves the exception associated with the rejected HTTP request.
+        /// </summary>
+        Exception? Error { get; }
+    }
+}
diff --git a/src/Http/Http.Features/src/PublicAPI.Unshipped.txt b/src/Http/Http.Features/src/PublicAPI.Unshipped.txt
index a391a2f049a..cd3928f9937 100644
--- a/src/Http/Http.Features/src/PublicAPI.Unshipped.txt
+++ b/src/Http/Http.Features/src/PublicAPI.Unshipped.txt
@@ -36,6 +36,8 @@
 *REMOVED*Microsoft.AspNetCore.Http.Features.FeatureReferences<TCache>.Revision.get -> int
 *REMOVED*static readonly Microsoft.AspNetCore.Http.Features.FeatureReference<T>.Default -> Microsoft.AspNetCore.Http.Features.FeatureReference<T>
 *REMOVED*virtual Microsoft.AspNetCore.Http.Features.FeatureCollection.Revision.get -> int
+Microsoft.AspNetCore.Http.Features.IBadRequestExceptionFeature
+Microsoft.AspNetCore.Http.Features.IBadRequestExceptionFeature.Error.get -> System.Exception?
 Microsoft.AspNetCore.Http.Features.IServerVariablesFeature.this[string! variableName].get -> string?
 Microsoft.AspNetCore.Http.IHeaderDictionary.Accept.get -> Microsoft.Extensions.Primitives.StringValues
 Microsoft.AspNetCore.Http.IHeaderDictionary.Accept.set -> void
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs
index 7a7e716e472..40dc6d89ef1 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.FeatureCollection.cs
@@ -241,6 +241,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
 
         Stream IHttpResponseBodyFeature.Stream => ResponseBody;
 
+        Exception? IBadRequestExceptionFeature.Error
+        {
+            get => _requestRejectedException;
+        }
+
         void IHttpResponseFeature.OnStarting(Func<object, Task> callback, object state)
         {
             OnStarting(callback, state);
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs
index 7dcf8216f59..b8516497ba2 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.Generated.cs
@@ -29,7 +29,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
                                           IHttpRequestLifetimeFeature,
                                           IHttpBodyControlFeature,
                                           IHttpMaxRequestBodySizeFeature,
-                                          IHttpRequestBodyDetectionFeature
+                                          IHttpRequestBodyDetectionFeature,
+                                          IBadRequestExceptionFeature
     {
         // Implemented features
         internal protected IHttpRequestFeature? _currentIHttpRequestFeature;
@@ -46,6 +47,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
         internal protected IHttpBodyControlFeature? _currentIHttpBodyControlFeature;
         internal protected IHttpMaxRequestBodySizeFeature? _currentIHttpMaxRequestBodySizeFeature;
         internal protected IHttpRequestBodyDetectionFeature? _currentIHttpRequestBodyDetectionFeature;
+        internal protected IBadRequestExceptionFeature? _currentIBadRequestExceptionFeature;
 
         // Other reserved feature slots
         internal protected IServiceProvidersFeature? _currentIServiceProvidersFeature;
@@ -85,6 +87,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
             _currentIHttpBodyControlFeature = this;
             _currentIHttpMaxRequestBodySizeFeature = this;
             _currentIHttpRequestBodyDetectionFeature = this;
+            _currentIBadRequestExceptionFeature = this;
 
             _currentIServiceProvidersFeature = null;
             _currentIHttpActivityFeature = null;
@@ -257,6 +260,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
                 {
                     feature = _currentIHttpWebSocketFeature;
                 }
+                else if (key == typeof(IBadRequestExceptionFeature))
+                {
+                    feature = _currentIBadRequestExceptionFeature;
+                }
                 else if (key == typeof(IHttp2StreamIdFeature))
                 {
                     feature = _currentIHttp2StreamIdFeature;
@@ -389,6 +396,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
                 {
                     _currentIHttpWebSocketFeature = (IHttpWebSocketFeature?)value;
                 }
+                else if (key == typeof(IBadRequestExceptionFeature))
+                {
+                    _currentIBadRequestExceptionFeature = (IBadRequestExceptionFeature?)value;
+                }
                 else if (key == typeof(IHttp2StreamIdFeature))
                 {
                     _currentIHttp2StreamIdFeature = (IHttp2StreamIdFeature?)value;
@@ -523,6 +534,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
             {
                 feature = Unsafe.As<IHttpWebSocketFeature?, TFeature?>(ref _currentIHttpWebSocketFeature);
             }
+            else if (typeof(TFeature) == typeof(IBadRequestExceptionFeature))
+            {
+                feature = Unsafe.As<IBadRequestExceptionFeature?, TFeature?>(ref _currentIBadRequestExceptionFeature);
+            }
             else if (typeof(TFeature) == typeof(IHttp2StreamIdFeature))
             {
                 feature = Unsafe.As<IHttp2StreamIdFeature?, TFeature?>(ref _currentIHttp2StreamIdFeature);
@@ -663,6 +678,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
             {
                 _currentIHttpWebSocketFeature = Unsafe.As<TFeature?, IHttpWebSocketFeature?>(ref feature);
             }
+            else if (typeof(TFeature) == typeof(IBadRequestExceptionFeature))
+            {
+                _currentIBadRequestExceptionFeature = Unsafe.As<TFeature?, IBadRequestExceptionFeature?>(ref feature);
+            }
             else if (typeof(TFeature) == typeof(IHttp2StreamIdFeature))
             {
                 _currentIHttp2StreamIdFeature = Unsafe.As<TFeature?, IHttp2StreamIdFeature?>(ref feature);
@@ -791,6 +810,10 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
             {
                 yield return new KeyValuePair<Type, object>(typeof(IHttpWebSocketFeature), _currentIHttpWebSocketFeature);
             }
+            if (_currentIBadRequestExceptionFeature != null)
+            {
+                yield return new KeyValuePair<Type, object>(typeof(IBadRequestExceptionFeature), _currentIBadRequestExceptionFeature);
+            }
             if (_currentIHttp2StreamIdFeature != null)
             {
                 yield return new KeyValuePair<Type, object>(typeof(IHttp2StreamIdFeature), _currentIHttp2StreamIdFeature);
diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
index b62e2186b18..c868b3223d7 100644
--- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
@@ -1324,14 +1324,20 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
         public void SetBadRequestState(BadHttpRequestException ex)
         {
             Log.ConnectionBadRequest(ConnectionId, ex);
+            _requestRejectedException = ex;
 
             if (!HasResponseStarted)
             {
                 SetErrorResponseException(ex);
             }
 
+            const string badRequestEventName = "Microsoft.AspNetCore.Server.Kestrel.BadRequest";
+            if (ServiceContext.DiagnosticSource?.IsEnabled(badRequestEventName) == true)
+            {
+                ServiceContext.DiagnosticSource.Write(badRequestEventName, this);
+            }
+
             _keepAlive = false;
-            _requestRejectedException = ex;
         }
 
         public void ReportApplicationError(Exception? ex)
diff --git a/src/Servers/Kestrel/Core/src/Internal/KestrelServerImpl.cs b/src/Servers/Kestrel/Core/src/Internal/KestrelServerImpl.cs
index 140d978f6d1..01e8d6a7edf 100644
--- a/src/Servers/Kestrel/Core/src/Internal/KestrelServerImpl.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/KestrelServerImpl.cs
@@ -40,7 +40,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core
             IOptions<KestrelServerOptions> options,
             IEnumerable<IConnectionListenerFactory> transportFactories,
             ILoggerFactory loggerFactory)
-            : this(transportFactories, null, CreateServiceContext(options, loggerFactory))
+            : this(transportFactories, null, CreateServiceContext(options, loggerFactory, null))
         {
         }
 
@@ -49,7 +49,17 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core
             IEnumerable<IConnectionListenerFactory> transportFactories,
             IEnumerable<IMultiplexedConnectionListenerFactory> multiplexedFactories,
             ILoggerFactory loggerFactory)
-            : this(transportFactories, multiplexedFactories, CreateServiceContext(options, loggerFactory))
+            : this(transportFactories, multiplexedFactories, CreateServiceContext(options, loggerFactory, null))
+        {
+        }
+
+        public KestrelServerImpl(
+            IOptions<KestrelServerOptions> options,
+            IEnumerable<IConnectionListenerFactory> transportFactories,
+            IEnumerable<IMultiplexedConnectionListenerFactory> multiplexedFactories,
+            ILoggerFactory loggerFactory,
+            DiagnosticSource diagnosticSource)
+            : this(transportFactories, multiplexedFactories, CreateServiceContext(options, loggerFactory, diagnosticSource))
         {
         }
 
@@ -89,7 +99,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core
             HttpCharacters.Initialize();
         }
 
-        private static ServiceContext CreateServiceContext(IOptions<KestrelServerOptions> options, ILoggerFactory loggerFactory)
+        private static ServiceContext CreateServiceContext(IOptions<KestrelServerOptions> options, ILoggerFactory loggerFactory, DiagnosticSource? diagnosticSource)
         {
             if (options == null)
             {
@@ -124,7 +134,8 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core
                 DateHeaderValueManager = dateHeaderValueManager,
                 ConnectionManager = connectionManager,
                 Heartbeat = heartbeat,
-                ServerOptions = serverOptions
+                ServerOptions = serverOptions,
+                DiagnosticSource = diagnosticSource
             };
         }
 
diff --git a/src/Servers/Kestrel/Core/src/Internal/ServiceContext.cs b/src/Servers/Kestrel/Core/src/Internal/ServiceContext.cs
index 363b65080a3..6ceddc9c656 100644
--- a/src/Servers/Kestrel/Core/src/Internal/ServiceContext.cs
+++ b/src/Servers/Kestrel/Core/src/Internal/ServiceContext.cs
@@ -1,6 +1,7 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Diagnostics;
 using System.IO.Pipelines;
 using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http;
 using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
@@ -27,5 +28,7 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal
         public Heartbeat Heartbeat { get; set; } = default!;
 
         public KestrelServerOptions ServerOptions { get; set; } = default!;
+
+        public DiagnosticSource? DiagnosticSource { get; set; }
     }
 }
diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/BadHttpRequestTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/BadHttpRequestTests.cs
index cef2434fdf3..fffb7b89eb1 100644
--- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/BadHttpRequestTests.cs
+++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/BadHttpRequestTests.cs
@@ -1,15 +1,13 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
-using System.Collections.Generic;
-using System.Linq;
-using System.Threading.Tasks;
+using System.Diagnostics;
+using Microsoft.AspNetCore.Http.Features;
 using Microsoft.AspNetCore.Server.Kestrel.Core;
 using Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure;
 using Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests.TestTransport;
 using Microsoft.AspNetCore.Testing;
 using Microsoft.Extensions.Logging;
-using Microsoft.Extensions.Logging.Testing;
 using Moq;
 using Xunit;
 using BadHttpRequestException = Microsoft.AspNetCore.Http.BadHttpRequestException;
@@ -195,6 +193,33 @@ namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests
             }
         }
 
+        private class BadRequestEventListener : IObserver<KeyValuePair<string, object>>, IDisposable
+        {
+            private IDisposable _subscription;
+            private Action<KeyValuePair<string, object>> _callback;
+
+            public bool EventFired { get; set; }
+
+            public BadRequestEventListener(DiagnosticListener diagnosticListener, Action<KeyValuePair<string, object>> callback)
+            {
+                _subscription = diagnosticListener.Subscribe(this, IsEnabled);
+                _callback = callback;
+            }
+            private static readonly Predicate<string> IsEnabled = (provider) => provider switch
+            {
+                "Microsoft.AspNetCore.Server.Kestrel.BadRequest" => true,
+                _ => false
+            };
+            public void OnNext(KeyValuePair<string, object> pair)
+            {
+                EventFired = true;
+                _callback(pair);
+            }
+            public void OnError(Exception error) { }
+            public void OnCompleted() { }
+            public virtual void Dispose() => _subscription.Dispose();
+        }
+
         private async Task TestBadRequest(string request, string expectedResponseStatusCode, string expectedExceptionMessage, string expectedAllowHeader = null)
         {
             BadHttpRequestException loggedException = null;
@@ -207,7 +232,21 @@ namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests
                 .Setup(trace => trace.ConnectionBadRequest(It.IsAny<string>(), It.IsAny<BadHttpRequestException>()))
                 .Callback<string, BadHttpRequestException>((connectionId, exception) => loggedException = exception);
 
-            await using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory, mockKestrelTrace.Object)))
+            // Set up a listener to catch the BadRequest event
+            var diagListener = new DiagnosticListener("BadRequestTestsDiagListener");
+            string eventProviderName = "";
+            string exceptionString = "";
+            var badRequestEventListener = new BadRequestEventListener(diagListener, (pair) => {
+                eventProviderName = pair.Key;
+                var featureCollection = pair.Value as IFeatureCollection;
+                if (featureCollection is not null)
+                {
+                    var badRequestFeature = featureCollection.Get<IBadRequestExceptionFeature>();
+                    exceptionString = badRequestFeature.Error.ToString();
+                }
+            });
+
+            await using (var server = new TestServer(context => Task.CompletedTask, new TestServiceContext(LoggerFactory, mockKestrelTrace.Object) { DiagnosticSource = diagListener }))
             {
                 using (var connection = server.CreateConnection())
                 {
@@ -218,6 +257,11 @@ namespace Microsoft.AspNetCore.Server.Kestrel.InMemory.FunctionalTests
 
             mockKestrelTrace.Verify(trace => trace.ConnectionBadRequest(It.IsAny<string>(), It.IsAny<BadHttpRequestException>()));
             Assert.Equal(expectedExceptionMessage, loggedException.Message);
+
+            // Verify DiagnosticSource event for bad request
+            Assert.True(badRequestEventListener.EventFired);
+            Assert.Equal("Microsoft.AspNetCore.Server.Kestrel.BadRequest", eventProviderName);
+            Assert.Contains(expectedExceptionMessage, exceptionString);
         }
 
         private async Task ReceiveBadRequestResponse(InMemoryConnection connection, string expectedResponseStatusCode, string expectedDateHeaderValue, string expectedAllowHeader = null)
diff --git a/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs
index 9eca307738b..c048ebf3442 100644
--- a/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs
+++ b/src/Servers/Kestrel/tools/CodeGenerator/HttpProtocolFeatureCollection.cs
@@ -39,7 +39,8 @@ namespace CodeGenerator
                 "IHttpResponseTrailersFeature",
                 "ITlsConnectionFeature",
                 "IHttpUpgradeFeature",
-                "IHttpWebSocketFeature"
+                "IHttpWebSocketFeature",
+                "IBadRequestExceptionFeature"
             };
             var maybeFeatures = new[]
             {
@@ -78,6 +79,7 @@ namespace CodeGenerator
                 "IHttpBodyControlFeature",
                 "IHttpMaxRequestBodySizeFeature",
                 "IHttpRequestBodyDetectionFeature",
+                "IBadRequestExceptionFeature"
             };
 
             var usings = $@"
-- 
GitLab