Skip to content
代码片段 群组 项目
提交 cd9ed922 编辑于 作者: David Fowler's avatar David Fowler 提交者: GitHub
浏览文件

Remove streaming transport as a top level API (#110)

- Remove Streaming* classes from Sockets. The main
API will be channels based and streaming transports
will use the PipelineChannel (formerly FramingChannel) to
access messages.
- Added WriteAsync and ReadAsync to Connection and hid
the IChannelConnection from public API.
- Also fixed the fact that unknown methods caused server side
exceptions.
- Changed the consumption pattern to WaitToReadAsync/TryRead to avoid
exceptions.
- React to API changes
上级 9dbb3742
No related branches found
No related tags found
无相关合并请求
显示
172 个添加272 个删除
......@@ -16,7 +16,7 @@ namespace ChatSample.Hubs
{
if (!Context.User.Identity.IsAuthenticated)
{
Context.Connection.Transport.Dispose();
Context.Connection.Dispose();
}
return Task.CompletedTask;
......
......@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Threading.Tasks;
......@@ -13,19 +14,19 @@ namespace SocialWeather
public class PersistentConnectionLifeTimeManager
{
private readonly FormatterResolver _formatterResolver;
private readonly ConnectionList<StreamingConnection> _connectionList = new ConnectionList<StreamingConnection>();
private readonly ConnectionList _connectionList = new ConnectionList();
public PersistentConnectionLifeTimeManager(FormatterResolver formatterResolver)
{
_formatterResolver = formatterResolver;
}
public void OnConnectedAsync(StreamingConnection connection)
public void OnConnectedAsync(Connection connection)
{
_connectionList.Add(connection);
}
public void OnDisconnectedAsync(StreamingConnection connection)
public void OnDisconnectedAsync(Connection connection)
{
_connectionList.Remove(connection);
}
......@@ -35,7 +36,10 @@ namespace SocialWeather
foreach (var connection in _connectionList)
{
var formatter = _formatterResolver.GetFormatter<T>(connection.Metadata.Get<string>("formatType"));
await formatter.WriteAsync(data, connection.Transport.GetStream());
var ms = new MemoryStream();
await formatter.WriteAsync(data, ms);
var buffer = ReadableBuffer.Create(ms.ToArray()).Preserve();
await connection.Transport.Output.WriteAsync(new Message(buffer, Format.Binary, endOfMessage: true));
}
}
......@@ -54,7 +58,7 @@ namespace SocialWeather
throw new NotImplementedException();
}
public void AddGroupAsync(StreamingConnection connection, string groupName)
public void AddGroupAsync(Connection connection, string groupName)
{
var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>());
lock (groups)
......@@ -63,7 +67,7 @@ namespace SocialWeather
}
}
public void RemoveGroupAsync(StreamingConnection connection, string groupName)
public void RemoveGroupAsync(Connection connection, string groupName)
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
if (groups != null)
......
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System.IO;
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
......@@ -8,7 +9,7 @@ using Microsoft.Extensions.Logging;
namespace SocialWeather
{
public class SocialWeatherEndPoint : StreamingEndPoint
public class SocialWeatherEndPoint : EndPoint
{
private readonly PersistentConnectionLifeTimeManager _lifetimeManager;
private readonly FormatterResolver _formatterResolver;
......@@ -22,22 +23,24 @@ namespace SocialWeather
_logger = logger;
}
public async override Task OnConnectedAsync(StreamingConnection connection)
public async override Task OnConnectedAsync(Connection connection)
{
_lifetimeManager.OnConnectedAsync(connection);
await ProcessRequests(connection);
_lifetimeManager.OnDisconnectedAsync(connection);
}
public async Task ProcessRequests(StreamingConnection connection)
public async Task ProcessRequests(Connection connection)
{
var stream = connection.Transport.GetStream();
var formatter = _formatterResolver.GetFormatter<WeatherReport>(
connection.Metadata.Get<string>("formatType"));
WeatherReport weatherReport;
while ((weatherReport = await formatter.ReadAsync(stream)) != null)
while (true)
{
Message message = await connection.Transport.Input.ReadAsync();
var stream = new MemoryStream();
await message.Payload.Buffer.CopyToAsync(stream);
WeatherReport weatherReport = await formatter.ReadAsync(stream);
await _lifetimeManager.SendToAllAsync(weatherReport);
}
}
......
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO.Pipelines;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
namespace SocketsSample
{
public class ChatEndPoint : StreamingEndPoint
{
public ConnectionList<StreamingConnection> Connections { get; } = new ConnectionList<StreamingConnection>();
public override async Task OnConnectedAsync(StreamingConnection connection)
{
Connections.Add(connection);
await Broadcast($"{connection.ConnectionId} connected ({connection.Metadata["transport"]})");
while (true)
{
var result = await connection.Transport.Input.ReadAsync();
var input = result.Buffer;
try
{
if (input.IsEmpty && result.IsCompleted)
{
break;
}
// We can avoid the copy here but we'll deal with that later
await Broadcast(input.ToArray());
}
finally
{
connection.Transport.Input.Advance(input.End);
}
}
Connections.Remove(connection);
await Broadcast($"{connection.ConnectionId} disconnected ({connection.Metadata["transport"]})");
}
private Task Broadcast(string text)
{
return Broadcast(Encoding.UTF8.GetBytes(text));
}
private Task Broadcast(byte[] payload)
{
var tasks = new List<Task>(Connections.Count);
foreach (var c in Connections)
{
tasks.Add(c.Transport.Output.WriteAsync(payload));
}
return Task.WhenAll(tasks);
}
}
}
......@@ -12,11 +12,11 @@ using Microsoft.AspNetCore.Sockets;
namespace SocketsSample.EndPoints
{
public class MessagesEndPoint : MessagingEndPoint
public class MessagesEndPoint : EndPoint
{
public ConnectionList<MessagingConnection> Connections { get; } = new ConnectionList<MessagingConnection>();
public ConnectionList Connections { get; } = new ConnectionList();
public override async Task OnConnectedAsync(MessagingConnection connection)
public override async Task OnConnectedAsync(Connection connection)
{
Connections.Add(connection);
......@@ -24,19 +24,19 @@ namespace SocketsSample.EndPoints
try
{
while (true)
while (await connection.Transport.Input.WaitToReadAsync())
{
using (var message = await connection.Transport.Input.ReadAsync())
Message message;
if (connection.Transport.Input.TryRead(out message))
{
// We can avoid the copy here but we'll deal with that later
await Broadcast(message.Payload.Buffer, message.MessageFormat, message.EndOfMessage);
using (message)
{
// We can avoid the copy here but we'll deal with that later
await Broadcast(message.Payload.Buffer, message.MessageFormat, message.EndOfMessage);
}
}
}
}
catch (Exception ex) when (ex.GetType().IsNested && ex.GetType().DeclaringType == typeof(Channel))
{
// Gross that we have to catch this this way. See https://github.com/dotnet/corefxlab/issues/1068
}
finally
{
Connections.Remove(connection);
......
......@@ -29,7 +29,6 @@ namespace SocketsSample
});
// .AddRedis();
services.AddSingleton<ChatEndPoint>();
services.AddSingleton<MessagesEndPoint>();
services.AddSingleton<ProtobufSerializer>();
}
......@@ -53,8 +52,7 @@ namespace SocketsSample
app.UseSockets(routes =>
{
routes.MapEndpoint<ChatEndPoint>("/chat");
routes.MapEndpoint<MessagesEndPoint>("/msgs");
routes.MapEndpoint<MessagesEndPoint>("/chat");
});
}
}
......
......@@ -6,18 +6,12 @@
</head>
<body>
<h1>ASP.NET Sockets</h1>
<h2>Streaming</h2>
<h2>Messaging</h2>
<ul>
<li><a href="sse.html#/chat">Server Sent Events</a></li>
<li><a href="polling.html#/chat">Long polling</a></li>
<li><a href="ws.html#/chat">Web Sockets</a></li>
</ul>
<h2>Messaging</h2>
<ul>
<li><a href="sse.html#/msgs">Server Sent Events</a></li>
<li><a href="polling.html#/msgs">Long polling</a></li>
<li><a href="ws.html#/msgs">Web Sockets</a></li>
</ul>
<h1>ASP.NET SignalR</h1>
<ul>
<li><a href="hubs.html">Hubs</a></li>
......
......@@ -20,7 +20,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
public class RedisHubLifetimeManager<THub> : HubLifetimeManager<THub>, IDisposable
{
private readonly ConnectionList<StreamingConnection> _connections = new ConnectionList<StreamingConnection>();
private readonly ConnectionList _connections = new ConnectionList();
// TODO: Investigate "memory leak" entries never get removed
private readonly ConcurrentDictionary<string, GroupData> _groups = new ConcurrentDictionary<string, GroupData>();
private readonly InvocationAdapterRegistry _registry;
......@@ -51,7 +51,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
foreach (var connection in _connections)
{
tasks.Add(connection.Transport.Output.WriteAsync((byte[])data));
tasks.Add(WriteAsync(connection, data));
}
previousBroadcastTask = Task.WhenAll(tasks);
......@@ -116,7 +116,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
public override Task OnConnectedAsync(StreamingConnection connection)
public override Task OnConnectedAsync(Connection connection)
{
var redisSubscriptions = connection.Metadata.GetOrAdd("redis_subscriptions", _ => new HashSet<string>());
var connectionTask = TaskCache.CompletedTask;
......@@ -133,7 +133,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
await previousConnectionTask;
previousConnectionTask = connection.Transport.Output.WriteAsync((byte[])data);
previousConnectionTask = WriteAsync(connection, data);
});
......@@ -149,14 +149,14 @@ namespace Microsoft.AspNetCore.SignalR.Redis
{
await previousUserTask;
previousUserTask = connection.Transport.Output.WriteAsync((byte[])data);
previousUserTask = WriteAsync(connection, data);
});
}
return Task.WhenAll(connectionTask, userTask);
}
public override Task OnDisconnectedAsync(StreamingConnection connection)
public override Task OnDisconnectedAsync(Connection connection)
{
_connections.Remove(connection);
......@@ -186,7 +186,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
return Task.WhenAll(tasks);
}
public override async Task AddGroupAsync(StreamingConnection connection, string groupName)
public override async Task AddGroupAsync(Connection connection, string groupName)
{
var groupChannel = typeof(THub).FullName + ".group." + groupName;
......@@ -220,9 +220,9 @@ namespace Microsoft.AspNetCore.SignalR.Redis
await previousTask;
var tasks = new List<Task>(group.Connections.Count);
foreach (var groupConnection in group.Connections.Cast<StreamingConnection>())
foreach (var groupConnection in group.Connections.Cast<Connection>())
{
tasks.Add(groupConnection.Transport.Output.WriteAsync((byte[])data));
tasks.Add(WriteAsync(groupConnection, data));
}
previousTask = Task.WhenAll(tasks);
......@@ -234,7 +234,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
}
}
public override async Task RemoveGroupAsync(StreamingConnection connection, string groupName)
public override async Task RemoveGroupAsync(Connection connection, string groupName)
{
var groupChannel = typeof(THub).FullName + ".group." + groupName;
......@@ -275,6 +275,12 @@ namespace Microsoft.AspNetCore.SignalR.Redis
_redisServerConnection.Dispose();
}
private Task WriteAsync(Connection connection, byte[] data)
{
var buffer = ReadableBuffer.Create(data).Preserve();
return connection.Transport.Output.WriteAsync(new Message(buffer, Format.Binary, endOfMessage: true));
}
private class LoggerTextWriter : TextWriter
{
private readonly ILogger _logger;
......@@ -300,7 +306,7 @@ namespace Microsoft.AspNetCore.SignalR.Redis
private class GroupData
{
public SemaphoreSlim Lock = new SemaphoreSlim(1, 1);
public ConnectionList<StreamingConnection> Connections = new ConnectionList<StreamingConnection>();
public ConnectionList Connections = new ConnectionList();
}
}
}
......@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Threading.Tasks;
using Microsoft.AspNetCore.Sockets;
......@@ -12,7 +13,7 @@ namespace Microsoft.AspNetCore.SignalR
{
public class DefaultHubLifetimeManager<THub> : HubLifetimeManager<THub>
{
private readonly ConnectionList<StreamingConnection> _connections = new ConnectionList<StreamingConnection>();
private readonly ConnectionList _connections = new ConnectionList();
private readonly InvocationAdapterRegistry _registry;
public DefaultHubLifetimeManager(InvocationAdapterRegistry registry)
......@@ -20,7 +21,7 @@ namespace Microsoft.AspNetCore.SignalR
_registry = registry;
}
public override Task AddGroupAsync(StreamingConnection connection, string groupName)
public override Task AddGroupAsync(Connection connection, string groupName)
{
var groups = connection.Metadata.GetOrAdd("groups", _ => new HashSet<string>());
......@@ -32,7 +33,7 @@ namespace Microsoft.AspNetCore.SignalR
return TaskCache.CompletedTask;
}
public override Task RemoveGroupAsync(StreamingConnection connection, string groupName)
public override Task RemoveGroupAsync(Connection connection, string groupName)
{
var groups = connection.Metadata.Get<HashSet<string>>("groups");
......@@ -54,7 +55,7 @@ namespace Microsoft.AspNetCore.SignalR
return InvokeAllWhere(methodName, args, c => true);
}
private Task InvokeAllWhere(string methodName, object[] args, Func<StreamingConnection, bool> include)
private Task InvokeAllWhere(string methodName, object[] args, Func<Connection, bool> include)
{
var tasks = new List<Task>(_connections.Count);
var message = new InvocationDescriptor
......@@ -73,7 +74,7 @@ namespace Microsoft.AspNetCore.SignalR
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
tasks.Add(invocationAdapter.WriteMessageAsync(message, connection.Transport.GetStream()));
tasks.Add(WriteAsync(connection, invocationAdapter, message));
}
return Task.WhenAll(tasks);
......@@ -91,7 +92,7 @@ namespace Microsoft.AspNetCore.SignalR
Arguments = args
};
return invocationAdapter.WriteMessageAsync(message, connection.Transport.GetStream());
return WriteAsync(connection, invocationAdapter, message);
}
public override Task InvokeGroupAsync(string groupName, string methodName, object[] args)
......@@ -111,17 +112,24 @@ namespace Microsoft.AspNetCore.SignalR
});
}
public override Task OnConnectedAsync(StreamingConnection connection)
public override Task OnConnectedAsync(Connection connection)
{
_connections.Add(connection);
return TaskCache.CompletedTask;
}
public override Task OnDisconnectedAsync(StreamingConnection connection)
public override Task OnDisconnectedAsync(Connection connection)
{
_connections.Remove(connection);
return TaskCache.CompletedTask;
}
}
private static Task WriteAsync(Connection connection, IInvocationAdapter invocationAdapter, InvocationDescriptor message)
{
var stream = new MemoryStream();
invocationAdapter.WriteMessageAsync(message, stream);
var buffer = ReadableBuffer.Create(stream.ToArray()).Preserve();
return connection.Transport.Output.WriteAsync(new Message(buffer, Format.Binary, endOfMessage: true));
}
}
}
......@@ -8,12 +8,12 @@ namespace Microsoft.AspNetCore.SignalR
{
public class HubCallerContext
{
public HubCallerContext(StreamingConnection connection)
public HubCallerContext(Connection connection)
{
Connection = connection;
}
public StreamingConnection Connection { get; }
public Connection Connection { get; }
public ClaimsPrincipal User => Connection.User;
......
......@@ -3,6 +3,7 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.IO.Pipelines;
using System.Linq;
using System.Reflection;
......@@ -25,10 +26,10 @@ namespace Microsoft.AspNetCore.SignalR
}
}
public class HubEndPoint<THub, TClient> : StreamingEndPoint, IInvocationBinder where THub : Hub<TClient>
public class HubEndPoint<THub, TClient> : EndPoint, IInvocationBinder where THub : Hub<TClient>
{
private readonly Dictionary<string, Func<StreamingConnection, InvocationDescriptor, Task<InvocationResultDescriptor>>> _callbacks
= new Dictionary<string, Func<StreamingConnection, InvocationDescriptor, Task<InvocationResultDescriptor>>>(StringComparer.OrdinalIgnoreCase);
private readonly Dictionary<string, Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>>> _callbacks
= new Dictionary<string, Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>>>(StringComparer.OrdinalIgnoreCase);
private readonly Dictionary<string, Type[]> _paramTypes = new Dictionary<string, Type[]>();
private readonly HubLifetimeManager<THub> _lifetimeManager;
......@@ -52,7 +53,7 @@ namespace Microsoft.AspNetCore.SignalR
DiscoverHubMethods();
}
public override async Task OnConnectedAsync(StreamingConnection connection)
public override async Task OnConnectedAsync(Connection connection)
{
// TODO: Dispatch from the caller
await Task.Yield();
......@@ -68,7 +69,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task RunHubAsync(StreamingConnection connection)
private async Task RunHubAsync(Connection connection)
{
await HubOnConnectedAsync(connection);
......@@ -86,7 +87,7 @@ namespace Microsoft.AspNetCore.SignalR
await HubOnDisconnectedAsync(connection, null);
}
private async Task HubOnConnectedAsync(StreamingConnection connection)
private async Task HubOnConnectedAsync(Connection connection)
{
try
{
......@@ -112,7 +113,7 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task HubOnDisconnectedAsync(StreamingConnection connection, Exception exception)
private async Task HubOnDisconnectedAsync(Connection connection, Exception exception)
{
try
{
......@@ -138,15 +139,26 @@ namespace Microsoft.AspNetCore.SignalR
}
}
private async Task DispatchMessagesAsync(StreamingConnection connection)
private async Task DispatchMessagesAsync(Connection connection)
{
var stream = connection.Transport.GetStream();
var invocationAdapter = _registry.GetInvocationAdapter(connection.Metadata.Get<string>("formatType"));
while (true)
while (await connection.Transport.Input.WaitToReadAsync())
{
// TODO: Handle receiving InvocationResultDescriptor
var invocationDescriptor = await invocationAdapter.ReadMessageAsync(stream, this) as InvocationDescriptor;
Message message;
if (!connection.Transport.Input.TryRead(out message))
{
continue;
}
InvocationDescriptor invocationDescriptor;
using (message)
{
var inputStream = new MemoryStream(message.Payload.Buffer.ToArray());
// TODO: Handle receiving InvocationResultDescriptor
invocationDescriptor = await invocationAdapter.ReadMessageAsync(inputStream, this) as InvocationDescriptor;
}
// Is there a better way of detecting that a connection was closed?
if (invocationDescriptor == null)
......@@ -160,7 +172,7 @@ namespace Microsoft.AspNetCore.SignalR
}
InvocationResultDescriptor result;
Func<StreamingConnection, InvocationDescriptor, Task<InvocationResultDescriptor>> callback;
Func<Connection, InvocationDescriptor, Task<InvocationResultDescriptor>> callback;
if (_callbacks.TryGetValue(invocationDescriptor.Method, out callback))
{
result = await callback(connection, invocationDescriptor);
......@@ -177,11 +189,19 @@ namespace Microsoft.AspNetCore.SignalR
_logger.LogError("Unknown hub method '{method}'", invocationDescriptor.Method);
}
await invocationAdapter.WriteMessageAsync(result, stream);
// TODO: Pool memory
var outStream = new MemoryStream();
await invocationAdapter.WriteMessageAsync(result, outStream);
var buffer = ReadableBuffer.Create(outStream.ToArray()).Preserve();
if (await connection.Transport.Output.WaitToWriteAsync())
{
connection.Transport.Output.TryWrite(new Message(buffer, Format.Binary, endOfMessage: true));
}
}
}
private void InitializeHub(THub hub, StreamingConnection connection)
private void InitializeHub(THub hub, Connection connection)
{
hub.Clients = _hubContext.Clients;
hub.Context = new HubCallerContext(connection);
......@@ -290,7 +310,7 @@ namespace Microsoft.AspNetCore.SignalR
Type[] types;
if (!_paramTypes.TryGetValue(methodName, out types))
{
throw new InvalidOperationException($"The hub method '{methodName}' could not be resolved.");
return Type.EmptyTypes;
}
return types;
}
......
......@@ -8,9 +8,9 @@ namespace Microsoft.AspNetCore.SignalR
{
public abstract class HubLifetimeManager<THub>
{
public abstract Task OnConnectedAsync(StreamingConnection connection);
public abstract Task OnConnectedAsync(Connection connection);
public abstract Task OnDisconnectedAsync(StreamingConnection connection);
public abstract Task OnDisconnectedAsync(Connection connection);
public abstract Task InvokeAllAsync(string methodName, object[] args);
......@@ -20,9 +20,9 @@ namespace Microsoft.AspNetCore.SignalR
public abstract Task InvokeUserAsync(string userId, string methodName, object[] args);
public abstract Task AddGroupAsync(StreamingConnection connection, string groupName);
public abstract Task AddGroupAsync(Connection connection, string groupName);
public abstract Task RemoveGroupAsync(StreamingConnection connection, string groupName);
public abstract Task RemoveGroupAsync(Connection connection, string groupName);
}
}
......@@ -75,10 +75,10 @@ namespace Microsoft.AspNetCore.SignalR
public class GroupManager<THub> : IGroupManager
{
private readonly StreamingConnection _connection;
private readonly Connection _connection;
private readonly HubLifetimeManager<THub> _lifetimeManager;
public GroupManager(StreamingConnection connection, HubLifetimeManager<THub> lifetimeManager)
public GroupManager(Connection connection, HubLifetimeManager<THub> lifetimeManager)
{
_connection = connection;
_lifetimeManager = lifetimeManager;
......
......@@ -3,24 +3,28 @@
using System;
using System.Security.Claims;
using System.Threading.Tasks;
namespace Microsoft.AspNetCore.Sockets
{
public abstract class Connection : IDisposable
public class Connection : IDisposable
{
public abstract ConnectionMode Mode { get; }
public string ConnectionId { get; }
public ClaimsPrincipal User { get; set; }
public ConnectionMetadata Metadata { get; } = new ConnectionMetadata();
protected Connection(string id)
public IChannelConnection<Message> Transport { get; }
public Connection(string id, IChannelConnection<Message> transport)
{
Transport = transport;
ConnectionId = id;
}
public virtual void Dispose()
public void Dispose()
{
Transport.Dispose();
}
}
}
......@@ -8,15 +8,15 @@ using System.Collections.Generic;
namespace Microsoft.AspNetCore.Sockets
{
public class ConnectionList<T> : IReadOnlyCollection<T> where T: Connection
public class ConnectionList : IReadOnlyCollection<Connection>
{
private readonly ConcurrentDictionary<string, T> _connections = new ConcurrentDictionary<string, T>();
private readonly ConcurrentDictionary<string, Connection> _connections = new ConcurrentDictionary<string, Connection>();
public T this[string connectionId]
public Connection this[string connectionId]
{
get
{
T connection;
Connection connection;
if (_connections.TryGetValue(connectionId, out connection))
{
return connection;
......@@ -27,18 +27,18 @@ namespace Microsoft.AspNetCore.Sockets
public int Count => _connections.Count;
public void Add(T connection)
public void Add(Connection connection)
{
_connections.TryAdd(connection.ConnectionId, connection);
}
public void Remove(T connection)
public void Remove(Connection connection)
{
T dummy;
Connection dummy;
_connections.TryRemove(connection.ConnectionId, out dummy);
}
public IEnumerator<T> GetEnumerator()
public IEnumerator<Connection> GetEnumerator()
{
foreach (var item in _connections)
{
......
......@@ -3,8 +3,6 @@
using System;
using System.Collections.Concurrent;
using System.Diagnostics;
using System.IO.Pipelines;
using System.Threading;
using System.Threading.Tasks.Channels;
using Microsoft.AspNetCore.Sockets.Internal;
......@@ -15,11 +13,9 @@ namespace Microsoft.AspNetCore.Sockets
{
private readonly ConcurrentDictionary<string, ConnectionState> _connections = new ConcurrentDictionary<string, ConnectionState>();
private readonly Timer _timer;
private readonly PipelineFactory _pipelineFactory;
public ConnectionManager(PipelineFactory pipelineFactory)
public ConnectionManager()
{
_pipelineFactory = pipelineFactory;
_timer = new Timer(Scan, this, 0, 1000);
}
......@@ -28,8 +24,23 @@ namespace Microsoft.AspNetCore.Sockets
return _connections.TryGetValue(id, out state);
}
public ConnectionState CreateConnection(ConnectionMode mode) =>
mode == ConnectionMode.Streaming ? CreateStreamingConnection() : CreateMessagingConnection();
public ConnectionState CreateConnection()
{
var id = MakeNewConnectionId();
var transportToApplication = Channel.CreateUnbounded<Message>();
var applicationToTransport = Channel.CreateUnbounded<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
var state = new ConnectionState(
new Connection(id, applicationSide),
transportSide);
_connections.TryAdd(id, state);
return state;
}
public void RemoveConnection(string id)
{
......@@ -92,41 +103,5 @@ namespace Microsoft.AspNetCore.Sockets
}
}
}
private ConnectionState CreateMessagingConnection()
{
var id = MakeNewConnectionId();
var transportToApplication = Channel.Create<Message>();
var applicationToTransport = Channel.Create<Message>();
var transportSide = new ChannelConnection<Message>(applicationToTransport, transportToApplication);
var applicationSide = new ChannelConnection<Message>(transportToApplication, applicationToTransport);
var state = new MessagingConnectionState(
new MessagingConnection(id, applicationSide),
transportSide);
_connections.TryAdd(id, state);
return state;
}
private ConnectionState CreateStreamingConnection()
{
var id = MakeNewConnectionId();
var transportToApplication = _pipelineFactory.Create();
var applicationToTransport = _pipelineFactory.Create();
var transportSide = new PipelineConnection(applicationToTransport, transportToApplication);
var applicationSide = new PipelineConnection(transportToApplication, applicationToTransport);
var state = new StreamingConnectionState(
new StreamingConnection(id, applicationSide),
transportSide);
_connections.TryAdd(id, state);
return state;
}
}
}
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
namespace Microsoft.AspNetCore.Sockets
{
public enum ConnectionMode
{
Streaming,
Messaging
}
}
......@@ -11,14 +11,6 @@ namespace Microsoft.AspNetCore.Sockets
// REVIEW: This doesn't have any members any more... marker interface? Still even necessary?
public abstract class EndPoint
{
/// <summary>
/// Gets the connection mode supported by this endpoint.
/// </summary>
/// <remarks>
/// This maps directly to whichever of <see cref="MessagingEndPoint"/> or <see cref="StreamingEndPoint"/> the end point subclasses.
/// </remarks>
public abstract ConnectionMode Mode { get; }
/// <summary>
/// Called when a new connection is accepted to the endpoint
/// </summary>
......
......@@ -18,14 +18,12 @@ namespace Microsoft.AspNetCore.Sockets
public class HttpConnectionDispatcher
{
private readonly ConnectionManager _manager;
private readonly PipelineFactory _pipelineFactory;
private readonly ILoggerFactory _loggerFactory;
private readonly ILogger _logger;
public HttpConnectionDispatcher(ConnectionManager manager, PipelineFactory factory, ILoggerFactory loggerFactory)
public HttpConnectionDispatcher(ConnectionManager manager, ILoggerFactory loggerFactory)
{
_manager = manager;
_pipelineFactory = factory;
_loggerFactory = loggerFactory;
_logger = _loggerFactory.CreateLogger<HttpConnectionDispatcher>();
}
......@@ -37,7 +35,7 @@ namespace Microsoft.AspNetCore.Sockets
if (context.Request.Path.StartsWithSegments(path + "/getid"))
{
await ProcessGetId(context, endpoint.Mode);
await ProcessGetId(context);
}
else if (context.Request.Path.StartsWithSegments(path + "/send"))
{
......@@ -56,10 +54,10 @@ namespace Microsoft.AspNetCore.Sockets
? Format.Binary
: Format.Text;
var state = GetOrCreateConnection(context, endpoint.Mode);
var state = GetOrCreateConnection(context);
// Adapt the connection to a message-based transport if necessary, since all the HTTP transports are message-based.
var application = GetMessagingChannel(state, format);
var application = state.Application;
// Server sent events transport
if (context.Request.Path.StartsWithSegments(path + "/sse"))
......@@ -137,7 +135,7 @@ namespace Microsoft.AspNetCore.Sockets
// Notify the long polling transport to end
if (endpointTask.IsFaulted)
{
state.TerminateTransport(endpointTask.Exception.InnerException);
state.Connection.Transport.Output.TryComplete(endpointTask.Exception.InnerException);
}
state.Connection.Dispose();
......@@ -151,19 +149,6 @@ namespace Microsoft.AspNetCore.Sockets
}
}
private static IChannelConnection<Message> GetMessagingChannel(ConnectionState state, Format format)
{
if (state.Connection.Mode == ConnectionMode.Messaging)
{
return ((MessagingConnectionState)state).Application;
}
else
{
// We need to build an adapter
return new FramingChannel(((StreamingConnectionState)state).Application, format);
}
}
private ConnectionState InitializePersistentConnection(ConnectionState state, string transport, HttpContext context, EndPoint endpoint, Format format)
{
state.Connection.User = context.User;
......@@ -197,10 +182,10 @@ namespace Microsoft.AspNetCore.Sockets
await Task.WhenAll(endpointTask, transportTask);
}
private Task ProcessGetId(HttpContext context, ConnectionMode mode)
private Task ProcessGetId(HttpContext context)
{
// Establish the connection
var state = _manager.CreateConnection(mode);
var state = _manager.CreateConnection();
// Get the bytes for the connection id
var connectionIdBuffer = Encoding.UTF8.GetBytes(state.Connection.ConnectionId);
......@@ -221,34 +206,27 @@ namespace Microsoft.AspNetCore.Sockets
ConnectionState state;
if (_manager.TryGetConnection(connectionId, out state))
{
if (state.Connection.Mode == ConnectionMode.Streaming)
// Collect the message and write it to the channel
// TODO: Need to use some kind of pooled memory here.
byte[] buffer;
using (var stream = new MemoryStream())
{
var streamingState = (StreamingConnectionState)state;
await context.Request.Body.CopyToAsync(streamingState.Application.Output);
await context.Request.Body.CopyToAsync(stream);
buffer = stream.ToArray();
}
else
{
// Collect the message and write it to the channel
// TODO: Need to use some kind of pooled memory here.
byte[] buffer;
using (var strm = new MemoryStream())
{
await context.Request.Body.CopyToAsync(strm);
await strm.FlushAsync();
buffer = strm.ToArray();
}
var format =
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
? Format.Binary
: Format.Text;
var message = new Message(
ReadableBuffer.Create(buffer).Preserve(),
format,
endOfMessage: true);
await ((MessagingConnectionState)state).Application.Output.WriteAsync(message);
}
var format =
string.Equals(context.Request.Query["format"], "binary", StringComparison.OrdinalIgnoreCase)
? Format.Binary
: Format.Text;
var message = new Message(
ReadableBuffer.Create(buffer).Preserve(),
format,
endOfMessage: true);
await state.Application.Output.WriteAsync(message);
}
else
{
......@@ -256,7 +234,7 @@ namespace Microsoft.AspNetCore.Sockets
}
}
private ConnectionState GetOrCreateConnection(HttpContext context, ConnectionMode mode)
private ConnectionState GetOrCreateConnection(HttpContext context)
{
var connectionId = context.Request.Query["id"];
ConnectionState connectionState;
......@@ -264,7 +242,7 @@ namespace Microsoft.AspNetCore.Sockets
// There's no connection id so this is a brand new connection
if (StringValues.IsNullOrEmpty(connectionId))
{
connectionState = _manager.CreateConnection(mode);
connectionState = _manager.CreateConnection();
}
else if (!_manager.TryGetConnection(connectionId, out connectionState))
{
......
......@@ -2,9 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using System.Threading.Tasks.Channels;
namespace Microsoft.AspNetCore.Sockets.Internal
......@@ -15,7 +12,6 @@ namespace Microsoft.AspNetCore.Sockets.Internal
public IChannel<T> Output { get; }
IReadableChannel<T> IChannelConnection<T>.Input => Input;
IWritableChannel<T> IChannelConnection<T>.Output => Output;
public ChannelConnection(IChannel<T> input, IChannel<T> output)
......
0% 加载中 .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册