diff --git a/src/SignalR/common/Shared/ClientResultsManager.cs b/src/SignalR/common/Shared/ClientResultsManager.cs index 12544fb649dddb1dd2075ca38bc4713f8271430d..f8c9e25d9fad7b853b7a0246d8d30f218294295b 100644 --- a/src/SignalR/common/Shared/ClientResultsManager.cs +++ b/src/SignalR/common/Shared/ClientResultsManager.cs @@ -42,6 +42,11 @@ internal sealed class ClientResultsManager : IInvocationBinder { var result = _pendingInvocations.TryAdd(invocationId, invocationInfo); Debug.Assert(result); + // Should have a 50% chance of happening once every 2.71 quintillion invocations (see UUID in Wikipedia) + if (!result) + { + invocationInfo.Complete(invocationInfo.Tcs, CompletionMessage.WithError(invocationId, "ID collision occurred when using client results. This is likely a bug in SignalR.")); + } } public void TryCompleteResult(string connectionId, CompletionMessage message) diff --git a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs index fcaad6c86345d6575da31193a64eb3c9390aec16..3e26bc244f5879c09fc63c02332613dd514fcbdf 100644 --- a/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs +++ b/src/SignalR/server/Specification.Tests/src/ScaleoutHubLifetimeManagerTests.cs @@ -588,25 +588,22 @@ public abstract class ScaleoutHubLifetimeManagerTests<TBackplane> : HubLifetimeM var manager1 = CreateNewHubLifetimeManager(backplane); var manager2 = CreateNewHubLifetimeManager(backplane); - using (var client1 = new TestClient()) - using (var client2 = new TestClient()) + using (var client = new TestClient()) { - var connection1 = HubConnectionContextUtils.Create(client1.Connection); - var connection2 = HubConnectionContextUtils.Create(client2.Connection); + var connection = HubConnectionContextUtils.Create(client.Connection); - await manager1.OnConnectedAsync(connection1).DefaultTimeout(); - await manager2.OnConnectedAsync(connection2).DefaultTimeout(); + await manager1.OnConnectedAsync(connection).DefaultTimeout(); - var invoke1 = manager1.InvokeConnectionAsync<int>(connection2.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); - var invocation2 = Assert.IsType<InvocationMessage>(await client2.ReadAsync().DefaultTimeout()); + var invoke1 = manager1.InvokeConnectionAsync<int>(connection.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); + var invocation2 = Assert.IsType<InvocationMessage>(await client.ReadAsync().DefaultTimeout()); - var invoke2 = manager2.InvokeConnectionAsync<int>(connection1.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); - var invocation1 = Assert.IsType<InvocationMessage>(await client1.ReadAsync().DefaultTimeout()); + var invoke2 = manager2.InvokeConnectionAsync<int>(connection.ConnectionId, "Result", new object[] { "test" }, cancellationToken: default); + var invocation1 = Assert.IsType<InvocationMessage>(await client.ReadAsync().DefaultTimeout()); Assert.NotEqual(invocation1.InvocationId, invocation2.InvocationId); - await manager1.SetConnectionResultAsync(connection2.ConnectionId, CompletionMessage.WithResult(invocation2.InvocationId, 2)).DefaultTimeout(); - await manager2.SetConnectionResultAsync(connection1.ConnectionId, CompletionMessage.WithResult(invocation1.InvocationId, 5)).DefaultTimeout(); + await manager1.SetConnectionResultAsync(connection.ConnectionId, CompletionMessage.WithResult(invocation2.InvocationId, 2)).DefaultTimeout(); + await manager2.SetConnectionResultAsync(connection.ConnectionId, CompletionMessage.WithResult(invocation1.InvocationId, 5)).DefaultTimeout(); var res = await invoke1.DefaultTimeout(); Assert.Equal(2, res); diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs index a0793091382e36f7263bd2311bc85912ff4cc9a8..676e07bc7f1a10f44f69cb3eea4c08b191071630 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisChannels.cs @@ -23,12 +23,18 @@ internal sealed class RedisChannels /// </summary> public string GroupManagement { get; } - public RedisChannels(string prefix) + /// <summary> + /// Gets the name of the internal channel for receiving client results. + /// </summary> + public string ReturnResults { get; } + + public RedisChannels(string prefix, string serverName) { _prefix = prefix; All = prefix + ":all"; GroupManagement = prefix + ":internal:groups"; + ReturnResults = _prefix + ":internal:return:" + serverName; } /// <summary> @@ -71,15 +77,4 @@ internal sealed class RedisChannels { return _prefix + ":internal:ack:" + serverName; } - - /// <summary> - /// Gets the name of the client return results channel for the specified server. - /// </summary> - /// <param name="serverName">The name of the server to get the client return results channel for.</param> - /// <returns></returns> - [MethodImpl(MethodImplOptions.AggressiveInlining)] - public string ReturnResults(string serverName) - { - return _prefix + ":internal:return:" + serverName; - } } diff --git a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs index 1d8397272ab4f2c3a92b396c7a0c04163c46f8fa..6748720a7e103df5f47112039c6f19383b16a79d 100644 --- a/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs +++ b/src/SignalR/server/StackExchangeRedis/src/RedisHubLifetimeManager.cs @@ -39,7 +39,6 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab private readonly AckHandler _ackHandler; private int _internalAckId; - private ulong _lastInvocationId; /// <summary> /// Constructs the <see cref="RedisHubLifetimeManager{THub}"/> with types from Dependency Injection. @@ -72,7 +71,7 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab _logger = logger; _options = options.Value; _ackHandler = new AckHandler(); - _channels = new RedisChannels(typeof(THub).FullName!); + _channels = new RedisChannels(typeof(THub).FullName!, _serverName); if (globalHubOptions != null && hubOptions != null) { _protocol = new RedisProtocol(new DefaultHubMessageSerializer(hubProtocolResolver, globalHubOptions.Value.SupportedProtocols, hubOptions.Value.SupportedProtocols)); @@ -416,8 +415,8 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab var connection = _connections[connectionId]; - // Needs to be unique across servers, easiest way to do that is prefix with connection ID. - var invocationId = $"{connectionId}{Interlocked.Increment(ref _lastInvocationId)}"; + // ID needs to be unique for each invocation and across servers, we generate a GUID every time, that should provide enough uniqueness guarantees. + var invocationId = GenerateInvocationId(); using var _ = CancellationTokenUtils.CreateLinkedToken(cancellationToken, connection?.ConnectionAborted ?? default, out var linkedToken); @@ -428,7 +427,7 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab if (connection == null) { // TODO: Need to handle other server going away while waiting for connection result - var messageBytes = _protocol.WriteInvocation(methodName, args, invocationId, returnChannel: _channels.ReturnResults(_serverName)); + var messageBytes = _protocol.WriteInvocation(methodName, args, invocationId, returnChannel: _channels.ReturnResults); var received = await PublishAsync(_channels.Connection(connectionId), messageBytes); if (received < 1) { @@ -674,7 +673,7 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab private async Task SubscribeToReturnResultsAsync() { - var channel = await _bus!.SubscribeAsync(_channels.ReturnResults(_serverName)); + var channel = await _bus!.SubscribeAsync(_channels.ReturnResults); channel.OnMessage((channelMessage) => { var completion = RedisProtocol.ReadCompletion(channelMessage.Message); @@ -700,6 +699,7 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab Debug.Assert(parseSuccess); var invocationInfo = _clientResultsManager.RemoveInvocation(((CompletionMessage)hubMessage!).InvocationId!); + invocationInfo?.Completion(invocationInfo?.Tcs!, (CompletionMessage)hubMessage!); }); } @@ -784,6 +784,21 @@ public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposab return $"{Environment.MachineName}_{Guid.NewGuid():N}"; } + private static string GenerateInvocationId() + { + Span<byte> buffer = stackalloc byte[16]; + var success = Guid.NewGuid().TryWriteBytes(buffer); + Debug.Assert(success); + // 16 * 4/3 = 21.333 which means base64 encoding will use 22 characters of actual data and 2 characters of padding ('=') + Span<char> base64 = stackalloc char[24]; + success = Convert.TryToBase64Chars(buffer, base64, out var written); + Debug.Assert(success); + Debug.Assert(written == 24); + // Trim the two '==' + Debug.Assert(base64.EndsWith("==")); + return new string(base64[..^2]); + } + private sealed class LoggerTextWriter : TextWriter { private readonly ILogger _logger;