From fae3dd12aeba7c9995f69bfaa1c9b74d82307ef1 Mon Sep 17 00:00:00 2001 From: Hao Kung <HaoK@users.noreply.github.com> Date: Fri, 10 Jul 2020 17:56:18 -0700 Subject: [PATCH] Switch to new host apis (#23783) * Update tests * Switch to new host apis * Update host apis * Update CookieTests.cs * Update tests * PR feedback/cleanup * More cleanup --- .../test/InMemory.Test/FunctionalTest.cs | 205 +++--- .../test/AuthenticationMiddlewareTests.cs | 61 +- .../Authentication/test/CertificateTests.cs | 220 +++--- .../Authentication/test/CookieTests.cs | 644 ++++++++++-------- .../Authentication/test/DynamicSchemeTests.cs | 90 +-- .../Authentication/test/FacebookTests.cs | 64 +- .../Authentication/test/GoogleTests.cs | 273 ++++---- .../Authentication/test/JwtBearerTests.cs | 232 ++++--- .../test/MicrosoftAccountTests.cs | 153 +++-- .../Authentication/test/OAuthTests.cs | 75 +- .../OpenIdConnectConfigurationTests.cs | 24 +- .../OpenIdConnect/OpenIdConnectEventTests.cs | 65 +- .../test/OpenIdConnect/TestServerBuilder.cs | 105 +-- .../test/RemoteAuthenticationTests.cs | 48 +- .../Authentication/test/TwitterTests.cs | 114 ++-- .../test/WsFederation/WsFederationTest.cs | 250 +++---- 16 files changed, 1477 insertions(+), 1146 deletions(-) diff --git a/src/Identity/test/InMemory.Test/FunctionalTest.cs b/src/Identity/test/InMemory.Test/FunctionalTest.cs index a2609b4a435..56a07d761ed 100644 --- a/src/Identity/test/InMemory.Test/FunctionalTest.cs +++ b/src/Identity/test/InMemory.Test/FunctionalTest.cs @@ -18,6 +18,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Identity.Test; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.Net.Http.Headers; using Xunit; @@ -31,7 +32,7 @@ namespace Microsoft.AspNetCore.Identity.InMemory public async Task CanChangePasswordOptions() { var clock = new TestClock(); - var server = CreateServer(services => services.Configure<IdentityOptions>(options => + var server = await CreateServer(services => services.Configure<IdentityOptions>(options => { options.Password.RequireUppercase = false; options.Password.RequireNonAlphanumeric = false; @@ -49,7 +50,7 @@ namespace Microsoft.AspNetCore.Identity.InMemory public async Task CookieContainsRoleClaim() { var clock = new TestClock(); - var server = CreateServer(null, null, null, testCore: true); + var server = await CreateServer(null, null, null, testCore: true); var transaction1 = await SendAsync(server, "http://example.com/createMe"); Assert.Equal(HttpStatusCode.OK, transaction1.Response.StatusCode); @@ -70,7 +71,7 @@ namespace Microsoft.AspNetCore.Identity.InMemory public async Task CanCreateMeLoginAndCookieStopsWorkingAfterExpiration() { var clock = new TestClock(); - var server = CreateServer(services => + var server = await CreateServer(services => { services.ConfigureApplicationCookie(options => { @@ -112,7 +113,7 @@ namespace Microsoft.AspNetCore.Identity.InMemory public async Task CanCreateMeLoginAndSecurityStampExtendsExpiration(bool rememberMe) { var clock = new TestClock(); - var server = CreateServer(services => services.AddSingleton<ISystemClock>(clock)); + var server = await CreateServer(services => services.AddSingleton<ISystemClock>(clock)); var transaction1 = await SendAsync(server, "http://example.com/createMe"); Assert.Equal(HttpStatusCode.OK, transaction1.Response.StatusCode); @@ -156,7 +157,7 @@ namespace Microsoft.AspNetCore.Identity.InMemory public async Task CanAccessOldPrincipalDuringSecurityStampReplacement() { var clock = new TestClock(); - var server = CreateServer(services => + var server = await CreateServer(services => { services.Configure<SecurityStampValidatorOptions>(options => { @@ -207,7 +208,7 @@ namespace Microsoft.AspNetCore.Identity.InMemory public async Task TwoFactorRememberCookieVerification() { var clock = new TestClock(); - var server = CreateServer(services => services.AddSingleton<ISystemClock>(clock)); + var server = await CreateServer(services => services.AddSingleton<ISystemClock>(clock)); var transaction1 = await SendAsync(server, "http://example.com/createMe"); Assert.Equal(HttpStatusCode.OK, transaction1.Response.StatusCode); @@ -234,7 +235,7 @@ namespace Microsoft.AspNetCore.Identity.InMemory public async Task TwoFactorRememberCookieClearedBySecurityStampChange() { var clock = new TestClock(); - var server = CreateServer(services => services.AddSingleton<ISystemClock>(clock)); + var server = await CreateServer(services => services.AddSingleton<ISystemClock>(clock)); var transaction1 = await SendAsync(server, "http://example.com/createMe"); Assert.Equal(HttpStatusCode.OK, transaction1.Response.StatusCode); @@ -285,111 +286,115 @@ namespace Microsoft.AspNetCore.Identity.InMemory return me; } - private static TestServer CreateServer(Action<IServiceCollection> configureServices = null, Func<HttpContext, Task> testpath = null, Uri baseAddress = null, bool testCore = false) + private static async Task<TestServer> CreateServer(Action<IServiceCollection> configureServices = null, Func<HttpContext, Task> testpath = null, Uri baseAddress = null, bool testCore = false) { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Use(async (context, next) => + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.Configure(app => { - var req = context.Request; - var res = context.Response; - var userManager = context.RequestServices.GetRequiredService<UserManager<PocoUser>>(); - var roleManager = context.RequestServices.GetRequiredService<RoleManager<PocoRole>>(); - var signInManager = context.RequestServices.GetRequiredService<SignInManager<PocoUser>>(); - PathString remainder; - if (req.Path == new PathString("/normal")) - { - res.StatusCode = 200; - } - else if (req.Path == new PathString("/createMe")) + app.UseAuthentication(); + app.Use(async (context, next) => { - var user = new PocoUser("hao"); - var result = await userManager.CreateAsync(user, TestPassword); - if (result.Succeeded) + var req = context.Request; + var res = context.Response; + var userManager = context.RequestServices.GetRequiredService<UserManager<PocoUser>>(); + var roleManager = context.RequestServices.GetRequiredService<RoleManager<PocoRole>>(); + var signInManager = context.RequestServices.GetRequiredService<SignInManager<PocoUser>>(); + PathString remainder; + if (req.Path == new PathString("/normal")) { - result = await roleManager.CreateAsync(new PocoRole("role")); + res.StatusCode = 200; } - if (result.Succeeded) + else if (req.Path == new PathString("/createMe")) { - result = await userManager.AddToRoleAsync(user, "role"); + var user = new PocoUser("hao"); + var result = await userManager.CreateAsync(user, TestPassword); + if (result.Succeeded) + { + result = await roleManager.CreateAsync(new PocoRole("role")); + } + if (result.Succeeded) + { + result = await userManager.AddToRoleAsync(user, "role"); + } + res.StatusCode = result.Succeeded ? 200 : 500; } - res.StatusCode = result.Succeeded ? 200 : 500; - } - else if (req.Path == new PathString("/createSimple")) - { - var result = await userManager.CreateAsync(new PocoUser("simple"), "aaaaaa"); - res.StatusCode = result.Succeeded ? 200 : 500; - } - else if (req.Path == new PathString("/signoutEverywhere")) - { - var user = await userManager.FindByNameAsync("hao"); - var result = await userManager.UpdateSecurityStampAsync(user); - res.StatusCode = result.Succeeded ? 200 : 500; - } - else if (req.Path.StartsWithSegments(new PathString("/pwdLogin"), out remainder)) - { - var isPersistent = bool.Parse(remainder.Value.Substring(1)); - var result = await signInManager.PasswordSignInAsync("hao", TestPassword, isPersistent, false); - res.StatusCode = result.Succeeded ? 200 : 500; - } - else if (req.Path == new PathString("/twofactorRememeber")) - { - var user = await userManager.FindByNameAsync("hao"); - await signInManager.RememberTwoFactorClientAsync(user); - res.StatusCode = 200; - } - else if (req.Path == new PathString("/isTwoFactorRememebered")) - { - var user = await userManager.FindByNameAsync("hao"); - var result = await signInManager.IsTwoFactorClientRememberedAsync(user); - res.StatusCode = result ? 200 : 500; - } - else if (req.Path == new PathString("/hasTwoFactorUserId")) - { - var result = await context.AuthenticateAsync(IdentityConstants.TwoFactorUserIdScheme); - res.StatusCode = result.Succeeded ? 200 : 500; - } - else if (req.Path == new PathString("/me")) - { - await DescribeAsync(res, AuthenticateResult.Success(new AuthenticationTicket(context.User, null, "Application"))); - } - else if (req.Path.StartsWithSegments(new PathString("/me"), out remainder)) - { - var auth = await context.AuthenticateAsync(remainder.Value.Substring(1)); - await DescribeAsync(res, auth); - } - else if (req.Path == new PathString("/testpath") && testpath != null) + else if (req.Path == new PathString("/createSimple")) + { + var result = await userManager.CreateAsync(new PocoUser("simple"), "aaaaaa"); + res.StatusCode = result.Succeeded ? 200 : 500; + } + else if (req.Path == new PathString("/signoutEverywhere")) + { + var user = await userManager.FindByNameAsync("hao"); + var result = await userManager.UpdateSecurityStampAsync(user); + res.StatusCode = result.Succeeded ? 200 : 500; + } + else if (req.Path.StartsWithSegments(new PathString("/pwdLogin"), out remainder)) + { + var isPersistent = bool.Parse(remainder.Value.Substring(1)); + var result = await signInManager.PasswordSignInAsync("hao", TestPassword, isPersistent, false); + res.StatusCode = result.Succeeded ? 200 : 500; + } + else if (req.Path == new PathString("/twofactorRememeber")) + { + var user = await userManager.FindByNameAsync("hao"); + await signInManager.RememberTwoFactorClientAsync(user); + res.StatusCode = 200; + } + else if (req.Path == new PathString("/isTwoFactorRememebered")) + { + var user = await userManager.FindByNameAsync("hao"); + var result = await signInManager.IsTwoFactorClientRememberedAsync(user); + res.StatusCode = result ? 200 : 500; + } + else if (req.Path == new PathString("/hasTwoFactorUserId")) + { + var result = await context.AuthenticateAsync(IdentityConstants.TwoFactorUserIdScheme); + res.StatusCode = result.Succeeded ? 200 : 500; + } + else if (req.Path == new PathString("/me")) + { + await DescribeAsync(res, AuthenticateResult.Success(new AuthenticationTicket(context.User, null, "Application"))); + } + else if (req.Path.StartsWithSegments(new PathString("/me"), out remainder)) + { + var auth = await context.AuthenticateAsync(remainder.Value.Substring(1)); + await DescribeAsync(res, auth); + } + else if (req.Path == new PathString("/testpath") && testpath != null) + { + await testpath(context); + } + else + { + await next(); + } + }); + }) + .ConfigureServices(services => + { + if (testCore) { - await testpath(context); + services.AddIdentityCore<PocoUser>() + .AddRoles<PocoRole>() + .AddSignInManager() + .AddDefaultTokenProviders(); + services.AddAuthentication(IdentityConstants.ApplicationScheme).AddIdentityCookies(); } else { - await next(); + services.AddIdentity<PocoUser, PocoRole>().AddDefaultTokenProviders(); } - }); - }) - .ConfigureServices(services => - { - if (testCore) - { - services.AddIdentityCore<PocoUser>() - .AddRoles<PocoRole>() - .AddSignInManager() - .AddDefaultTokenProviders(); - services.AddAuthentication(IdentityConstants.ApplicationScheme).AddIdentityCookies(); - } - else - { - services.AddIdentity<PocoUser, PocoRole>().AddDefaultTokenProviders(); - } - var store = new InMemoryStore<PocoUser, PocoRole>(); - services.AddSingleton<IUserStore<PocoUser>>(store); - services.AddSingleton<IRoleStore<PocoRole>>(store); - configureServices?.Invoke(services); - }); - var server = new TestServer(builder); + var store = new InMemoryStore<PocoUser, PocoRole>(); + services.AddSingleton<IUserStore<PocoUser>>(store); + services.AddSingleton<IRoleStore<PocoRole>>(store); + configureServices?.Invoke(services); + }) + .UseTestServer()) + .Build(); + await host.StartAsync(); + var server = host.GetTestServer(); server.BaseAddress = baseAddress; return server; } diff --git a/src/Security/Authentication/test/AuthenticationMiddlewareTests.cs b/src/Security/Authentication/test/AuthenticationMiddlewareTests.cs index b09f13cab9b..232b9fed0e8 100644 --- a/src/Security/Authentication/test/AuthenticationMiddlewareTests.cs +++ b/src/Security/Authentication/test/AuthenticationMiddlewareTests.cs @@ -8,6 +8,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Xunit; namespace Microsoft.AspNetCore.Authentication @@ -17,33 +18,38 @@ namespace Microsoft.AspNetCore.Authentication [Fact] public async Task OnlyInvokesCanHandleRequestHandlers() { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - }) - .ConfigureServices(services => services.AddAuthentication(o => - { - o.AddScheme("Skip", s => - { - s.HandlerType = typeof(SkipHandler); - }); - // Won't get hit since CanHandleRequests is false - o.AddScheme("throws", s => - { - s.HandlerType = typeof(ThrowsHandler); - }); - o.AddScheme("607", s => - { - s.HandlerType = typeof(SixOhSevenHandler); - }); - // Won't get run since 607 will finish - o.AddScheme("305", s => - { - s.HandlerType = typeof(ThreeOhFiveHandler); - }); - })); - var server = new TestServer(builder); + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => + { + app.UseAuthentication(); + }) + .ConfigureServices(services => services.AddAuthentication(o => + { + o.AddScheme("Skip", s => + { + s.HandlerType = typeof(SkipHandler); + }); + // Won't get hit since CanHandleRequests is false + o.AddScheme("throws", s => + { + s.HandlerType = typeof(ThrowsHandler); + }); + o.AddScheme("607", s => + { + s.HandlerType = typeof(SixOhSevenHandler); + }); + // Won't get run since 607 will finish + o.AddScheme("305", s => + { + s.HandlerType = typeof(ThreeOhFiveHandler); + }); + }))) + .Build(); + + await host.StartAsync(); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("http://example.com/"); Assert.Equal(607, (int)response.StatusCode); } @@ -191,6 +197,5 @@ namespace Microsoft.AspNetCore.Authentication throw new NotImplementedException(); } } - } } diff --git a/src/Security/Authentication/test/CertificateTests.cs b/src/Security/Authentication/test/CertificateTests.cs index ca0e92024e5..2ff5d5b25ff 100644 --- a/src/Security/Authentication/test/CertificateTests.cs +++ b/src/Security/Authentication/test/CertificateTests.cs @@ -15,6 +15,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Xunit; namespace Microsoft.AspNetCore.Authentication.Certificate.Test @@ -44,7 +45,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyValidSelfSignedWithClientEkuAuthenticates() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned, @@ -52,6 +53,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedValidWithClientEku); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); } @@ -59,7 +61,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyValidSelfSignedWithNoEkuAuthenticates() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned, @@ -67,6 +69,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedValidWithNoEku); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); } @@ -74,13 +77,14 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyValidSelfSignedWithClientEkuFailsWhenSelfSignedCertsNotAllowed() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.Chained }, Certificates.SelfSignedValidWithClientEku); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); } @@ -88,7 +92,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyValidSelfSignedWithNoEkuFailsWhenSelfSignedCertsNotAllowed() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.Chained, @@ -96,6 +100,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedValidWithNoEku); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); } @@ -103,7 +108,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyValidSelfSignedWithServerFailsEvenIfSelfSignedCertsAreAllowed() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned, @@ -111,6 +116,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedValidWithServerEku); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); } @@ -118,7 +124,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyValidSelfSignedWithServerPassesWhenSelfSignedCertsAreAllowedAndPurposeValidationIsOff() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned, @@ -127,6 +133,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedValidWithServerEku); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); } @@ -134,7 +141,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyValidSelfSignedWithServerFailsPurposeValidationIsOffButSelfSignedCertsAreNotAllowed() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.Chained, @@ -143,6 +150,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedValidWithServerEku); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); } @@ -150,7 +158,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyExpiredSelfSignedFails() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned, @@ -159,6 +167,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedExpired); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); } @@ -166,7 +175,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyExpiredSelfSignedPassesIfDateRangeValidationIsDisabled() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned, @@ -175,6 +184,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedExpired); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); } @@ -182,7 +192,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyNotYetValidSelfSignedFails() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned, @@ -191,6 +201,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedNotYetValid); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); } @@ -198,7 +209,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyNotYetValidSelfSignedPassesIfDateRangeValidationIsDisabled() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned, @@ -207,6 +218,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedNotYetValid); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); } @@ -214,7 +226,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyFailingInTheValidationEventReturnsForbidden() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { ValidateCertificateUse = false, @@ -222,6 +234,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedValidWithServerEku); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); } @@ -229,7 +242,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task DoingNothingInTheValidationEventReturnsOK() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned, @@ -238,6 +251,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedValidWithServerEku); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); } @@ -245,12 +259,13 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyNotSendingACertificateEndsUpInForbidden() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { Events = successfulValidationEvents }); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); } @@ -258,12 +273,13 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyUntrustedClientCertEndsUpInForbidden() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { Events = successfulValidationEvents }, Certificates.SignedClient); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); } @@ -271,7 +287,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyClientCertWithUntrustedRootAndTrustedChainEndsUpInForbidden() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { Events = successfulValidationEvents, @@ -280,6 +296,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test RevocationMode = X509RevocationMode.NoCheck }, Certificates.SignedClient); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.Forbidden, response.StatusCode); } @@ -287,7 +304,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyValidClientCertWithTrustedChainAuthenticates() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { Events = successfulValidationEvents, @@ -296,6 +313,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test RevocationMode = X509RevocationMode.NoCheck }, Certificates.SignedClient); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); } @@ -303,7 +321,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyHeaderIsUsedIfCertIsNotPresent() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned, @@ -311,6 +329,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, wireUpHeaderMiddleware: true); + using var server = host.GetTestServer(); var client = server.CreateClient(); client.DefaultRequestHeaders.Add("X-Client-Cert", Convert.ToBase64String(Certificates.SelfSignedValidWithNoEku.RawData)); var response = await client.GetAsync("https://example.com/"); @@ -320,13 +339,14 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyHeaderEncodedCertFailsOnBadEncoding() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { Events = successfulValidationEvents }, wireUpHeaderMiddleware: true); + using var server = host.GetTestServer(); var client = server.CreateClient(); client.DefaultRequestHeaders.Add("X-Client-Cert", "OOPS" + Convert.ToBase64String(Certificates.SelfSignedValidWithNoEku.RawData)); var response = await client.GetAsync("https://example.com/"); @@ -336,7 +356,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifySettingTheAzureHeaderOnTheForwarderOptionsWorks() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned, @@ -345,6 +365,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test wireUpHeaderMiddleware: true, headerName: "X-ARR-ClientCert"); + using var server = host.GetTestServer(); var client = server.CreateClient(); client.DefaultRequestHeaders.Add("X-ARR-ClientCert", Convert.ToBase64String(Certificates.SelfSignedValidWithNoEku.RawData)); var response = await client.GetAsync("https://example.com/"); @@ -354,7 +375,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyACustomHeaderFailsIfTheHeaderIsNotPresent() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { Events = successfulValidationEvents @@ -362,6 +383,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test wireUpHeaderMiddleware: true, headerName: "X-ARR-ClientCert"); + using var server = host.GetTestServer(); var client = server.CreateClient(); client.DefaultRequestHeaders.Add("random-Weird-header", Convert.ToBase64String(Certificates.SelfSignedValidWithNoEku.RawData)); var response = await client.GetAsync("https://example.com/"); @@ -371,13 +393,14 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test [Fact] public async Task VerifyNoEventWireupWithAValidCertificateCreatesADefaultUser() { - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned }, Certificates.SelfSignedValidWithNoEku); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); @@ -481,7 +504,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test const string Expected = "John Doe"; var validationCount = 0; - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned, @@ -508,6 +531,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedValidWithNoEku, null, null, false, "", cache); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); @@ -554,7 +578,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test { const string Expected = "John Doe"; - var server = CreateServer( + using var host = await CreateHost( new CertificateAuthenticationOptions { AllowedCertificateTypes = CertificateTypes.SelfSigned, @@ -577,6 +601,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test }, Certificates.SelfSignedValidWithNoEku); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("https://example.com/"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); @@ -596,7 +621,7 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test Assert.Single(responseAsXml.Elements("claim")); } - private static TestServer CreateServer( + private static async Task<IHost> CreateHost( CertificateAuthenticationOptions configureOptions, X509Certificate2 clientCertificate = null, Func<HttpContext, bool> handler = null, @@ -605,92 +630,95 @@ namespace Microsoft.AspNetCore.Authentication.Certificate.Test string headerName = "", bool useCache = false) { - var builder = new WebHostBuilder() - .Configure(app => - { - app.Use((context, next) => - { - if (clientCertificate != null) + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => { - context.Connection.ClientCertificate = clientCertificate; - } - return next(); - }); + app.Use((context, next) => + { + if (clientCertificate != null) + { + context.Connection.ClientCertificate = clientCertificate; + } + return next(); + }); - if (wireUpHeaderMiddleware) - { - app.UseCertificateForwarding(); - } + if (wireUpHeaderMiddleware) + { + app.UseCertificateForwarding(); + } - app.UseAuthentication(); + app.UseAuthentication(); - app.Use(async (context, next) => + app.Use(async (context, next) => + { + var request = context.Request; + var response = context.Response; + + var authenticationResult = await context.AuthenticateAsync(); + + if (authenticationResult.Succeeded) + { + response.StatusCode = (int)HttpStatusCode.OK; + response.ContentType = "text/xml"; + + await response.WriteAsync("<claims>"); + foreach (Claim claim in context.User.Claims) + { + await response.WriteAsync($"<claim Type=\"{claim.Type}\" Issuer=\"{claim.Issuer}\">{claim.Value}</claim>"); + } + await response.WriteAsync("</claims>"); + } + else + { + await context.ChallengeAsync(); + } + }); + }) + .ConfigureServices(services => { - var request = context.Request; - var response = context.Response; - - var authenticationResult = await context.AuthenticateAsync(); - - if (authenticationResult.Succeeded) + AuthenticationBuilder authBuilder; + if (configureOptions != null) { - response.StatusCode = (int)HttpStatusCode.OK; - response.ContentType = "text/xml"; - - await response.WriteAsync("<claims>"); - foreach (Claim claim in context.User.Claims) + authBuilder = services.AddAuthentication(CertificateAuthenticationDefaults.AuthenticationScheme).AddCertificate(options => { - await response.WriteAsync($"<claim Type=\"{claim.Type}\" Issuer=\"{claim.Issuer}\">{claim.Value}</claim>"); - } - await response.WriteAsync("</claims>"); + options.CustomTrustStore = configureOptions.CustomTrustStore; + options.ChainTrustValidationMode = configureOptions.ChainTrustValidationMode; + options.AllowedCertificateTypes = configureOptions.AllowedCertificateTypes; + options.Events = configureOptions.Events; + options.ValidateCertificateUse = configureOptions.ValidateCertificateUse; + options.RevocationFlag = configureOptions.RevocationFlag; + options.RevocationMode = configureOptions.RevocationMode; + options.ValidateValidityPeriod = configureOptions.ValidateValidityPeriod; + }); } else { - await context.ChallengeAsync(); + authBuilder = services.AddAuthentication(CertificateAuthenticationDefaults.AuthenticationScheme).AddCertificate(); + } + if (useCache) + { + authBuilder.AddCertificateCache(); } - }); - }) - .ConfigureServices(services => - { - AuthenticationBuilder authBuilder; - if (configureOptions != null) - { - authBuilder = services.AddAuthentication(CertificateAuthenticationDefaults.AuthenticationScheme).AddCertificate(options => - { - options.CustomTrustStore = configureOptions.CustomTrustStore; - options.ChainTrustValidationMode = configureOptions.ChainTrustValidationMode; - options.AllowedCertificateTypes = configureOptions.AllowedCertificateTypes; - options.Events = configureOptions.Events; - options.ValidateCertificateUse = configureOptions.ValidateCertificateUse; - options.RevocationFlag = configureOptions.RevocationFlag; - options.RevocationMode = configureOptions.RevocationMode; - options.ValidateValidityPeriod = configureOptions.ValidateValidityPeriod; - }); - } - else - { - authBuilder = services.AddAuthentication(CertificateAuthenticationDefaults.AuthenticationScheme).AddCertificate(); - } - if (useCache) - { - authBuilder.AddCertificateCache(); - } - if (wireUpHeaderMiddleware && !string.IsNullOrEmpty(headerName)) - { - services.AddCertificateForwarding(options => - { - options.CertificateHeader = headerName; - }); - } - }); + if (wireUpHeaderMiddleware && !string.IsNullOrEmpty(headerName)) + { + services.AddCertificateForwarding(options => + { + options.CertificateHeader = headerName; + }); + } + })) + .Build(); + + await host.StartAsync(); - var server = new TestServer(builder) - { - BaseAddress = baseAddress - }; - return server; + var server = host.GetTestServer(); + server.BaseAddress = baseAddress; + return host; } private CertificateAuthenticationEvents successfulValidationEvents = new CertificateAuthenticationEvents() diff --git a/src/Security/Authentication/test/CookieTests.cs b/src/Security/Authentication/test/CookieTests.cs index e1a3840e193..56bea6c65a0 100644 --- a/src/Security/Authentication/test/CookieTests.cs +++ b/src/Security/Authentication/test/CookieTests.cs @@ -17,6 +17,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.Testing; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Options; using Xunit; @@ -37,7 +38,8 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task NormalRequestPassesThrough() { - var server = CreateServer(s => { }); + using var host = await CreateHost(s => { }); + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("http://example.com/normal"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); } @@ -45,7 +47,8 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task AjaxLoginRedirectToReturnUrlTurnsInto200WithLocationHeader() { - var server = CreateServer(o => o.LoginPath = "/login"); + using var host = await CreateHost(o => o.LoginPath = "/login"); + using var server = host.GetTestServer(); var transaction = await SendAsync(server, "http://example.com/challenge?X-Requested-With=XMLHttpRequest"); Assert.Equal(HttpStatusCode.Unauthorized, transaction.Response.StatusCode); var responded = transaction.Response.Headers.GetValues("Location"); @@ -56,7 +59,8 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task AjaxForbidTurnsInto403WithLocationHeader() { - var server = CreateServer(o => o.AccessDeniedPath = "/denied"); + using var host = await CreateHost(o => o.AccessDeniedPath = "/denied"); + using var server = host.GetTestServer(); var transaction = await SendAsync(server, "http://example.com/forbid?X-Requested-With=XMLHttpRequest"); Assert.Equal(HttpStatusCode.Forbidden, transaction.Response.StatusCode); var responded = transaction.Response.Headers.GetValues("Location"); @@ -67,7 +71,8 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task AjaxLogoutRedirectToReturnUrlTurnsInto200WithLocationHeader() { - var server = CreateServer(o => o.LogoutPath = "/signout"); + using var host = await CreateHost(o => o.LogoutPath = "/signout"); + using var server = host.GetTestServer(); var transaction = await SendAsync(server, "http://example.com/signout?X-Requested-With=XMLHttpRequest&ReturnUrl=/"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); var responded = transaction.Response.Headers.GetValues("Location"); @@ -78,7 +83,8 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task AjaxChallengeRedirectTurnsInto200WithLocationHeader() { - var server = CreateServer(s => { }); + using var host = await CreateHost(s => { }); + using var server = host.GetTestServer(); var transaction = await SendAsync(server, "http://example.com/challenge?X-Requested-With=XMLHttpRequest&ReturnUrl=/"); Assert.Equal(HttpStatusCode.Unauthorized, transaction.Response.StatusCode); var responded = transaction.Response.Headers.GetValues("Location"); @@ -89,7 +95,8 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task ProtectedCustomRequestShouldRedirectToCustomRedirectUri() { - var server = CreateServer(s => { }); + using var host = await CreateHost(s => { }); + using var server = host.GetTestServer(); var transaction = await SendAsync(server, "http://example.com/protected/CustomRedirect"); @@ -122,12 +129,13 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task SignInCausesDefaultCookieToBeCreated() { - var server = CreateServerWithServices(s => s.AddAuthentication().AddCookie(o => + using var host = await CreateHostWithServices(s => s.AddAuthentication().AddCookie(o => { o.LoginPath = new PathString("/login"); o.Cookie.Name = "TestCookie"; }), SignInAsAlice); + using var server = host.GetTestServer(); var transaction = await SendAsync(server, "http://example.com/testpath"); var setCookie = transaction.SetCookie; @@ -158,11 +166,12 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task SignInWrongAuthTypeThrows() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.LoginPath = new PathString("/login"); o.Cookie.Name = "TestCookie"; }, SignInAsWrong); + using var server = host.GetTestServer(); await Assert.ThrowsAsync<InvalidOperationException>(async () => await SendAsync(server, "http://example.com/testpath")); } @@ -170,12 +179,13 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task SignOutWrongAuthTypeThrows() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.LoginPath = new PathString("/login"); o.Cookie.Name = "TestCookie"; }, SignOutAsWrong); + using var server = host.GetTestServer(); await Assert.ThrowsAsync<InvalidOperationException>(async () => await SendAsync(server, "http://example.com/testpath")); } @@ -191,13 +201,14 @@ namespace Microsoft.AspNetCore.Authentication.Cookies string requestUri, bool shouldBeSecureOnly) { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.LoginPath = new PathString("/login"); o.Cookie.Name = "TestCookie"; o.Cookie.SecurePolicy = cookieSecurePolicy; }, SignInAsAlice); + using var server = host.GetTestServer(); var transaction = await SendAsync(server, requestUri); var setCookie = transaction.SetCookie; @@ -214,7 +225,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieOptionsAlterSetCookieHeader() { - var server1 = CreateServer(o => + using var host = await CreateHost(o => { o.Cookie.Name = "TestCookie"; o.Cookie.Path = "/foo"; @@ -224,6 +235,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies o.Cookie.HttpOnly = true; }, SignInAsAlice, baseAddress: new Uri("http://example.com/base")); + using var server1 = host.GetTestServer(); var transaction1 = await SendAsync(server1, "http://example.com/base/testpath"); var setCookie1 = transaction1.SetCookie; @@ -235,7 +247,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies Assert.Contains(" samesite=none", setCookie1); Assert.Contains(" httponly", setCookie1); - var server2 = CreateServer(o => + using var host2 = await CreateHost(o => { o.Cookie.Name = "SecondCookie"; o.Cookie.SecurePolicy = CookieSecurePolicy.None; @@ -243,6 +255,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies o.Cookie.HttpOnly = false; }, SignInAsAlice, baseAddress: new Uri("http://example.com/base")); + using var server2 = host2.GetTestServer(); var transaction2 = await SendAsync(server2, "http://example.com/base/testpath"); var setCookie2 = transaction2.SetCookie; @@ -258,7 +271,8 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieContainsIdentity() { - var server = CreateServer(o => { }, SignInAsAlice); + using var host = await CreateHost(o => { }, SignInAsAlice); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); @@ -270,11 +284,12 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieAppliesClaimsTransform() { - var server = CreateServer(o => { }, + using var host = await CreateHost(o => { }, SignInAsAlice, baseAddress: null, claimsTransform: true); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -287,12 +302,13 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieStopsWorkingAfterExpiration() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.SlidingExpiration = false; }, SignInAsAlice); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -316,7 +332,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieExpirationCanBeOverridenInSignin() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.SlidingExpiration = false; @@ -326,6 +342,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))), new AuthenticationProperties() { ExpiresUtc = _clock.UtcNow.Add(TimeSpan.FromMinutes(5)) })); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -349,7 +366,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task ExpiredCookieWithValidatorStillExpired() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.Events = new CookieAuthenticationEvents @@ -365,6 +382,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies context.SignInAsync("Cookies", new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))))); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); _clock.Add(TimeSpan.FromMinutes(11)); @@ -377,7 +395,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieCanBeRejectedAndSignedOutByValidator() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.SlidingExpiration = false; @@ -395,6 +413,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies context.SignInAsync("Cookies", new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))))); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -405,7 +424,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieNotRenewedAfterSignOut() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.SlidingExpiration = false; @@ -422,6 +441,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies context.SignInAsync("Cookies", new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))))); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); // renews on every request @@ -441,7 +461,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieCanBeRenewedByValidator() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.SlidingExpiration = false; @@ -458,6 +478,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies context.SignInAsync("Cookies", new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))))); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -486,7 +507,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieCanBeReplacedByValidator() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.Events = new CookieAuthenticationEvents { @@ -502,6 +523,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies context.SignInAsync("Cookies", new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))))); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -513,7 +535,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies public async Task CookieCanBeUpdatedByValidatorDuringRefresh() { var replace = false; - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.Events = new CookieAuthenticationEvents @@ -534,6 +556,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies context.SignInAsync("Cookies", new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))))); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -560,7 +583,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieCanBeRenewedByValidatorWithSlidingExpiry() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.Events = new CookieAuthenticationEvents @@ -576,6 +599,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies context.SignInAsync("Cookies", new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))))); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -604,7 +628,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieCanBeRenewedByValidatorWithModifiedProperties() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.Events = new CookieAuthenticationEvents @@ -631,6 +655,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies context.SignInAsync("Cookies", new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))))); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -659,7 +684,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieValidatorOnlyCalledOnce() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.SlidingExpiration = false; @@ -676,6 +701,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies context.SignInAsync("Cookies", new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))))); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -708,7 +734,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies { DateTimeOffset? lastValidateIssuedDate = null; DateTimeOffset? lastExpiresDate = null; - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.SlidingExpiration = sliding; @@ -727,6 +753,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies context.SignInAsync("Cookies", new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))))); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -758,7 +785,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieExpirationCanBeOverridenInEvent() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.SlidingExpiration = false; @@ -773,6 +800,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies }, SignInAsAlice); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -795,13 +823,14 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieIsRenewedWithSlidingExpiration() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.SlidingExpiration = true; }, SignInAsAlice); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -831,7 +860,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieIsRenewedWithSlidingExpirationWithoutTransformations() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ExpireTimeSpan = TimeSpan.FromMinutes(10); o.SlidingExpiration = true; @@ -847,6 +876,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies SignInAsAlice, claimsTransform: true); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); var transaction2 = await SendAsync(server, "http://example.com/me/Cookies", transaction1.CookieNameValue); @@ -876,7 +906,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieUsesPathBaseByDefault() { - var server = CreateServer(o => { }, + using var host = await CreateHost(o => { }, context => { Assert.Equal(new PathString("/base"), context.Request.PathBase); @@ -885,6 +915,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies }, new Uri("http://example.com/base")); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/base/testpath"); Assert.Contains("path=/base", transaction1.SetCookie); } @@ -892,9 +923,10 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieChallengeRedirectsToLoginWithoutCookie() { - var server = CreateServer(o => { }, SignInAsAlice); + using var host = await CreateHost(o => { }, SignInAsAlice); var url = "http://example.com/challenge"; + using var server = host.GetTestServer(); var transaction = await SendAsync(server, url); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); @@ -905,9 +937,10 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieForbidRedirectsWithoutCookie() { - var server = CreateServer(o => { }, SignInAsAlice); + using var host = await CreateHost(o => { }, SignInAsAlice); var url = "http://example.com/forbid"; + using var server = host.GetTestServer(); var transaction = await SendAsync(server, url); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); @@ -918,10 +951,11 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieChallengeRedirectsWithLoginPath() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.LoginPath = new PathString("/page"); }); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); @@ -933,10 +967,11 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CookieChallengeWithUnauthorizedRedirectsToLoginIfNotAuthenticated() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.LoginPath = new PathString("/page"); }); + using var server = host.GetTestServer(); var transaction1 = await SendAsync(server, "http://example.com/testpath"); @@ -950,17 +985,21 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [InlineData(false)] public async Task MapWillAffectChallengeOnlyWithUseAuth(bool useAuth) { - var builder = new WebHostBuilder() - .Configure(app => - { - if (useAuth) - { - app.UseAuthentication(); - } - app.Map("/login", signoutApp => signoutApp.Run(context => context.ChallengeAsync("Cookies", new AuthenticationProperties() { RedirectUri = "/" }))); - }) - .ConfigureServices(s => s.AddAuthentication().AddCookie(o => o.LoginPath = new PathString("/page"))); - var server = new TestServer(builder); + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => + { + if (useAuth) + { + app.UseAuthentication(); + } + app.Map("/login", signoutApp => signoutApp.Run(context => context.ChallengeAsync("Cookies", new AuthenticationProperties() { RedirectUri = "/" }))); + }) + .ConfigureServices(s => s.AddAuthentication().AddCookie(o => o.LoginPath = new PathString("/page")))) + .Build(); + await host.StartAsync(); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/login"); @@ -981,17 +1020,22 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [ConditionalFact(Skip = "Revisit, exception no longer thrown")] public async Task ChallengeDoesNotSet401OnUnauthorized() { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Run(async context => + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => { - await Assert.ThrowsAsync<InvalidOperationException>(() => context.ChallengeAsync(CookieAuthenticationDefaults.AuthenticationScheme)); - }); - }) - .ConfigureServices(services => services.AddAuthentication().AddCookie()); - var server = new TestServer(builder); + app.UseAuthentication(); + app.Run(async context => + { + await Assert.ThrowsAsync<InvalidOperationException>(() => context.ChallengeAsync(CookieAuthenticationDefaults.AuthenticationScheme)); + }); + }) + .ConfigureServices(services => services.AddAuthentication().AddCookie())) + .Build(); + + await host.StartAsync(); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); @@ -1000,19 +1044,24 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CanConfigureDefaultCookieInstance() { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Run(context => context.SignInAsync(CookieAuthenticationDefaults.AuthenticationScheme, new ClaimsPrincipal(new ClaimsIdentity("whatever")))); - }) - .ConfigureServices(services => - { - services.AddAuthentication().AddCookie(); - services.Configure<CookieAuthenticationOptions>(CookieAuthenticationDefaults.AuthenticationScheme, - o => o.Cookie.Name = "One"); - }); - var server = new TestServer(builder); + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => + { + app.UseAuthentication(); + app.Run(context => context.SignInAsync(CookieAuthenticationDefaults.AuthenticationScheme, new ClaimsPrincipal(new ClaimsIdentity("whatever")))); + }) + .ConfigureServices(services => + { + services.AddAuthentication().AddCookie(); + services.Configure<CookieAuthenticationOptions>(CookieAuthenticationDefaults.AuthenticationScheme, + o => o.Cookie.Name = "One"); + })) + .Build(); + + await host.StartAsync(); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com"); @@ -1023,19 +1072,24 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CanConfigureNamedCookieInstance() { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Run(context => context.SignInAsync("Cookie1", new ClaimsPrincipal(new ClaimsIdentity("whatever")))); - }) - .ConfigureServices(services => - { - services.AddAuthentication().AddCookie("Cookie1"); - services.Configure<CookieAuthenticationOptions>("Cookie1", - o => o.Cookie.Name = "One"); - }); - var server = new TestServer(builder); + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => + { + app.UseAuthentication(); + app.Run(context => context.SignInAsync("Cookie1", new ClaimsPrincipal(new ClaimsIdentity("whatever")))); + }) + .ConfigureServices(services => + { + services.AddAuthentication().AddCookie("Cookie1"); + services.Configure<CookieAuthenticationOptions>("Cookie1", + o => o.Cookie.Name = "One"); + })) + .Build(); + + await host.StartAsync(); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com"); @@ -1046,15 +1100,20 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task MapWithSignInOnlyRedirectToReturnUrlOnLoginPath() { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Map("/notlogin", signoutApp => signoutApp.Run(context => context.SignInAsync("Cookies", - new ClaimsPrincipal(new ClaimsIdentity("whatever"))))); - }) - .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.LoginPath = new PathString("/login"))); - var server = new TestServer(builder); + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => + { + app.UseAuthentication(); + app.Map("/notlogin", signoutApp => signoutApp.Run(context => context.SignInAsync("Cookies", + new ClaimsPrincipal(new ClaimsIdentity("whatever"))))); + }) + .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.LoginPath = new PathString("/login")))) + .Build(); + + await host.StartAsync(); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/notlogin?ReturnUrl=%2Fpage"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); @@ -1064,15 +1123,19 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task MapWillNotAffectSignInRedirectToReturnUrl() { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Map("/login", signoutApp => signoutApp.Run(context => context.SignInAsync("Cookies", new ClaimsPrincipal(new ClaimsIdentity("whatever"))))); - }) - .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.LoginPath = new PathString("/login"))); + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => + { + app.UseAuthentication(); + app.Map("/login", signoutApp => signoutApp.Run(context => context.SignInAsync("Cookies", new ClaimsPrincipal(new ClaimsIdentity("whatever"))))); + }) + .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.LoginPath = new PathString("/login")))) + .Build(); - var server = new TestServer(builder); + await host.StartAsync(); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/login?ReturnUrl=%2Fpage"); @@ -1086,14 +1149,19 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task MapWithSignOutOnlyRedirectToReturnUrlOnLogoutPath() { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Map("/notlogout", signoutApp => signoutApp.Run(context => context.SignOutAsync("Cookies"))); - }) - .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.LogoutPath = new PathString("/logout"))); - var server = new TestServer(builder); + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => + { + app.UseAuthentication(); + app.Map("/notlogout", signoutApp => signoutApp.Run(context => context.SignOutAsync("Cookies"))); + }) + .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.LogoutPath = new PathString("/logout")))) + .Build(); + + await host.StartAsync(); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/notlogout?ReturnUrl=%2Fpage"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); @@ -1103,14 +1171,19 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task MapWillNotAffectSignOutRedirectToReturnUrl() { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Map("/logout", signoutApp => signoutApp.Run(context => context.SignOutAsync("Cookies"))); - }) - .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.LogoutPath = new PathString("/logout"))); - var server = new TestServer(builder); + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => + { + app.UseAuthentication(); + app.Map("/logout", signoutApp => signoutApp.Run(context => context.SignOutAsync("Cookies"))); + }) + .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.LogoutPath = new PathString("/logout")))) + .Build(); + + await host.StartAsync(); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/logout?ReturnUrl=%2Fpage"); @@ -1124,14 +1197,19 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task MapWillNotAffectAccessDenied() { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Map("/forbid", signoutApp => signoutApp.Run(context => context.ForbidAsync("Cookies"))); - }) - .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.AccessDeniedPath = new PathString("/denied"))); - var server = new TestServer(builder); + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => + { + app.UseAuthentication(); + app.Map("/forbid", signoutApp => signoutApp.Run(context => context.ForbidAsync("Cookies"))); + }) + .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.AccessDeniedPath = new PathString("/denied")))) + .Build(); + + await host.StartAsync(); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/forbid"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); @@ -1143,15 +1221,20 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task NestedMapWillNotAffectLogin() { - var builder = new WebHostBuilder() - .Configure(app => - app.Map("/base", map => - { - map.UseAuthentication(); - map.Map("/login", signoutApp => signoutApp.Run(context => context.ChallengeAsync("Cookies", new AuthenticationProperties() { RedirectUri = "/" }))); - })) - .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.LoginPath = new PathString("/page"))); - var server = new TestServer(builder); + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => + app.Map("/base", map => + { + map.UseAuthentication(); + map.Map("/login", signoutApp => signoutApp.Run(context => context.ChallengeAsync("Cookies", new AuthenticationProperties() { RedirectUri = "/" }))); + })) + .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.LoginPath = new PathString("/page")))) + .Build(); + + await host.StartAsync(); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/base/login"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); @@ -1166,7 +1249,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [InlineData("http://example.com/redirect_to")] public async Task RedirectUriIsHoneredAfterSignin(string redirectUrl) { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.LoginPath = "/testpath"; o.Cookie.Name = "TestCookie"; @@ -1177,6 +1260,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", CookieAuthenticationDefaults.AuthenticationScheme))), new AuthenticationProperties { RedirectUri = redirectUrl }) ); + using var server = host.GetTestServer(); var transaction = await SendAsync(server, "http://example.com/testpath"); Assert.NotEmpty(transaction.SetCookie); @@ -1187,7 +1271,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task RedirectUriInQueryIsHoneredAfterSignin() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.LoginPath = "/testpath"; o.ReturnUrlParameter = "return"; @@ -1199,6 +1283,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies CookieAuthenticationDefaults.AuthenticationScheme, new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", CookieAuthenticationDefaults.AuthenticationScheme)))); }); + using var server = host.GetTestServer(); var transaction = await SendAsync(server, "http://example.com/testpath?return=%2Fret_path_2"); Assert.NotEmpty(transaction.SetCookie); @@ -1209,7 +1294,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task AbsoluteRedirectUriInQueryStringIsRejected() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.LoginPath = "/testpath"; o.ReturnUrlParameter = "return"; @@ -1221,6 +1306,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies CookieAuthenticationDefaults.AuthenticationScheme, new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", CookieAuthenticationDefaults.AuthenticationScheme)))); }); + using var server = host.GetTestServer(); var transaction = await SendAsync(server, "http://example.com/testpath?return=http%3A%2F%2Fexample.com%2Fredirect_to"); Assert.NotEmpty(transaction.SetCookie); @@ -1230,7 +1316,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task EnsurePrecedenceOfRedirectUriAfterSignin() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.LoginPath = "/testpath"; o.ReturnUrlParameter = "return"; @@ -1243,6 +1329,7 @@ namespace Microsoft.AspNetCore.Authentication.Cookies new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", CookieAuthenticationDefaults.AuthenticationScheme))), new AuthenticationProperties { RedirectUri = "/redirect_test" }); }); + using var server = host.GetTestServer(); var transaction = await SendAsync(server, "http://example.com/testpath?return=%2Fret_path_2"); Assert.NotEmpty(transaction.SetCookie); @@ -1253,15 +1340,19 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task NestedMapWillNotAffectAccessDenied() { - var builder = new WebHostBuilder() - .Configure(app => - app.Map("/base", map => - { - map.UseAuthentication(); - map.Map("/forbid", signoutApp => signoutApp.Run(context => context.ForbidAsync("Cookies"))); - })) - .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.AccessDeniedPath = new PathString("/denied"))); - var server = new TestServer(builder); + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => + app.Map("/base", map => + { + map.UseAuthentication(); + map.Map("/forbid", signoutApp => signoutApp.Run(context => context.ForbidAsync("Cookies"))); + })) + .ConfigureServices(services => services.AddAuthentication().AddCookie(o => o.AccessDeniedPath = new PathString("/denied")))) + .Build(); + await host.StartAsync(); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/base/forbid"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); @@ -1273,43 +1364,50 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task CanSpecifyAndShareDataProtector() { - var dp = new NoOpDataProtector(); - var builder1 = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Use((context, next) => - context.SignInAsync("Cookies", - new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))), - new AuthenticationProperties())); - }) - .ConfigureServices(services => services.AddAuthentication().AddCookie(o => - { - o.TicketDataFormat = new TicketDataFormat(dp); - o.Cookie.Name = "Cookie"; - })); - var server1 = new TestServer(builder1); + using var host1 = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => + { + app.UseAuthentication(); + app.Use((context, next) => + context.SignInAsync("Cookies", + new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies"))), + new AuthenticationProperties())); + }) + .ConfigureServices(services => services.AddAuthentication().AddCookie(o => + { + o.TicketDataFormat = new TicketDataFormat(dp); + o.Cookie.Name = "Cookie"; + }))) + .Build(); + await host1.StartAsync(); + using var server1 = host1.GetTestServer(); ; var transaction = await SendAsync(server1, "http://example.com/stuff"); Assert.NotNull(transaction.SetCookie); - var builder2 = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Use(async (context, next) => + using var host2 = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => { - var result = await context.AuthenticateAsync("Cookies"); - await DescribeAsync(context.Response, result); - }); - }) - .ConfigureServices(services => services.AddAuthentication().AddCookie("Cookies", o => - { - o.Cookie.Name = "Cookie"; - o.TicketDataFormat = new TicketDataFormat(dp); - })); - var server2 = new TestServer(builder2); + app.UseAuthentication(); + app.Use(async (context, next) => + { + var result = await context.AuthenticateAsync("Cookies"); + await DescribeAsync(context.Response, result); + }); + }) + .ConfigureServices(services => services.AddAuthentication().AddCookie("Cookies", o => + { + o.Cookie.Name = "Cookie"; + o.TicketDataFormat = new TicketDataFormat(dp); + }))) + .Build(); + await host2.StartAsync(); + using var server2 = host2.GetTestServer(); var transaction2 = await SendAsync(server2, "http://example.com/stuff", transaction.CookieNameValue); Assert.Equal("Alice", FindClaimValue(transaction2, ClaimTypes.Name)); } @@ -1318,39 +1416,43 @@ namespace Microsoft.AspNetCore.Authentication.Cookies [Fact] public async Task NullExpiresUtcPropertyIsGuarded() { - var builder = new WebHostBuilder() - .ConfigureServices(services => services.AddAuthentication().AddCookie(o => - { - o.Events = new CookieAuthenticationEvents + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .ConfigureServices(services => services.AddAuthentication().AddCookie(o => { - OnValidatePrincipal = context => + o.Events = new CookieAuthenticationEvents { - context.Properties.ExpiresUtc = null; - context.ShouldRenew = true; - return Task.FromResult(0); - } - }; - })) - .Configure(app => - { - app.UseAuthentication(); - - app.Run(async context => + OnValidatePrincipal = context => + { + context.Properties.ExpiresUtc = null; + context.ShouldRenew = true; + return Task.FromResult(0); + } + }; + })) + .Configure(app => { - if (context.Request.Path == "/signin") - { - await context.SignInAsync( - CookieAuthenticationDefaults.AuthenticationScheme, - new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies")))); - } - else + app.UseAuthentication(); + + app.Run(async context => { - await context.Response.WriteAsync("ha+1"); - } - }); - }); + if (context.Request.Path == "/signin") + { + await context.SignInAsync( + CookieAuthenticationDefaults.AuthenticationScheme, + new ClaimsPrincipal(new ClaimsIdentity(new GenericIdentity("Alice", "Cookies")))); + } + else + { + await context.Response.WriteAsync("ha+1"); + } + }); + })) + .Build(); - var server = new TestServer(builder); + await host.StartAsync(); + using var server = host.GetTestServer(); var cookie = (await server.SendAsync("http://www.example.com/signin")).SetCookie.FirstOrDefault(); Assert.NotNull(cookie); @@ -1429,8 +1531,8 @@ namespace Microsoft.AspNetCore.Authentication.Cookies } } - private TestServer CreateServer(Action<CookieAuthenticationOptions> configureOptions, Func<HttpContext, Task> testpath = null, Uri baseAddress = null, bool claimsTransform = false) - => CreateServerWithServices(s => + private Task<IHost> CreateHost(Action<CookieAuthenticationOptions> configureOptions, Func<HttpContext, Task> testpath = null, Uri baseAddress = null, bool claimsTransform = false) + => CreateHostWithServices(s => { s.AddSingleton<ISystemClock>(_clock); s.AddAuthentication(CookieAuthenticationDefaults.AuthenticationScheme).AddCookie(configureOptions); @@ -1440,73 +1542,79 @@ namespace Microsoft.AspNetCore.Authentication.Cookies } }, testpath, baseAddress); - private static TestServer CreateServerWithServices(Action<IServiceCollection> configureServices, Func<HttpContext, Task> testpath = null, Uri baseAddress = null) + private static async Task<IHost> CreateHostWithServices(Action<IServiceCollection> configureServices, Func<HttpContext, Task> testpath = null, Uri baseAddress = null) { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Use(async (context, next) => - { - var req = context.Request; - var res = context.Response; - PathString remainder; - if (req.Path == new PathString("/normal")) + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => { - res.StatusCode = 200; - } - else if (req.Path == new PathString("/forbid")) // Simulate forbidden - { - await context.ForbidAsync(CookieAuthenticationDefaults.AuthenticationScheme); - } - else if (req.Path == new PathString("/challenge")) - { - await context.ChallengeAsync(CookieAuthenticationDefaults.AuthenticationScheme); - } - else if (req.Path == new PathString("/signout")) - { - await context.SignOutAsync(CookieAuthenticationDefaults.AuthenticationScheme); - } - else if (req.Path == new PathString("/unauthorized")) - { - await context.ChallengeAsync(CookieAuthenticationDefaults.AuthenticationScheme, new AuthenticationProperties()); - } - else if (req.Path == new PathString("/protected/CustomRedirect")) - { - await context.ChallengeAsync(CookieAuthenticationDefaults.AuthenticationScheme, new AuthenticationProperties() { RedirectUri = "/CustomRedirect" }); - } - else if (req.Path == new PathString("/me")) - { - await DescribeAsync(res, AuthenticateResult.Success(new AuthenticationTicket(context.User, new AuthenticationProperties(), CookieAuthenticationDefaults.AuthenticationScheme))); - } - else if (req.Path.StartsWithSegments(new PathString("/me"), out remainder)) - { - var ticket = await context.AuthenticateAsync(remainder.Value.Substring(1)); - await DescribeAsync(res, ticket); - } - else if (req.Path == new PathString("/testpath") && testpath != null) - { - await testpath(context); - } - else if (req.Path == new PathString("/checkforerrors")) - { - var result = await context.AuthenticateAsync(CookieAuthenticationDefaults.AuthenticationScheme); // this used to be "Automatic" - if (result.Failure != null) + app.UseAuthentication(); + app.Use(async (context, next) => { - throw new Exception("Failed to authenticate", result.Failure); - } - return; - } - else - { - await next(); - } - }); - }) - .ConfigureServices(configureServices); - var server = new TestServer(builder); + var req = context.Request; + var res = context.Response; + PathString remainder; + if (req.Path == new PathString("/normal")) + { + res.StatusCode = 200; + } + else if (req.Path == new PathString("/forbid")) // Simulate forbidden + { + await context.ForbidAsync(CookieAuthenticationDefaults.AuthenticationScheme); + } + else if (req.Path == new PathString("/challenge")) + { + await context.ChallengeAsync(CookieAuthenticationDefaults.AuthenticationScheme); + } + else if (req.Path == new PathString("/signout")) + { + await context.SignOutAsync(CookieAuthenticationDefaults.AuthenticationScheme); + } + else if (req.Path == new PathString("/unauthorized")) + { + await context.ChallengeAsync(CookieAuthenticationDefaults.AuthenticationScheme, new AuthenticationProperties()); + } + else if (req.Path == new PathString("/protected/CustomRedirect")) + { + await context.ChallengeAsync(CookieAuthenticationDefaults.AuthenticationScheme, new AuthenticationProperties() { RedirectUri = "/CustomRedirect" }); + } + else if (req.Path == new PathString("/me")) + { + await DescribeAsync(res, AuthenticateResult.Success(new AuthenticationTicket(context.User, new AuthenticationProperties(), CookieAuthenticationDefaults.AuthenticationScheme))); + } + else if (req.Path.StartsWithSegments(new PathString("/me"), out remainder)) + { + var ticket = await context.AuthenticateAsync(remainder.Value.Substring(1)); + await DescribeAsync(res, ticket); + } + else if (req.Path == new PathString("/testpath") && testpath != null) + { + await testpath(context); + } + else if (req.Path == new PathString("/checkforerrors")) + { + var result = await context.AuthenticateAsync(CookieAuthenticationDefaults.AuthenticationScheme); // this used to be "Automatic" + if (result.Failure != null) + { + throw new Exception("Failed to authenticate", result.Failure); + } + return; + } + else + { + await next(); + } + }); + }) + .ConfigureServices(configureServices)) + .Build(); + + await host.StartAsync(); + + var server = host.GetTestServer(); server.BaseAddress = baseAddress; - return server; + return host; } private static Task DescribeAsync(HttpResponse res, AuthenticateResult result) diff --git a/src/Security/Authentication/test/DynamicSchemeTests.cs b/src/Security/Authentication/test/DynamicSchemeTests.cs index 6df65f66adf..62f250f369b 100644 --- a/src/Security/Authentication/test/DynamicSchemeTests.cs +++ b/src/Security/Authentication/test/DynamicSchemeTests.cs @@ -10,6 +10,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Xunit; @@ -21,12 +22,13 @@ namespace Microsoft.AspNetCore.Authentication [Fact] public async Task OptionsAreConfiguredOnce() { - var server = CreateServer(s => + using var host = await CreateHost(s => { s.Configure<TestOptions>("One", o => o.Instance = new Singleton()); s.Configure<TestOptions>("Two", o => o.Instance = new Singleton()); }); // Add One scheme + using var server = host.GetTestServer(); var response = await server.CreateClient().GetAsync("http://example.com/add/One"); Assert.Equal(HttpStatusCode.OK, response.StatusCode); var transaction = await server.SendAsync("http://example.com/auth/One"); @@ -57,7 +59,8 @@ namespace Microsoft.AspNetCore.Authentication [Fact] public async Task CanAddAndRemoveSchemes() { - var server = CreateServer(); + using var host = await CreateHost(); + using var server = host.GetTestServer(); await Assert.ThrowsAsync<InvalidOperationException>(() => server.SendAsync("http://example.com/auth/One")); // Add One scheme @@ -124,47 +127,52 @@ namespace Microsoft.AspNetCore.Authentication } } - private static TestServer CreateServer(Action<IServiceCollection> configureServices = null) + private static async Task<IHost> CreateHost(Action<IServiceCollection> configureServices = null) { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Use(async (context, next) => - { - var req = context.Request; - var res = context.Response; - if (req.Path.StartsWithSegments(new PathString("/add"), out var remainder)) - { - var name = remainder.Value.Substring(1); - var auth = context.RequestServices.GetRequiredService<IAuthenticationSchemeProvider>(); - var scheme = new AuthenticationScheme(name, name, typeof(TestHandler)); - auth.AddScheme(scheme); - } - else if (req.Path.StartsWithSegments(new PathString("/auth"), out remainder)) - { - var name = (remainder.Value.Length > 0) ? remainder.Value.Substring(1) : null; - var result = await context.AuthenticateAsync(name); - await res.DescribeAsync(result?.Ticket?.Principal); - } - else if (req.Path.StartsWithSegments(new PathString("/remove"), out remainder)) + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => + { + app.UseAuthentication(); + app.Use(async (context, next) => + { + var req = context.Request; + var res = context.Response; + if (req.Path.StartsWithSegments(new PathString("/add"), out var remainder)) + { + var name = remainder.Value.Substring(1); + var auth = context.RequestServices.GetRequiredService<IAuthenticationSchemeProvider>(); + var scheme = new AuthenticationScheme(name, name, typeof(TestHandler)); + auth.AddScheme(scheme); + } + else if (req.Path.StartsWithSegments(new PathString("/auth"), out remainder)) + { + var name = (remainder.Value.Length > 0) ? remainder.Value.Substring(1) : null; + var result = await context.AuthenticateAsync(name); + await res.DescribeAsync(result?.Ticket?.Principal); + } + else if (req.Path.StartsWithSegments(new PathString("/remove"), out remainder)) + { + var name = remainder.Value.Substring(1); + var auth = context.RequestServices.GetRequiredService<IAuthenticationSchemeProvider>(); + auth.RemoveScheme(name); + } + else + { + await next(); + } + }); + }) + .ConfigureServices(services => { - var name = remainder.Value.Substring(1); - var auth = context.RequestServices.GetRequiredService<IAuthenticationSchemeProvider>(); - auth.RemoveScheme(name); - } - else - { - await next(); - } - }); - }) - .ConfigureServices(services => - { - configureServices?.Invoke(services); - services.AddAuthentication(); - }); - return new TestServer(builder); + configureServices?.Invoke(services); + services.AddAuthentication(); + })) + .Build(); + + await host.StartAsync(); + return host; } } } diff --git a/src/Security/Authentication/test/FacebookTests.cs b/src/Security/Authentication/test/FacebookTests.cs index 66a36057579..d2c33019e4c 100644 --- a/src/Security/Authentication/test/FacebookTests.cs +++ b/src/Security/Authentication/test/FacebookTests.cs @@ -9,6 +9,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging.Abstractions; using System; using System.Linq; @@ -47,7 +48,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook [Fact] public async Task ThrowsIfAppIdMissing() { - var server = CreateServer( + using var host = await CreateHost( app => { }, services => services.AddAuthentication().AddFacebook(o => o.SignInScheme = "Whatever"), async context => @@ -55,6 +56,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook await Assert.ThrowsAsync<ArgumentException>("AppId", () => context.ChallengeAsync("Facebook")); return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -62,7 +64,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook [Fact] public async Task ThrowsIfAppSecretMissing() { - var server = CreateServer( + using var host = await CreateHost( app => { }, services => services.AddAuthentication().AddFacebook(o => o.AppId = "Whatever"), async context => @@ -70,6 +72,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook await Assert.ThrowsAsync<ArgumentException>("AppSecret", () => context.ChallengeAsync("Facebook")); return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -77,7 +80,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook [Fact] public async Task ChallengeWillTriggerApplyRedirectEvent() { - var server = CreateServer( + using var host = await CreateHost( app => { app.UseAuthentication(); @@ -105,6 +108,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook await context.ChallengeAsync("Facebook"); return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var query = transaction.Response.Headers.Location.Query; @@ -114,7 +118,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook [Fact] public async Task ChallengeWillIncludeScopeAsConfigured() { - var server = CreateServer( + using var host = await CreateHost( app => app.UseAuthentication(), services => { @@ -133,6 +137,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); var res = transaction.Response; @@ -143,7 +148,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook [Fact] public async Task ChallengeWillIncludeScopeAsOverwritten() { - var server = CreateServer( + using var host = await CreateHost( app => app.UseAuthentication(), services => { @@ -164,6 +169,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); var res = transaction.Response; @@ -174,7 +180,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook [Fact] public async Task ChallengeWillIncludeScopeAsOverwrittenWithBaseAuthenticationProperties() { - var server = CreateServer( + using var host = await CreateHost( app => app.UseAuthentication(), services => { @@ -195,6 +201,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); var res = transaction.Response; @@ -205,7 +212,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook [Fact] public async Task NestedMapWillNotAffectRedirect() { - var server = CreateServer(app => app.Map("/base", map => + using var host = await CreateHost(app => app.Map("/base", map => { map.UseAuthentication(); map.Map("/login", signoutApp => signoutApp.Run(context => context.ChallengeAsync("Facebook", new AuthenticationProperties() { RedirectUri = "/" }))); @@ -222,6 +229,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook }, handler: null); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/base/login"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var location = transaction.Response.Headers.Location.AbsoluteUri; @@ -236,7 +244,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook [Fact] public async Task MapWillNotAffectRedirect() { - var server = CreateServer( + using var host = await CreateHost( app => { app.UseAuthentication(); @@ -254,6 +262,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook }); }, handler: null); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/login"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var location = transaction.Response.Headers.Location.AbsoluteUri; @@ -268,7 +277,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook [Fact] public async Task ChallengeWillTriggerRedirection() { - var server = CreateServer( + using var host = await CreateHost( app => app.UseAuthentication(), services => { @@ -288,6 +297,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook await context.ChallengeAsync("Facebook"); return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var location = transaction.Response.Headers.Location.AbsoluteUri; @@ -305,7 +315,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook var customUserInfoEndpoint = "https://graph.facebook.com/me?fields=email,timezone,picture"; var finalUserInfoEndpoint = string.Empty; var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("FacebookTest")); - var server = CreateServer( + using var host = await CreateHost( app => app.UseAuthentication(), services => { @@ -350,6 +360,7 @@ namespace Microsoft.AspNetCore.Authentication.Facebook properties.Items.Add(correlationKey, correlationValue); properties.RedirectUri = "/me"; var state = stateFormat.Protect(properties); + using var server = host.GetTestServer(); var transaction = await server.SendAsync( "https://example.com/signin-facebook?code=TestCode&state=" + UrlEncoder.Default.Encode(state), $".AspNetCore.Correlation.Facebook.{correlationValue}=N"); @@ -360,22 +371,27 @@ namespace Microsoft.AspNetCore.Authentication.Facebook Assert.Contains("&access_token=", finalUserInfoEndpoint); } - private static TestServer CreateServer(Action<IApplicationBuilder> configure, Action<IServiceCollection> configureServices, Func<HttpContext, Task<bool>> handler) + private static async Task<IHost> CreateHost(Action<IApplicationBuilder> configure, Action<IServiceCollection> configureServices, Func<HttpContext, Task<bool>> handler) { - var builder = new WebHostBuilder() - .Configure(app => - { - configure?.Invoke(app); - app.Use(async (context, next) => - { - if (handler == null || !await handler(context)) + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => { - await next(); - } - }); - }) - .ConfigureServices(configureServices); - return new TestServer(builder); + configure?.Invoke(app); + app.Use(async (context, next) => + { + if (handler == null || !await handler(context)) + { + await next(); + } + }); + }) + .ConfigureServices(configureServices)) + .Build(); + + await host.StartAsync(); + return host; } } } diff --git a/src/Security/Authentication/test/GoogleTests.cs b/src/Security/Authentication/test/GoogleTests.cs index 98e001d7d4f..27d90adaadc 100644 --- a/src/Security/Authentication/test/GoogleTests.cs +++ b/src/Security/Authentication/test/GoogleTests.cs @@ -8,6 +8,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging.Abstractions; using System; using System.Collections.Generic; @@ -48,11 +49,12 @@ namespace Microsoft.AspNetCore.Authentication.Google [Fact] public async Task ChallengeWillTriggerRedirection() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var location = transaction.Response.Headers.Location.ToString(); @@ -72,11 +74,12 @@ namespace Microsoft.AspNetCore.Authentication.Google [Fact] public async Task SignInThrows() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signIn"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -84,11 +87,12 @@ namespace Microsoft.AspNetCore.Authentication.Google [Fact] public async Task SignOutThrows() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signOut"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -96,11 +100,12 @@ namespace Microsoft.AspNetCore.Authentication.Google [Fact] public async Task ForbidThrows() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signOut"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -108,11 +113,12 @@ namespace Microsoft.AspNetCore.Authentication.Google [Fact] public async Task Challenge401WillNotTriggerRedirection() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/401"); Assert.Equal(HttpStatusCode.Unauthorized, transaction.Response.StatusCode); } @@ -120,11 +126,12 @@ namespace Microsoft.AspNetCore.Authentication.Google [Fact] public async Task ChallengeWillSetCorrelationCookie() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/challenge"); Assert.Contains(transaction.SetCookie, cookie => cookie.StartsWith(".AspNetCore.Correlation.Google.")); } @@ -132,11 +139,12 @@ namespace Microsoft.AspNetCore.Authentication.Google [Fact] public async Task ChallengeWillSetDefaultScope() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var query = transaction.Response.Headers.Location.Query; @@ -147,7 +155,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task ChallengeWillUseAuthenticationPropertiesParametersAsQueryArguments() { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -172,6 +180,7 @@ namespace Microsoft.AspNetCore.Authentication.Google return Task.FromResult<object>(null); }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/challenge2"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); @@ -198,7 +207,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task ChallengeWillUseAuthenticationPropertiesItemsAsParameters() { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -223,6 +232,7 @@ namespace Microsoft.AspNetCore.Authentication.Google return Task.FromResult<object>(null); }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/challenge2"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); @@ -249,7 +259,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task ChallengeWillUseAuthenticationPropertiesItemsAsQueryArgumentsButParametersWillOverwrite() { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -278,6 +288,7 @@ namespace Microsoft.AspNetCore.Authentication.Google return Task.FromResult<object>(null); }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/challenge2"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); @@ -303,7 +314,7 @@ namespace Microsoft.AspNetCore.Authentication.Google [Fact] public async Task ChallengeWillTriggerApplyRedirectEvent() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -316,6 +327,7 @@ namespace Microsoft.AspNetCore.Authentication.Google } }; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var query = transaction.Response.Headers.Location.Query; @@ -325,7 +337,7 @@ namespace Microsoft.AspNetCore.Authentication.Google [Fact] public async Task AuthenticateWithoutCookieWillFail() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -340,6 +352,7 @@ namespace Microsoft.AspNetCore.Authentication.Google Assert.NotNull(result.Failure); } }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/auth"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -347,11 +360,12 @@ namespace Microsoft.AspNetCore.Authentication.Google [Fact] public async Task ReplyPathWithoutStateQueryStringWillBeRejected() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; }); + using var server = host.GetTestServer(); var error = await Assert.ThrowsAnyAsync<Exception>(() => server.SendAsync("https://example.com/signin-google?code=TestCode")); Assert.Equal("The oauth state was missing or invalid.", error.GetBaseException().Message); } @@ -361,7 +375,7 @@ namespace Microsoft.AspNetCore.Authentication.Google [InlineData(false)] public async Task ReplyPathWithAccessDeniedErrorFails(bool redirect) { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -376,6 +390,7 @@ namespace Microsoft.AspNetCore.Authentication.Google } } : new OAuthEvents(); }); + using var server = host.GetTestServer(); var sendTask = server.SendAsync("https://example.com/signin-google?error=access_denied&error_description=SoBad&error_uri=foobar&state=protected_state", ".AspNetCore.Correlation.Google.correlationId=N"); if (redirect) @@ -394,7 +409,7 @@ namespace Microsoft.AspNetCore.Authentication.Google [Fact] public async Task ReplyPathWithAccessDeniedError_AllowsCustomizingPath() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -414,6 +429,7 @@ namespace Microsoft.AspNetCore.Authentication.Google } }; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signin-google?error=access_denied&error_description=SoBad&error_uri=foobar&state=protected_state", ".AspNetCore.Correlation.Google.correlationId=N"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); @@ -425,7 +441,7 @@ namespace Microsoft.AspNetCore.Authentication.Google { var accessDeniedCalled = false; var remoteFailureCalled = false; - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -456,6 +472,7 @@ namespace Microsoft.AspNetCore.Authentication.Google } }; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signin-google?error=access_denied&error_description=whyitfailed&error_uri=https://example.com/fail&state=protected_state", ".AspNetCore.Correlation.Google.correlationId=N"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); @@ -469,7 +486,7 @@ namespace Microsoft.AspNetCore.Authentication.Google [InlineData(false)] public async Task ReplyPathWithErrorFails(bool redirect) { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -491,6 +508,7 @@ namespace Microsoft.AspNetCore.Authentication.Google } } : new OAuthEvents(); }); + using var server = host.GetTestServer(); var sendTask = server.SendAsync("https://example.com/signin-google?error=itfailed&error_description=whyitfailed&error_uri=https://example.com/fail&state=protected_state", ".AspNetCore.Correlation.Google.correlationId=N"); if (redirect) @@ -512,7 +530,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task ReplyPathWillAuthenticateValidAuthorizeCodeAndState(string claimsIssuer) { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -531,6 +549,7 @@ namespace Microsoft.AspNetCore.Authentication.Google properties.Items.Add(correlationKey, correlationValue); properties.RedirectUri = "/me"; var state = stateFormat.Protect(properties); + using var server = host.GetTestServer(); var transaction = await server.SendAsync( "https://example.com/signin-google?code=TestCode&state=" + UrlEncoder.Default.Encode(state), $".AspNetCore.Correlation.Google.{correlationValue}=N"); @@ -567,7 +586,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task ReplyPathWillThrowIfCodeIsInvalid(bool redirect) { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -597,6 +616,7 @@ namespace Microsoft.AspNetCore.Authentication.Google properties.RedirectUri = "/me"; var state = stateFormat.Protect(properties); + using var server = host.GetTestServer(); var sendTask = server.SendAsync( "https://example.com/signin-google?code=TestCode&state=" + UrlEncoder.Default.Encode(state), $".AspNetCore.Correlation.Google.{correlationValue}=N"); @@ -620,7 +640,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task ReplyPathWillRejectIfAccessTokenIsMissing(bool redirect) { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -648,6 +668,7 @@ namespace Microsoft.AspNetCore.Authentication.Google properties.Items.Add(correlationKey, correlationValue); properties.RedirectUri = "/me"; var state = stateFormat.Protect(properties); + using var server = host.GetTestServer(); var sendTask = server.SendAsync( "https://example.com/signin-google?code=TestCode&state=" + UrlEncoder.Default.Encode(state), $".AspNetCore.Correlation.Google.{correlationValue}=N"); @@ -669,7 +690,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task AuthenticatedEventCanGetRefreshToken() { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -691,6 +712,7 @@ namespace Microsoft.AspNetCore.Authentication.Google properties.Items.Add(correlationKey, correlationValue); properties.RedirectUri = "/me"; var state = stateFormat.Protect(properties); + using var server = host.GetTestServer(); var transaction = await server.SendAsync( "https://example.com/signin-google?code=TestCode&state=" + UrlEncoder.Default.Encode(state), $".AspNetCore.Correlation.Google.{correlationValue}=N"); @@ -710,7 +732,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task NullRedirectUriWillRedirectToSlash() { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -730,6 +752,7 @@ namespace Microsoft.AspNetCore.Authentication.Google var correlationValue = "TestCorrelationId"; properties.Items.Add(correlationKey, correlationValue); var state = stateFormat.Protect(properties); + using var server = host.GetTestServer(); var transaction = await server.SendAsync( "https://example.com/signin-google?code=TestCode&state=" + UrlEncoder.Default.Encode(state), $".AspNetCore.Correlation.Google.{correlationValue}=N"); @@ -744,7 +767,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task ValidateAuthenticatedContext() { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -776,6 +799,7 @@ namespace Microsoft.AspNetCore.Authentication.Google var state = stateFormat.Protect(properties); //Post a message to the Google middleware + using var server = host.GetTestServer(); var transaction = await server.SendAsync( "https://example.com/signin-google?code=TestCode&state=" + UrlEncoder.Default.Encode(state), $".AspNetCore.Correlation.Google.{correlationValue}=N"); @@ -787,13 +811,14 @@ namespace Microsoft.AspNetCore.Authentication.Google [Fact] public async Task NoStateCausesException() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; }); //Post a message to the Google middleware + using var server = host.GetTestServer(); var error = await Assert.ThrowsAnyAsync<Exception>(() => server.SendAsync("https://example.com/signin-google?code=TestCode")); Assert.Equal("The oauth state was missing or invalid.", error.GetBaseException().Message); } @@ -802,7 +827,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task CanRedirectOnError() { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -818,6 +843,7 @@ namespace Microsoft.AspNetCore.Authentication.Google }); //Post a message to the Google middleware + using var server = host.GetTestServer(); var transaction = await server.SendAsync( "https://example.com/signin-google?code=TestCode"); @@ -830,7 +856,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task AuthenticateAutomaticWhenAlreadySignedInSucceeds() { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -847,6 +873,7 @@ namespace Microsoft.AspNetCore.Authentication.Google properties.Items.Add(correlationKey, correlationValue); properties.RedirectUri = "/me"; var state = stateFormat.Protect(properties); + using var server = host.GetTestServer(); var transaction = await server.SendAsync( "https://example.com/signin-google?code=TestCode&state=" + UrlEncoder.Default.Encode(state), $".AspNetCore.Correlation.Google.{correlationValue}=N"); @@ -873,7 +900,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task AuthenticateGoogleWhenAlreadySignedInSucceeds() { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -890,6 +917,7 @@ namespace Microsoft.AspNetCore.Authentication.Google properties.Items.Add(correlationKey, correlationValue); properties.RedirectUri = "/me"; var state = stateFormat.Protect(properties); + using var server = host.GetTestServer(); var transaction = await server.SendAsync( "https://example.com/signin-google?code=TestCode&state=" + UrlEncoder.Default.Encode(state), $".AspNetCore.Correlation.Google.{correlationValue}=N"); @@ -916,7 +944,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task AuthenticateFacebookWhenAlreadySignedWithGoogleReturnsNull() { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -933,6 +961,7 @@ namespace Microsoft.AspNetCore.Authentication.Google properties.Items.Add(correlationKey, correlationValue); properties.RedirectUri = "/me"; var state = stateFormat.Protect(properties); + using var server = host.GetTestServer(); var transaction = await server.SendAsync( "https://example.com/signin-google?code=TestCode&state=" + UrlEncoder.Default.Encode(state), $".AspNetCore.Correlation.Google.{correlationValue}=N"); @@ -952,7 +981,7 @@ namespace Microsoft.AspNetCore.Authentication.Google public async Task ChallengeFacebookWhenAlreadySignedWithGoogleSucceeds() { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("GoogleTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -969,6 +998,7 @@ namespace Microsoft.AspNetCore.Authentication.Google properties.Items.Add(correlationKey, correlationValue); properties.RedirectUri = "/me"; var state = stateFormat.Protect(properties); + using var server = host.GetTestServer(); var transaction = await server.SendAsync( "https://example.com/signin-google?code=TestCode&state=" + UrlEncoder.Default.Encode(state), $".AspNetCore.Correlation.Google.{correlationValue}=N"); @@ -1040,99 +1070,104 @@ namespace Microsoft.AspNetCore.Authentication.Google } } - private static TestServer CreateServer(Action<GoogleOptions> configureOptions, Func<HttpContext, Task> testpath = null) + private static async Task<IHost> CreateHost(Action<GoogleOptions> configureOptions, Func<HttpContext, Task> testpath = null) { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Use(async (context, next) => - { - var req = context.Request; - var res = context.Response; - if (req.Path == new PathString("/challenge")) - { - await context.ChallengeAsync(); - } - else if (req.Path == new PathString("/challengeFacebook")) - { - await context.ChallengeAsync("Facebook"); - } - else if (req.Path == new PathString("/tokens")) - { - var result = await context.AuthenticateAsync(TestExtensions.CookieAuthenticationScheme); - var tokens = result.Properties.GetTokens(); - await res.DescribeAsync(tokens); - } - else if (req.Path == new PathString("/me")) - { - await res.DescribeAsync(context.User); - } - else if (req.Path == new PathString("/authenticate")) - { - var result = await context.AuthenticateAsync(TestExtensions.CookieAuthenticationScheme); - await res.DescribeAsync(result.Principal); - } - else if (req.Path == new PathString("/authenticateGoogle")) - { - var result = await context.AuthenticateAsync("Google"); - await res.DescribeAsync(result?.Principal); - } - else if (req.Path == new PathString("/authenticateFacebook")) - { - var result = await context.AuthenticateAsync("Facebook"); - await res.DescribeAsync(result?.Principal); - } - else if (req.Path == new PathString("/unauthorized")) - { - // Simulate Authorization failure - var result = await context.AuthenticateAsync("Google"); - await context.ChallengeAsync("Google"); - } - else if (req.Path == new PathString("/unauthorizedAuto")) + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => { - var result = await context.AuthenticateAsync("Google"); - await context.ChallengeAsync("Google"); - } - else if (req.Path == new PathString("/401")) + app.UseAuthentication(); + app.Use(async (context, next) => + { + var req = context.Request; + var res = context.Response; + if (req.Path == new PathString("/challenge")) + { + await context.ChallengeAsync(); + } + else if (req.Path == new PathString("/challengeFacebook")) + { + await context.ChallengeAsync("Facebook"); + } + else if (req.Path == new PathString("/tokens")) + { + var result = await context.AuthenticateAsync(TestExtensions.CookieAuthenticationScheme); + var tokens = result.Properties.GetTokens(); + await res.DescribeAsync(tokens); + } + else if (req.Path == new PathString("/me")) + { + await res.DescribeAsync(context.User); + } + else if (req.Path == new PathString("/authenticate")) + { + var result = await context.AuthenticateAsync(TestExtensions.CookieAuthenticationScheme); + await res.DescribeAsync(result.Principal); + } + else if (req.Path == new PathString("/authenticateGoogle")) + { + var result = await context.AuthenticateAsync("Google"); + await res.DescribeAsync(result?.Principal); + } + else if (req.Path == new PathString("/authenticateFacebook")) + { + var result = await context.AuthenticateAsync("Facebook"); + await res.DescribeAsync(result?.Principal); + } + else if (req.Path == new PathString("/unauthorized")) + { + // Simulate Authorization failure + var result = await context.AuthenticateAsync("Google"); + await context.ChallengeAsync("Google"); + } + else if (req.Path == new PathString("/unauthorizedAuto")) + { + var result = await context.AuthenticateAsync("Google"); + await context.ChallengeAsync("Google"); + } + else if (req.Path == new PathString("/401")) + { + res.StatusCode = 401; + } + else if (req.Path == new PathString("/signIn")) + { + await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignInAsync("Google", new ClaimsPrincipal())); + } + else if (req.Path == new PathString("/signOut")) + { + await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignOutAsync("Google")); + } + else if (req.Path == new PathString("/forbid")) + { + await Assert.ThrowsAsync<InvalidOperationException>(() => context.ForbidAsync("Google")); + } + else if (testpath != null) + { + await testpath(context); + } + else + { + await next(); + } + }); + }) + .ConfigureServices(services => { - res.StatusCode = 401; - } - else if (req.Path == new PathString("/signIn")) - { - await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignInAsync("Google", new ClaimsPrincipal())); - } - else if (req.Path == new PathString("/signOut")) - { - await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignOutAsync("Google")); - } - else if (req.Path == new PathString("/forbid")) - { - await Assert.ThrowsAsync<InvalidOperationException>(() => context.ForbidAsync("Google")); - } - else if (testpath != null) - { - await testpath(context); - } - else - { - await next(); - } - }); - }) - .ConfigureServices(services => - { - services.AddTransient<IClaimsTransformation, ClaimsTransformer>(); - services.AddAuthentication(TestExtensions.CookieAuthenticationScheme) - .AddCookie(TestExtensions.CookieAuthenticationScheme, o => o.ForwardChallenge = GoogleDefaults.AuthenticationScheme) - .AddGoogle(configureOptions) - .AddFacebook(o => - { - o.ClientId = "Test ClientId"; - o.ClientSecret = "Test AppSecrent"; - }); - }); - return new TestServer(builder); + services.AddTransient<IClaimsTransformation, ClaimsTransformer>(); + services.AddAuthentication(TestExtensions.CookieAuthenticationScheme) + .AddCookie(TestExtensions.CookieAuthenticationScheme, o => o.ForwardChallenge = GoogleDefaults.AuthenticationScheme) + .AddGoogle(configureOptions) + .AddFacebook(o => + { + o.ClientId = "Test ClientId"; + o.ClientSecret = "Test AppSecrent"; + }); + })) + .Build(); + + await host.StartAsync(); + return host; } private class TestStateDataFormat : ISecureDataFormat<AuthenticationProperties> diff --git a/src/Security/Authentication/test/JwtBearerTests.cs b/src/Security/Authentication/test/JwtBearerTests.cs index 0fe3e85cd70..5075f1c888a 100755 --- a/src/Security/Authentication/test/JwtBearerTests.cs +++ b/src/Security/Authentication/test/JwtBearerTests.cs @@ -6,6 +6,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.IdentityModel.Tokens; using System; using System.IdentityModel.Tokens.Jwt; @@ -61,7 +62,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer var tokenText = new JwtSecurityTokenHandler().WriteToken(token); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.TokenValidationParameters = new TokenValidationParameters() { @@ -72,6 +73,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer }); var newBearerToken = "Bearer " + tokenText; + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth", newBearerToken); Assert.Equal(HttpStatusCode.OK, response.Response.StatusCode); } @@ -96,7 +98,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer var tokenText = new JwtSecurityTokenHandler().WriteToken(token); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.SaveToken = true; o.TokenValidationParameters = new TokenValidationParameters() @@ -108,6 +110,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer }); var newBearerToken = "Bearer " + tokenText; + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/token", newBearerToken); Assert.Equal(HttpStatusCode.OK, response.Response.StatusCode); Assert.Equal(tokenText, await response.Response.Content.ReadAsStringAsync()); @@ -116,7 +119,8 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task SignInThrows() { - var server = CreateServer(); + using var host = await CreateHost(); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signIn"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -124,7 +128,8 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task SignOutThrows() { - var server = CreateServer(); + using var host = await CreateHost(); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signOut"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -132,7 +137,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task ThrowAtAuthenticationFailedEvent() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.Events = new JwtBearerEvents { @@ -164,6 +169,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer } }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signIn"); Assert.Equal(HttpStatusCode.Unauthorized, transaction.Response.StatusCode); @@ -172,7 +178,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task CustomHeaderReceived() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.Events = new JwtBearerEvents() { @@ -193,6 +199,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer }; }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth", "someHeader someblob"); Assert.Equal(HttpStatusCode.OK, response.Response.StatusCode); Assert.Equal("Bob le Magnifique", response.ResponseText); @@ -201,7 +208,8 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task NoHeaderReceived() { - var server = CreateServer(); + using var host = await CreateHost(); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth"); Assert.Equal(HttpStatusCode.Unauthorized, response.Response.StatusCode); } @@ -209,7 +217,8 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task HeaderWithoutBearerReceived() { - var server = CreateServer(); + using var host = await CreateHost(); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth", "Token"); Assert.Equal(HttpStatusCode.Unauthorized, response.Response.StatusCode); } @@ -217,7 +226,8 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task UnrecognizedTokenReceived() { - var server = CreateServer(); + using var host = await CreateHost(); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth", "Bearer someblob"); Assert.Equal(HttpStatusCode.Unauthorized, response.Response.StatusCode); Assert.Equal("", response.ResponseText); @@ -226,12 +236,13 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task InvalidTokenReceived() { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.SecurityTokenValidators.Clear(); options.SecurityTokenValidators.Add(new InvalidTokenValidator()); }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth", "Bearer someblob"); Assert.Equal(HttpStatusCode.Unauthorized, response.Response.StatusCode); Assert.Equal("Bearer error=\"invalid_token\"", response.Response.Headers.WwwAuthenticate.First().ToString()); @@ -249,12 +260,13 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [InlineData(typeof(SecurityTokenSignatureKeyNotFoundException), "The signature key was not found")] public async Task ExceptionReportedInHeaderForAuthenticationFailures(Type errorType, string message) { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.SecurityTokenValidators.Clear(); options.SecurityTokenValidators.Add(new InvalidTokenValidator(errorType)); }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth", "Bearer someblob"); Assert.Equal(HttpStatusCode.Unauthorized, response.Response.StatusCode); Assert.Equal($"Bearer error=\"invalid_token\", error_description=\"{message}\"", response.Response.Headers.WwwAuthenticate.First().ToString()); @@ -269,12 +281,13 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [InlineData(typeof(SecurityTokenExpiredException), "The token expired at '02/20/2000 00:00:00'")] public async Task ExceptionReportedInHeaderWithDetailsForAuthenticationFailures(Type errorType, string message) { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.SecurityTokenValidators.Clear(); options.SecurityTokenValidators.Add(new DetailedInvalidTokenValidator(errorType)); }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth", "Bearer someblob"); Assert.Equal(HttpStatusCode.Unauthorized, response.Response.StatusCode); Assert.Equal($"Bearer error=\"invalid_token\", error_description=\"{message}\"", response.Response.Headers.WwwAuthenticate.First().ToString()); @@ -285,12 +298,13 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [InlineData(typeof(ArgumentException))] public async Task ExceptionNotReportedInHeaderForOtherFailures(Type errorType) { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.SecurityTokenValidators.Clear(); options.SecurityTokenValidators.Add(new InvalidTokenValidator(errorType)); }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth", "Bearer someblob"); Assert.Equal(HttpStatusCode.Unauthorized, response.Response.StatusCode); Assert.Equal("Bearer error=\"invalid_token\"", response.Response.Headers.WwwAuthenticate.First().ToString()); @@ -300,13 +314,14 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task ExceptionsReportedInHeaderForMultipleAuthenticationFailures() { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.SecurityTokenValidators.Clear(); options.SecurityTokenValidators.Add(new InvalidTokenValidator(typeof(SecurityTokenInvalidAudienceException))); options.SecurityTokenValidators.Add(new InvalidTokenValidator(typeof(SecurityTokenSignatureKeyNotFoundException))); }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth", "Bearer someblob"); Assert.Equal(HttpStatusCode.Unauthorized, response.Response.StatusCode); Assert.Equal("Bearer error=\"invalid_token\", error_description=\"The audience '(null)' is invalid; The signature key was not found\"", @@ -323,7 +338,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [InlineData(null, null, "custom_uri")] public async Task ExceptionsReportedInHeaderExposesUserDefinedError(string error, string description, string uri) { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.Events = new JwtBearerEvents { @@ -338,6 +353,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer }; }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth", "Bearer someblob"); Assert.Equal(HttpStatusCode.Unauthorized, response.Response.StatusCode); Assert.Equal("", response.ResponseText); @@ -380,11 +396,12 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task ExceptionNotReportedInHeaderWhenIncludeErrorDetailsIsFalse() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.IncludeErrorDetails = false; }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth", "Bearer someblob"); Assert.Equal(HttpStatusCode.Unauthorized, response.Response.StatusCode); Assert.Equal("Bearer", response.Response.Headers.WwwAuthenticate.First().ToString()); @@ -394,8 +411,9 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task ExceptionNotReportedInHeaderWhenTokenWasMissing() { - var server = CreateServer(); + using var host = await CreateHost(); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth"); Assert.Equal(HttpStatusCode.Unauthorized, response.Response.StatusCode); Assert.Equal("Bearer", response.Response.Headers.WwwAuthenticate.First().ToString()); @@ -405,7 +423,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task CustomTokenValidated() { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.Events = new JwtBearerEvents() { @@ -432,6 +450,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer options.SecurityTokenValidators.Add(new BlobTokenValidator(JwtBearerDefaults.AuthenticationScheme)); }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth", "Bearer someblob"); Assert.Equal(HttpStatusCode.OK, response.Response.StatusCode); Assert.Equal("Bob le Magnifique", response.ResponseText); @@ -440,7 +459,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task RetrievingTokenFromAlternateLocation() { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.Events = new JwtBearerEvents() { @@ -457,6 +476,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer })); }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/oauth", "Bearer Token"); Assert.Equal(HttpStatusCode.OK, response.Response.StatusCode); Assert.Equal("Bob le Tout Puissant", response.ResponseText); @@ -465,7 +485,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task EventOnMessageReceivedSkip_NoMoreEventsExecuted() { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.Events = new JwtBearerEvents() { @@ -489,6 +509,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer }; }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/checkforerrors", "Bearer Token"); Assert.Equal(HttpStatusCode.OK, response.Response.StatusCode); Assert.Equal(string.Empty, response.ResponseText); @@ -497,7 +518,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task EventOnMessageReceivedReject_NoMoreEventsExecuted() { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.Events = new JwtBearerEvents() { @@ -522,6 +543,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer }; }); + using var server = host.GetTestServer(); var exception = await Assert.ThrowsAsync<Exception>(delegate { return SendAsync(server, "http://example.com/checkforerrors", "Bearer Token"); @@ -533,7 +555,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task EventOnTokenValidatedSkip_NoMoreEventsExecuted() { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.Events = new JwtBearerEvents() { @@ -555,6 +577,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer options.SecurityTokenValidators.Add(new BlobTokenValidator("JWT")); }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/checkforerrors", "Bearer Token"); Assert.Equal(HttpStatusCode.OK, response.Response.StatusCode); Assert.Equal(string.Empty, response.ResponseText); @@ -563,7 +586,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task EventOnTokenValidatedReject_NoMoreEventsExecuted() { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.Events = new JwtBearerEvents() { @@ -586,6 +609,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer options.SecurityTokenValidators.Add(new BlobTokenValidator("JWT")); }); + using var server = host.GetTestServer(); var exception = await Assert.ThrowsAsync<Exception>(delegate { return SendAsync(server, "http://example.com/checkforerrors", "Bearer Token"); @@ -597,7 +621,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task EventOnAuthenticationFailedSkip_NoMoreEventsExecuted() { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.Events = new JwtBearerEvents() { @@ -619,6 +643,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer options.SecurityTokenValidators.Add(new BlobTokenValidator("JWT")); }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/checkforerrors", "Bearer Token"); Assert.Equal(HttpStatusCode.OK, response.Response.StatusCode); Assert.Equal(string.Empty, response.ResponseText); @@ -627,7 +652,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task EventOnAuthenticationFailedReject_NoMoreEventsExecuted() { - var server = CreateServer(options => + using var host = await CreateHost(options => { options.Events = new JwtBearerEvents() { @@ -650,6 +675,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer options.SecurityTokenValidators.Add(new BlobTokenValidator("JWT")); }); + using var server = host.GetTestServer(); var exception = await Assert.ThrowsAsync<Exception>(delegate { return SendAsync(server, "http://example.com/checkforerrors", "Bearer Token"); @@ -661,7 +687,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer [Fact] public async Task EventOnChallengeSkip_ResponseNotModified() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.Events = new JwtBearerEvents() { @@ -673,6 +699,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer }; }); + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/unauthorized", "Bearer Token"); Assert.Equal(HttpStatusCode.OK, response.Response.StatusCode); Assert.Empty(response.Response.Headers.WwwAuthenticate); @@ -684,7 +711,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer { var tokenData = CreateStandardTokenAndKey(); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.TokenValidationParameters = new TokenValidationParameters() { @@ -694,6 +721,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer }; }); var newBearerToken = "Bearer " + tokenData.tokenText; + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/forbidden", newBearerToken); Assert.Equal(HttpStatusCode.Forbidden, response.Response.StatusCode); } @@ -702,7 +730,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer public async Task EventOnForbiddenSkip_ResponseNotModified() { var tokenData = CreateStandardTokenAndKey(); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.TokenValidationParameters = new TokenValidationParameters() { @@ -719,6 +747,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer }; }); var newBearerToken = "Bearer " + tokenData.tokenText; + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/forbidden", newBearerToken); Assert.Equal(HttpStatusCode.Forbidden, response.Response.StatusCode); } @@ -727,7 +756,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer public async Task EventOnForbidden_ResponseModified() { var tokenData = CreateStandardTokenAndKey(); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.TokenValidationParameters = new TokenValidationParameters() { @@ -745,6 +774,7 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer }; }); var newBearerToken = "Bearer " + tokenData.tokenText; + using var server = host.GetTestServer(); var response = await SendAsync(server, "http://example.com/forbidden", newBearerToken); Assert.Equal(418, (int)response.Response.StatusCode); Assert.Equal("You Shall Not Pass", await response.Response.Content.ReadAsStringAsync()); @@ -896,82 +926,86 @@ namespace Microsoft.AspNetCore.Authentication.JwtBearer } } - private static TestServer CreateServer(Action<JwtBearerOptions> options = null, Func<HttpContext, Func<Task>, Task> handlerBeforeAuth = null) + private static async Task<IHost> CreateHost(Action<JwtBearerOptions> options = null, Func<HttpContext, Func<Task>, Task> handlerBeforeAuth = null) { - var builder = new WebHostBuilder() - .Configure(app => - { - if (handlerBeforeAuth != null) - { - app.Use(handlerBeforeAuth); - } - - app.UseAuthentication(); - app.Use(async (context, next) => - { - if (context.Request.Path == new PathString("/checkforerrors")) + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => { - var result = await context.AuthenticateAsync(JwtBearerDefaults.AuthenticationScheme); // this used to be "Automatic" - if (result.Failure != null) + if (handlerBeforeAuth != null) { - throw new Exception("Failed to authenticate", result.Failure); - } - return; - } - else if (context.Request.Path == new PathString("/oauth")) - { - if (context.User == null || - context.User.Identity == null || - !context.User.Identity.IsAuthenticated) - { - context.Response.StatusCode = 401; - // REVIEW: no more automatic challenge - await context.ChallengeAsync(JwtBearerDefaults.AuthenticationScheme); - return; + app.Use(handlerBeforeAuth); } - var identifier = context.User.FindFirst(ClaimTypes.NameIdentifier); - if (identifier == null) + app.UseAuthentication(); + app.Use(async (context, next) => { - context.Response.StatusCode = 500; - return; - } - - await context.Response.WriteAsync(identifier.Value); - } - else if (context.Request.Path == new PathString("/token")) - { - var token = await context.GetTokenAsync("access_token"); - await context.Response.WriteAsync(token); - } - else if (context.Request.Path == new PathString("/unauthorized")) - { - // Simulate Authorization failure - var result = await context.AuthenticateAsync(JwtBearerDefaults.AuthenticationScheme); - await context.ChallengeAsync(JwtBearerDefaults.AuthenticationScheme); - } - else if (context.Request.Path == new PathString("/forbidden")) - { - // Simulate Forbidden - await context.ForbidAsync(JwtBearerDefaults.AuthenticationScheme); - } - else if (context.Request.Path == new PathString("/signIn")) - { - await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignInAsync(JwtBearerDefaults.AuthenticationScheme, new ClaimsPrincipal())); - } - else if (context.Request.Path == new PathString("/signOut")) - { - await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignOutAsync(JwtBearerDefaults.AuthenticationScheme)); - } - else - { - await next(); - } - }); - }) - .ConfigureServices(services => services.AddAuthentication(JwtBearerDefaults.AuthenticationScheme).AddJwtBearer(options)); - - return new TestServer(builder); + if (context.Request.Path == new PathString("/checkforerrors")) + { + var result = await context.AuthenticateAsync(JwtBearerDefaults.AuthenticationScheme); // this used to be "Automatic" + if (result.Failure != null) + { + throw new Exception("Failed to authenticate", result.Failure); + } + return; + } + else if (context.Request.Path == new PathString("/oauth")) + { + if (context.User == null || + context.User.Identity == null || + !context.User.Identity.IsAuthenticated) + { + context.Response.StatusCode = 401; + // REVIEW: no more automatic challenge + await context.ChallengeAsync(JwtBearerDefaults.AuthenticationScheme); + return; + } + + var identifier = context.User.FindFirst(ClaimTypes.NameIdentifier); + if (identifier == null) + { + context.Response.StatusCode = 500; + return; + } + + await context.Response.WriteAsync(identifier.Value); + } + else if (context.Request.Path == new PathString("/token")) + { + var token = await context.GetTokenAsync("access_token"); + await context.Response.WriteAsync(token); + } + else if (context.Request.Path == new PathString("/unauthorized")) + { + // Simulate Authorization failure + var result = await context.AuthenticateAsync(JwtBearerDefaults.AuthenticationScheme); + await context.ChallengeAsync(JwtBearerDefaults.AuthenticationScheme); + } + else if (context.Request.Path == new PathString("/forbidden")) + { + // Simulate Forbidden + await context.ForbidAsync(JwtBearerDefaults.AuthenticationScheme); + } + else if (context.Request.Path == new PathString("/signIn")) + { + await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignInAsync(JwtBearerDefaults.AuthenticationScheme, new ClaimsPrincipal())); + } + else if (context.Request.Path == new PathString("/signOut")) + { + await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignOutAsync(JwtBearerDefaults.AuthenticationScheme)); + } + else + { + await next(); + } + }); + }) + .ConfigureServices(services => services.AddAuthentication(JwtBearerDefaults.AuthenticationScheme).AddJwtBearer(options))) + .Build(); + + await host.StartAsync(); + return host; } // TODO: see if we can share the TestExtensions SendAsync method (only diff is auth header) diff --git a/src/Security/Authentication/test/MicrosoftAccountTests.cs b/src/Security/Authentication/test/MicrosoftAccountTests.cs index 21105cedd9d..8797357b461 100644 --- a/src/Security/Authentication/test/MicrosoftAccountTests.cs +++ b/src/Security/Authentication/test/MicrosoftAccountTests.cs @@ -9,6 +9,7 @@ using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Logging.Abstractions; using System; using System.Linq; @@ -48,7 +49,7 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount [Fact] public async Task ChallengeWillTriggerApplyRedirectEvent() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Client Id"; o.ClientSecret = "Test Client Secret"; @@ -61,6 +62,8 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount } }; }); + using var server = host.GetTestServer(); + var transaction = await server.SendAsync("http://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var query = transaction.Response.Headers.Location.Query; @@ -70,11 +73,12 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount [Fact] public async Task SignInThrows() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signIn"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -82,11 +86,12 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount [Fact] public async Task SignOutThrows() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signOut"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -94,11 +99,12 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount [Fact] public async Task ForbidThrows() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signOut"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -106,11 +112,12 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount [Fact] public async Task ChallengeWillTriggerRedirection() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var location = transaction.Response.Headers.Location.AbsoluteUri; @@ -127,7 +134,7 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount [Fact] public async Task ChallengeWillIncludeScopeAsConfigured() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -135,6 +142,7 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount o.Scope.Add("foo"); o.Scope.Add("bar"); }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); var res = transaction.Response; Assert.Equal(HttpStatusCode.Redirect, res.StatusCode); @@ -144,7 +152,7 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount [Fact] public async Task ChallengeWillIncludeScopeAsOverwritten() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -152,6 +160,7 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount o.Scope.Add("foo"); o.Scope.Add("bar"); }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challengeWithOtherScope"); var res = transaction.Response; Assert.Equal(HttpStatusCode.Redirect, res.StatusCode); @@ -161,7 +170,7 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount [Fact] public async Task ChallengeWillIncludeScopeAsOverwrittenWithBaseAuthenticationProperties() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; @@ -169,6 +178,7 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount o.Scope.Add("foo"); o.Scope.Add("bar"); }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challengeWithOtherScopeWithBaseAuthenticationProperties"); var res = transaction.Response; Assert.Equal(HttpStatusCode.Redirect, res.StatusCode); @@ -179,7 +189,7 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount public async Task AuthenticatedEventCanGetRefreshToken() { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("MsftTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Client Id"; o.ClientSecret = "Test Client Secret"; @@ -229,6 +239,7 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount properties.Items.Add(correlationKey, correlationValue); properties.RedirectUri = "/me"; var state = stateFormat.Protect(properties); + using var server = host.GetTestServer(); var transaction = await server.SendAsync( "https://example.com/signin-microsoft?code=TestCode&state=" + UrlEncoder.Default.Encode(state), $".AspNetCore.Correlation.Microsoft.{correlationValue}=N"); @@ -248,12 +259,13 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount public async Task ChallengeWillUseAuthenticationPropertiesParametersAsQueryArguments() { var stateFormat = new PropertiesDataFormat(new EphemeralDataProtectionProvider(NullLoggerFactory.Instance).CreateProtector("MicrosoftTest")); - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Id"; o.ClientSecret = "Test Secret"; o.StateDataFormat = stateFormat; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); @@ -277,7 +289,7 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount [Fact] public async Task PkceSentToTokenEndpoint() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ClientId = "Test Client Id"; o.ClientSecret = "Test Client Secret"; @@ -320,6 +332,7 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount } }; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var locationUri = transaction.Response.Headers.Location; @@ -342,68 +355,72 @@ namespace Microsoft.AspNetCore.Authentication.Tests.MicrosoftAccount Assert.StartsWith(".AspNetCore." + TestExtensions.CookieAuthenticationScheme, transaction.SetCookie[1]); } - private static TestServer CreateServer(Action<MicrosoftAccountOptions> configureOptions) + private static async Task<IHost> CreateHost(Action<MicrosoftAccountOptions> configureOptions) { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Use(async (context, next) => - { - var req = context.Request; - var res = context.Response; - if (req.Path == new PathString("/challenge")) + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => { - await context.ChallengeAsync("Microsoft", new MicrosoftChallengeProperties + app.UseAuthentication(); + app.Use(async (context, next) => { - Prompt = "select_account", - LoginHint = "username", - DomainHint = "consumers", - ResponseMode = "query", - RedirectUri = "/me" + var req = context.Request; + var res = context.Response; + if (req.Path == new PathString("/challenge")) + { + await context.ChallengeAsync("Microsoft", new MicrosoftChallengeProperties + { + Prompt = "select_account", + LoginHint = "username", + DomainHint = "consumers", + ResponseMode = "query", + RedirectUri = "/me" + }); + } + else if (req.Path == new PathString("/challengeWithOtherScope")) + { + var properties = new OAuthChallengeProperties(); + properties.SetScope("baz", "qux"); + await context.ChallengeAsync("Microsoft", properties); + } + else if (req.Path == new PathString("/challengeWithOtherScopeWithBaseAuthenticationProperties")) + { + var properties = new AuthenticationProperties(); + properties.SetParameter(OAuthChallengeProperties.ScopeKey, new string[] { "baz", "qux" }); + await context.ChallengeAsync("Microsoft", properties); + } + else if (req.Path == new PathString("/me")) + { + await res.DescribeAsync(context.User); + } + else if (req.Path == new PathString("/signIn")) + { + await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignInAsync("Microsoft", new ClaimsPrincipal())); + } + else if (req.Path == new PathString("/signOut")) + { + await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignOutAsync("Microsoft")); + } + else if (req.Path == new PathString("/forbid")) + { + await Assert.ThrowsAsync<InvalidOperationException>(() => context.ForbidAsync("Microsoft")); + } + else + { + await next(); + } }); - } - else if (req.Path == new PathString("/challengeWithOtherScope")) - { - var properties = new OAuthChallengeProperties(); - properties.SetScope("baz", "qux"); - await context.ChallengeAsync("Microsoft", properties); - } - else if (req.Path == new PathString("/challengeWithOtherScopeWithBaseAuthenticationProperties")) - { - var properties = new AuthenticationProperties(); - properties.SetParameter(OAuthChallengeProperties.ScopeKey, new string[] { "baz", "qux" }); - await context.ChallengeAsync("Microsoft", properties); - } - else if (req.Path == new PathString("/me")) - { - await res.DescribeAsync(context.User); - } - else if (req.Path == new PathString("/signIn")) - { - await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignInAsync("Microsoft", new ClaimsPrincipal())); - } - else if (req.Path == new PathString("/signOut")) - { - await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignOutAsync("Microsoft")); - } - else if (req.Path == new PathString("/forbid")) + }) + .ConfigureServices(services => { - await Assert.ThrowsAsync<InvalidOperationException>(() => context.ForbidAsync("Microsoft")); - } - else - { - await next(); - } - }); - }) - .ConfigureServices(services => - { - services.AddAuthentication(TestExtensions.CookieAuthenticationScheme) - .AddCookie(TestExtensions.CookieAuthenticationScheme, o => { }) - .AddMicrosoftAccount(configureOptions); - }); - return new TestServer(builder); + services.AddAuthentication(TestExtensions.CookieAuthenticationScheme) + .AddCookie(TestExtensions.CookieAuthenticationScheme, o => { }) + .AddMicrosoftAccount(configureOptions); + })) + .Build(); + await host.StartAsync(); + return host; } private static HttpResponseMessage ReturnJsonResponse(object content) diff --git a/src/Security/Authentication/test/OAuthTests.cs b/src/Security/Authentication/test/OAuthTests.cs index 1a8e6bb4f2b..5b3ab2eef95 100644 --- a/src/Security/Authentication/test/OAuthTests.cs +++ b/src/Security/Authentication/test/OAuthTests.cs @@ -7,6 +7,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using System; using System.Collections.Generic; using System.Net; @@ -34,7 +35,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth [Fact] public async Task ThrowsIfClientIdMissing() { - var server = CreateServer( + using var host = await CreateHost( services => services.AddAuthentication().AddOAuth("weeblie", o => { o.SignInScheme = "whatever"; @@ -43,13 +44,14 @@ namespace Microsoft.AspNetCore.Authentication.OAuth o.TokenEndpoint = "/"; o.AuthorizationEndpoint = "/"; })); + using var server = host.GetTestServer(); await Assert.ThrowsAsync<ArgumentException>("ClientId", () => server.SendAsync("http://example.com/")); } [Fact] public async Task ThrowsIfClientSecretMissing() { - var server = CreateServer( + using var host = await CreateHost( services => services.AddAuthentication().AddOAuth("weeblie", o => { o.SignInScheme = "whatever"; @@ -58,13 +60,14 @@ namespace Microsoft.AspNetCore.Authentication.OAuth o.TokenEndpoint = "/"; o.AuthorizationEndpoint = "/"; })); + using var server = host.GetTestServer(); await Assert.ThrowsAsync<ArgumentException>("ClientSecret", () => server.SendAsync("http://example.com/")); } [Fact] public async Task ThrowsIfCallbackPathMissing() { - var server = CreateServer( + using var host = await CreateHost( services => services.AddAuthentication().AddOAuth("weeblie", o => { o.ClientId = "Whatever;"; @@ -73,13 +76,14 @@ namespace Microsoft.AspNetCore.Authentication.OAuth o.AuthorizationEndpoint = "/"; o.SignInScheme = "eh"; })); + using var server = host.GetTestServer(); await Assert.ThrowsAsync<ArgumentException>("CallbackPath", () => server.SendAsync("http://example.com/")); } [Fact] public async Task ThrowsIfTokenEndpointMissing() { - var server = CreateServer( + using var host = await CreateHost( services => services.AddAuthentication().AddOAuth("weeblie", o => { o.ClientId = "Whatever;"; @@ -88,13 +92,14 @@ namespace Microsoft.AspNetCore.Authentication.OAuth o.AuthorizationEndpoint = "/"; o.SignInScheme = "eh"; })); + using var server = host.GetTestServer(); await Assert.ThrowsAsync<ArgumentException>("TokenEndpoint", () => server.SendAsync("http://example.com/")); } [Fact] public async Task ThrowsIfAuthorizationEndpointMissing() { - var server = CreateServer( + using var host = await CreateHost( services => services.AddAuthentication().AddOAuth("weeblie", o => { o.ClientId = "Whatever;"; @@ -103,13 +108,14 @@ namespace Microsoft.AspNetCore.Authentication.OAuth o.TokenEndpoint = "/"; o.SignInScheme = "eh"; })); + using var server = host.GetTestServer(); await Assert.ThrowsAsync<ArgumentException>("AuthorizationEndpoint", () => server.SendAsync("http://example.com/")); } [Fact] public async Task RedirectToIdentityProvider_SetsCorrelationIdCookiePath_ToCallBackPath() { - var server = CreateServer( + using var host = await CreateHost( s => s.AddAuthentication().AddOAuth( "Weblie", opt => @@ -122,6 +128,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://www.example.com/challenge"); var res = transaction.Response; @@ -135,7 +142,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth [Fact] public async Task RedirectToAuthorizeEndpoint_CorrelationIdCookieOptions_CanBeOverriden() { - var server = CreateServer( + using var host = await CreateHost( s => s.AddAuthentication().AddOAuth( "Weblie", opt => @@ -149,6 +156,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://www.example.com/challenge"); var res = transaction.Response; @@ -162,7 +170,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth [Fact] public async Task RedirectToAuthorizeEndpoint_HasScopeAsConfigured() { - var server = CreateServer( + using var host = await CreateHost( s => s.AddAuthentication().AddOAuth( "Weblie", opt => @@ -178,6 +186,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://www.example.com/challenge"); var res = transaction.Response; @@ -188,7 +197,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth [Fact] public async Task RedirectToAuthorizeEndpoint_HasScopeAsOverwritten() { - var server = CreateServer( + using var host = await CreateHost( s => s.AddAuthentication().AddOAuth( "Weblie", opt => @@ -206,6 +215,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://www.example.com/challenge"); var res = transaction.Response; @@ -216,7 +226,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth [Fact] public async Task RedirectToAuthorizeEndpoint_HasScopeAsOverwrittenWithBaseAuthenticationProperties() { - var server = CreateServer( + using var host = await CreateHost( s => s.AddAuthentication().AddOAuth( "Weblie", opt => @@ -234,6 +244,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://www.example.com/challenge"); var res = transaction.Response; @@ -254,7 +265,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth [Fact] public async Task HandleRequestAsync_RedirectsToAccessDeniedPathWhenExplicitlySet() { - var server = CreateServer( + using var host = await CreateHost( s => s.AddAuthentication().AddOAuth( "Weblie", opt => @@ -270,6 +281,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth opt.Events.OnRemoteFailure = context => throw new InvalidOperationException("This event should not be called."); })); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://www.example.com/oauth-callback?error=access_denied&state=protected_state", ".AspNetCore.Correlation.Weblie.correlationId=N"); @@ -280,7 +292,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth [Fact] public async Task HandleRequestAsync_InvokesAccessDeniedEvent() { - var server = CreateServer( + using var host = await CreateHost( s => s.AddAuthentication().AddOAuth( "Weblie", opt => @@ -304,6 +316,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth }; })); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://www.example.com/oauth-callback?error=access_denied&state=protected_state", ".AspNetCore.Correlation.Weblie.correlationId=N"); @@ -314,7 +327,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth [Fact] public async Task HandleRequestAsync_InvokesRemoteFailureEventWhenAccessDeniedPathIsNotExplicitlySet() { - var server = CreateServer( + using var host = await CreateHost( s => s.AddAuthentication().AddOAuth( "Weblie", opt => @@ -339,6 +352,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth }; })); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://www.example.com/oauth-callback?error=access_denied&state=protected_state", ".AspNetCore.Correlation.Weblie.correlationId=N"); @@ -349,7 +363,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth [Fact] public async Task RemoteAuthenticationFailed_OAuthError_IncludesProperties() { - var server = CreateServer( + using var host = await CreateHost( s => s.AddAuthentication().AddOAuth( "Weblie", opt => @@ -374,6 +388,7 @@ namespace Microsoft.AspNetCore.Authentication.OAuth }; })); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://www.example.com/oauth-callback?error=custom_error&state=protected_state", ".AspNetCore.Correlation.Weblie.correlationId=N"); @@ -381,22 +396,26 @@ namespace Microsoft.AspNetCore.Authentication.OAuth Assert.Null(transaction.Response.Headers.Location); } - private static TestServer CreateServer(Action<IServiceCollection> configureServices, Func<HttpContext, Task<bool>> handler = null) + private static async Task<IHost> CreateHost(Action<IServiceCollection> configureServices, Func<HttpContext, Task<bool>> handler = null) { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Use(async (context, next) => - { - if (handler == null || ! await handler(context)) + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => { - await next(); - } - }); - }) - .ConfigureServices(configureServices); - return new TestServer(builder); + app.UseAuthentication(); + app.Use(async (context, next) => + { + if (handler == null || ! await handler(context)) + { + await next(); + } + }); + }) + .ConfigureServices(configureServices)) + .Build(); + await host.StartAsync(); + return host; } private class TestStateDataFormat : ISecureDataFormat<AuthenticationProperties> diff --git a/src/Security/Authentication/test/OpenIdConnect/OpenIdConnectConfigurationTests.cs b/src/Security/Authentication/test/OpenIdConnect/OpenIdConnectConfigurationTests.cs index ed368c1ef7f..af159488743 100644 --- a/src/Security/Authentication/test/OpenIdConnect/OpenIdConnectConfigurationTests.cs +++ b/src/Security/Authentication/test/OpenIdConnect/OpenIdConnectConfigurationTests.cs @@ -12,6 +12,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Xunit; namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect @@ -550,16 +551,19 @@ namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect private TestServer BuildTestServer(Action<OpenIdConnectOptions> options) { - var builder = new WebHostBuilder() - .ConfigureServices(services => - { - services.AddAuthentication() - .AddCookie() - .AddOpenIdConnect(options); - }) - .Configure(app => app.UseAuthentication()); - - return new TestServer(builder); + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .ConfigureServices(services => + { + services.AddAuthentication() + .AddCookie() + .AddOpenIdConnect(options); + }) + .Configure(app => app.UseAuthentication())) + .Build(); + host.Start(); + return host.GetTestServer(); } private async Task TestConfigurationException<T>( diff --git a/src/Security/Authentication/test/OpenIdConnect/OpenIdConnectEventTests.cs b/src/Security/Authentication/test/OpenIdConnect/OpenIdConnectEventTests.cs index 090bf3dec4c..691bd9f7301 100644 --- a/src/Security/Authentication/test/OpenIdConnect/OpenIdConnectEventTests.cs +++ b/src/Security/Authentication/test/OpenIdConnect/OpenIdConnectEventTests.cs @@ -1,4 +1,4 @@ -// Copyright (c) .NET Foundation. All rights reserved. +// Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; @@ -18,6 +18,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.Extensions.Primitives; using Microsoft.IdentityModel.Protocols.OpenIdConnect; using Microsoft.IdentityModel.Tokens; @@ -1266,39 +1267,43 @@ namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect private TestServer CreateServer(OpenIdConnectEvents events, RequestDelegate appCode) { - var builder = new WebHostBuilder() - .ConfigureServices(services => - { - services.AddAuthentication(auth => + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .ConfigureServices(services => { - auth.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; - auth.DefaultChallengeScheme = OpenIdConnectDefaults.AuthenticationScheme; + services.AddAuthentication(auth => + { + auth.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; + auth.DefaultChallengeScheme = OpenIdConnectDefaults.AuthenticationScheme; + }) + .AddCookie() + .AddOpenIdConnect(o => + { + o.Events = events; + o.ClientId = "ClientId"; + o.GetClaimsFromUserInfoEndpoint = true; + o.Configuration = new OpenIdConnectConfiguration() + { + TokenEndpoint = "http://testhost/tokens", + UserInfoEndpoint = "http://testhost/user", + EndSessionEndpoint = "http://testhost/end" + }; + o.StateDataFormat = new TestStateDataFormat(); + o.SecurityTokenValidator = new TestTokenValidator(); + o.ProtocolValidator = new TestProtocolValidator(); + o.BackchannelHttpHandler = new TestBackchannel(); + }); }) - .AddCookie() - .AddOpenIdConnect(o => + .Configure(app => { - o.Events = events; - o.ClientId = "ClientId"; - o.GetClaimsFromUserInfoEndpoint = true; - o.Configuration = new OpenIdConnectConfiguration() - { - TokenEndpoint = "http://testhost/tokens", - UserInfoEndpoint = "http://testhost/user", - EndSessionEndpoint = "http://testhost/end" - }; - o.StateDataFormat = new TestStateDataFormat(); - o.SecurityTokenValidator = new TestTokenValidator(); - o.ProtocolValidator = new TestProtocolValidator(); - o.BackchannelHttpHandler = new TestBackchannel(); - }); - }) - .Configure(app => - { - app.UseAuthentication(); - app.Run(appCode); - }); + app.UseAuthentication(); + app.Run(appCode); + })) + .Build(); - return new TestServer(builder); + host.Start(); + return host.GetTestServer(); } private Task<HttpResponseMessage> PostAsync(TestServer server, string path, string form) diff --git a/src/Security/Authentication/test/OpenIdConnect/TestServerBuilder.cs b/src/Security/Authentication/test/OpenIdConnect/TestServerBuilder.cs index c37da8c0438..41fd484b733 100644 --- a/src/Security/Authentication/test/OpenIdConnect/TestServerBuilder.cs +++ b/src/Security/Authentication/test/OpenIdConnect/TestServerBuilder.cs @@ -11,6 +11,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.IdentityModel.Protocols; using Microsoft.IdentityModel.Protocols.OpenIdConnect; @@ -62,59 +63,63 @@ namespace Microsoft.AspNetCore.Authentication.Test.OpenIdConnect Func<HttpContext, Task> handler, AuthenticationProperties properties) { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Use(async (context, next) => - { - var req = context.Request; - var res = context.Response; - - if (req.Path == new PathString(Challenge)) - { - await context.ChallengeAsync(OpenIdConnectDefaults.AuthenticationScheme); - } - else if (req.Path == new PathString(ChallengeWithProperties)) - { - await context.ChallengeAsync(OpenIdConnectDefaults.AuthenticationScheme, properties); - } - else if (req.Path == new PathString(ChallengeWithOutContext)) - { - res.StatusCode = 401; - } - else if (req.Path == new PathString(Signin)) + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => { - await context.SignInAsync(OpenIdConnectDefaults.AuthenticationScheme, new ClaimsPrincipal()); - } - else if (req.Path == new PathString(Signout)) - { - await context.SignOutAsync(OpenIdConnectDefaults.AuthenticationScheme); - } - else if (req.Path == new PathString("/signout_with_specific_redirect_uri")) - { - await context.SignOutAsync( - OpenIdConnectDefaults.AuthenticationScheme, - new AuthenticationProperties() { RedirectUri = "http://www.example.com/specific_redirect_uri" }); - } - else if (handler != null) - { - await handler(context); - } - else + app.UseAuthentication(); + app.Use(async (context, next) => + { + var req = context.Request; + var res = context.Response; + + if (req.Path == new PathString(Challenge)) + { + await context.ChallengeAsync(OpenIdConnectDefaults.AuthenticationScheme); + } + else if (req.Path == new PathString(ChallengeWithProperties)) + { + await context.ChallengeAsync(OpenIdConnectDefaults.AuthenticationScheme, properties); + } + else if (req.Path == new PathString(ChallengeWithOutContext)) + { + res.StatusCode = 401; + } + else if (req.Path == new PathString(Signin)) + { + await context.SignInAsync(OpenIdConnectDefaults.AuthenticationScheme, new ClaimsPrincipal()); + } + else if (req.Path == new PathString(Signout)) + { + await context.SignOutAsync(OpenIdConnectDefaults.AuthenticationScheme); + } + else if (req.Path == new PathString("/signout_with_specific_redirect_uri")) + { + await context.SignOutAsync( + OpenIdConnectDefaults.AuthenticationScheme, + new AuthenticationProperties() { RedirectUri = "http://www.example.com/specific_redirect_uri" }); + } + else if (handler != null) + { + await handler(context); + } + else + { + await next(); + } + }); + }) + .ConfigureServices(services => { - await next(); - } - }); - }) - .ConfigureServices(services => - { - services.AddAuthentication(CookieAuthenticationDefaults.AuthenticationScheme) - .AddCookie() - .AddOpenIdConnect(options); - }); + services.AddAuthentication(CookieAuthenticationDefaults.AuthenticationScheme) + .AddCookie() + .AddOpenIdConnect(options); + })) + .Build(); - return new TestServer(builder); + host.Start(); + return host.GetTestServer(); } } } diff --git a/src/Security/Authentication/test/RemoteAuthenticationTests.cs b/src/Security/Authentication/test/RemoteAuthenticationTests.cs index d477e75347c..7222a759028 100644 --- a/src/Security/Authentication/test/RemoteAuthenticationTests.cs +++ b/src/Security/Authentication/test/RemoteAuthenticationTests.cs @@ -6,6 +6,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using System; using System.Threading.Tasks; using Xunit; @@ -16,8 +17,8 @@ namespace Microsoft.AspNetCore.Authentication { protected override string DisplayName => DefaultScheme; - private TestServer CreateServer(Action<TOptions> configureOptions, Func<HttpContext, Task> testpath = null, bool isDefault = true) - => CreateServerWithServices(s => + private Task<IHost> CreateHost(Action<TOptions> configureOptions, Func<HttpContext, Task> testpath = null, bool isDefault = true) + => CreateHostWithServices(s => { var builder = s.AddAuthentication(); if (isDefault) @@ -29,23 +30,26 @@ namespace Microsoft.AspNetCore.Authentication }, testpath); - protected virtual TestServer CreateServerWithServices(Action<IServiceCollection> configureServices, Func<HttpContext, Task> testpath = null) + protected virtual async Task<IHost> CreateHostWithServices(Action<IServiceCollection> configureServices, Func<HttpContext, Task> testpath = null) { - //private static TestServer CreateServer(Action<IApplicationBuilder> configure, Action<IServiceCollection> configureServices, Func<HttpContext, Task<bool>> handler) - var builder = new WebHostBuilder() - .Configure(app => - { - app.Use(async (context, next) => - { - if (testpath != null) + var host = new HostBuilder() + .ConfigureWebHost(webHostBuilder => + webHostBuilder.UseTestServer() + .Configure(app => { - await testpath(context); - } - await next(); - }); - }) - .ConfigureServices(configureServices); - return new TestServer(builder); + app.Use(async (context, next) => + { + if (testpath != null) + { + await testpath(context); + } + await next(); + }); + }) + .ConfigureServices(configureServices)) + .Build(); + await host.StartAsync(); + return host; } protected abstract void ConfigureDefaults(TOptions o); @@ -53,13 +57,14 @@ namespace Microsoft.AspNetCore.Authentication [Fact] public async Task VerifySignInSchemeCannotBeSetToSelf() { - var server = CreateServer( + using var host = await CreateHost( o => { ConfigureDefaults(o); o.SignInScheme = DefaultScheme; }, context => context.ChallengeAsync(DefaultScheme)); + using var server = host.GetTestServer(); var error = await Assert.ThrowsAsync<InvalidOperationException>(() => server.SendAsync("https://example.com/challenge")); Assert.Contains("cannot be set to itself", error.Message); } @@ -67,10 +72,12 @@ namespace Microsoft.AspNetCore.Authentication [Fact] public async Task VerifySignInSchemeCannotBeSetToSelfUsingDefaultScheme() { - var server = CreateServer( + + using var host = await CreateHost( o => o.SignInScheme = null, context => context.ChallengeAsync(DefaultScheme), isDefault: true); + using var server = host.GetTestServer(); var error = await Assert.ThrowsAsync<InvalidOperationException>(() => server.SendAsync("https://example.com/challenge")); Assert.Contains("cannot be set to itself", error.Message); } @@ -78,13 +85,14 @@ namespace Microsoft.AspNetCore.Authentication [Fact] public async Task VerifySignInSchemeCannotBeSetToSelfUsingDefaultSignInScheme() { - var server = CreateServerWithServices( + using var host = await CreateHostWithServices( services => { var builder = services.AddAuthentication(o => o.DefaultSignInScheme = DefaultScheme); RegisterAuth(builder, o => o.SignInScheme = null); }, context => context.ChallengeAsync(DefaultScheme)); + using var server = host.GetTestServer(); var error = await Assert.ThrowsAsync<InvalidOperationException>(() => server.SendAsync("https://example.com/challenge")); Assert.Contains("cannot be set to itself", error.Message); } diff --git a/src/Security/Authentication/test/TwitterTests.cs b/src/Security/Authentication/test/TwitterTests.cs index 7d028396d22..6958cabfa88 100644 --- a/src/Security/Authentication/test/TwitterTests.cs +++ b/src/Security/Authentication/test/TwitterTests.cs @@ -5,6 +5,7 @@ using Microsoft.AspNetCore.Hosting; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.TestHost; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.Net.Http.Headers; using System; using System.Linq; @@ -43,7 +44,7 @@ namespace Microsoft.AspNetCore.Authentication.Twitter [Fact] public async Task ChallengeWillTriggerApplyRedirectEvent() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ConsumerKey = "Test Consumer Key"; o.ConsumerSecret = "Test Consumer Secret"; @@ -65,6 +66,7 @@ namespace Microsoft.AspNetCore.Authentication.Twitter await context.ChallengeAsync("Twitter"); return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var query = transaction.Response.Headers.Location.Query; @@ -78,11 +80,12 @@ namespace Microsoft.AspNetCore.Authentication.Twitter [Fact] public async Task ThrowsIfClientIdMissing() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ConsumerSecret = "Test Consumer Secret"; }); + using var server = host.GetTestServer(); await Assert.ThrowsAsync<ArgumentException>("ConsumerKey", async () => await server.SendAsync("http://example.com/challenge")); } @@ -93,24 +96,26 @@ namespace Microsoft.AspNetCore.Authentication.Twitter [Fact] public async Task ThrowsIfClientSecretMissing() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ConsumerKey = "Test Consumer Key"; }); + using var server = host.GetTestServer(); await Assert.ThrowsAsync<ArgumentException>("ConsumerSecret", async () => await server.SendAsync("http://example.com/challenge")); } [Fact] public async Task BadSignInWillThrow() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ConsumerKey = "Test Consumer Key"; o.ConsumerSecret = "Test Consumer Secret"; }); // Send a bogus sign in + using var server = host.GetTestServer(); var error = await Assert.ThrowsAnyAsync<Exception>(() => server.SendAsync("https://example.com/signin-twitter")); Assert.Equal("Invalid state cookie.", error.GetBaseException().Message); } @@ -118,11 +123,12 @@ namespace Microsoft.AspNetCore.Authentication.Twitter [Fact] public async Task SignInThrows() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ConsumerKey = "Test Consumer Key"; o.ConsumerSecret = "Test Consumer Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signIn"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -130,11 +136,12 @@ namespace Microsoft.AspNetCore.Authentication.Twitter [Fact] public async Task SignOutThrows() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ConsumerKey = "Test Consumer Key"; o.ConsumerSecret = "Test Consumer Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signOut"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -142,11 +149,12 @@ namespace Microsoft.AspNetCore.Authentication.Twitter [Fact] public async Task ForbidThrows() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ConsumerKey = "Test Consumer Key"; o.ConsumerSecret = "Test Consumer Secret"; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("https://example.com/signOut"); Assert.Equal(HttpStatusCode.OK, transaction.Response.StatusCode); } @@ -154,7 +162,7 @@ namespace Microsoft.AspNetCore.Authentication.Twitter [Fact] public async Task ChallengeWillTriggerRedirection() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ConsumerKey = "Test Consumer Key"; o.ConsumerSecret = "Test Consumer Secret"; @@ -168,6 +176,7 @@ namespace Microsoft.AspNetCore.Authentication.Twitter await context.ChallengeAsync("Twitter"); return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var location = transaction.Response.Headers.Location.AbsoluteUri; @@ -177,7 +186,7 @@ namespace Microsoft.AspNetCore.Authentication.Twitter [Fact] public async Task HandleRequestAsync_RedirectsToAccessDeniedPathWhenExplicitlySet() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ConsumerKey = "Test Consumer Key"; o.ConsumerSecret = "Test Consumer Secret"; @@ -195,6 +204,7 @@ namespace Microsoft.AspNetCore.Authentication.Twitter await context.ChallengeAsync("Twitter", properties); return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var location = transaction.Response.Headers.Location.AbsoluteUri; @@ -217,7 +227,7 @@ namespace Microsoft.AspNetCore.Authentication.Twitter [Fact] public async Task BadCallbackCallsAccessDeniedWithState() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ConsumerKey = "Test Consumer Key"; o.ConsumerSecret = "Test Consumer Secret"; @@ -244,6 +254,7 @@ namespace Microsoft.AspNetCore.Authentication.Twitter await context.ChallengeAsync("Twitter", properties); return true; }); + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var location = transaction.Response.Headers.Location.AbsoluteUri; @@ -265,7 +276,7 @@ namespace Microsoft.AspNetCore.Authentication.Twitter [Fact] public async Task BadCallbackCallsRemoteAuthFailedWithState() { - var server = CreateServer(o => + using var host = await CreateHost(o => { o.ConsumerKey = "Test Consumer Key"; o.ConsumerSecret = "Test Consumer Secret"; @@ -294,6 +305,8 @@ namespace Microsoft.AspNetCore.Authentication.Twitter await context.ChallengeAsync("Twitter", properties); return true; }); + + using var server = host.GetTestServer(); var transaction = await server.SendAsync("http://example.com/challenge"); Assert.Equal(HttpStatusCode.Redirect, transaction.Response.StatusCode); var location = transaction.Response.Headers.Location.AbsoluteUri; @@ -312,46 +325,51 @@ namespace Microsoft.AspNetCore.Authentication.Twitter Assert.Equal(HttpStatusCode.NotAcceptable, response.StatusCode); } - private static TestServer CreateServer(Action<TwitterOptions> options, Func<HttpContext, Task<bool>> handler = null) + private static async Task<IHost> CreateHost(Action<TwitterOptions> options, Func<HttpContext, Task<bool>> handler = null) { - var builder = new WebHostBuilder() - .Configure(app => - { - app.UseAuthentication(); - app.Use(async (context, next) => - { - var req = context.Request; - var res = context.Response; - if (req.Path == new PathString("/signIn")) - { - await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignInAsync("Twitter", new ClaimsPrincipal())); - } - else if (req.Path == new PathString("/signOut")) + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(app => { - await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignOutAsync("Twitter")); - } - else if (req.Path == new PathString("/forbid")) + app.UseAuthentication(); + app.Use(async (context, next) => + { + var req = context.Request; + var res = context.Response; + if (req.Path == new PathString("/signIn")) + { + await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignInAsync("Twitter", new ClaimsPrincipal())); + } + else if (req.Path == new PathString("/signOut")) + { + await Assert.ThrowsAsync<InvalidOperationException>(() => context.SignOutAsync("Twitter")); + } + else if (req.Path == new PathString("/forbid")) + { + await Assert.ThrowsAsync<InvalidOperationException>(() => context.ForbidAsync("Twitter")); + } + else if (handler == null || !await handler(context)) + { + await next(); + } + }); + }) + .ConfigureServices(services => { - await Assert.ThrowsAsync<InvalidOperationException>(() => context.ForbidAsync("Twitter")); - } - else if (handler == null || ! await handler(context)) - { - await next(); - } - }); - }) - .ConfigureServices(services => - { - Action<TwitterOptions> wrapOptions = o => - { - o.SignInScheme = "External"; - options(o); - }; - services.AddAuthentication() - .AddCookie("External", _ => { }) - .AddTwitter(wrapOptions); - }); - return new TestServer(builder); + Action<TwitterOptions> wrapOptions = o => + { + o.SignInScheme = "External"; + options(o); + }; + services.AddAuthentication() + .AddCookie("External", _ => { }) + .AddTwitter(wrapOptions); + })) + .Build(); + + await host.StartAsync(); + return host; } private HttpResponseMessage BackchannelRequestToken(HttpRequestMessage req) diff --git a/src/Security/Authentication/test/WsFederation/WsFederationTest.cs b/src/Security/Authentication/test/WsFederation/WsFederationTest.cs index 1c4fc9e94d6..eff5985347e 100644 --- a/src/Security/Authentication/test/WsFederation/WsFederationTest.cs +++ b/src/Security/Authentication/test/WsFederation/WsFederationTest.cs @@ -19,6 +19,7 @@ using Microsoft.AspNetCore.Http.Extensions; using Microsoft.AspNetCore.TestHost; using Microsoft.AspNetCore.WebUtilities; using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; using Microsoft.IdentityModel.Tokens; using Microsoft.Net.Http.Headers; using Xunit; @@ -43,20 +44,25 @@ namespace Microsoft.AspNetCore.Authentication.WsFederation [Fact] public async Task MissingConfigurationThrows() { - var builder = new WebHostBuilder() - .Configure(ConfigureApp) - .ConfigureServices(services => - { - services.AddAuthentication(sharedOptions => - { - sharedOptions.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; - sharedOptions.DefaultSignInScheme = CookieAuthenticationDefaults.AuthenticationScheme; - sharedOptions.DefaultChallengeScheme = WsFederationDefaults.AuthenticationScheme; - }) - .AddCookie() - .AddWsFederation(); - }); - var server = new TestServer(builder); + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(ConfigureApp) + .ConfigureServices(services => + { + services.AddAuthentication(sharedOptions => + { + sharedOptions.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; + sharedOptions.DefaultSignInScheme = CookieAuthenticationDefaults.AuthenticationScheme; + sharedOptions.DefaultChallengeScheme = WsFederationDefaults.AuthenticationScheme; + }) + .AddCookie() + .AddWsFederation(); + })) + .Build(); + + await host.StartAsync(); + using var server = host.GetTestServer(); var httpClient = server.CreateClient(); // Verify if the request is redirected to STS with right parameters @@ -67,7 +73,7 @@ namespace Microsoft.AspNetCore.Authentication.WsFederation [Fact] public async Task ChallengeRedirects() { - var httpClient = CreateClient(); + var httpClient = await CreateClient(); // Verify if the request is redirected to STS with right parameters var response = await httpClient.GetAsync("/"); @@ -83,7 +89,7 @@ namespace Microsoft.AspNetCore.Authentication.WsFederation [Fact] public async Task MapWillNotAffectRedirect() { - var httpClient = CreateClient(); + var httpClient = await CreateClient(); // Verify if the request is redirected to STS with right parameters var response = await httpClient.GetAsync("/mapped-challenge"); @@ -99,7 +105,7 @@ namespace Microsoft.AspNetCore.Authentication.WsFederation [Fact] public async Task PreMappedWillAffectRedirect() { - var httpClient = CreateClient(); + var httpClient = await CreateClient(); // Verify if the request is redirected to STS with right parameters var response = await httpClient.GetAsync("/premapped-challenge"); @@ -115,7 +121,7 @@ namespace Microsoft.AspNetCore.Authentication.WsFederation [Fact] public async Task ValidTokenIsAccepted() { - var httpClient = CreateClient(); + var httpClient = await CreateClient(); // Verify if the request is redirected to STS with right parameters var response = await httpClient.GetAsync("/"); @@ -139,7 +145,7 @@ namespace Microsoft.AspNetCore.Authentication.WsFederation [Fact] public async Task ValidUnsolicitedTokenIsRefused() { - var httpClient = CreateClient(); + var httpClient = await CreateClient(); var form = CreateSignInContent("WsFederation/ValidToken.xml", suppressWctx: true); var exception = await Assert.ThrowsAsync<Exception>(() => httpClient.PostAsync(httpClient.BaseAddress + "signin-wsfed", form)); Assert.Contains("Unsolicited logins are not allowed.", exception.InnerException.Message); @@ -148,7 +154,7 @@ namespace Microsoft.AspNetCore.Authentication.WsFederation [Fact] public async Task ValidUnsolicitedTokenIsAcceptedWhenAllowed() { - var httpClient = CreateClient(allowUnsolicited: true); + var httpClient = await CreateClient(allowUnsolicited: true); var form = CreateSignInContent("WsFederation/ValidToken.xml", suppressWctx: true); var response = await httpClient.PostAsync(httpClient.BaseAddress + "signin-wsfed", form); @@ -166,7 +172,7 @@ namespace Microsoft.AspNetCore.Authentication.WsFederation [Fact] public async Task InvalidTokenIsRejected() { - var httpClient = CreateClient(); + var httpClient = await CreateClient(); // Verify if the request is redirected to STS with right parameters var response = await httpClient.GetAsync("/"); @@ -184,7 +190,7 @@ namespace Microsoft.AspNetCore.Authentication.WsFederation [Fact] public async Task RemoteSignoutRequestTriggersSignout() { - var httpClient = CreateClient(); + var httpClient = await CreateClient(); var response = await httpClient.GetAsync("/signin-wsfed?wa=wsignoutcleanup1.0"); response.EnsureSuccessStatusCode(); @@ -198,30 +204,35 @@ namespace Microsoft.AspNetCore.Authentication.WsFederation [Fact] public async Task EventsResolvedFromDI() { - var builder = new WebHostBuilder() - .ConfigureServices(services => - { - services.AddSingleton<MyWsFedEvents>(); - services.AddAuthentication(sharedOptions => - { - sharedOptions.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; - sharedOptions.DefaultSignInScheme = CookieAuthenticationDefaults.AuthenticationScheme; - sharedOptions.DefaultChallengeScheme = WsFederationDefaults.AuthenticationScheme; - }) - .AddCookie() - .AddWsFederation(options => - { - options.Wtrealm = "http://Automation1"; - options.MetadataAddress = "https://login.windows.net/4afbc689-805b-48cf-a24c-d4aa3248a248/federationmetadata/2007-06/federationmetadata.xml"; - options.BackchannelHttpHandler = new WaadMetadataDocumentHandler(); - options.EventsType = typeof(MyWsFedEvents); - }); - }) - .Configure(app => - { - app.Run(context => context.ChallengeAsync()); - }); - var server = new TestServer(builder); + using var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .ConfigureServices(services => + { + services.AddSingleton<MyWsFedEvents>(); + services.AddAuthentication(sharedOptions => + { + sharedOptions.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; + sharedOptions.DefaultSignInScheme = CookieAuthenticationDefaults.AuthenticationScheme; + sharedOptions.DefaultChallengeScheme = WsFederationDefaults.AuthenticationScheme; + }) + .AddCookie() + .AddWsFederation(options => + { + options.Wtrealm = "http://Automation1"; + options.MetadataAddress = "https://login.windows.net/4afbc689-805b-48cf-a24c-d4aa3248a248/federationmetadata/2007-06/federationmetadata.xml"; + options.BackchannelHttpHandler = new WaadMetadataDocumentHandler(); + options.EventsType = typeof(MyWsFedEvents); + }); + }) + .Configure(app => + { + app.Run(context => context.ChallengeAsync()); + })) + .Build(); + + await host.StartAsync(); + using var server = host.GetTestServer(); var result = await server.CreateClient().GetAsync(""); Assert.Contains("CustomKey=CustomValue", result.Headers.Location.Query); @@ -264,86 +275,91 @@ namespace Microsoft.AspNetCore.Authentication.WsFederation } } - private HttpClient CreateClient(bool allowUnsolicited = false) + private async Task<HttpClient> CreateClient(bool allowUnsolicited = false) { - var builder = new WebHostBuilder() - .Configure(ConfigureApp) - .ConfigureServices(services => - { - services.AddAuthentication(sharedOptions => - { - sharedOptions.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; - sharedOptions.DefaultSignInScheme = CookieAuthenticationDefaults.AuthenticationScheme; - sharedOptions.DefaultChallengeScheme = WsFederationDefaults.AuthenticationScheme; - }) - .AddCookie() - .AddWsFederation(options => + var host = new HostBuilder() + .ConfigureWebHost(builder => + builder.UseTestServer() + .Configure(ConfigureApp) + .ConfigureServices(services => { - options.Wtrealm = "http://Automation1"; - options.MetadataAddress = "https://login.windows.net/4afbc689-805b-48cf-a24c-d4aa3248a248/federationmetadata/2007-06/federationmetadata.xml"; - options.BackchannelHttpHandler = new WaadMetadataDocumentHandler(); - options.StateDataFormat = new CustomStateDataFormat(); - options.SecurityTokenHandlers = new List<ISecurityTokenValidator>() { new TestSecurityTokenValidator() }; - options.UseTokenLifetime = false; - options.AllowUnsolicitedLogins = allowUnsolicited; - options.Events = new WsFederationEvents() + services.AddAuthentication(sharedOptions => { - OnMessageReceived = context => + sharedOptions.DefaultScheme = CookieAuthenticationDefaults.AuthenticationScheme; + sharedOptions.DefaultSignInScheme = CookieAuthenticationDefaults.AuthenticationScheme; + sharedOptions.DefaultChallengeScheme = WsFederationDefaults.AuthenticationScheme; + }) + .AddCookie() + .AddWsFederation(options => + { + options.Wtrealm = "http://Automation1"; + options.MetadataAddress = "https://login.windows.net/4afbc689-805b-48cf-a24c-d4aa3248a248/federationmetadata/2007-06/federationmetadata.xml"; + options.BackchannelHttpHandler = new WaadMetadataDocumentHandler(); + options.StateDataFormat = new CustomStateDataFormat(); + options.SecurityTokenHandlers = new List<ISecurityTokenValidator>() { new TestSecurityTokenValidator() }; + options.UseTokenLifetime = false; + options.AllowUnsolicitedLogins = allowUnsolicited; + options.Events = new WsFederationEvents() { - if (!context.ProtocolMessage.Parameters.TryGetValue("suppressWctx", out var suppress)) + OnMessageReceived = context => { - Assert.True(context.ProtocolMessage.Wctx.Equals("customValue"), "wctx is not my custom value"); - } - context.HttpContext.Items["MessageReceived"] = true; - return Task.FromResult(0); - }, - OnRedirectToIdentityProvider = context => - { - if (context.ProtocolMessage.IsSignInMessage) + if (!context.ProtocolMessage.Parameters.TryGetValue("suppressWctx", out var suppress)) + { + Assert.True(context.ProtocolMessage.Wctx.Equals("customValue"), "wctx is not my custom value"); + } + context.HttpContext.Items["MessageReceived"] = true; + return Task.FromResult(0); + }, + OnRedirectToIdentityProvider = context => { - // Sign in message - context.ProtocolMessage.Wctx = "customValue"; - } - - return Task.FromResult(0); - }, - OnSecurityTokenReceived = context => - { - context.HttpContext.Items["SecurityTokenReceived"] = true; - return Task.FromResult(0); - }, - OnSecurityTokenValidated = context => - { - Assert.True((bool)context.HttpContext.Items["MessageReceived"], "MessageReceived notification not invoked"); - Assert.True((bool)context.HttpContext.Items["SecurityTokenReceived"], "SecurityTokenReceived notification not invoked"); - - if (context.Principal != null) + if (context.ProtocolMessage.IsSignInMessage) + { + // Sign in message + context.ProtocolMessage.Wctx = "customValue"; + } + + return Task.FromResult(0); + }, + OnSecurityTokenReceived = context => { - var identity = context.Principal.Identities.Single(); - identity.AddClaim(new Claim("ReturnEndpoint", "true")); - identity.AddClaim(new Claim("Authenticated", "true")); - identity.AddClaim(new Claim(identity.RoleClaimType, "Guest", ClaimValueTypes.String)); + context.HttpContext.Items["SecurityTokenReceived"] = true; + return Task.FromResult(0); + }, + OnSecurityTokenValidated = context => + { + Assert.True((bool)context.HttpContext.Items["MessageReceived"], "MessageReceived notification not invoked"); + Assert.True((bool)context.HttpContext.Items["SecurityTokenReceived"], "SecurityTokenReceived notification not invoked"); + + if (context.Principal != null) + { + var identity = context.Principal.Identities.Single(); + identity.AddClaim(new Claim("ReturnEndpoint", "true")); + identity.AddClaim(new Claim("Authenticated", "true")); + identity.AddClaim(new Claim(identity.RoleClaimType, "Guest", ClaimValueTypes.String)); + } + + return Task.FromResult(0); + }, + OnAuthenticationFailed = context => + { + context.HttpContext.Items["AuthenticationFailed"] = true; + //Change the request url to something different and skip Wsfed. This new url will handle the request and let us know if this notification was invoked. + context.HttpContext.Request.Path = new PathString("/AuthenticationFailed"); + context.SkipHandler(); + return Task.FromResult(0); + }, + OnRemoteSignOut = context => + { + context.Response.Headers["EventHeader"] = "OnRemoteSignOut"; + return Task.FromResult(0); } + }; + }); + })) + .Build(); - return Task.FromResult(0); - }, - OnAuthenticationFailed = context => - { - context.HttpContext.Items["AuthenticationFailed"] = true; - //Change the request url to something different and skip Wsfed. This new url will handle the request and let us know if this notification was invoked. - context.HttpContext.Request.Path = new PathString("/AuthenticationFailed"); - context.SkipHandler(); - return Task.FromResult(0); - }, - OnRemoteSignOut = context => - { - context.Response.Headers["EventHeader"] = "OnRemoteSignOut"; - return Task.FromResult(0); - } - }; - }); - }); - var server = new TestServer(builder); + await host.StartAsync(); + var server = host.GetTestServer(); return server.CreateClient(); } -- GitLab