diff --git a/samples/ChatSample/Hubs/Chat.cs b/samples/ChatSample/Hubs/Chat.cs index 4ace1387c50a3b260868b7ded7e76a2cd155ae9c..ee41f2c1a77dbd4daf30c886e355767ff24f57c0 100644 --- a/samples/ChatSample/Hubs/Chat.cs +++ b/samples/ChatSample/Hubs/Chat.cs @@ -16,7 +16,7 @@ namespace ChatSample.Hubs { if (!Context.User.Identity.IsAuthenticated) { - Context.Connection.Transport.Dispose(); + Context.Connection.Dispose(); } return Task.CompletedTask; diff --git a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs index e48fa16229f3518b4f6563b715af79a070520ed3..39bc889042c830f4acb8f48263a9d40033081676 100644 --- a/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs +++ b/samples/SocialWeather/PersistentConnectionLifeTimeManager.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.IO; using System.IO.Pipelines; using System.Linq; using System.Threading.Tasks; @@ -13,19 +14,19 @@ namespace SocialWeather public class PersistentConnectionLifeTimeManager { private readonly FormatterResolver _formatterResolver; - private readonly ConnectionList<StreamingConnection> _connectionList = new ConnectionList<StreamingConnection>(); + private readonly ConnectionList _connectionList = new ConnectionList(); public PersistentConnectionLifeTimeManager(FormatterResolver formatterResolver) { _formatterResolver = formatterResolver; } - public void OnConnectedAsync(StreamingConnection connection) + public void OnConnectedAsync(Connection connection) { _connectionList.Add(connection); } - public void OnDisconnectedAsync(StreamingConnection connection) + public void OnDisconnectedAsync(Connection connection) { _connectionList.Remove(connection); } @@ -35,7 +36,10 @@ namespace SocialWeather foreach (var connection in _connectionList) { var formatter = _formatterResolver.GetFormatter<T>(connection.Metadata.Get<string>("formatType")); - await formatter.WriteAsync(data, connection.Transport.GetStream()); + var ms = new MemoryStream(); + await formatter.WriteAsync(data, ms); + var buffer = ReadableBuffer.Create(ms.ToArray()).Preserve(); + await connection.Transport.Output.WriteAsync(new Message(buffer, Format.Binary, endOfMessage: true)); } } @@ -54,7 +58,7 @@ namespace SocialWeather throw new NotImplementedException(); } - public void AddGroupAsync(StreamingConnection connection, string groupName) + public void AddGroupAsync(Connection connection, string groupName) { var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>()); lock (groups) @@ -63,7 +67,7 @@ namespace SocialWeather } } - public void RemoveGroupAsync(StreamingConnection connection, string groupName) + public void RemoveGroupAsync(Connection connection, string groupName) { var groups = connection.Metadata.Get<HashSet<string>>("groups"); if (groups != null) diff --git a/samples/SocialWeather/SocialWeatherEndPoint.cs b/samples/SocialWeather/SocialWeatherEndPoint.cs index d838e66d6602e4b54067c46cdcaea729c982653f..659f2aea7f8e7f0795e88f8ade1d16b95c971f66 100644 --- a/samples/SocialWeather/SocialWeatherEndPoint.cs +++ b/samples/SocialWeather/SocialWeatherEndPoint.cs @@ -1,6 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.IO; using System.IO.Pipelines; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; @@ -8,7 +9,7 @@ using Microsoft.Extensions.Logging; namespace SocialWeather { - public class SocialWeatherEndPoint : StreamingEndPoint + public class SocialWeatherEndPoint : EndPoint { private readonly PersistentConnectionLifeTimeManager _lifetimeManager; private readonly FormatterResolver _formatterResolver; @@ -22,22 +23,24 @@ namespace SocialWeather _logger = logger; } - public async override Task OnConnectedAsync(StreamingConnection connection) + public async override Task OnConnectedAsync(Connection connection) { _lifetimeManager.OnConnectedAsync(connection); await ProcessRequests(connection); _lifetimeManager.OnDisconnectedAsync(connection); } - public async Task ProcessRequests(StreamingConnection connection) + public async Task ProcessRequests(Connection connection) { - var stream = connection.Transport.GetStream(); var formatter = _formatterResolver.GetFormatter<WeatherReport>( connection.Metadata.Get<string>("formatType")); - WeatherReport weatherReport; - while ((weatherReport = await formatter.ReadAsync(stream)) != null) + while (true) { + Message message = await connection.Transport.Input.ReadAsync(); + var stream = new MemoryStream(); + await message.Payload.Buffer.CopyToAsync(stream); + WeatherReport weatherReport = await formatter.ReadAsync(stream); await _lifetimeManager.SendToAllAsync(weatherReport); } } diff --git a/samples/SocketsSample/EndPoints/ChatEndPoint.cs b/samples/SocketsSample/EndPoints/ChatEndPoint.cs deleted file mode 100644 index 39117504f053ce810bf430afd27ef97edde1720b..0000000000000000000000000000000000000000 --- a/samples/SocketsSample/EndPoints/ChatEndPoint.cs +++ /dev/null @@ -1,67 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Collections.Generic; -using System.IO.Pipelines; -using System.Linq; -using System.Text; -using System.Threading.Tasks; -using Microsoft.AspNetCore.Sockets; - -namespace SocketsSample -{ - public class ChatEndPoint : StreamingEndPoint - { - public ConnectionList<StreamingConnection> Connections { get; } = new ConnectionList<StreamingConnection>(); - - public override async Task OnConnectedAsync(StreamingConnection connection) - { - Connections.Add(connection); - - await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})"); - - while (true) - { - var result = await connection.Transport.Input.ReadAsync(); - var input = result.Buffer; - try - { - if (input.IsEmpty && result.IsCompleted) - { - break; - } - - // We can avoid the copy here but we'll deal with that later - await Broadcast(input.ToArray()); - } - finally - { - connection.Transport.Input.Advance(input.End); - } - } - - Connections.Remove(connection); - - await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})"); - } - - private Task Broadcast(string text) - { - return Broadcast(Encoding.UTF8.GetBytes(text)); - } - - private Task Broadcast(byte[] payload) - { - var tasks = new List<Task>(Connections.Count); - - foreach (var c in Connections) - { - tasks.Add(c.Transport.Output.WriteAsync(payload)); - } - - return Task.WhenAll(tasks); - } - } - -} diff --git a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs index d0afa43d54649f0f294e961b73083583a0b1d59c..1ef2a6d811d29c3706cc0dd41f3d9f813321e960 100644 --- a/samples/SocketsSample/EndPoints/MessagesEndPoint.cs +++ b/samples/SocketsSample/EndPoints/MessagesEndPoint.cs @@ -12,11 +12,11 @@ using Microsoft.AspNetCore.Sockets; namespace SocketsSample.EndPoints { - public class MessagesEndPoint : MessagingEndPoint + public class MessagesEndPoint : EndPoint { - public ConnectionList<MessagingConnection> Connections { get; } = new ConnectionList<MessagingConnection>(); + public ConnectionList Connections { get; } = new ConnectionList(); - public override async Task OnConnectedAsync(MessagingConnection connection) + public override async Task OnConnectedAsync(Connection connection) { Connections.Add(connection); @@ -24,19 +24,19 @@ namespace SocketsSample.EndPoints try { - while (true) + while (await connection.Transport.Input.WaitToReadAsync()) { - using (var message = await connection.Transport.Input.ReadAsync()) + Message message; + if (connection.Transport.Input.TryRead(out message)) { - // We can avoid the copy here but we'll deal with that later - await Broadcast(message.Payload.Buffer, message.MessageFormat, message.EndOfMessage); + using (message) + { + // We can avoid the copy here but we'll deal with that later + await Broadcast(message.Payload.Buffer, message.MessageFormat, message.EndOfMessage); + } } } } - catch (Exception ex) when (ex.GetType().IsNested && ex.GetType().DeclaringType == typeof(Channel)) - { - // Gross that we have to catch this this way. See https://github.com/dotnet/corefxlab/issues/1068 - } finally { Connections.Remove(connection); diff --git a/samples/SocketsSample/Startup.cs b/samples/SocketsSample/Startup.cs index a71b54d8bd61df87ee1ee66c61f01591369e07b3..7ff36396811f89138da79f1c645b397540902ada 100644 --- a/samples/SocketsSample/Startup.cs +++ b/samples/SocketsSample/Startup.cs @@ -29,7 +29,6 @@ namespace SocketsSample }); // .AddRedis(); - services.AddSingleton<ChatEndPoint>(); services.AddSingleton<MessagesEndPoint>(); services.AddSingleton<ProtobufSerializer>(); } @@ -53,8 +52,7 @@ namespace SocketsSample app.UseSockets(routes => { - routes.MapEndpoint<ChatEndPoint>("/chat"); - routes.MapEndpoint<MessagesEndPoint>("/msgs"); + routes.MapEndpoint<MessagesEndPoint>("/chat"); }); } } diff --git a/samples/SocketsSample/wwwroot/index.html b/samples/SocketsSample/wwwroot/index.html index bf8ed6f0ccfd2febd1aff2cc0ff17604b5d08ebc..6bf0d73183d8fe242e7e659f360c7a371a02f7a7 100644 --- a/samples/SocketsSample/wwwroot/index.html +++ b/samples/SocketsSample/wwwroot/index.html @@ -6,18 +6,12 @@ </head> <body> <h1>ASP.NET Sockets</h1> - <h2>Streaming</h2> + <h2>Messaging</h2> <ul> <li><a href="sse.html#/chat">Server Sent Events</a></li> <li><a href="polling.html#/chat">Long polling</a></li> <li><a href="ws.html#/chat">Web Sockets</a></li> </ul> - <h2>Messaging</h2> - <ul> - <li><a href="sse.html#/msgs">Server Sent Events</a></li> - <li><a href="polling.html#/msgs">Long polling</a></li> - <li><a href="ws.html#/msgs">Web Sockets</a></li> - </ul> <h1>ASP.NET SignalR</h1> <ul> <li><a href="hubs.html">Hubs</a></li> diff --git a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs index 5f218b09863553034bbce20bbae1fbf17d85d086..44c6d24922e0ce53d6ee631ff34b9120270373d2 100644 --- a/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR.Redis/RedisHubLifetimeManager.cs @@ -20,7 +20,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable { - private readonly ConnectionList<StreamingConnection> _connections = new ConnectionList<StreamingConnection>(); + private readonly ConnectionList _connections = new ConnectionList(); // TODO: Investigate "memory leak" entries never get removed private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>(); private readonly InvocationAdapterRegistry _registry; @@ -51,7 +51,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis foreach (var connection in _connections) { - tasks.Add(connection.Transport.Output.WriteAsync((byte[])data)); + tasks.Add(WriteAsync(connection, data)); } previousBroadcastTask = Task.WhenAll(tasks); @@ -116,7 +116,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis } } - public override Task OnConnectedAsync(StreamingConnection connection) + public override Task OnConnectedAsync(Connection connection) { var redisSubscriptions = connection.Metadata.GetOrAdd("redis_subscriptions", _ => new HashSet<string>()); var connectionTask = TaskCache.CompletedTask; @@ -133,7 +133,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis { await previousConnectionTask; - previousConnectionTask = connection.Transport.Output.WriteAsync((byte[])data); + previousConnectionTask = WriteAsync(connection, data); }); @@ -149,14 +149,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis { await previousUserTask; - previousUserTask = connection.Transport.Output.WriteAsync((byte[])data); + previousUserTask = WriteAsync(connection, data); }); } return Task.WhenAll(connectionTask, userTask); } - public override Task OnDisconnectedAsync(StreamingConnection connection) + public override Task OnDisconnectedAsync(Connection connection) { _connections.Remove(connection); @@ -186,7 +186,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis return Task.WhenAll(tasks); } - public override async Task AddGroupAsync(StreamingConnection connection, string groupName) + public override async Task AddGroupAsync(Connection connection, string groupName) { var groupChannel = typeof(THub).FullName + ".group." + groupName; @@ -220,9 +220,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis await previousTask; var tasks = new List<Task>(group.Connections.Count); - foreach (var groupConnection in group.Connections.Cast<StreamingConnection>()) + foreach (var groupConnection in group.Connections.Cast<Connection>()) { - tasks.Add(groupConnection.Transport.Output.WriteAsync((byte[])data)); + tasks.Add(WriteAsync(groupConnection, data)); } previousTask = Task.WhenAll(tasks); @@ -234,7 +234,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis } } - public override async Task RemoveGroupAsync(StreamingConnection connection, string groupName) + public override async Task RemoveGroupAsync(Connection connection, string groupName) { var groupChannel = typeof(THub).FullName + ".group." + groupName; @@ -275,6 +275,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis _redisServerConnection.Dispose(); } + private Task WriteAsync(Connection connection, byte[] data) + { + var buffer = ReadableBuffer.Create(data).Preserve(); + return connection.Transport.Output.WriteAsync(new Message(buffer, Format.Binary, endOfMessage: true)); + } + private class LoggerTextWriter : TextWriter { private readonly ILogger _logger; @@ -300,7 +306,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis private class GroupData { public SemaphoreSlim Lock = new SemaphoreSlim(1, 1); - public ConnectionList<StreamingConnection> Connections = new ConnectionList<StreamingConnection>(); + public ConnectionList Connections = new ConnectionList(); } } } diff --git a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs index 50cd69fcb01c17e992bfbb1ec31ad072cd80caeb..91cb6e13e64d136f1ea10e695bfb180a6df65407 100644 --- a/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/DefaultHubLifetimeManager.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.IO; using System.IO.Pipelines; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets; @@ -12,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR { public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub> { - private readonly ConnectionList<StreamingConnection> _connections = new ConnectionList<StreamingConnection>(); + private readonly ConnectionList _connections = new ConnectionList(); private readonly InvocationAdapterRegistry _registry; public DefaultHubLifetimeManager(InvocationAdapterRegistry registry) @@ -20,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR _registry = registry; } - public override Task AddGroupAsync(StreamingConnection connection, string groupName) + public override Task AddGroupAsync(Connection connection, string groupName) { var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>()); @@ -32,7 +33,7 @@ namespace Microsoft.AspNetCore.SignalR return TaskCache.CompletedTask; } - public override Task RemoveGroupAsync(StreamingConnection connection, string groupName) + public override Task RemoveGroupAsync(Connection connection, string groupName) { var groups = connection.Metadata.Get<HashSet<string>>("groups"); @@ -54,7 +55,7 @@ namespace Microsoft.AspNetCore.SignalR return InvokeAllWhere(methodName, args, c => true); } - private Task InvokeAllWhere(string methodName, object[] args, Func<StreamingConnection, bool> include) + private Task InvokeAllWhere(string methodName, object[] args, Func<Connection, bool> include) { var tasks = new List<Task>(_connections.Count); var message = new InvocationDescriptor @@ -73,7 +74,7 @@ namespace Microsoft.AspNetCore.SignalR var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType")); - tasks.Add(invocationAdapter.WriteMessageAsync(message, connection.Transport.GetStream())); + tasks.Add(WriteAsync(connection, invocationAdapter, message)); } return Task.WhenAll(tasks); @@ -91,7 +92,7 @@ namespace Microsoft.AspNetCore.SignalR Arguments = args }; - return invocationAdapter.WriteMessageAsync(message, connection.Transport.GetStream()); + return WriteAsync(connection, invocationAdapter, message); } public override Task InvokeGroupAsync(string groupName, string methodName, object[] args) @@ -111,17 +112,24 @@ namespace Microsoft.AspNetCore.SignalR }); } - public override Task OnConnectedAsync(StreamingConnection connection) + public override Task OnConnectedAsync(Connection connection) { _connections.Add(connection); return TaskCache.CompletedTask; } - public override Task OnDisconnectedAsync(StreamingConnection connection) + public override Task OnDisconnectedAsync(Connection connection) { _connections.Remove(connection); return TaskCache.CompletedTask; } - } + private static Task WriteAsync(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor message) + { + var stream = new MemoryStream(); + invocationAdapter.WriteMessageAsync(message, stream); + var buffer = ReadableBuffer.Create(stream.ToArray()).Preserve(); + return connection.Transport.Output.WriteAsync(new Message(buffer, Format.Binary, endOfMessage: true)); + } + } } diff --git a/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs b/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs index 6d01a68ae7d6d7e0c7d5161740255962ec9fa194..50df41dc6acb5b0283e65023fff3c66182fead98 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubCallerContext.cs @@ -8,12 +8,12 @@ namespace Microsoft.AspNetCore.SignalR { public class HubCallerContext { - public HubCallerContext(StreamingConnection connection) + public HubCallerContext(Connection connection) { Connection = connection; } - public StreamingConnection Connection { get; } + public Connection Connection { get; } public ClaimsPrincipal User => Connection.User; diff --git a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs index b0c184bfba0ef3e474ebf7e0234a83265c7bdef0..28a52668bee0372832c164e900cd2044b615299f 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubEndPoint.cs @@ -3,6 +3,7 @@ using System; using System.Collections.Generic; +using System.IO; using System.IO.Pipelines; using System.Linq; using System.Reflection; @@ -25,10 +26,10 @@ namespace Microsoft.AspNetCore.SignalR } } - public class HubEndPoint<THub, TClient> : StreamingEndPoint, IInvocationBinder where THub : Hub<TClient> + public class HubEndPoint<THub, TClient> : EndPoint, IInvocationBinder where THub : Hub<TClient> { - private readonly Dictionary<string, Func<StreamingConnection, InvocationDescriptor, Task<InvocationResultDescriptor>>> _callbacks - = new Dictionary<string, Func<StreamingConnection, InvocationDescriptor, Task<InvocationResultDescriptor>>>(StringComparer.OrdinalIgnoreCase); + private readonly Dictionary<string, Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>>> _callbacks + = new Dictionary<string, Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>>>(StringComparer.OrdinalIgnoreCase); private readonly Dictionary<string, Type[]> _paramTypes = new Dictionary<string, Type[]>(); private readonly HubLifetimeManager<THub> _lifetimeManager; @@ -52,7 +53,7 @@ namespace Microsoft.AspNetCore.SignalR DiscoverHubMethods(); } - public override async Task OnConnectedAsync(StreamingConnection connection) + public override async Task OnConnectedAsync(Connection connection) { // TODO: Dispatch from the caller await Task.Yield(); @@ -68,7 +69,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task RunHubAsync(StreamingConnection connection) + private async Task RunHubAsync(Connection connection) { await HubOnConnectedAsync(connection); @@ -86,7 +87,7 @@ namespace Microsoft.AspNetCore.SignalR await HubOnDisconnectedAsync(connection, null); } - private async Task HubOnConnectedAsync(StreamingConnection connection) + private async Task HubOnConnectedAsync(Connection connection) { try { @@ -112,7 +113,7 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task HubOnDisconnectedAsync(StreamingConnection connection, Exception exception) + private async Task HubOnDisconnectedAsync(Connection connection, Exception exception) { try { @@ -138,15 +139,26 @@ namespace Microsoft.AspNetCore.SignalR } } - private async Task DispatchMessagesAsync(StreamingConnection connection) + private async Task DispatchMessagesAsync(Connection connection) { - var stream = connection.Transport.GetStream(); var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType")); - while (true) + while (await connection.Transport.Input.WaitToReadAsync()) { - // TODO: Handle receiving InvocationResultDescriptor - var invocationDescriptor = await invocationAdapter.ReadMessageAsync(stream, this) as InvocationDescriptor; + Message message; + if (!connection.Transport.Input.TryRead(out message)) + { + continue; + } + + InvocationDescriptor invocationDescriptor; + using (message) + { + var inputStream = new MemoryStream(message.Payload.Buffer.ToArray()); + + // TODO: Handle receiving InvocationResultDescriptor + invocationDescriptor = await invocationAdapter.ReadMessageAsync(inputStream, this) as InvocationDescriptor; + } // Is there a better way of detecting that a connection was closed? if (invocationDescriptor == null) @@ -160,7 +172,7 @@ namespace Microsoft.AspNetCore.SignalR } InvocationResultDescriptor result; - Func<StreamingConnection, InvocationDescriptor, Task<InvocationResultDescriptor>> callback; + Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>> callback; if (_callbacks.TryGetValue(invocationDescriptor.Method, out callback)) { result = await callback(connection, invocationDescriptor); @@ -177,11 +189,19 @@ namespace Microsoft.AspNetCore.SignalR _logger.LogError("Unknown hub method '{method}'", invocationDescriptor.Method); } - await invocationAdapter.WriteMessageAsync(result, stream); + // TODO: Pool memory + var outStream = new MemoryStream(); + await invocationAdapter.WriteMessageAsync(result, outStream); + + var buffer = ReadableBuffer.Create(outStream.ToArray()).Preserve(); + if (await connection.Transport.Output.WaitToWriteAsync()) + { + connection.Transport.Output.TryWrite(new Message(buffer, Format.Binary, endOfMessage: true)); + } } } - private void InitializeHub(THub hub, StreamingConnection connection) + private void InitializeHub(THub hub, Connection connection) { hub.Clients = _hubContext.Clients; hub.Context = new HubCallerContext(connection); @@ -290,7 +310,7 @@ namespace Microsoft.AspNetCore.SignalR Type[] types; if (!_paramTypes.TryGetValue(methodName, out types)) { - throw new InvalidOperationException($"The hub method '{methodName}' could not be resolved."); + return Type.EmptyTypes; } return types; } diff --git a/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs b/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs index 7fa770dfc8184472f1c1904432b03e29982305dd..51a2b4d4f7f437686aaa2638c2b19829e9e6636d 100644 --- a/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs +++ b/src/Microsoft.AspNetCore.SignalR/HubLifetimeManager.cs @@ -8,9 +8,9 @@ namespace Microsoft.AspNetCore.SignalR { public abstract class HubLifetimeManager<THub> { - public abstract Task OnConnectedAsync(StreamingConnection connection); + public abstract Task OnConnectedAsync(Connection connection); - public abstract Task OnDisconnectedAsync(StreamingConnection connection); + public abstract Task OnDisconnectedAsync(Connection connection); public abstract Task InvokeAllAsync(string methodName, object[] args); @@ -20,9 +20,9 @@ namespace Microsoft.AspNetCore.SignalR public abstract Task InvokeUserAsync(string userId, string methodName, object[] args); - public abstract Task AddGroupAsync(StreamingConnection connection, string groupName); + public abstract Task AddGroupAsync(Connection connection, string groupName); - public abstract Task RemoveGroupAsync(StreamingConnection connection, string groupName); + public abstract Task RemoveGroupAsync(Connection connection, string groupName); } } diff --git a/src/Microsoft.AspNetCore.SignalR/Proxies.cs b/src/Microsoft.AspNetCore.SignalR/Proxies.cs index c2dc2675ab33515abc55be259b227d7cdbb412ba..dfed4a80ff52cdb5b7016b649cb7558971373366 100644 --- a/src/Microsoft.AspNetCore.SignalR/Proxies.cs +++ b/src/Microsoft.AspNetCore.SignalR/Proxies.cs @@ -75,10 +75,10 @@ namespace Microsoft.AspNetCore.SignalR public class GroupManager<THub> : IGroupManager { - private readonly StreamingConnection _connection; + private readonly Connection _connection; private readonly HubLifetimeManager<THub> _lifetimeManager; - public GroupManager(StreamingConnection connection, HubLifetimeManager<THub> lifetimeManager) + public GroupManager(Connection connection, HubLifetimeManager<THub> lifetimeManager) { _connection = connection; _lifetimeManager = lifetimeManager; diff --git a/src/Microsoft.AspNetCore.Sockets/Connection.cs b/src/Microsoft.AspNetCore.Sockets/Connection.cs index dffd4fe2d9b50df4ab98c7ee748e2e4c4ab3faa2..ab2f0dad5af72cdfbb480153db87e42335cb0a6f 100644 --- a/src/Microsoft.AspNetCore.Sockets/Connection.cs +++ b/src/Microsoft.AspNetCore.Sockets/Connection.cs @@ -3,24 +3,28 @@ using System; using System.Security.Claims; +using System.Threading.Tasks; namespace Microsoft.AspNetCore.Sockets { - public abstract class Connection : IDisposable + public class Connection : IDisposable { - public abstract ConnectionMode Mode { get; } public string ConnectionId { get; } public ClaimsPrincipal User { get; set; } public ConnectionMetadata Metadata { get; } = new ConnectionMetadata(); - protected Connection(string id) + public IChannelConnection<Message> Transport { get; } + + public Connection(string id, IChannelConnection<Message> transport) { + Transport = transport; ConnectionId = id; } - public virtual void Dispose() + public void Dispose() { + Transport.Dispose(); } } } diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs index bc5016a3bbdde877865a3dbb83f68b6350408e93..5f5efaadb8ff7e492e10d74df24917aae8fa2842 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionList.cs @@ -8,15 +8,15 @@ using System.Collections.Generic; namespace Microsoft.AspNetCore.Sockets { - public class ConnectionList<T> : IReadOnlyCollection<T> where T: Connection + public class ConnectionList : IReadOnlyCollection<Connection> { - private readonly ConcurrentDictionary<string, T> _connections = new ConcurrentDictionary<string, T>(); + private readonly ConcurrentDictionary<string, Connection> _connections = new ConcurrentDictionary<string, Connection>(); - public T this[string connectionId] + public Connection this[string connectionId] { get { - T connection; + Connection connection; if (_connections.TryGetValue(connectionId, out connection)) { return connection; @@ -27,18 +27,18 @@ namespace Microsoft.AspNetCore.Sockets public int Count => _connections.Count; - public void Add(T connection) + public void Add(Connection connection) { _connections.TryAdd(connection.ConnectionId, connection); } - public void Remove(T connection) + public void Remove(Connection connection) { - T dummy; + Connection dummy; _connections.TryRemove(connection.ConnectionId, out dummy); } - public IEnumerator<T> GetEnumerator() + public IEnumerator<Connection> GetEnumerator() { foreach (var item in _connections) { diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs index 24f45c1949edb3f7edb3c6be958cb45f872e92b6..e22a171cc2d2e0518f08f3755e34f1461bb660d9 100644 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs +++ b/src/Microsoft.AspNetCore.Sockets/ConnectionManager.cs @@ -3,8 +3,6 @@ using System; using System.Collections.Concurrent; -using System.Diagnostics; -using System.IO.Pipelines; using System.Threading; using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Sockets.Internal; @@ -15,11 +13,9 @@ namespace Microsoft.AspNetCore.Sockets { private readonly ConcurrentDictionary<string, ConnectionState> _connections = new ConcurrentDictionary<string, ConnectionState>(); private readonly Timer _timer; - private readonly PipelineFactory _pipelineFactory; - public ConnectionManager(PipelineFactory pipelineFactory) + public ConnectionManager() { - _pipelineFactory = pipelineFactory; _timer = new Timer(Scan, this, 0, 1000); } @@ -28,8 +24,23 @@ namespace Microsoft.AspNetCore.Sockets return _connections.TryGetValue(id, out state); } - public ConnectionState CreateConnection(ConnectionMode mode) => - mode == ConnectionMode.Streaming ? CreateStreamingConnection() : CreateMessagingConnection(); + public ConnectionState CreateConnection() + { + var id = MakeNewConnectionId(); + + var transportToApplication = Channel.CreateUnbounded<Message>(); + var applicationToTransport = Channel.CreateUnbounded<Message>(); + + var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication); + var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport); + + var state = new ConnectionState( + new Connection(id, applicationSide), + transportSide); + + _connections.TryAdd(id, state); + return state; + } public void RemoveConnection(string id) { @@ -92,41 +103,5 @@ namespace Microsoft.AspNetCore.Sockets } } } - - private ConnectionState CreateMessagingConnection() - { - var id = MakeNewConnectionId(); - - var transportToApplication = Channel.Create<Message>(); - var applicationToTransport = Channel.Create<Message>(); - - var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication); - var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport); - - var state = new MessagingConnectionState( - new MessagingConnection(id, applicationSide), - transportSide); - - _connections.TryAdd(id, state); - return state; - } - - private ConnectionState CreateStreamingConnection() - { - var id = MakeNewConnectionId(); - - var transportToApplication = _pipelineFactory.Create(); - var applicationToTransport = _pipelineFactory.Create(); - - var transportSide = new PipelineConnection(applicationToTransport, transportToApplication); - var applicationSide = new PipelineConnection(transportToApplication, applicationToTransport); - - var state = new StreamingConnectionState( - new StreamingConnection(id, applicationSide), - transportSide); - - _connections.TryAdd(id, state); - return state; - } } } diff --git a/src/Microsoft.AspNetCore.Sockets/ConnectionMode.cs b/src/Microsoft.AspNetCore.Sockets/ConnectionMode.cs deleted file mode 100644 index 91644f4159c53c181cc6f43769b24773ea6ead79..0000000000000000000000000000000000000000 --- a/src/Microsoft.AspNetCore.Sockets/ConnectionMode.cs +++ /dev/null @@ -1,11 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -namespace Microsoft.AspNetCore.Sockets -{ - public enum ConnectionMode - { - Streaming, - Messaging - } -} diff --git a/src/Microsoft.AspNetCore.Sockets/EndPoint.cs b/src/Microsoft.AspNetCore.Sockets/EndPoint.cs index 256edfb4beb322bbd550ebc1b65a7633fff115bd..a975839e3481d34048920daad2e304efc1a98038 100644 --- a/src/Microsoft.AspNetCore.Sockets/EndPoint.cs +++ b/src/Microsoft.AspNetCore.Sockets/EndPoint.cs @@ -11,14 +11,6 @@ namespace Microsoft.AspNetCore.Sockets // REVIEW: This doesn't have any members any more... marker interface? Still even necessary? public abstract class EndPoint { - /// <summary> - /// Gets the connection mode supported by this endpoint. - /// </summary> - /// <remarks> - /// This maps directly to whichever of <see cref="MessagingEndPoint"/> or <see cref="StreamingEndPoint"/> the end point subclasses. - /// </remarks> - public abstract ConnectionMode Mode { get; } - /// <summary> /// Called when a new connection is accepted to the endpoint /// </summary> diff --git a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs index c83dac4851774fd2d57cc99f0db516bb5278a716..e4fb33650bf4b654a898d98ee353e622f577d3dd 100644 --- a/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs +++ b/src/Microsoft.AspNetCore.Sockets/HttpConnectionDispatcher.cs @@ -18,14 +18,12 @@ namespace Microsoft.AspNetCore.Sockets public class HttpConnectionDispatcher { private readonly ConnectionManager _manager; - private readonly PipelineFactory _pipelineFactory; private readonly ILoggerFactory _loggerFactory; private readonly ILogger _logger; - public HttpConnectionDispatcher(ConnectionManager manager, PipelineFactory factory, ILoggerFactory loggerFactory) + public HttpConnectionDispatcher(ConnectionManager manager, ILoggerFactory loggerFactory) { _manager = manager; - _pipelineFactory = factory; _loggerFactory = loggerFactory; _logger = _loggerFactory.CreateLogger<HttpConnectionDispatcher>(); } @@ -37,7 +35,7 @@ namespace Microsoft.AspNetCore.Sockets if (context.Request.Path.StartsWithSegments(path + "/getid")) { - await ProcessGetId(context, endpoint.Mode); + await ProcessGetId(context); } else if (context.Request.Path.StartsWithSegments(path + "/send")) { @@ -56,10 +54,10 @@ namespace Microsoft.AspNetCore.Sockets ? Format.Binary : Format.Text; - var state = GetOrCreateConnection(context, endpoint.Mode); + var state = GetOrCreateConnection(context); // Adapt the connection to a message-based transport if necessary, since all the HTTP transports are message-based. - var application = GetMessagingChannel(state, format); + var application = state.Application; // Server sent events transport if (context.Request.Path.StartsWithSegments(path + "/sse")) @@ -137,7 +135,7 @@ namespace Microsoft.AspNetCore.Sockets // Notify the long polling transport to end if (endpointTask.IsFaulted) { - state.TerminateTransport(endpointTask.Exception.InnerException); + state.Connection.Transport.Output.TryComplete(endpointTask.Exception.InnerException); } state.Connection.Dispose(); @@ -151,19 +149,6 @@ namespace Microsoft.AspNetCore.Sockets } } - private static IChannelConnection<Message> GetMessagingChannel(ConnectionState state, Format format) - { - if (state.Connection.Mode == ConnectionMode.Messaging) - { - return ((MessagingConnectionState)state).Application; - } - else - { - // We need to build an adapter - return new FramingChannel(((StreamingConnectionState)state).Application, format); - } - } - private ConnectionState InitializePersistentConnection(ConnectionState state, string transport, HttpContext context, EndPoint endpoint, Format format) { state.Connection.User = context.User; @@ -197,10 +182,10 @@ namespace Microsoft.AspNetCore.Sockets await Task.WhenAll(endpointTask, transportTask); } - private Task ProcessGetId(HttpContext context, ConnectionMode mode) + private Task ProcessGetId(HttpContext context) { // Establish the connection - var state = _manager.CreateConnection(mode); + var state = _manager.CreateConnection(); // Get the bytes for the connection id var connectionIdBuffer = Encoding.UTF8.GetBytes(state.Connection.ConnectionId); @@ -221,34 +206,27 @@ namespace Microsoft.AspNetCore.Sockets ConnectionState state; if (_manager.TryGetConnection(connectionId, out state)) { - if (state.Connection.Mode == ConnectionMode.Streaming) + // Collect the message and write it to the channel + // TODO: Need to use some kind of pooled memory here. + byte[] buffer; + using (var stream = new MemoryStream()) { - var streamingState = (StreamingConnectionState)state; - - await context.Request.Body.CopyToAsync(streamingState.Application.Output); + await context.Request.Body.CopyToAsync(stream); + buffer = stream.ToArray(); } - else - { - // Collect the message and write it to the channel - // TODO: Need to use some kind of pooled memory here. - byte[] buffer; - using (var strm = new MemoryStream()) - { - await context.Request.Body.CopyToAsync(strm); - await strm.FlushAsync(); - buffer = strm.ToArray(); - } - var format = - string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase) - ? Format.Binary - : Format.Text; - var message = new Message( - ReadableBuffer.Create(buffer).Preserve(), - format, - endOfMessage: true); - await ((MessagingConnectionState)state).Application.Output.WriteAsync(message); - } + var format = + string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase) + ? Format.Binary + : Format.Text; + + var message = new Message( + ReadableBuffer.Create(buffer).Preserve(), + format, + endOfMessage: true); + + await state.Application.Output.WriteAsync(message); + } else { @@ -256,7 +234,7 @@ namespace Microsoft.AspNetCore.Sockets } } - private ConnectionState GetOrCreateConnection(HttpContext context, ConnectionMode mode) + private ConnectionState GetOrCreateConnection(HttpContext context) { var connectionId = context.Request.Query["id"]; ConnectionState connectionState; @@ -264,7 +242,7 @@ namespace Microsoft.AspNetCore.Sockets // There's no connection id so this is a brand new connection if (StringValues.IsNullOrEmpty(connectionId)) { - connectionState = _manager.CreateConnection(mode); + connectionState = _manager.CreateConnection(); } else if (!_manager.TryGetConnection(connectionId, out connectionState)) { diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/ChannelConnection.cs b/src/Microsoft.AspNetCore.Sockets/Internal/ChannelConnection.cs index facc91a2c5b43d8daa7722e1114f5cc3d7e4a1ce..07eb6260ff4f7f20020a40744bfae3cb5603a522 100644 --- a/src/Microsoft.AspNetCore.Sockets/Internal/ChannelConnection.cs +++ b/src/Microsoft.AspNetCore.Sockets/Internal/ChannelConnection.cs @@ -2,9 +2,6 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; -using System.Collections.Generic; -using System.Linq; -using System.Threading.Tasks; using System.Threading.Tasks.Channels; namespace Microsoft.AspNetCore.Sockets.Internal @@ -15,7 +12,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal public IChannel<T> Output { get; } IReadableChannel<T> IChannelConnection<T>.Input => Input; - IWritableChannel<T> IChannelConnection<T>.Output => Output; public ChannelConnection(IChannel<T> input, IChannel<T> output) diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs index ff6ad5aad8952057c1274b0630278a4bbf78b168..d360cb2deb61077c816eb349b3054d34f4dbaf6c 100644 --- a/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs +++ b/src/Microsoft.AspNetCore.Sockets/Internal/ConnectionState.cs @@ -5,24 +5,27 @@ using System; namespace Microsoft.AspNetCore.Sockets.Internal { - public abstract class ConnectionState : IDisposable + public class ConnectionState : IDisposable { public Connection Connection { get; set; } - public ConnectionMode Mode => Connection.Mode; + public IChannelConnection<Message> Application { get; } // These are used for long polling mostly public Action Close { get; set; } public DateTime LastSeenUtc { get; set; } public bool Active { get; set; } = true; - protected ConnectionState(Connection connection) + public ConnectionState(Connection connection, IChannelConnection<Message> application) { Connection = connection; + Application = application; LastSeenUtc = DateTime.UtcNow; } - public abstract void Dispose(); - - public abstract void TerminateTransport(Exception innerException); + public void Dispose() + { + Connection.Dispose(); + Application.Dispose(); + } } } diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/FramingChannel.cs b/src/Microsoft.AspNetCore.Sockets/Internal/FramingChannel.cs deleted file mode 100644 index fa791075c2d1bf685276f934594cf0e21afe8214..0000000000000000000000000000000000000000 --- a/src/Microsoft.AspNetCore.Sockets/Internal/FramingChannel.cs +++ /dev/null @@ -1,142 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.IO.Pipelines; -using System.Threading; -using System.Threading.Tasks; -using System.Threading.Tasks.Channels; - -namespace Microsoft.AspNetCore.Sockets.Internal -{ - /// <summary> - /// Creates a <see cref="IChannelConnection{Message}"/> out of a <see cref="IPipelineConnection"/> by framing data - /// read out of the Pipeline and flattening out frames to write them to the Pipeline when received. - /// </summary> - public class FramingChannel : IChannelConnection<Message>, IReadableChannel<Message>, IWritableChannel<Message> - { - private readonly IPipelineConnection _connection; - private readonly TaskCompletionSource<object> _tcs = new TaskCompletionSource<object>(); - private readonly Format _format; - - Task IReadableChannel<Message>.Completion => _tcs.Task; - - public IReadableChannel<Message> Input => this; - public IWritableChannel<Message> Output => this; - - public FramingChannel(IPipelineConnection connection, Format format) - { - _connection = connection; - _format = format; - } - - ValueTask<Message> IReadableChannel<Message>.ReadAsync(CancellationToken cancellationToken) - { - var awaiter = _connection.Input.ReadAsync(); - if (awaiter.IsCompleted) - { - return new ValueTask<Message>(ReadSync(awaiter.GetResult(), cancellationToken)); - } - else - { - return new ValueTask<Message>(AwaitReadAsync(awaiter, cancellationToken)); - } - } - - private void CancelRead() - { - // We need to fake cancellation support until we get a newer build of pipelines that has CancelPendingRead() - - // HACK: from hell, we attempt to cast the input to a pipeline writer and write 0 bytes so it so that we can - // force yielding the awaiter, this is buggy because overlapping writes can be a problem. - (_connection.Input as IPipelineWriter)?.WriteAsync(Span<byte>.Empty); - } - - bool IReadableChannel<Message>.TryRead(out Message item) - { - // We need to think about how we do this. There's no way to check if there is data available in a Pipeline... though maybe there should be - // We could ReadAsync and check IsCompleted, but then we'd also need to stash that Awaitable for later since we can't call ReadAsync a second time... - // CancelPendingReads could help here. - item = default(Message); - return false; - } - - Task<bool> IReadableChannel<Message>.WaitToReadAsync(CancellationToken cancellationToken) - { - // See above for TryRead. Same problems here. - throw new NotSupportedException(); - } - - Task IWritableChannel<Message>.WriteAsync(Message item, CancellationToken cancellationToken) - { - // Just dump the message on to the pipeline - var buffer = _connection.Output.Alloc(); - buffer.Append(item.Payload.Buffer); - return buffer.FlushAsync(); - } - - Task<bool> IWritableChannel<Message>.WaitToWriteAsync(CancellationToken cancellationToken) - { - // We need to think about how we do this. We don't have a wait to synchronously check for back-pressure in the Pipeline. - throw new NotSupportedException(); - } - - bool IWritableChannel<Message>.TryWrite(Message item) - { - // We need to think about how we do this. We don't have a wait to synchronously check for back-pressure in the Pipeline. - return false; - } - - bool IWritableChannel<Message>.TryComplete(Exception error) - { - _connection.Output.Complete(error); - _connection.Input.Complete(error); - return true; - } - - private async Task<Message> AwaitReadAsync(ReadableBufferAwaitable awaiter, CancellationToken cancellationToken) - { - using (cancellationToken.Register(state => ((FramingChannel)state).CancelRead(), this)) - { - // Just await and then call ReadSync - var result = await awaiter; - return ReadSync(result, cancellationToken); - } - } - - private Message ReadSync(ReadResult result, CancellationToken cancellationToken) - { - var buffer = result.Buffer; - - // Preserve the buffer and advance the pipeline past it - var preserved = buffer.Preserve(); - _connection.Input.Advance(buffer.End); - - var msg = new Message(preserved, _format, endOfMessage: true); - - if (result.IsCompleted) - { - // Complete the task - _tcs.TrySetResult(null); - } - - if (cancellationToken.IsCancellationRequested) - { - _tcs.TrySetCanceled(); - - msg.Dispose(); - - // In order to keep the behavior consistent between the transports, we throw if the token was cancelled - throw new OperationCanceledException(); - } - - return msg; - } - - public void Dispose() - { - _tcs.TrySetResult(null); - _connection.Dispose(); - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/MessagingConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/Internal/MessagingConnectionState.cs deleted file mode 100644 index 5deffe7a85becf3cfa3052bed8cb2f08bf632f62..0000000000000000000000000000000000000000 --- a/src/Microsoft.AspNetCore.Sockets/Internal/MessagingConnectionState.cs +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; - -namespace Microsoft.AspNetCore.Sockets.Internal -{ - public class MessagingConnectionState : ConnectionState - { - public new MessagingConnection Connection => (MessagingConnection)base.Connection; - public IChannelConnection<Message> Application { get; } - - public MessagingConnectionState(MessagingConnection connection, IChannelConnection<Message> application) : base(connection) - { - Application = application; - } - - public override void Dispose() - { - Connection.Dispose(); - Application.Dispose(); - } - - public override void TerminateTransport(Exception innerException) - { - Connection.Transport.Output.TryComplete(innerException); - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/PipelineConnection.cs b/src/Microsoft.AspNetCore.Sockets/Internal/PipelineConnection.cs deleted file mode 100644 index 85fc313c3ad29d8af5dde49fffb88df0deffb9b3..0000000000000000000000000000000000000000 --- a/src/Microsoft.AspNetCore.Sockets/Internal/PipelineConnection.cs +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.IO.Pipelines; - -namespace Microsoft.AspNetCore.Sockets.Internal -{ - public class PipelineConnection : IPipelineConnection - { - public PipelineReaderWriter Input { get; } - public PipelineReaderWriter Output { get; } - - IPipelineReader IPipelineConnection.Input => Input; - IPipelineWriter IPipelineConnection.Output => Output; - - public PipelineConnection(PipelineReaderWriter input, PipelineReaderWriter output) - { - Input = input; - Output = output; - } - - public void Dispose() - { - Input.CompleteReader(); - Input.CompleteWriter(); - Output.CompleteReader(); - Output.CompleteWriter(); - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets/Internal/StreamingConnectionState.cs b/src/Microsoft.AspNetCore.Sockets/Internal/StreamingConnectionState.cs deleted file mode 100644 index dc69b3f18edc233925bac297702b0f937a9d2b16..0000000000000000000000000000000000000000 --- a/src/Microsoft.AspNetCore.Sockets/Internal/StreamingConnectionState.cs +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.IO.Pipelines; - -namespace Microsoft.AspNetCore.Sockets.Internal -{ - public class StreamingConnectionState : ConnectionState - { - public new StreamingConnection Connection => (StreamingConnection)base.Connection; - public IPipelineConnection Application { get; } - - public StreamingConnectionState(StreamingConnection connection, IPipelineConnection application) : base(connection) - { - Application = application; - } - - public override void Dispose() - { - Connection.Dispose(); - Application.Dispose(); - } - - public override void TerminateTransport(Exception innerException) - { - Connection.Transport.Output.Complete(innerException); - Connection.Transport.Input.Complete(innerException); - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets/MessagingConnection.cs b/src/Microsoft.AspNetCore.Sockets/MessagingConnection.cs deleted file mode 100644 index 6dd086b6f592d64dc64401bf3a019d847adf21de..0000000000000000000000000000000000000000 --- a/src/Microsoft.AspNetCore.Sockets/MessagingConnection.cs +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; - -namespace Microsoft.AspNetCore.Sockets -{ - public class MessagingConnection : Connection - { - public override ConnectionMode Mode => ConnectionMode.Messaging; - public IChannelConnection<Message> Transport { get; } - - public MessagingConnection(string id, IChannelConnection<Message> transport) : base(id) - { - Transport = transport; - } - - public override void Dispose() - { - Transport.Dispose(); - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets/MessagingEndPoint.cs b/src/Microsoft.AspNetCore.Sockets/MessagingEndPoint.cs deleted file mode 100644 index 4223d94bcfa5c44c04fbd075beb44d3b7e96698f..0000000000000000000000000000000000000000 --- a/src/Microsoft.AspNetCore.Sockets/MessagingEndPoint.cs +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading.Tasks; - -namespace Microsoft.AspNetCore.Sockets -{ - public abstract class MessagingEndPoint : EndPoint - { - public override ConnectionMode Mode => ConnectionMode.Messaging; - - public override Task OnConnectedAsync(Connection connection) - { - if (connection.Mode != Mode) - { - throw new InvalidOperationException($"Connection mode does not match endpoint mode. Connection mode is '{connection.Mode}', endpoint mode is '{Mode}'"); - } - return OnConnectedAsync((MessagingConnection)connection); - } - - /// <summary> - /// Called when a new connection is accepted to the endpoint - /// </summary> - /// <param name="connection">The new <see cref="MessagingConnection"/></param> - /// <returns>A <see cref="Task"/> that represents the connection lifetime. When the task completes, the connection is complete.</returns> - public abstract Task OnConnectedAsync(MessagingConnection connection); - } -} diff --git a/src/Microsoft.AspNetCore.Sockets/StreamingConnection.cs b/src/Microsoft.AspNetCore.Sockets/StreamingConnection.cs deleted file mode 100644 index 58dd16555ad1440141625ea608a486541ef88086..0000000000000000000000000000000000000000 --- a/src/Microsoft.AspNetCore.Sockets/StreamingConnection.cs +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System.IO.Pipelines; - -namespace Microsoft.AspNetCore.Sockets -{ - public class StreamingConnection : Connection - { - public override ConnectionMode Mode => ConnectionMode.Streaming; - - public IPipelineConnection Transport { get; set; } - - public StreamingConnection(string id, IPipelineConnection transport) : base(id) - { - Transport = transport; - } - - public override void Dispose() - { - Transport.Dispose(); - } - } -} diff --git a/src/Microsoft.AspNetCore.Sockets/StreamingEndPoint.cs b/src/Microsoft.AspNetCore.Sockets/StreamingEndPoint.cs deleted file mode 100644 index 231f6f55d6f0aa3fa8f9964d61854a563c50e34b..0000000000000000000000000000000000000000 --- a/src/Microsoft.AspNetCore.Sockets/StreamingEndPoint.cs +++ /dev/null @@ -1,29 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Threading.Tasks; - -namespace Microsoft.AspNetCore.Sockets -{ - public abstract class StreamingEndPoint : EndPoint - { - public override ConnectionMode Mode => ConnectionMode.Streaming; - - public override Task OnConnectedAsync(Connection connection) - { - if(connection.Mode != Mode) - { - throw new InvalidOperationException($"Connection mode does not match endpoint mode. Connection mode is '{connection.Mode}', endpoint mode is '{Mode}'"); - } - return OnConnectedAsync((StreamingConnection)connection); - } - - /// <summary> - /// Called when a new connection is accepted to the endpoint - /// </summary> - /// <param name="connection">The new <see cref="StreamingConnection"/></param> - /// <returns>A <see cref="Task"/> that represents the connection lifetime. When the task completes, the connection is complete.</returns> - public abstract Task OnConnectedAsync(StreamingConnection connection); - } -} diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs index db2c7dc7c8145445964951b21e5a0614acaf6234..17ce98e18e09e2b38ac9ef12b0652e8740ac0ce6 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/LongPollingTransport.cs @@ -13,18 +13,18 @@ namespace Microsoft.AspNetCore.Sockets.Transports { public class LongPollingTransport : IHttpTransport { - private readonly IReadableChannel<Message> _connection; + private readonly IReadableChannel<Message> _application; private readonly ILogger _logger; - public LongPollingTransport(IReadableChannel<Message> connection, ILoggerFactory loggerFactory) + public LongPollingTransport(IReadableChannel<Message> application, ILoggerFactory loggerFactory) { - _connection = connection; + _application = application; _logger = loggerFactory.CreateLogger<LongPollingTransport>(); } public async Task ProcessRequestAsync(HttpContext context) { - if (_connection.Completion.IsCompleted) + if (_application.Completion.IsCompleted) { // Client should stop if it receives a 204 _logger.LogInformation("Terminating Long Polling connection by sending 204 response."); @@ -37,18 +37,19 @@ namespace Microsoft.AspNetCore.Sockets.Transports // TODO: We need the ability to yield the connection without completing the channel. // This is to force ReadAsync to yield without data to end to poll but not the entire connection. // This is for cases when the client reconnects see issue #27 - using (var message = await _connection.ReadAsync(context.RequestAborted)) + await _application.WaitToReadAsync(context.RequestAborted); + + Message message; + if (_application.TryRead(out message)) { - _logger.LogDebug("Writing {0} byte message to response", message.Payload.Buffer.Length); - context.Response.ContentLength = message.Payload.Buffer.Length; - await message.Payload.Buffer.CopyToAsync(context.Response.Body); + using (message) + { + _logger.LogDebug("Writing {0} byte message to response", message.Payload.Buffer.Length); + context.Response.ContentLength = message.Payload.Buffer.Length; + await message.Payload.Buffer.CopyToAsync(context.Response.Body); + } } } - catch (Exception ex) when (ex.GetType().IsNested && ex.GetType().DeclaringType == typeof(Channel)) - { - // The Channel was closed, while we were waiting to read. That's fine, just means we're done. - // Gross that we have to catch this this way. See https://github.com/dotnet/corefxlab/issues/1068 - } catch (OperationCanceledException) { // Suppress the exception diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs index e69744d0d1e79b144d49d28d739273bc4d3e0298..88979b5cf08c336bd3e3d2b2e7cefc50e6f49d0a 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/ServerSentEventsTransport.cs @@ -30,18 +30,18 @@ namespace Microsoft.AspNetCore.Sockets.Transports try { - while (true) + while (await _application.WaitToReadAsync(context.RequestAborted)) { - using (var message = await _application.ReadAsync(context.RequestAborted)) + Message message; + if (_application.TryRead(out message)) { - await Send(context, message); + using (message) + { + await Send(context, message); + } } } } - catch (Exception ex) when (ex.GetType().IsNested && ex.GetType().DeclaringType == typeof(Channel)) - { - // Gross that we have to catch this this way. See https://github.com/dotnet/corefxlab/issues/1068 - } catch (OperationCanceledException) { // Closed connection diff --git a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs index 74fd027f941ab330872eced356c835fd576629f6..07f2b9fd5e040d072e55a8063307b22bf02e51e4 100644 --- a/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs +++ b/src/Microsoft.AspNetCore.Sockets/Transports/WebSocketsTransport.cs @@ -22,20 +22,20 @@ namespace Microsoft.AspNetCore.Sockets.Transports private bool _lastFrameIncomplete = false; private readonly ILogger _logger; - private readonly IChannelConnection<Message> _connection; + private readonly IChannelConnection<Message> _application; - public WebSocketsTransport(IChannelConnection<Message> connection, ILoggerFactory loggerFactory) + public WebSocketsTransport(IChannelConnection<Message> application, ILoggerFactory loggerFactory) { - if (connection == null) + if (application == null) { - throw new ArgumentNullException(nameof(connection)); + throw new ArgumentNullException(nameof(application)); } if (loggerFactory == null) { throw new ArgumentNullException(nameof(loggerFactory)); } - _connection = connection; + _application = application; _logger = loggerFactory.CreateLogger<WebSocketsTransport>(); } @@ -84,7 +84,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports // Shutting down because we received a close frame from the client. // Complete the input writer so that the application knows there won't be any more input. _logger.LogDebug("Client closed connection with status code '{0}' ({1}). Signaling end-of-input to application", receiving.Result.Status, receiving.Result.Description); - _connection.Output.TryComplete(); + _application.Output.TryComplete(); // Wait for the application to finish sending. _logger.LogDebug("Waiting for the application to finish sending data"); @@ -95,7 +95,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports } else { - var failed = sending.IsFaulted || sending.IsCompleted; + var failed = sending.IsFaulted || _application.Input.Completion.IsFaulted; // The application finished sending. Close our end of the connection _logger.LogDebug(!failed ? "Application finished sending. Sending close frame." : "Application failed during sending. Sending InternalServerError close frame"); @@ -109,7 +109,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports // Wait for the client to close. // TODO: Consider timing out here and cancelling the receive loop. await receiving; - _connection.Output.TryComplete(); + _application.Output.TryComplete(); } } @@ -138,7 +138,7 @@ namespace Microsoft.AspNetCore.Sockets.Transports var message = new Message(frame.Payload.Preserve(), effectiveOpcode == WebSocketOpcode.Binary ? Format.Binary : Format.Text, frame.EndOfMessage); // Write the message to the channel - return _connection.Output.WriteAsync(message); + return _application.Output.WriteAsync(message); } private void LogFrame(string action, WebSocketFrame frame) @@ -152,12 +152,13 @@ namespace Microsoft.AspNetCore.Sockets.Transports private async Task StartSending(IWebSocketConnection ws) { - while (true) + while (await _application.Input.WaitToReadAsync()) { // Get a frame from the application - try + Message message; + if (_application.Input.TryRead(out message)) { - using (var message = await _connection.Input.ReadAsync()) + using (message) { if (message.Payload.Buffer.Length > 0) { @@ -185,11 +186,6 @@ namespace Microsoft.AspNetCore.Sockets.Transports } } } - catch (Exception ex) when (ex.GetType().IsNested && ex.GetType().DeclaringType == typeof(Channel)) - { - // Gross that we have to catch this this way. See https://github.com/dotnet/corefxlab/issues/1068 - break; - } } } } diff --git a/src/Microsoft.AspNetCore.Sockets/project.json b/src/Microsoft.AspNetCore.Sockets/project.json index 7ec861d96def8df038c273c5a5a76e68be618c02..e975f6f78ef937bbcd59d6fb1233e48d02f9703c 100644 --- a/src/Microsoft.AspNetCore.Sockets/project.json +++ b/src/Microsoft.AspNetCore.Sockets/project.json @@ -20,7 +20,6 @@ "xmlDoc": true }, "dependencies": { - "System.IO.Pipelines": "0.1.0-*", "System.Threading.Tasks.Channels": "0.1.0-*", "System.Security.Claims": "4.4.0-*", diff --git a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketCloseResult.cs b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketCloseResult.cs index 18d134ab8cb68f638d67b1b4599475b56d376629..f01b2425a9276e5eaf098c4c500248bb7072efe8 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketCloseResult.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketCloseResult.cs @@ -71,7 +71,7 @@ namespace Microsoft.Extensions.WebSockets.Internal buffer.WriteBigEndian((ushort)Status); if (!string.IsNullOrEmpty(Description)) { - buffer.Append(Description, EncodingData.InvariantUtf8.TextEncoding); + buffer.Append(Description, EncodingData.InvariantUtf8); } } } diff --git a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs index cc3154a2de0332a34e1dfd95ec6738a5682dc419..5530309af68784bdd2d11c52ded88880416e5929 100644 --- a/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs +++ b/src/Microsoft.Extensions.WebSockets.Internal/WebSocketConnection.cs @@ -246,7 +246,7 @@ namespace Microsoft.Extensions.WebSockets.Internal _options.MaskingKeyGenerator.GetBytes(_maskingKeyBuffer); } - buffer.Set(_maskingKeyBuffer); + _maskingKeyBuffer.CopyTo(buffer); } private async Task<WebSocketCloseResult> ReceiveLoop(Func<WebSocketFrame, object, Task> messageHandler, object state, CancellationToken cancellationToken) @@ -550,7 +550,7 @@ namespace Microsoft.Extensions.WebSockets.Internal { // TODO: Could use TryGetPointer, GetBytes does take a byte*, but it seems like just waiting until we have a version that uses Span is best. // Slow path - Allocate a heap buffer for the encoded bytes before writing them out. - payload.Span.Set(Encoding.UTF8.GetBytes(str)); + Encoding.UTF8.GetBytes(str).CopyTo(payload.Span); } if (maskingKey.Length > 0) diff --git a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs index 0e3d16bdde98c1b266a5e64eb52d608317db772c..0631e7dbd7ded281269cb3ef7910c632a6249fb3 100644 --- a/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Client.FunctionalTests/HubConnectionTests.cs @@ -134,10 +134,10 @@ namespace Microsoft.AspNetCore.SignalR.Client.FunctionalTests EnsureConnectionEstablished(connection); - var ex = await Assert.ThrowsAnyAsync<InvalidOperationException>( - async () => await connection.Invoke<Task>("!@#$%")); + var ex = await Assert.ThrowsAnyAsync<Exception>( + async () => await connection.Invoke<object>("!@#$%")); - Assert.Equal(ex.Message, "The hub method '!@#$%' could not be resolved."); + Assert.Equal(ex.Message, "Unknown hub method '!@#$%'"); } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs b/test/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs index 92f149a87c42abe786b208affcd51fdb35a680ee..7b930e543799c0a7f31998739a553160f0e93897 100644 --- a/test/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs +++ b/test/Microsoft.AspNetCore.SignalR.Test.Server/EchoEndPoint.cs @@ -8,11 +8,11 @@ using Microsoft.AspNetCore.Sockets; namespace Microsoft.AspNetCore.SignalR.Test.Server { - public class EchoEndPoint : StreamingEndPoint + public class EchoEndPoint : EndPoint { - public async override Task OnConnectedAsync(StreamingConnection connection) + public async override Task OnConnectedAsync(Connection connection) { - await connection.Transport.Input.CopyToAsync(connection.Transport.Output); + await connection.Transport.Output.WriteAsync(await connection.Transport.Input.ReadAsync()); } } } diff --git a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs index e7eca16802fa2741ea280925e6db796f2871c3b5..084fa15a8c775fc0080cde3807e8b2c9c66884d0 100644 --- a/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs +++ b/test/Microsoft.AspNetCore.SignalR.Tests/HubEndpointTests.cs @@ -4,9 +4,11 @@ using System; using System.IO; using System.IO.Pipelines; +using System.Runtime.CompilerServices; using System.Security.Claims; using System.Threading; using System.Threading.Tasks; +using System.Threading.Tasks.Channels; using Microsoft.AspNetCore.Sockets; using Microsoft.AspNetCore.Sockets.Internal; using Microsoft.Extensions.DependencyInjection; @@ -33,7 +35,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests await connectionWrapper.ApplicationStartedReading; // kill the connection - connectionWrapper.ConnectionState.Dispose(); + connectionWrapper.Dispose(); await endPointTask; @@ -41,44 +43,12 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } - [Fact] - public async Task OnDisconnectedCalledWithExceptionIfHubMethodNotFound() - { - var hub = Mock.Of<Hub>(); - - var endPointType = GetEndPointType(hub.GetType()); - var serviceProvider = CreateServiceProvider(s => - { - s.AddSingleton(endPointType); - s.AddTransient(hub.GetType(), sp => hub); - }); - - dynamic endPoint = serviceProvider.GetService(endPointType); - - using (var connectionWrapper = new ConnectionWrapper()) - { - var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); - - await connectionWrapper.ApplicationStartedReading; - - var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>(); - var adapter = invocationAdapter.GetInvocationAdapter("json"); - await SendRequest(connectionWrapper.Connection.Transport, adapter, "0xdeadbeef"); - - connectionWrapper.Dispose(); - - await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask); - - Mock.Get(hub).Verify(h => h.OnDisconnectedAsync(It.IsNotNull<InvalidOperationException>()), Times.Once()); - } - } - [Fact] public async Task LifetimeManagerOnDisconnectedAsyncCalledIfLifetimeManagerOnConnectedAsyncThrows() { var mockLifetimeManager = new Mock<HubLifetimeManager<Hub>>(); mockLifetimeManager - .Setup(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>())) + .Setup(m => m.OnConnectedAsync(It.IsAny<Connection>())) .Throws(new InvalidOperationException("Lifetime manager OnConnectedAsync failed.")); var mockHubActivator = new Mock<IHubActivator<Hub, IClientProxy>>(); @@ -97,10 +67,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests async () => await endPoint.OnConnectedAsync(connectionWrapper.Connection)); Assert.Equal("Lifetime manager OnConnectedAsync failed.", exception.Message); - connectionWrapper.ConnectionState.Dispose(); + connectionWrapper.Dispose(); - mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()), Times.Once); - mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<StreamingConnection>()), Times.Once); + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once); // No hubs should be created since the connection is terminated mockHubActivator.Verify(m => m.Create(), Times.Never); mockHubActivator.Verify(m => m.Release(It.IsAny<Hub>()), Times.Never); @@ -121,13 +91,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var connectionWrapper = new ConnectionWrapper()) { var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); - connectionWrapper.ConnectionState.Dispose(); + connectionWrapper.Dispose(); var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask); Assert.Equal("Hub OnConnected failed.", exception.Message); - mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()), Times.Once); - mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<StreamingConnection>()), Times.Once); + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once); } } @@ -145,44 +115,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests using (var connectionWrapper = new ConnectionWrapper()) { var endPointTask = endPoint.OnConnectedAsync(connectionWrapper.Connection); - connectionWrapper.ConnectionState.Dispose(); + connectionWrapper.Dispose(); var exception = await Assert.ThrowsAsync<InvalidOperationException>(async () => await endPointTask); Assert.Equal("Hub OnDisconnected failed.", exception.Message); - mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<StreamingConnection>()), Times.Once); - mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<StreamingConnection>()), Times.Once); - } - } - - private static Type GetEndPointType(Type hubType) - { - var endPointType = typeof(HubEndPoint<>); - return endPointType.MakeGenericType(hubType); - } - - private static Type GetGenericType(Type genericType, Type hubType) - { - return genericType.MakeGenericType(hubType); - } - - public class OnConnectedThrowsHub : Hub - { - public override Task OnConnectedAsync() - { - var tcs = new TaskCompletionSource<object>(); - tcs.SetException(new InvalidOperationException("Hub OnConnected failed.")); - return tcs.Task; - } - } - - public class OnDisconnectedThrowsHub : Hub - { - public override Task OnDisconnectedAsync(Exception exception) - { - var tcs = new TaskCompletionSource<object>(); - tcs.SetException(new InvalidOperationException("Hub OnDisconnected failed.")); - return tcs.Task; + mockLifetimeManager.Verify(m => m.OnConnectedAsync(It.IsAny<Connection>()), Times.Once); + mockLifetimeManager.Verify(m => m.OnDisconnectedAsync(It.IsAny<Connection>()), Times.Once); } } @@ -202,10 +141,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>(); var adapter = invocationAdapter.GetInvocationAdapter("json"); - await SendRequest(connectionWrapper.Connection.Transport, adapter, "TaskValueMethod"); - var res = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper.Connection.Transport); + await SendRequest(connectionWrapper, adapter, nameof(MethodHub.TaskValueMethod)); + var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper); + // json serializer makes this a long - Assert.Equal(42L, res.Result); + Assert.Equal(42L, result.Result); // kill the connection connectionWrapper.Connection.Dispose(); @@ -230,10 +170,11 @@ namespace Microsoft.AspNetCore.SignalR.Tests var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>(); var adapter = invocationAdapter.GetInvocationAdapter("json"); - await SendRequest(connectionWrapper.Connection.Transport, adapter, "ValueMethod"); - var res = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper.Connection.Transport); + await SendRequest(connectionWrapper, adapter, "ValueMethod"); + var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper); + // json serializer makes this a long - Assert.Equal(43L, res.Result); + Assert.Equal(43L, result.Result); // kill the connection connectionWrapper.Connection.Dispose(); @@ -258,9 +199,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>(); var adapter = invocationAdapter.GetInvocationAdapter("json"); - await SendRequest(connectionWrapper.Connection.Transport, adapter, "StaticMethod"); - var res = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper.Connection.Transport); - Assert.Equal("fromStatic", res.Result); + await SendRequest(connectionWrapper, adapter, "StaticMethod"); + var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper); + + Assert.Equal("fromStatic", result.Result); // kill the connection connectionWrapper.Connection.Dispose(); @@ -285,9 +227,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>(); var adapter = invocationAdapter.GetInvocationAdapter("json"); - await SendRequest(connectionWrapper.Connection.Transport, adapter, "VoidMethod"); - var res = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper.Connection.Transport); - Assert.Equal(null, res.Result); + await SendRequest(connectionWrapper, adapter, "VoidMethod"); + var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper); + + Assert.Null(result.Result); // kill the connection connectionWrapper.Connection.Dispose(); @@ -312,9 +255,9 @@ namespace Microsoft.AspNetCore.SignalR.Tests var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>(); var adapter = invocationAdapter.GetInvocationAdapter("json"); - await SendRequest(connectionWrapper.Connection.Transport, adapter, "ConcatString", (byte)32, 42, 'm', "string"); - var res = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper.Connection.Transport); - Assert.Equal("32, 42, m, string", res.Result); + await SendRequest(connectionWrapper, adapter, "ConcatString", (byte)32, 42, 'm', "string"); + var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper); + Assert.Equal("32, 42, m, string", result.Result); // kill the connection connectionWrapper.Connection.Dispose(); @@ -339,17 +282,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>(); var adapter = invocationAdapter.GetInvocationAdapter("json"); - await SendRequest(connectionWrapper.Connection.Transport, adapter, "OnDisconnectedAsync"); + await SendRequest(connectionWrapper, adapter, "OnDisconnectedAsync"); + var result = await ReadConnectionOutputAsync<InvocationResultDescriptor>(connectionWrapper); - try - { - await endPointTask; - Assert.True(false); - } - catch (InvalidOperationException ex) - { - Assert.Equal("The hub method 'OnDisconnectedAsync' could not be resolved.", ex.Message); - } + Assert.Equal("Unknown hub method 'OnDisconnectedAsync'", result.Error); } } @@ -371,21 +307,21 @@ namespace Microsoft.AspNetCore.SignalR.Tests var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>(); var adapter = invocationAdapter.GetInvocationAdapter("json"); - await SendRequest(firstConnection.Connection.Transport, adapter, "BroadcastMethod", "test"); + await SendRequest(firstConnection, adapter, "BroadcastMethod", "test"); - foreach (var res in await Task.WhenAll( - ReadConnectionOutputAsync<InvocationDescriptor>(firstConnection.Connection.Transport), - ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection.Connection.Transport))) + foreach (var result in await Task.WhenAll( + ReadConnectionOutputAsync<InvocationDescriptor>(firstConnection), + ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection))) { - Assert.Equal("Broadcast", res.Method); - Assert.Equal(1, res.Arguments.Length); - Assert.Equal("test", res.Arguments[0]); + Assert.Equal("Broadcast", result.Method); + Assert.Equal(1, result.Arguments.Length); + Assert.Equal("test", result.Arguments[0]); } // kill the connections firstConnection.Connection.Dispose(); secondConnection.Connection.Dispose(); - + await Task.WhenAll(firstEndPointTask, secondEndPointTask); } } @@ -408,18 +344,20 @@ namespace Microsoft.AspNetCore.SignalR.Tests var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>(); var adapter = invocationAdapter.GetInvocationAdapter("json"); - await SendRequest_IgnoreReceive(firstConnection.Connection.Transport, adapter, "GroupSendMethod", "testGroup", "test"); + await SendRequest_IgnoreReceive(firstConnection, adapter, "GroupSendMethod", "testGroup", "test"); // check that 'secondConnection' hasn't received the group send - Assert.False(((PipelineReaderWriter)secondConnection.Connection.Transport.Output).ReadAsync().IsCompleted); + Message message; + Assert.False(secondConnection.Transport.Output.TryRead(out message)); + + await SendRequest_IgnoreReceive(secondConnection, adapter, "GroupAddMethod", "testGroup"); - await SendRequest_IgnoreReceive(secondConnection.Connection.Transport, adapter, "GroupAddMethod", "testGroup"); + await SendRequest(firstConnection, adapter, "GroupSendMethod", "testGroup", "test"); - await SendRequest(firstConnection.Connection.Transport, adapter, "GroupSendMethod", "testGroup", "test"); // check that 'firstConnection' hasn't received the group send - Assert.False(((PipelineReaderWriter)firstConnection.Connection.Transport.Output).ReadAsync().IsCompleted); + Assert.False(firstConnection.Transport.Output.TryRead(out message)); // check that 'secondConnection' has received the group send - var res = await ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection.Connection.Transport); + var res = await ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection); Assert.Equal("Send", res.Method); Assert.Equal(1, res.Arguments.Length); Assert.Equal("test", res.Arguments[0]); @@ -448,7 +386,7 @@ namespace Microsoft.AspNetCore.SignalR.Tests var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>(); var writer = invocationAdapter.GetInvocationAdapter("json"); - await SendRequest_IgnoreReceive(connection.Connection.Transport, writer, "GroupRemoveMethod", "testGroup"); + await SendRequest_IgnoreReceive(connection, writer, "GroupRemoveMethod", "testGroup"); // kill the connection connection.Connection.Dispose(); @@ -475,10 +413,10 @@ namespace Microsoft.AspNetCore.SignalR.Tests var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>(); var adapter = invocationAdapter.GetInvocationAdapter("json"); - await SendRequest_IgnoreReceive(firstConnection.Connection.Transport, adapter, "ClientSendMethod", secondConnection.Connection.User.Identity.Name, "test"); + await SendRequest_IgnoreReceive(firstConnection, adapter, "ClientSendMethod", secondConnection.Connection.User.Identity.Name, "test"); // check that 'secondConnection' has received the group send - var res = await ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection.Connection.Transport); + var res = await ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection); Assert.Equal("Send", res.Method); Assert.Equal(1, res.Arguments.Length); Assert.Equal("test", res.Arguments[0]); @@ -509,13 +447,13 @@ namespace Microsoft.AspNetCore.SignalR.Tests var invocationAdapter = serviceProvider.GetService<InvocationAdapterRegistry>(); var adapter = invocationAdapter.GetInvocationAdapter("json"); - await SendRequest_IgnoreReceive(firstConnection.Connection.Transport, adapter, "ConnectionSendMethod", secondConnection.Connection.ConnectionId, "test"); + await SendRequest_IgnoreReceive(firstConnection, adapter, "ConnectionSendMethod", secondConnection.Connection.ConnectionId, "test"); // check that 'secondConnection' has received the group send - var res = await ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection.Connection.Transport); - Assert.Equal("Send", res.Method); - Assert.Equal(1, res.Arguments.Length); - Assert.Equal("test", res.Arguments[0]); + var result = await ReadConnectionOutputAsync<InvocationDescriptor>(secondConnection); + Assert.Equal("Send", result.Method); + Assert.Equal(1, result.Arguments.Length); + Assert.Equal("test", result.Arguments[0]); // kill the connections firstConnection.Connection.Dispose(); @@ -525,6 +463,84 @@ namespace Microsoft.AspNetCore.SignalR.Tests } } + private static Type GetEndPointType(Type hubType) + { + var endPointType = typeof(HubEndPoint<>); + return endPointType.MakeGenericType(hubType); + } + + private static Type GetGenericType(Type genericType, Type hubType) + { + return genericType.MakeGenericType(hubType); + } + + public async Task SendRequest(ConnectionWrapper connection, IInvocationAdapter writer, string method, params object[] args) + { + if (connection == null) + { + throw new ArgumentNullException(); + } + + var stream = new MemoryStream(); + await writer.WriteMessageAsync(new InvocationDescriptor + { + Arguments = args, + Method = method + }, + stream); + + var buffer = ReadableBuffer.Create(stream.ToArray()).Preserve(); + await connection.Transport.Input.WriteAsync(new Message(buffer, Format.Binary, endOfMessage: true)); + } + + public async Task SendRequest_IgnoreReceive(ConnectionWrapper connection, IInvocationAdapter writer, string method, params object[] args) + { + await SendRequest(connection, writer, method, args); + + // Consume the result + await connection.Transport.Output.ReadAsync(); + } + + private async Task<T> ReadConnectionOutputAsync<T>(ConnectionWrapper connection) + { + // TODO: other formats? + var message = await connection.Transport.Output.ReadAsync(); + var serializer = new JsonSerializer(); + return serializer.Deserialize<T>(new JsonTextReader(new StreamReader(new MemoryStream(message.Payload.Buffer.ToArray())))); + } + + private IServiceProvider CreateServiceProvider(Action<ServiceCollection> addServices = null) + { + var services = new ServiceCollection(); + services.AddOptions() + .AddLogging() + .AddSignalR(); + + addServices?.Invoke(services); + + return services.BuildServiceProvider(); + } + + public class OnConnectedThrowsHub : Hub + { + public override Task OnConnectedAsync() + { + var tcs = new TaskCompletionSource<object>(); + tcs.SetException(new InvalidOperationException("Hub OnConnected failed.")); + return tcs.Task; + } + } + + public class OnDisconnectedThrowsHub : Hub + { + public override Task OnDisconnectedAsync(Exception exception) + { + var tcs = new TaskCompletionSource<object>(); + tcs.SetException(new InvalidOperationException("Hub OnDisconnected failed.")); + return tcs.Task; + } + } + private class MethodHub : Hub { public Task GroupRemoveMethod(string groupName) @@ -610,83 +626,91 @@ namespace Microsoft.AspNetCore.SignalR.Tests public int DisposeCount = 0; } - public async Task SendRequest(IPipelineConnection connection, IInvocationAdapter writer, string method, params object[] args) + public class ConnectionWrapper : IDisposable { - if (connection == null) - { - throw new ArgumentNullException(); - } + private static int _id; + private readonly TestChannel<Message> _input; - var stream = new MemoryStream(); - await writer.WriteMessageAsync(new InvocationDescriptor + public Connection Connection { get; } + + public ChannelConnection<Message> Transport { get; } + + public Task ApplicationStartedReading => _input.ReadingStarted; + + public ConnectionWrapper(string format = "json") { - Arguments = args, - Method = method - }, stream); + var transportToApplication = Channel.CreateUnbounded<Message>(); + var applicationToTransport = Channel.CreateUnbounded<Message>(); - var buffer = ((PipelineReaderWriter)connection.Input).Alloc(); - buffer.Write(stream.ToArray()); - await buffer.FlushAsync(); - } + _input = new TestChannel<Message>(transportToApplication); - public async Task SendRequest_IgnoreReceive(IPipelineConnection connection, IInvocationAdapter writer, string method, params object[] args) - { - await SendRequest(connection, writer, method, args); + Transport = new ChannelConnection<Message>(_input, applicationToTransport); - var methodResult = await ((PipelineReaderWriter)connection.Output).ReadAsync(); - ((PipelineReaderWriter)connection.Output).AdvanceReader(methodResult.Buffer.End, methodResult.Buffer.End); - } + Connection = new Connection(Guid.NewGuid().ToString(), Transport); + Connection.Metadata["formatType"] = format; + Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref _id).ToString()) })); + } - private async Task<T> ReadConnectionOutputAsync<T>(IPipelineConnection connection) - { - // TODO: other formats? - var methodResult = await ((PipelineReaderWriter)connection.Output).ReadAsync(); - var serializer = new JsonSerializer(); - var res = serializer.Deserialize<T>(new JsonTextReader(new StreamReader(new MemoryStream(methodResult.Buffer.ToArray())))); - ((PipelineReaderWriter)connection.Output).AdvanceReader(methodResult.Buffer.End, methodResult.Buffer.End); + public void Dispose() + { + Connection.Dispose(); + } - return res; - } + private class TestChannel<T> : IChannel<T> + { + private IChannel<T> _channel; + private TaskCompletionSource<object> _tcs = new TaskCompletionSource<object>(); - private IServiceProvider CreateServiceProvider(Action<ServiceCollection> addServices = null) - { - var services = new ServiceCollection(); - services.AddOptions() - .AddLogging() - .AddSignalR(); + public TestChannel(IChannel<T> channel) + { + _channel = channel; + } - addServices?.Invoke(services); + public Task Completion => _channel.Completion; - return services.BuildServiceProvider(); - } + public Task ReadingStarted => _tcs.Task; - private class ConnectionWrapper : IDisposable - { - private static int ID; - private PipelineFactory _factory; + public ValueAwaiter<T> GetAwaiter() + { + return _channel.GetAwaiter(); + } - public StreamingConnectionState ConnectionState; + public ValueTask<T> ReadAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + _tcs.TrySetResult(null); + return _channel.ReadAsync(cancellationToken); + } - public StreamingConnection Connection => ConnectionState.Connection; + public bool TryComplete(Exception error = null) + { + return _channel.TryComplete(error); + } - // Still kinda gross... - public Task ApplicationStartedReading => ((PipelineReaderWriter)Connection.Transport.Input).ReadingStarted; + public bool TryRead(out T item) + { + return _channel.TryRead(out item); + } - public ConnectionWrapper(string format = "json") - { - _factory = new PipelineFactory(); + public bool TryWrite(T item) + { + return _channel.TryWrite(item); + } - var connectionManager = new ConnectionManager(_factory); + public Task<bool> WaitToReadAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + _tcs.TrySetResult(null); + return _channel.WaitToReadAsync(cancellationToken); + } - ConnectionState = (StreamingConnectionState)connectionManager.CreateConnection(ConnectionMode.Streaming); - ConnectionState.Connection.Metadata["formatType"] = format; - ConnectionState.Connection.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim(ClaimTypes.Name, Interlocked.Increment(ref ID).ToString()) })); - } + public Task<bool> WaitToWriteAsync(CancellationToken cancellationToken = default(CancellationToken)) + { + return _channel.WaitToWriteAsync(cancellationToken); + } - public void Dispose() - { - ConnectionState.Dispose(); - _factory.Dispose(); + public Task WriteAsync(T item, CancellationToken cancellationToken = default(CancellationToken)) + { + return _channel.WriteAsync(item, cancellationToken); + } } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs index d3609196bb535ed304cb91110fb7d67a45618c6c..7ac86dcc19d8aae1e0af80cee2f8f733d9232483 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ConnectionManagerTests.cs @@ -1,7 +1,7 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. -using System.IO.Pipelines; +using System; using System.Threading.Tasks; using Microsoft.AspNetCore.Sockets.Internal; using Xunit; @@ -13,100 +13,87 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public void NewConnectionsHaveConnectionId() { - using (var factory = new PipelineFactory()) - { - var connectionManager = new ConnectionManager(factory); - var state = connectionManager.CreateConnection(ConnectionMode.Streaming); - - Assert.NotNull(state.Connection); - Assert.NotNull(state.Connection.ConnectionId); - Assert.True(state.Active); - Assert.Null(state.Close); - Assert.NotNull(((StreamingConnectionState)state).Connection.Transport); - } + + var connectionManager = new ConnectionManager(); + var state = connectionManager.CreateConnection(); + + Assert.NotNull(state.Connection); + Assert.NotNull(state.Connection.ConnectionId); + Assert.True(state.Active); + Assert.Null(state.Close); + Assert.NotNull(state.Connection.Transport); } [Fact] public void NewConnectionsCanBeRetrieved() { - using (var factory = new PipelineFactory()) - { - var connectionManager = new ConnectionManager(factory); - var state = connectionManager.CreateConnection(ConnectionMode.Streaming); + var connectionManager = new ConnectionManager(); + var state = connectionManager.CreateConnection(); - Assert.NotNull(state.Connection); - Assert.NotNull(state.Connection.ConnectionId); + Assert.NotNull(state.Connection); + Assert.NotNull(state.Connection.ConnectionId); - ConnectionState newState; - Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); - Assert.Same(newState, state); - } + ConnectionState newState; + Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); + Assert.Same(newState, state); } [Fact] public void AddNewConnection() { - using (var factory = new PipelineFactory()) - { - var connectionManager = new ConnectionManager(factory); - var state = connectionManager.CreateConnection(ConnectionMode.Streaming); + var connectionManager = new ConnectionManager(); + var state = connectionManager.CreateConnection(); - var transport = ((StreamingConnectionState)state).Connection.Transport; + var transport = state.Connection.Transport; - Assert.NotNull(state.Connection); - Assert.NotNull(state.Connection.ConnectionId); - Assert.NotNull(transport); + Assert.NotNull(state.Connection); + Assert.NotNull(state.Connection.ConnectionId); + Assert.NotNull(transport); - ConnectionState newState; - Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); - Assert.Same(newState, state); - Assert.Same(transport, ((StreamingConnectionState)newState).Connection.Transport); - } + ConnectionState newState; + Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); + Assert.Same(newState, state); + Assert.Same(transport, newState.Connection.Transport); } [Fact] public void RemoveConnection() { - using (var factory = new PipelineFactory()) - { - var connectionManager = new ConnectionManager(factory); - var state = connectionManager.CreateConnection(ConnectionMode.Streaming); + var connectionManager = new ConnectionManager(); + var state = connectionManager.CreateConnection(); - var transport = ((StreamingConnectionState)state).Connection.Transport; + var transport = state.Connection.Transport; - Assert.NotNull(state.Connection); - Assert.NotNull(state.Connection.ConnectionId); - Assert.NotNull(transport); + Assert.NotNull(state.Connection); + Assert.NotNull(state.Connection.ConnectionId); + Assert.NotNull(transport); - ConnectionState newState; - Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); - Assert.Same(newState, state); - Assert.Same(transport, ((StreamingConnectionState)newState).Connection.Transport); + ConnectionState newState; + Assert.True(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); + Assert.Same(newState, state); + Assert.Same(transport, newState.Connection.Transport); - connectionManager.RemoveConnection(state.Connection.ConnectionId); - Assert.False(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); - } + connectionManager.RemoveConnection(state.Connection.ConnectionId); + Assert.False(connectionManager.TryGetConnection(state.Connection.ConnectionId, out newState)); } [Fact] public async Task CloseConnectionsEndsAllPendingConnections() { - using (var factory = new PipelineFactory()) - { - var connectionManager = new ConnectionManager(factory); - var state = (StreamingConnectionState)connectionManager.CreateConnection(ConnectionMode.Streaming); + var connectionManager = new ConnectionManager(); + var state = connectionManager.CreateConnection(); - var task = Task.Run(async () => - { - var result = await state.Connection.Transport.Input.ReadAsync(); + var task = Task.Run(async () => + { + var connection = state.Connection; - Assert.True(result.IsCompleted); - }); + Assert.False(await connection.Transport.Input.WaitToReadAsync()); + Assert.True(connection.Transport.Input.Completion.IsCompleted); + }); - connectionManager.CloseConnections(); + connectionManager.CloseConnections(); - await task; - } + await task; } } } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs index a65a6a412ec9075ea3e91005f7d96034f12fd43d..7429b8efb45bdf710e9df150562af7089fe5827d 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/HttpConnectionDispatcherTests.cs @@ -22,96 +22,65 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task GetIdReservesConnectionIdAndReturnsIt() { - using (var factory = new PipelineFactory()) - { - var manager = new ConnectionManager(factory); - var dispatcher = new HttpConnectionDispatcher(manager, factory, new LoggerFactory()); - var context = new DefaultHttpContext(); - var services = new ServiceCollection(); - services.AddSingleton<TestEndPoint>(); - context.RequestServices = services.BuildServiceProvider(); - var ms = new MemoryStream(); - context.Request.Path = "/getid"; - context.Response.Body = ms; - await dispatcher.ExecuteAsync<TestEndPoint>("", context); + var manager = new ConnectionManager(); + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddSingleton<TestEndPoint>(); + context.RequestServices = services.BuildServiceProvider(); + var ms = new MemoryStream(); + context.Request.Path = "/getid"; + context.Response.Body = ms; + await dispatcher.ExecuteAsync<TestEndPoint>("", context); - var id = Encoding.UTF8.GetString(ms.ToArray()); + var id = Encoding.UTF8.GetString(ms.ToArray()); - ConnectionState state; - Assert.True(manager.TryGetConnection(id, out state)); - Assert.Equal(id, state.Connection.ConnectionId); - } + ConnectionState state; + Assert.True(manager.TryGetConnection(id, out state)); + Assert.Equal(id, state.Connection.ConnectionId); } - // REVIEW: No longer relevant since we establish the connection right away. - //[Fact] - //public async Task SendingToReservedConnectionsThatHaveNotConnectedThrows() - //{ - // using (var factory = new PipelineFactory()) - // { - // var manager = new ConnectionManager(factory); - // var state = manager.ReserveConnection(); - - // var dispatcher = new HttpConnectionDispatcher(manager, factory, loggerFactory: null); - // var context = new DefaultHttpContext(); - // context.Request.Path = "/send"; - // var values = new Dictionary<string, StringValues>(); - // values["id"] = state.Connection.ConnectionId; - // var qs = new QueryCollection(values); - // context.Request.Query = qs; - // await Assert.ThrowsAsync<InvalidOperationException>(async () => - // { - // await dispatcher.ExecuteAsync<TestEndPoint>("", context); - // }); - // } - //} - [Fact] public async Task SendingToUnknownConnectionIdThrows() { - using (var factory = new PipelineFactory()) + var manager = new ConnectionManager(); + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddSingleton<TestEndPoint>(); + context.RequestServices = services.BuildServiceProvider(); + context.Request.Path = "/send"; + var values = new Dictionary<string, StringValues>(); + values["id"] = "unknown"; + var qs = new QueryCollection(values); + context.Request.Query = qs; + await Assert.ThrowsAsync<InvalidOperationException>(async () => { - var manager = new ConnectionManager(factory); - var dispatcher = new HttpConnectionDispatcher(manager, factory, new LoggerFactory()); - var context = new DefaultHttpContext(); - var services = new ServiceCollection(); - services.AddSingleton<TestEndPoint>(); - context.RequestServices = services.BuildServiceProvider(); - context.Request.Path = "/send"; - var values = new Dictionary<string, StringValues>(); - values["id"] = "unknown"; - var qs = new QueryCollection(values); - context.Request.Query = qs; - await Assert.ThrowsAsync<InvalidOperationException>(async () => - { - await dispatcher.ExecuteAsync<TestEndPoint>("", context); - }); - } + await dispatcher.ExecuteAsync<TestEndPoint>("", context); + }); } [Fact] public async Task SendingWithoutConnectionIdThrows() { - using (var factory = new PipelineFactory()) + + var manager = new ConnectionManager(); + var dispatcher = new HttpConnectionDispatcher(manager, new LoggerFactory()); + var context = new DefaultHttpContext(); + var services = new ServiceCollection(); + services.AddSingleton<TestEndPoint>(); + context.RequestServices = services.BuildServiceProvider(); + context.Request.Path = "/send"; + await Assert.ThrowsAsync<InvalidOperationException>(async () => { - var manager = new ConnectionManager(factory); - var dispatcher = new HttpConnectionDispatcher(manager, factory, new LoggerFactory()); - var context = new DefaultHttpContext(); - var services = new ServiceCollection(); - services.AddSingleton<TestEndPoint>(); - context.RequestServices = services.BuildServiceProvider(); - context.Request.Path = "/send"; - await Assert.ThrowsAsync<InvalidOperationException>(async () => - { - await dispatcher.ExecuteAsync<TestEndPoint>("", context); - }); - } + await dispatcher.ExecuteAsync<TestEndPoint>("", context); + }); } } - public class TestEndPoint : StreamingEndPoint + public class TestEndPoint : EndPoint { - public override Task OnConnectedAsync(StreamingConnection connection) + public override Task OnConnectedAsync(Connection connection) { throw new NotImplementedException(); } diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs index 3f4209f9da059e2bea208500859a1b22b361e687..145744a937261d6225e8a79383c240a4fb64d1bf 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/LongPollingTests.cs @@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task Set204StatusCodeWhenChannelComplete() { - var channel = Channel.Create<Message>(); + var channel = Channel.CreateUnbounded<Message>(); var context = new DefaultHttpContext(); var poll = new LongPollingTransport(channel, new LoggerFactory()); @@ -35,7 +35,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task FrameSentAsSingleResponse() { - var channel = Channel.Create<Message>(); + var channel = Channel.CreateUnbounded<Message>(); var context = new DefaultHttpContext(); var poll = new LongPollingTransport(channel, new LoggerFactory()); var ms = new MemoryStream(); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs index bbfdf1e0543ff5a87e3543475dde1b4e5a64843a..3c587cd11575a94ffbe22de43b8fd14201b593e4 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/ServerSentEventsTests.cs @@ -21,7 +21,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task SSESetsContentType() { - var channel = Channel.Create<Message>(); + var channel = Channel.CreateUnbounded<Message>(); var context = new DefaultHttpContext(); var sse = new ServerSentEventsTransport(channel, new LoggerFactory()); @@ -36,7 +36,7 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task SSEAddsAppropriateFraming() { - var channel = Channel.Create<Message>(); + var channel = Channel.CreateUnbounded<Message>(); var context = new DefaultHttpContext(); var sse = new ServerSentEventsTransport(channel, new LoggerFactory()); var ms = new MemoryStream(); diff --git a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs index 0eedefa3a4ce60065d3999c074af92c808db7a31..eca7c3c9d31427ccf10d94f9758e8a3d6db6fdab 100644 --- a/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs +++ b/test/Microsoft.AspNetCore.Sockets.Tests/WebSocketsTests.cs @@ -22,8 +22,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests [InlineData(Format.Binary, WebSocketOpcode.Binary)] public async Task ReceivedFramesAreWrittenToChannel(Format format, WebSocketOpcode opcode) { - var transportToApplication = Channel.Create<Message>(); - var applicationToTransport = Channel.Create<Message>(); + var transportToApplication = Channel.CreateUnbounded<Message>(); + var applicationToTransport = Channel.CreateUnbounded<Message>(); var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication); var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport); @@ -70,8 +70,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests [InlineData(Format.Binary, WebSocketOpcode.Binary)] public async Task MultiFrameMessagesArePropagatedToTheChannel(Format format, WebSocketOpcode opcode) { - var transportToApplication = Channel.Create<Message>(); - var applicationToTransport = Channel.Create<Message>(); + var transportToApplication = Channel.CreateUnbounded<Message>(); + var applicationToTransport = Channel.CreateUnbounded<Message>(); var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication); var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport); @@ -129,8 +129,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests [InlineData(Format.Binary, WebSocketOpcode.Binary)] public async Task IncompleteMessagesAreWrittenAsMultiFrameWebSocketMessages(Format format, WebSocketOpcode opcode) { - var transportToApplication = Channel.Create<Message>(); - var applicationToTransport = Channel.Create<Message>(); + var transportToApplication = Channel.CreateUnbounded<Message>(); + var applicationToTransport = Channel.CreateUnbounded<Message>(); var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication); var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport); @@ -177,8 +177,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests [InlineData(Format.Binary, WebSocketOpcode.Binary)] public async Task DataWrittenToOutputPipelineAreSentAsFrames(Format format, WebSocketOpcode opcode) { - var transportToApplication = Channel.Create<Message>(); - var applicationToTransport = Channel.Create<Message>(); + var transportToApplication = Channel.CreateUnbounded<Message>(); + var applicationToTransport = Channel.CreateUnbounded<Message>(); var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication); var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport); @@ -218,8 +218,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests [InlineData(Format.Binary, WebSocketOpcode.Binary)] public async Task FrameReceivedAfterServerCloseSent(Format format, WebSocketOpcode opcode) { - var transportToApplication = Channel.Create<Message>(); - var applicationToTransport = Channel.Create<Message>(); + var transportToApplication = Channel.CreateUnbounded<Message>(); + var applicationToTransport = Channel.CreateUnbounded<Message>(); var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication); var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport); @@ -261,8 +261,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task TransportFailsWhenClientDisconnectsAbnormally() { - var transportToApplication = Channel.Create<Message>(); - var applicationToTransport = Channel.Create<Message>(); + var transportToApplication = Channel.CreateUnbounded<Message>(); + var applicationToTransport = Channel.CreateUnbounded<Message>(); var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication); var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport); @@ -289,8 +289,8 @@ namespace Microsoft.AspNetCore.Sockets.Tests [Fact] public async Task ClientReceivesInternalServerErrorWhenTheApplicationFails() { - var transportToApplication = Channel.Create<Message>(); - var applicationToTransport = Channel.Create<Message>(); + var transportToApplication = Channel.CreateUnbounded<Message>(); + var applicationToTransport = Channel.CreateUnbounded<Message>(); var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication); var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport); diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs index a52652026732ce50432c2b7dfa1f56e9cf8f2ec3..4a74c8010d865cce7f4e92149a471c3a8ded82ea 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketConnectionTests.SendAsync.cs @@ -197,7 +197,7 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests } } - private static void CompleteChannels(params PipelineReaderWriter[] readerWriters) + private static void CompleteChannels(params Pipe[] readerWriters) { foreach (var readerWriter in readerWriters) { diff --git a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs index 4d09cccad7a6bb85e904bac04886dc7e95772367..dd137a12a270ad3e1230a3d277ef7c74185ee7ee 100644 --- a/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs +++ b/test/Microsoft.Extensions.WebSockets.Internal.Tests/WebSocketPair.cs @@ -14,13 +14,13 @@ namespace Microsoft.Extensions.WebSockets.Internal.Tests private PipelineFactory _factory; private readonly bool _ownFactory; - public PipelineReaderWriter ServerToClient { get; } - public PipelineReaderWriter ClientToServer { get; } + public Pipe ServerToClient { get; } + public Pipe ClientToServer { get; } public IWebSocketConnection ClientSocket { get; } public IWebSocketConnection ServerSocket { get; } - public WebSocketPair(bool ownFactory, PipelineFactory factory, PipelineReaderWriter serverToClient, PipelineReaderWriter clientToServer, IWebSocketConnection clientSocket, IWebSocketConnection serverSocket) + public WebSocketPair(bool ownFactory, PipelineFactory factory, Pipe serverToClient, Pipe clientToServer, IWebSocketConnection clientSocket, IWebSocketConnection serverSocket) { _ownFactory = ownFactory; _factory = factory;