From c36a0e07a4abb6473fa230d69eec05af8f5c871d Mon Sep 17 00:00:00 2001
From: Stephen Halter <halter73@gmail.com>
Date: Tue, 6 Sep 2022 14:12:52 -0700
Subject: [PATCH] Don't cache Endpoints if a source throws (#43729)

---
 .../src/CompositeEndpointDataSource.cs        | 14 +++++---
 .../CompositeEndpointDataSourceTest.cs        | 36 ++++++++++++++++++-
 2 files changed, 45 insertions(+), 5 deletions(-)

diff --git a/src/Http/Routing/src/CompositeEndpointDataSource.cs b/src/Http/Routing/src/CompositeEndpointDataSource.cs
index bb4250d26bb..8c1ea4c5860 100644
--- a/src/Http/Routing/src/CompositeEndpointDataSource.cs
+++ b/src/Http/Routing/src/CompositeEndpointDataSource.cs
@@ -242,8 +242,7 @@ public sealed class CompositeEndpointDataSource : EndpointDataSource, IDisposabl
     [MemberNotNull(nameof(_consumerChangeToken))]
     private void CreateChangeTokenUnsynchronized(bool collectionChanged)
     {
-        _cts = new CancellationTokenSource();
-        _consumerChangeToken = new CancellationChangeToken(_cts.Token);
+        var cts = new CancellationTokenSource();
 
         if (collectionChanged)
         {
@@ -255,17 +254,24 @@ public sealed class CompositeEndpointDataSource : EndpointDataSource, IDisposabl
                     () => HandleChange(collectionChanged: false)));
             }
         }
+
+        _cts = cts;
+        _consumerChangeToken = new CancellationChangeToken(cts.Token);
     }
 
     [MemberNotNull(nameof(_endpoints))]
     private void CreateEndpointsUnsynchronized()
     {
-        _endpoints = new List<Endpoint>();
+        var endpoints = new List<Endpoint>();
 
         foreach (var dataSource in _dataSources)
         {
-            _endpoints.AddRange(dataSource.Endpoints);
+            endpoints.AddRange(dataSource.Endpoints);
         }
+
+        // Only cache _endpoints after everything succeeds without throwing.
+        // We don't want to create a negative cache which would cause 404s when there should be 500s.
+        _endpoints = endpoints;
     }
 
     // Use private variable '_endpoints' to avoid initialization
diff --git a/src/Http/Routing/test/UnitTests/CompositeEndpointDataSourceTest.cs b/src/Http/Routing/test/UnitTests/CompositeEndpointDataSourceTest.cs
index 3cab4ecef27..4e6b5677ea5 100644
--- a/src/Http/Routing/test/UnitTests/CompositeEndpointDataSourceTest.cs
+++ b/src/Http/Routing/test/UnitTests/CompositeEndpointDataSourceTest.cs
@@ -2,7 +2,6 @@
 // The .NET Foundation licenses this file to you under the MIT license.
 
 using System.Collections.ObjectModel;
-using System.Linq;
 using Microsoft.AspNetCore.Builder;
 using Microsoft.AspNetCore.Http;
 using Microsoft.AspNetCore.Routing.Patterns;
@@ -58,6 +57,28 @@ public class CompositeEndpointDataSourceTest
         Assert.Equal(groupedEndpoints, resolvedGroupEndpoints);
     }
 
+    [Fact]
+    public void RepeatedlyThrows_WhenChildDataSourcesThrow()
+    {
+        var ex = new Exception();
+        var compositeDataSource = new CompositeEndpointDataSource(new[]
+        {
+            new EndpointThrowingDataSource(ex),
+        });
+        var groupContext = new RouteGroupContext
+        {
+            Prefix = RoutePatternFactory.Parse(""),
+            Conventions = Array.Empty<Action<EndpointBuilder>>(),
+            FinallyConventions = Array.Empty<Action<EndpointBuilder>>(),
+            ApplicationServices = new ServiceCollection().BuildServiceProvider(),
+        };
+
+        Assert.Same(ex, Assert.Throws<Exception>(() => compositeDataSource.Endpoints));
+        Assert.Same(ex, Assert.Throws<Exception>(() => compositeDataSource.Endpoints));
+        Assert.Same(ex, Assert.Throws<Exception>(() => compositeDataSource.GetGroupedEndpoints(groupContext)));
+        Assert.Same(ex, Assert.Throws<Exception>(() => compositeDataSource.GetGroupedEndpoints(groupContext)));
+    }
+
     [Fact]
     public void Endpoints_ReturnsAllEndpoints_FromMultipleDataSources()
     {
@@ -502,4 +523,17 @@ public class CompositeEndpointDataSourceTest
 
         public override IChangeToken GetChangeToken() => NullChangeToken.Singleton;
     }
+
+    private class EndpointThrowingDataSource : EndpointDataSource
+    {
+        private readonly Exception _ex;
+
+        public EndpointThrowingDataSource(Exception ex)
+        {
+            _ex = ex;
+        }
+
+        public override IReadOnlyList<Endpoint> Endpoints => throw _ex;
+        public override IChangeToken GetChangeToken() => NullChangeToken.Singleton;
+    }
 }
-- 
GitLab