From 17b49eccb8044b54282014d2b2528306854b43ac Mon Sep 17 00:00:00 2001
From: Javier Calvarro Nelson <jacalvar@microsoft.com>
Date: Tue, 5 Jul 2022 19:58:58 +0200
Subject: [PATCH] [Blazor] Cleanup remote authentication service serialization
 (#42573)

---
 .../src/Interop/AuthenticationService.ts      |  12 +-
 .../src/Models/RemoteAuthenticationStatus.cs  |   3 +
 .../Services/RemoteAuthenticationService.cs   | 117 ++++--------------
 .../test/RemoteAuthenticationServiceTests.cs  |  58 ++++-----
 4 files changed, 63 insertions(+), 127 deletions(-)

diff --git a/src/Components/WebAssembly/WebAssembly.Authentication/src/Interop/AuthenticationService.ts b/src/Components/WebAssembly/WebAssembly.Authentication/src/Interop/AuthenticationService.ts
index 67b9bf6b6f9..f359b773253 100644
--- a/src/Components/WebAssembly/WebAssembly.Authentication/src/Interop/AuthenticationService.ts
+++ b/src/Components/WebAssembly/WebAssembly.Authentication/src/Interop/AuthenticationService.ts
@@ -35,15 +35,15 @@ export interface AccessToken {
 }
 
 export enum AccessTokenResultStatus {
-    Success = 'success',
-    RequiresRedirect = 'requiresRedirect'
+    Success = 'Success',
+    RequiresRedirect = 'RequiresRedirect'
 }
 
 export enum AuthenticationResultStatus {
-    Redirect = 'redirect',
-    Success = 'success',
-    Failure = 'failure',
-    OperationCompleted = 'operationCompleted'
+    Redirect = 'Redirect',
+    Success = 'Success',
+    Failure = 'Failure',
+    OperationCompleted = 'OperationCompleted'
 };
 
 export interface AuthenticationResult {
diff --git a/src/Components/WebAssembly/WebAssembly.Authentication/src/Models/RemoteAuthenticationStatus.cs b/src/Components/WebAssembly/WebAssembly.Authentication/src/Models/RemoteAuthenticationStatus.cs
index 00039b2b70a..9aaf64a98d7 100644
--- a/src/Components/WebAssembly/WebAssembly.Authentication/src/Models/RemoteAuthenticationStatus.cs
+++ b/src/Components/WebAssembly/WebAssembly.Authentication/src/Models/RemoteAuthenticationStatus.cs
@@ -1,11 +1,14 @@
 // Licensed to the .NET Foundation under one or more agreements.
 // The .NET Foundation licenses this file to you under the MIT license.
 
+using System.Text.Json.Serialization;
+
 namespace Microsoft.AspNetCore.Components.WebAssembly.Authentication;
 
 /// <summary>
 /// Represents the status of an authentication operation.
 /// </summary>
+[JsonConverter(typeof(JsonStringEnumConverter))]
 public enum RemoteAuthenticationStatus
 {
     /// <summary>
diff --git a/src/Components/WebAssembly/WebAssembly.Authentication/src/Services/RemoteAuthenticationService.cs b/src/Components/WebAssembly/WebAssembly.Authentication/src/Services/RemoteAuthenticationService.cs
index 4ffe310768f..a4909c42230 100644
--- a/src/Components/WebAssembly/WebAssembly.Authentication/src/Services/RemoteAuthenticationService.cs
+++ b/src/Components/WebAssembly/WebAssembly.Authentication/src/Services/RemoteAuthenticationService.cs
@@ -3,6 +3,7 @@
 
 using System.Diagnostics.CodeAnalysis;
 using System.Security.Claims;
+using System.Text.Json.Serialization;
 using Microsoft.AspNetCore.Components.Authorization;
 using Microsoft.Extensions.Options;
 using Microsoft.JSInterop;
@@ -81,14 +82,8 @@ public class RemoteAuthenticationService<
         RemoteAuthenticationContext<TRemoteAuthenticationState> context)
     {
         await EnsureAuthService();
-        var internalResult = await JsRuntime.InvokeAsync<InternalRemoteAuthenticationResult<TRemoteAuthenticationState>>("AuthenticationService.signIn", context.State);
-        var result = internalResult.Convert();
-        if (result.Status == RemoteAuthenticationStatus.Success)
-        {
-            var getUserTask = GetUser();
-            await getUserTask;
-            UpdateUser(getUserTask);
-        }
+        var result = await JsRuntime.InvokeAsync<RemoteAuthenticationResult<TRemoteAuthenticationState>>("AuthenticationService.signIn", context.State);
+        await UpdateUserOnSuccess(result);
 
         return result;
     }
@@ -98,14 +93,8 @@ public class RemoteAuthenticationService<
         RemoteAuthenticationContext<TRemoteAuthenticationState> context)
     {
         await EnsureAuthService();
-        var internalResult = await JsRuntime.InvokeAsync<InternalRemoteAuthenticationResult<TRemoteAuthenticationState>>("AuthenticationService.completeSignIn", context.Url);
-        var result = internalResult.Convert();
-        if (result.Status == RemoteAuthenticationStatus.Success)
-        {
-            var getUserTask = GetUser();
-            await getUserTask;
-            UpdateUser(getUserTask);
-        }
+        var result = await JsRuntime.InvokeAsync<RemoteAuthenticationResult<TRemoteAuthenticationState>>("AuthenticationService.completeSignIn", context.Url);
+        await UpdateUserOnSuccess(result);
 
         return result;
     }
@@ -115,14 +104,8 @@ public class RemoteAuthenticationService<
         RemoteAuthenticationContext<TRemoteAuthenticationState> context)
     {
         await EnsureAuthService();
-        var internalResult = await JsRuntime.InvokeAsync<InternalRemoteAuthenticationResult<TRemoteAuthenticationState>>("AuthenticationService.signOut", context.State);
-        var result = internalResult.Convert();
-        if (result.Status == RemoteAuthenticationStatus.Success)
-        {
-            var getUserTask = GetUser();
-            await getUserTask;
-            UpdateUser(getUserTask);
-        }
+        var result = await JsRuntime.InvokeAsync<RemoteAuthenticationResult<TRemoteAuthenticationState>>("AuthenticationService.signOut", context.State);
+        await UpdateUserOnSuccess(result);
 
         return result;
     }
@@ -132,14 +115,8 @@ public class RemoteAuthenticationService<
         RemoteAuthenticationContext<TRemoteAuthenticationState> context)
     {
         await EnsureAuthService();
-        var internalResult = await JsRuntime.InvokeAsync<InternalRemoteAuthenticationResult<TRemoteAuthenticationState>>("AuthenticationService.completeSignOut", context.Url);
-        var result = internalResult.Convert();
-        if (result.Status == RemoteAuthenticationStatus.Success)
-        {
-            var getUserTask = GetUser();
-            await getUserTask;
-            UpdateUser(getUserTask);
-        }
+        var result = await JsRuntime.InvokeAsync<RemoteAuthenticationResult<TRemoteAuthenticationState>>("AuthenticationService.completeSignOut", context.Url);
+        await UpdateUserOnSuccess(result);
 
         return result;
     }
@@ -150,18 +127,10 @@ public class RemoteAuthenticationService<
         await EnsureAuthService();
         var result = await JsRuntime.InvokeAsync<InternalAccessTokenResult>("AuthenticationService.getAccessToken");
 
-        if (!Enum.TryParse<AccessTokenResultStatus>(result.Status, ignoreCase: true, out var parsedStatus))
-        {
-            throw new InvalidOperationException($"Invalid access token result status '{result.Status ?? "(null)"}'");
-        }
-
-        if (parsedStatus == AccessTokenResultStatus.RequiresRedirect)
-        {
-            var redirectUrl = GetRedirectUrl(null);
-            result.RedirectUrl = redirectUrl.ToString();
-        }
-
-        return new AccessTokenResult(parsedStatus, result.Token, result.RedirectUrl);
+        return new AccessTokenResult(
+            result.Status,
+            result.Token,
+            result.Status == AccessTokenResultStatus.RequiresRedirect ? GetRedirectUrl(null).ToString() : null);
     }
 
     /// <inheritdoc />
@@ -177,18 +146,10 @@ public class RemoteAuthenticationService<
         await EnsureAuthService();
         var result = await JsRuntime.InvokeAsync<InternalAccessTokenResult>("AuthenticationService.getAccessToken", options);
 
-        if (!Enum.TryParse<AccessTokenResultStatus>(result.Status, ignoreCase: true, out var parsedStatus))
-        {
-            throw new InvalidOperationException($"Invalid access token result status '{result.Status ?? "(null)"}'");
-        }
-
-        if (parsedStatus == AccessTokenResultStatus.RequiresRedirect)
-        {
-            var redirectUrl = GetRedirectUrl(options.ReturnUrl);
-            result.RedirectUrl = redirectUrl.ToString();
-        }
-
-        return new AccessTokenResult(parsedStatus, result.Token, result.RedirectUrl);
+        return new AccessTokenResult(
+            result.Status,
+            result.Token,
+            result.Status == AccessTokenResultStatus.RequiresRedirect ? GetRedirectUrl(options.ReturnUrl).ToString() : null);
     }
 
     private Uri GetRedirectUrl(string customReturnUrl)
@@ -234,6 +195,15 @@ public class RemoteAuthenticationService<
             _initialized = true;
         }
     }
+    private async Task UpdateUserOnSuccess(RemoteAuthenticationResult<TRemoteAuthenticationState> result)
+    {
+        if (result.Status == RemoteAuthenticationStatus.Success)
+        {
+            var getUserTask = GetUser();
+            await getUserTask;
+            UpdateUser(getUserTask);
+        }
+    }
 
     private void UpdateUser(Task<ClaimsPrincipal> task)
     {
@@ -244,37 +214,4 @@ public class RemoteAuthenticationService<
 }
 
 // Internal for testing purposes
-internal struct InternalAccessTokenResult
-{
-    public string Status { get; set; }
-    public AccessToken Token { get; set; }
-    public string RedirectUrl { get; set; }
-}
-
-// Internal for testing purposes
-internal struct InternalRemoteAuthenticationResult<TRemoteAuthenticationState> where TRemoteAuthenticationState : RemoteAuthenticationState
-{
-    public string Status { get; set; }
-
-    public string ErrorMessage { get; set; }
-
-    public TRemoteAuthenticationState State { get; set; }
-
-    public RemoteAuthenticationResult<TRemoteAuthenticationState> Convert()
-    {
-        var result = new RemoteAuthenticationResult<TRemoteAuthenticationState>();
-        result.ErrorMessage = ErrorMessage;
-        result.State = State;
-
-        if (Status != null && Enum.TryParse<RemoteAuthenticationStatus>(Status, ignoreCase: true, out var status))
-        {
-            result.Status = status;
-        }
-        else
-        {
-            throw new InvalidOperationException($"Can't convert status '${Status ?? "(null)"}'.");
-        }
-
-        return result;
-    }
-}
+internal record struct InternalAccessTokenResult([property: JsonConverter(typeof(JsonStringEnumConverter))] AccessTokenResultStatus Status, AccessToken Token);
diff --git a/src/Components/WebAssembly/WebAssembly.Authentication/test/RemoteAuthenticationServiceTests.cs b/src/Components/WebAssembly/WebAssembly.Authentication/test/RemoteAuthenticationServiceTests.cs
index 595310ff1d5..fb9808828fe 100644
--- a/src/Components/WebAssembly/WebAssembly.Authentication/test/RemoteAuthenticationServiceTests.cs
+++ b/src/Components/WebAssembly/WebAssembly.Authentication/test/RemoteAuthenticationServiceTests.cs
@@ -25,10 +25,10 @@ public class RemoteAuthenticationServiceTests
             new AccountClaimsPrincipalFactory<RemoteUserAccount>(Mock.Of<IAccessTokenProviderAccessor>()));
 
         var state = new RemoteAuthenticationState();
-        testJsRuntime.SignInResult = new InternalRemoteAuthenticationResult<RemoteAuthenticationState>
+        testJsRuntime.SignInResult = new RemoteAuthenticationResult<RemoteAuthenticationState>
         {
             State = state,
-            Status = RemoteAuthenticationStatus.Success.ToString()
+            Status = RemoteAuthenticationStatus.Success
         };
 
         // Act
@@ -56,9 +56,9 @@ public class RemoteAuthenticationServiceTests
             new AccountClaimsPrincipalFactory<RemoteUserAccount>(Mock.Of<IAccessTokenProviderAccessor>()));
 
         var state = new RemoteAuthenticationState();
-        testJsRuntime.SignInResult = new InternalRemoteAuthenticationResult<RemoteAuthenticationState>
+        testJsRuntime.SignInResult = new RemoteAuthenticationResult<RemoteAuthenticationState>
         {
-            Status = value.ToString()
+            Status = value
         };
 
         // Act
@@ -83,10 +83,10 @@ public class RemoteAuthenticationServiceTests
             new AccountClaimsPrincipalFactory<RemoteUserAccount>(Mock.Of<IAccessTokenProviderAccessor>()));
 
         var state = new RemoteAuthenticationState();
-        testJsRuntime.CompleteSignInResult = new InternalRemoteAuthenticationResult<RemoteAuthenticationState>
+        testJsRuntime.CompleteSignInResult = new RemoteAuthenticationResult<RemoteAuthenticationState>
         {
             State = state,
-            Status = RemoteAuthenticationStatus.Success.ToString()
+            Status = RemoteAuthenticationStatus.Success
         };
 
         // Act
@@ -114,9 +114,9 @@ public class RemoteAuthenticationServiceTests
             new AccountClaimsPrincipalFactory<RemoteUserAccount>(Mock.Of<IAccessTokenProviderAccessor>()));
 
         var state = new RemoteAuthenticationState();
-        testJsRuntime.CompleteSignInResult = new InternalRemoteAuthenticationResult<RemoteAuthenticationState>
+        testJsRuntime.CompleteSignInResult = new RemoteAuthenticationResult<RemoteAuthenticationState>
         {
-            Status = value.ToString().ToString()
+            Status = value
         };
 
         // Act
@@ -141,10 +141,10 @@ public class RemoteAuthenticationServiceTests
             new AccountClaimsPrincipalFactory<RemoteUserAccount>(Mock.Of<IAccessTokenProviderAccessor>()));
 
         var state = new RemoteAuthenticationState();
-        testJsRuntime.SignOutResult = new InternalRemoteAuthenticationResult<RemoteAuthenticationState>
+        testJsRuntime.SignOutResult = new RemoteAuthenticationResult<RemoteAuthenticationState>
         {
             State = state,
-            Status = RemoteAuthenticationStatus.Success.ToString()
+            Status = RemoteAuthenticationStatus.Success
         };
 
         // Act
@@ -172,9 +172,9 @@ public class RemoteAuthenticationServiceTests
             new AccountClaimsPrincipalFactory<RemoteUserAccount>(Mock.Of<IAccessTokenProviderAccessor>()));
 
         var state = new RemoteAuthenticationState();
-        testJsRuntime.SignOutResult = new InternalRemoteAuthenticationResult<RemoteAuthenticationState>
+        testJsRuntime.SignOutResult = new RemoteAuthenticationResult<RemoteAuthenticationState>
         {
-            Status = value.ToString()
+            Status = value
         };
 
         // Act
@@ -199,10 +199,10 @@ public class RemoteAuthenticationServiceTests
             new AccountClaimsPrincipalFactory<RemoteUserAccount>(Mock.Of<IAccessTokenProviderAccessor>()));
 
         var state = new RemoteAuthenticationState();
-        testJsRuntime.CompleteSignOutResult = new InternalRemoteAuthenticationResult<RemoteAuthenticationState>
+        testJsRuntime.CompleteSignOutResult = new RemoteAuthenticationResult<RemoteAuthenticationState>
         {
             State = state,
-            Status = RemoteAuthenticationStatus.Success.ToString()
+            Status = RemoteAuthenticationStatus.Success
         };
 
         // Act
@@ -230,9 +230,9 @@ public class RemoteAuthenticationServiceTests
             new AccountClaimsPrincipalFactory<RemoteUserAccount>(Mock.Of<IAccessTokenProviderAccessor>()));
 
         var state = new RemoteAuthenticationState();
-        testJsRuntime.CompleteSignOutResult = new InternalRemoteAuthenticationResult<RemoteAuthenticationState>
+        testJsRuntime.CompleteSignOutResult = new RemoteAuthenticationResult<RemoteAuthenticationState>
         {
-            Status = value.ToString()
+            Status = value
         };
 
         // Act
@@ -259,7 +259,7 @@ public class RemoteAuthenticationServiceTests
         var state = new RemoteAuthenticationState();
         testJsRuntime.GetAccessTokenResult = new InternalAccessTokenResult
         {
-            Status = "success",
+            Status = AccessTokenResultStatus.Success,
             Token = new AccessToken
             {
                 Value = "1234",
@@ -277,8 +277,7 @@ public class RemoteAuthenticationServiceTests
             testJsRuntime.PastInvocations.Select(i => i.identifier).ToArray());
 
         Assert.True(result.TryGetToken(out var token));
-        Assert.Equal(result.Status, Enum.Parse<AccessTokenResultStatus>(testJsRuntime.GetAccessTokenResult.Status, ignoreCase: true));
-        Assert.Equal(result.RedirectUrl, testJsRuntime.GetAccessTokenResult.RedirectUrl);
+        Assert.Equal(result.Status, testJsRuntime.GetAccessTokenResult.Status);
         Assert.Equal(token, testJsRuntime.GetAccessTokenResult.Token);
     }
 
@@ -295,10 +294,7 @@ public class RemoteAuthenticationServiceTests
             new AccountClaimsPrincipalFactory<RemoteUserAccount>(Mock.Of<IAccessTokenProviderAccessor>()));
 
         var state = new RemoteAuthenticationState();
-        testJsRuntime.GetAccessTokenResult = new InternalAccessTokenResult
-        {
-            Status = "requiresRedirect",
-        };
+        testJsRuntime.GetAccessTokenResult = new InternalAccessTokenResult(AccessTokenResultStatus.RequiresRedirect, null);
 
         var tokenOptions = new AccessTokenRequestOptions
         {
@@ -317,7 +313,7 @@ public class RemoteAuthenticationServiceTests
 
         Assert.False(result.TryGetToken(out var token));
         Assert.Null(token);
-        Assert.Equal(result.Status, Enum.Parse<AccessTokenResultStatus>(testJsRuntime.GetAccessTokenResult.Status, ignoreCase: true));
+        Assert.Equal(result.Status, testJsRuntime.GetAccessTokenResult.Status);
         Assert.Equal(expectedRedirectUrl, result.RedirectUrl);
         Assert.Equal(tokenOptions, (AccessTokenRequestOptions)testJsRuntime.PastInvocations[^1].args[0]);
     }
@@ -337,7 +333,7 @@ public class RemoteAuthenticationServiceTests
         var state = new RemoteAuthenticationState();
         testJsRuntime.GetAccessTokenResult = new InternalAccessTokenResult
         {
-            Status = "requiresRedirect",
+            Status = AccessTokenResultStatus.RequiresRedirect,
         };
 
         var tokenOptions = new AccessTokenRequestOptions
@@ -358,7 +354,7 @@ public class RemoteAuthenticationServiceTests
 
         Assert.False(result.TryGetToken(out var token));
         Assert.Null(token);
-        Assert.Equal(result.Status, Enum.Parse<AccessTokenResultStatus>(testJsRuntime.GetAccessTokenResult.Status, ignoreCase: true));
+        Assert.Equal(result.Status, testJsRuntime.GetAccessTokenResult.Status);
         Assert.Equal(expectedRedirectUrl, result.RedirectUrl);
         Assert.Equal(tokenOptions, (AccessTokenRequestOptions)testJsRuntime.PastInvocations[^1].args[0]);
     }
@@ -450,7 +446,7 @@ public class RemoteAuthenticationServiceTests
         testJsRuntime.GetUserResult = account;
         testJsRuntime.GetAccessTokenResult = new InternalAccessTokenResult
         {
-            Status = "success",
+            Status = AccessTokenResultStatus.Success,
             Token = new AccessToken
             {
                 Value = "1234",
@@ -509,13 +505,13 @@ public class RemoteAuthenticationServiceTests
     {
         public IList<(string identifier, object[] args)> PastInvocations { get; set; } = new List<(string, object[])>();
 
-        public InternalRemoteAuthenticationResult<RemoteAuthenticationState> SignInResult { get; set; }
+        public RemoteAuthenticationResult<RemoteAuthenticationState> SignInResult { get; set; }
 
-        public InternalRemoteAuthenticationResult<RemoteAuthenticationState> CompleteSignInResult { get; set; }
+        public RemoteAuthenticationResult<RemoteAuthenticationState> CompleteSignInResult { get; set; }
 
-        public InternalRemoteAuthenticationResult<RemoteAuthenticationState> SignOutResult { get; set; }
+        public RemoteAuthenticationResult<RemoteAuthenticationState> SignOutResult { get; set; }
 
-        public InternalRemoteAuthenticationResult<RemoteAuthenticationState> CompleteSignOutResult { get; set; }
+        public RemoteAuthenticationResult<RemoteAuthenticationState> CompleteSignOutResult { get; set; }
 
         public InternalAccessTokenResult GetAccessTokenResult { get; set; }
 
-- 
GitLab