diff --git a/eng/Versions.props b/eng/Versions.props index 987723b8bfaa7e4af7bf8d28bc9151a0fff61bc4..071e9119e69b16f0d146f546233efe48ed9f224f 100644 --- a/eng/Versions.props +++ b/eng/Versions.props @@ -199,7 +199,7 @@ <SystemCommandlineExperimentalPackageVersion>0.3.0-alpha.19317.1</SystemCommandlineExperimentalPackageVersion> <SystemComponentModelPackageVersion>4.3.0</SystemComponentModelPackageVersion> <SystemNetHttpPackageVersion>4.3.2</SystemNetHttpPackageVersion> - <SystemThreadingTasksExtensionsPackageVersion>4.5.2</SystemThreadingTasksExtensionsPackageVersion> + <SystemThreadingTasksExtensionsPackageVersion>4.5.3</SystemThreadingTasksExtensionsPackageVersion> <!-- Packages developed by @aspnet, but manually updated as necessary. --> <LibuvPackageVersion>1.10.0</LibuvPackageVersion> <MicrosoftAspNetWebApiClientPackageVersion>5.2.6</MicrosoftAspNetWebApiClientPackageVersion> @@ -242,7 +242,7 @@ <IdentityServer4PackageVersion>3.0.0</IdentityServer4PackageVersion> <IdentityServer4StoragePackageVersion>3.0.0</IdentityServer4StoragePackageVersion> <IdentityServer4EntityFrameworkStoragePackageVersion>3.0.0</IdentityServer4EntityFrameworkStoragePackageVersion> - <MessagePackPackageVersion>1.7.3.7</MessagePackPackageVersion> + <MessagePackPackageVersion>2.0.335</MessagePackPackageVersion> <MoqPackageVersion>4.10.0</MoqPackageVersion> <MonoCecilPackageVersion>0.10.1</MonoCecilPackageVersion> <NewtonsoftJsonBsonPackageVersion>1.0.2</NewtonsoftJsonBsonPackageVersion> diff --git a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs index 78631d3c7f40e35e84121f22e87e375e062578f6..b42dcf7cfd3458bf8906b17a1a0a4c80e5d950b5 100644 --- a/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs +++ b/src/SignalR/common/Protocols.MessagePack/src/Protocol/MessagePackHubProtocol.cs @@ -6,10 +6,11 @@ using System.Buffers; using System.Collections.Generic; using System.Diagnostics; using System.IO; +using System.Linq; using System.Runtime.ExceptionServices; -using System.Runtime.InteropServices; using MessagePack; using MessagePack.Formatters; +using MessagePack.Resolvers; using Microsoft.AspNetCore.Connections; using Microsoft.AspNetCore.Internal; using Microsoft.Extensions.Options; @@ -25,8 +26,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol private const int VoidResult = 2; private const int NonVoidResult = 3; - private IFormatterResolver _resolver; - + private MessagePackSerializerOptions _msgPackSerializerOptions; private static readonly string ProtocolName = "messagepack"; private static readonly int ProtocolVersion = 1; @@ -62,7 +62,9 @@ namespace Microsoft.AspNetCore.SignalR.Protocol // with the provided resolvers if (options.FormatterResolvers.Count != SignalRResolver.Resolvers.Count) { - _resolver = new CombinedResolvers(options.FormatterResolvers); + var resolver = CompositeResolver.Create(Array.Empty<IMessagePackFormatter>(), (IReadOnlyList<IFormatterResolver>)options.FormatterResolvers); + _msgPackSerializerOptions = MessagePackSerializerOptions.Standard.WithResolver(resolver); + return; } @@ -71,13 +73,14 @@ namespace Microsoft.AspNetCore.SignalR.Protocol // check if the user customized the resolvers if (options.FormatterResolvers[i] != SignalRResolver.Resolvers[i]) { - _resolver = new CombinedResolvers(options.FormatterResolvers); + var resolver = CompositeResolver.Create(Array.Empty<IMessagePackFormatter>(), (IReadOnlyList<IFormatterResolver>)options.FormatterResolvers); + _msgPackSerializerOptions = MessagePackSerializerOptions.Standard.WithResolver(resolver); return; } } // Use optimized cached resolver if the default is chosen - _resolver = SignalRResolver.Instance; + _msgPackSerializerOptions = MessagePackSerializerOptions.Standard.WithResolver(SignalRResolver.Instance); } /// <inheritdoc /> @@ -95,59 +98,43 @@ namespace Microsoft.AspNetCore.SignalR.Protocol return false; } - var arraySegment = GetArraySegment(payload); - - message = ParseMessage(arraySegment.Array, arraySegment.Offset, binder, _resolver); + var reader = new MessagePackReader(payload); + message = ParseMessage(ref reader, binder, _msgPackSerializerOptions); return true; } - private static ArraySegment<byte> GetArraySegment(in ReadOnlySequence<byte> input) - { - if (input.IsSingleSegment) - { - var isArray = MemoryMarshal.TryGetArray(input.First, out var arraySegment); - // This will never be false unless we started using un-managed buffers - Debug.Assert(isArray); - return arraySegment; - } - - // Should be rare - return new ArraySegment<byte>(input.ToArray()); - } - - private static HubMessage ParseMessage(byte[] input, int startOffset, IInvocationBinder binder, IFormatterResolver resolver) + private static HubMessage ParseMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions) { - var itemCount = MessagePackBinary.ReadArrayHeader(input, startOffset, out var readSize); - startOffset += readSize; + var itemCount = reader.ReadArrayHeader(); - var messageType = ReadInt32(input, ref startOffset, "messageType"); + var messageType = ReadInt32(ref reader, "messageType"); switch (messageType) { case HubProtocolConstants.InvocationMessageType: - return CreateInvocationMessage(input, ref startOffset, binder, resolver, itemCount); + return CreateInvocationMessage(ref reader, binder, msgPackSerializerOptions, itemCount); case HubProtocolConstants.StreamInvocationMessageType: - return CreateStreamInvocationMessage(input, ref startOffset, binder, resolver, itemCount); + return CreateStreamInvocationMessage(ref reader, binder, msgPackSerializerOptions, itemCount); case HubProtocolConstants.StreamItemMessageType: - return CreateStreamItemMessage(input, ref startOffset, binder, resolver); + return CreateStreamItemMessage(ref reader, binder, msgPackSerializerOptions); case HubProtocolConstants.CompletionMessageType: - return CreateCompletionMessage(input, ref startOffset, binder, resolver); + return CreateCompletionMessage(ref reader, binder, msgPackSerializerOptions); case HubProtocolConstants.CancelInvocationMessageType: - return CreateCancelInvocationMessage(input, ref startOffset); + return CreateCancelInvocationMessage(ref reader); case HubProtocolConstants.PingMessageType: return PingMessage.Instance; case HubProtocolConstants.CloseMessageType: - return CreateCloseMessage(input, ref startOffset, itemCount); + return CreateCloseMessage(ref reader, itemCount); default: // Future protocol changes can add message types, old clients can ignore them return null; } } - private static HubMessage CreateInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver, int itemCount) + private static HubMessage CreateInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions, int itemCount) { - var headers = ReadHeaders(input, ref offset); - var invocationId = ReadInvocationId(input, ref offset); + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); // For MsgPack, we represent an empty invocation ID as an empty string, // so we need to normalize that to "null", which is what indicates a non-blocking invocation. @@ -156,13 +143,13 @@ namespace Microsoft.AspNetCore.SignalR.Protocol invocationId = null; } - var target = ReadString(input, ref offset, "target"); + var target = ReadString(ref reader, "target"); object[] arguments = null; try { var parameterTypes = binder.GetParameterTypes(target); - arguments = BindArguments(input, ref offset, parameterTypes, resolver); + arguments = BindArguments(ref reader, parameterTypes, msgPackSerializerOptions); } catch (Exception ex) { @@ -173,23 +160,23 @@ namespace Microsoft.AspNetCore.SignalR.Protocol // Previous clients will send 5 items, so we check if they sent a stream array or not if (itemCount > 5) { - streams = ReadStreamIds(input, ref offset); + streams = ReadStreamIds(ref reader); } return ApplyHeaders(headers, new InvocationMessage(invocationId, target, arguments, streams)); } - private static HubMessage CreateStreamInvocationMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver, int itemCount) + private static HubMessage CreateStreamInvocationMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions, int itemCount) { - var headers = ReadHeaders(input, ref offset); - var invocationId = ReadInvocationId(input, ref offset); - var target = ReadString(input, ref offset, "target"); + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); + var target = ReadString(ref reader, "target"); object[] arguments = null; try { var parameterTypes = binder.GetParameterTypes(target); - arguments = BindArguments(input, ref offset, parameterTypes, resolver); + arguments = BindArguments(ref reader, parameterTypes, msgPackSerializerOptions); } catch (Exception ex) { @@ -200,21 +187,21 @@ namespace Microsoft.AspNetCore.SignalR.Protocol // Previous clients will send 5 items, so we check if they sent a stream array or not if (itemCount > 5) { - streams = ReadStreamIds(input, ref offset); + streams = ReadStreamIds(ref reader); } return ApplyHeaders(headers, new StreamInvocationMessage(invocationId, target, arguments, streams)); } - private static HubMessage CreateStreamItemMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver) + private static HubMessage CreateStreamItemMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions) { - var headers = ReadHeaders(input, ref offset); - var invocationId = ReadInvocationId(input, ref offset); + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); object value; try { var itemType = binder.GetStreamItemType(invocationId); - value = DeserializeObject(input, ref offset, itemType, "item", resolver); + value = DeserializeObject(ref reader, itemType, "item", msgPackSerializerOptions); } catch (Exception ex) { @@ -224,11 +211,11 @@ namespace Microsoft.AspNetCore.SignalR.Protocol return ApplyHeaders(headers, new StreamItemMessage(invocationId, value)); } - private static CompletionMessage CreateCompletionMessage(byte[] input, ref int offset, IInvocationBinder binder, IFormatterResolver resolver) + private static CompletionMessage CreateCompletionMessage(ref MessagePackReader reader, IInvocationBinder binder, MessagePackSerializerOptions msgPackSerializerOptions) { - var headers = ReadHeaders(input, ref offset); - var invocationId = ReadInvocationId(input, ref offset); - var resultKind = ReadInt32(input, ref offset, "resultKind"); + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); + var resultKind = ReadInt32(ref reader, "resultKind"); string error = null; object result = null; @@ -237,11 +224,11 @@ namespace Microsoft.AspNetCore.SignalR.Protocol switch (resultKind) { case ErrorResult: - error = ReadString(input, ref offset, "error"); + error = ReadString(ref reader, "error"); break; case NonVoidResult: var itemType = binder.GetReturnType(invocationId); - result = DeserializeObject(input, ref offset, itemType, "argument", resolver); + result = DeserializeObject(ref reader, itemType, "argument", msgPackSerializerOptions); hasResult = true; break; case VoidResult: @@ -254,21 +241,21 @@ namespace Microsoft.AspNetCore.SignalR.Protocol return ApplyHeaders(headers, new CompletionMessage(invocationId, error, result, hasResult)); } - private static CancelInvocationMessage CreateCancelInvocationMessage(byte[] input, ref int offset) + private static CancelInvocationMessage CreateCancelInvocationMessage(ref MessagePackReader reader) { - var headers = ReadHeaders(input, ref offset); - var invocationId = ReadInvocationId(input, ref offset); + var headers = ReadHeaders(ref reader); + var invocationId = ReadInvocationId(ref reader); return ApplyHeaders(headers, new CancelInvocationMessage(invocationId)); } - private static CloseMessage CreateCloseMessage(byte[] input, ref int offset, int itemCount) + private static CloseMessage CreateCloseMessage(ref MessagePackReader reader, int itemCount) { - var error = ReadString(input, ref offset, "error"); + var error = ReadString(ref reader, "error"); var allowReconnect = false; if (itemCount > 2) { - allowReconnect = ReadBoolean(input, ref offset, "allowReconnect"); + allowReconnect = ReadBoolean(ref reader, "allowReconnect"); } // An empty string is still an error @@ -280,17 +267,17 @@ namespace Microsoft.AspNetCore.SignalR.Protocol return new CloseMessage(error, allowReconnect); } - private static Dictionary<string, string> ReadHeaders(byte[] input, ref int offset) + private static Dictionary<string, string> ReadHeaders(ref MessagePackReader reader) { - var headerCount = ReadMapLength(input, ref offset, "headers"); + var headerCount = ReadMapLength(ref reader, "headers"); if (headerCount > 0) { var headers = new Dictionary<string, string>(StringComparer.Ordinal); for (var i = 0; i < headerCount; i++) { - var key = ReadString(input, ref offset, $"headers[{i}].Key"); - var value = ReadString(input, ref offset, $"headers[{i}].Value"); + var key = ReadString(ref reader, $"headers[{i}].Key"); + var value = ReadString(ref reader, $"headers[{i}].Value"); headers.Add(key, value); } return headers; @@ -301,9 +288,9 @@ namespace Microsoft.AspNetCore.SignalR.Protocol } } - private static string[] ReadStreamIds(byte[] input, ref int offset) + private static string[] ReadStreamIds(ref MessagePackReader reader) { - var streamIdCount = ReadArrayLength(input, ref offset, "streamIds"); + var streamIdCount = ReadArrayLength(ref reader, "streamIds"); List<string> streams = null; if (streamIdCount > 0) @@ -311,17 +298,16 @@ namespace Microsoft.AspNetCore.SignalR.Protocol streams = new List<string>(); for (var i = 0; i < streamIdCount; i++) { - streams.Add(MessagePackBinary.ReadString(input, offset, out var read)); - offset += read; + streams.Add(reader.ReadString()); } } return streams?.ToArray(); } - private static object[] BindArguments(byte[] input, ref int offset, IReadOnlyList<Type> parameterTypes, IFormatterResolver resolver) + private static object[] BindArguments(ref MessagePackReader reader, IReadOnlyList<Type> parameterTypes, MessagePackSerializerOptions msgPackSerializerOptions) { - var argumentCount = ReadArrayLength(input, ref offset, "arguments"); + var argumentCount = ReadArrayLength(ref reader, "arguments"); if (parameterTypes.Count != argumentCount) { @@ -334,7 +320,7 @@ namespace Microsoft.AspNetCore.SignalR.Protocol var arguments = new object[argumentCount]; for (var i = 0; i < argumentCount; i++) { - arguments[i] = DeserializeObject(input, ref offset, parameterTypes[i], "argument", resolver); + arguments[i] = DeserializeObject(ref reader, parameterTypes[i], "argument", msgPackSerializerOptions); } return arguments; @@ -358,339 +344,314 @@ namespace Microsoft.AspNetCore.SignalR.Protocol /// <inheritdoc /> public void WriteMessage(HubMessage message, IBufferWriter<byte> output) { - var writer = MemoryBufferWriter.Get(); + var memoryBufferWriter = MemoryBufferWriter.Get(); try { + var writer = new MessagePackWriter(memoryBufferWriter); + // Write message to a buffer so we can get its length - WriteMessageCore(message, writer); + WriteMessageCore(message, ref writer); // Write length then message to output - BinaryMessageFormatter.WriteLengthPrefix(writer.Length, output); - writer.CopyTo(output); + BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, output); + memoryBufferWriter.CopyTo(output); } finally { - MemoryBufferWriter.Return(writer); + MemoryBufferWriter.Return(memoryBufferWriter); } } /// <inheritdoc /> public ReadOnlyMemory<byte> GetMessageBytes(HubMessage message) { - var writer = MemoryBufferWriter.Get(); + var memoryBufferWriter = MemoryBufferWriter.Get(); try { + var writer = new MessagePackWriter(memoryBufferWriter); + // Write message to a buffer so we can get its length - WriteMessageCore(message, writer); + WriteMessageCore(message, ref writer); - var dataLength = writer.Length; - var prefixLength = BinaryMessageFormatter.LengthPrefixLength(writer.Length); + var dataLength = memoryBufferWriter.Length; + var prefixLength = BinaryMessageFormatter.LengthPrefixLength(memoryBufferWriter.Length); var array = new byte[dataLength + prefixLength]; var span = array.AsSpan(); // Write length then message to output - var written = BinaryMessageFormatter.WriteLengthPrefix(writer.Length, span); + var written = BinaryMessageFormatter.WriteLengthPrefix(memoryBufferWriter.Length, span); Debug.Assert(written == prefixLength); - writer.CopyTo(span.Slice(prefixLength)); + memoryBufferWriter.CopyTo(span.Slice(prefixLength)); return array; } finally { - MemoryBufferWriter.Return(writer); + MemoryBufferWriter.Return(memoryBufferWriter); } } - private void WriteMessageCore(HubMessage message, Stream packer) + private void WriteMessageCore(HubMessage message, ref MessagePackWriter writer) { switch (message) { case InvocationMessage invocationMessage: - WriteInvocationMessage(invocationMessage, packer); + WriteInvocationMessage(invocationMessage, ref writer); break; case StreamInvocationMessage streamInvocationMessage: - WriteStreamInvocationMessage(streamInvocationMessage, packer); + WriteStreamInvocationMessage(streamInvocationMessage, ref writer); break; case StreamItemMessage streamItemMessage: - WriteStreamingItemMessage(streamItemMessage, packer); + WriteStreamingItemMessage(streamItemMessage, ref writer); break; case CompletionMessage completionMessage: - WriteCompletionMessage(completionMessage, packer); + WriteCompletionMessage(completionMessage, ref writer); break; case CancelInvocationMessage cancelInvocationMessage: - WriteCancelInvocationMessage(cancelInvocationMessage, packer); + WriteCancelInvocationMessage(cancelInvocationMessage, ref writer); break; case PingMessage pingMessage: - WritePingMessage(pingMessage, packer); + WritePingMessage(pingMessage, ref writer); break; case CloseMessage closeMessage: - WriteCloseMessage(closeMessage, packer); + WriteCloseMessage(closeMessage, ref writer); break; default: throw new InvalidDataException($"Unexpected message type: {message.GetType().Name}"); } + + writer.Flush(); } - private void WriteInvocationMessage(InvocationMessage message, Stream packer) + private void WriteInvocationMessage(InvocationMessage message, ref MessagePackWriter writer) { - MessagePackBinary.WriteArrayHeader(packer, 6); + writer.WriteArrayHeader(6); - MessagePackBinary.WriteInt32(packer, HubProtocolConstants.InvocationMessageType); - PackHeaders(packer, message.Headers); + writer.Write(HubProtocolConstants.InvocationMessageType); + PackHeaders(message.Headers, ref writer); if (string.IsNullOrEmpty(message.InvocationId)) { - MessagePackBinary.WriteNil(packer); + writer.WriteNil(); } else { - MessagePackBinary.WriteString(packer, message.InvocationId); + writer.Write(message.InvocationId); } - MessagePackBinary.WriteString(packer, message.Target); - MessagePackBinary.WriteArrayHeader(packer, message.Arguments.Length); + writer.Write(message.Target); + writer.WriteArrayHeader(message.Arguments.Length); foreach (var arg in message.Arguments) { - WriteArgument(arg, packer); + WriteArgument(arg, ref writer); } - WriteStreamIds(message.StreamIds, packer); + WriteStreamIds(message.StreamIds, ref writer); } - private void WriteStreamInvocationMessage(StreamInvocationMessage message, Stream packer) + private void WriteStreamInvocationMessage(StreamInvocationMessage message, ref MessagePackWriter writer) { - MessagePackBinary.WriteArrayHeader(packer, 6); + writer.WriteArrayHeader(6); - MessagePackBinary.WriteInt16(packer, HubProtocolConstants.StreamInvocationMessageType); - PackHeaders(packer, message.Headers); - MessagePackBinary.WriteString(packer, message.InvocationId); - MessagePackBinary.WriteString(packer, message.Target); + writer.Write(HubProtocolConstants.StreamInvocationMessageType); + PackHeaders(message.Headers, ref writer); + writer.Write(message.InvocationId); + writer.Write(message.Target); - MessagePackBinary.WriteArrayHeader(packer, message.Arguments.Length); + writer.WriteArrayHeader(message.Arguments.Length); foreach (var arg in message.Arguments) { - WriteArgument(arg, packer); + WriteArgument(arg, ref writer); } - WriteStreamIds(message.StreamIds, packer); + WriteStreamIds(message.StreamIds, ref writer); } - private void WriteStreamingItemMessage(StreamItemMessage message, Stream packer) + private void WriteStreamingItemMessage(StreamItemMessage message, ref MessagePackWriter writer) { - MessagePackBinary.WriteArrayHeader(packer, 4); - MessagePackBinary.WriteInt16(packer, HubProtocolConstants.StreamItemMessageType); - PackHeaders(packer, message.Headers); - MessagePackBinary.WriteString(packer, message.InvocationId); - WriteArgument(message.Item, packer); + writer.WriteArrayHeader(4); + writer.Write(HubProtocolConstants.StreamItemMessageType); + PackHeaders(message.Headers, ref writer); + writer.Write(message.InvocationId); + WriteArgument(message.Item, ref writer); } - private void WriteArgument(object argument, Stream stream) + private void WriteArgument(object argument, ref MessagePackWriter writer) { if (argument == null) { - MessagePackBinary.WriteNil(stream); + writer.WriteNil(); } else { - MessagePackSerializer.NonGeneric.Serialize(argument.GetType(), stream, argument, _resolver); + MessagePackSerializer.Serialize(argument.GetType(), ref writer, argument, _msgPackSerializerOptions); } } - private void WriteStreamIds(string[] streamIds, Stream packer) + private void WriteStreamIds(string[] streamIds, ref MessagePackWriter writer) { if (streamIds != null) { - MessagePackBinary.WriteArrayHeader(packer, streamIds.Length); + writer.WriteArrayHeader(streamIds.Length); foreach (var streamId in streamIds) { - MessagePackBinary.WriteString(packer, streamId); + writer.Write(streamId); } } else { - MessagePackBinary.WriteArrayHeader(packer, 0); + writer.WriteArrayHeader(0); } } - private void WriteCompletionMessage(CompletionMessage message, Stream packer) + private void WriteCompletionMessage(CompletionMessage message, ref MessagePackWriter writer) { var resultKind = message.Error != null ? ErrorResult : message.HasResult ? NonVoidResult : VoidResult; - MessagePackBinary.WriteArrayHeader(packer, 4 + (resultKind != VoidResult ? 1 : 0)); - MessagePackBinary.WriteInt32(packer, HubProtocolConstants.CompletionMessageType); - PackHeaders(packer, message.Headers); - MessagePackBinary.WriteString(packer, message.InvocationId); - MessagePackBinary.WriteInt32(packer, resultKind); + writer.WriteArrayHeader(4 + (resultKind != VoidResult ? 1 : 0)); + writer.Write(HubProtocolConstants.CompletionMessageType); + PackHeaders(message.Headers, ref writer); + writer.Write(message.InvocationId); + writer.Write(resultKind); switch (resultKind) { case ErrorResult: - MessagePackBinary.WriteString(packer, message.Error); + writer.Write(message.Error); break; case NonVoidResult: - WriteArgument(message.Result, packer); + WriteArgument(message.Result, ref writer); break; } } - private void WriteCancelInvocationMessage(CancelInvocationMessage message, Stream packer) + private void WriteCancelInvocationMessage(CancelInvocationMessage message, ref MessagePackWriter writer) { - MessagePackBinary.WriteArrayHeader(packer, 3); - MessagePackBinary.WriteInt16(packer, HubProtocolConstants.CancelInvocationMessageType); - PackHeaders(packer, message.Headers); - MessagePackBinary.WriteString(packer, message.InvocationId); + writer.WriteArrayHeader(3); + writer.Write(HubProtocolConstants.CancelInvocationMessageType); + PackHeaders(message.Headers, ref writer); + writer.Write(message.InvocationId); } - private void WriteCloseMessage(CloseMessage message, Stream packer) + private void WriteCloseMessage(CloseMessage message, ref MessagePackWriter writer) { - MessagePackBinary.WriteArrayHeader(packer, 3); - MessagePackBinary.WriteInt16(packer, HubProtocolConstants.CloseMessageType); + writer.WriteArrayHeader(3); + writer.Write(HubProtocolConstants.CloseMessageType); if (string.IsNullOrEmpty(message.Error)) { - MessagePackBinary.WriteNil(packer); + writer.WriteNil(); } else { - MessagePackBinary.WriteString(packer, message.Error); + writer.Write(message.Error); } - MessagePackBinary.WriteBoolean(packer, message.AllowReconnect); + writer.Write(message.AllowReconnect); } - private void WritePingMessage(PingMessage pingMessage, Stream packer) + private void WritePingMessage(PingMessage pingMessage, ref MessagePackWriter writer) { - MessagePackBinary.WriteArrayHeader(packer, 1); - MessagePackBinary.WriteInt32(packer, HubProtocolConstants.PingMessageType); + writer.WriteArrayHeader(1); + writer.Write(HubProtocolConstants.PingMessageType); } - private void PackHeaders(Stream packer, IDictionary<string, string> headers) + private void PackHeaders(IDictionary<string, string> headers, ref MessagePackWriter writer) { if (headers != null) { - MessagePackBinary.WriteMapHeader(packer, headers.Count); + writer.WriteMapHeader(headers.Count); if (headers.Count > 0) { foreach (var header in headers) { - MessagePackBinary.WriteString(packer, header.Key); - MessagePackBinary.WriteString(packer, header.Value); + writer.Write(header.Key); + writer.Write(header.Value); } } } else { - MessagePackBinary.WriteMapHeader(packer, 0); + writer.WriteMapHeader(0); } } - private static string ReadInvocationId(byte[] input, ref int offset) - { - return ReadString(input, ref offset, "invocationId"); - } + private static string ReadInvocationId(ref MessagePackReader reader) => + ReadString(ref reader, "invocationId"); - private static bool ReadBoolean(byte[] input, ref int offset, string field) + private static bool ReadBoolean(ref MessagePackReader reader, string field) { - Exception msgPackException = null; try { - var readBool = MessagePackBinary.ReadBoolean(input, offset, out var readSize); - offset += readSize; - return readBool; + return reader.ReadBoolean(); } - catch (Exception e) + catch (Exception ex) { - msgPackException = e; + throw new InvalidDataException($"Reading '{field}' as Boolean failed.", ex); } - - throw new InvalidDataException($"Reading '{field}' as Boolean failed.", msgPackException); } - private static int ReadInt32(byte[] input, ref int offset, string field) + private static int ReadInt32(ref MessagePackReader reader, string field) { - Exception msgPackException = null; try { - var readInt = MessagePackBinary.ReadInt32(input, offset, out var readSize); - offset += readSize; - return readInt; + return reader.ReadInt32(); } - catch (Exception e) + catch (Exception ex) { - msgPackException = e; + throw new InvalidDataException($"Reading '{field}' as Int32 failed.", ex); } - - throw new InvalidDataException($"Reading '{field}' as Int32 failed.", msgPackException); } - private static string ReadString(byte[] input, ref int offset, string field) + private static string ReadString(ref MessagePackReader reader, string field) { - Exception msgPackException = null; try { - var readString = MessagePackBinary.ReadString(input, offset, out var readSize); - offset += readSize; - return readString; + return reader.ReadString(); } - catch (Exception e) + catch (Exception ex) { - msgPackException = e; + throw new InvalidDataException($"Reading '{field}' as String failed.", ex); } - - throw new InvalidDataException($"Reading '{field}' as String failed.", msgPackException); } - private static long ReadMapLength(byte[] input, ref int offset, string field) + private static long ReadMapLength(ref MessagePackReader reader, string field) { - Exception msgPackException = null; try { - var readMap = MessagePackBinary.ReadMapHeader(input, offset, out var readSize); - offset += readSize; - return readMap; + return reader.ReadMapHeader(); } - catch (Exception e) + catch (Exception ex) { - msgPackException = e; + throw new InvalidDataException($"Reading map length for '{field}' failed.", ex); } - throw new InvalidDataException($"Reading map length for '{field}' failed.", msgPackException); } - private static long ReadArrayLength(byte[] input, ref int offset, string field) + private static long ReadArrayLength(ref MessagePackReader reader, string field) { - Exception msgPackException = null; try { - var readArray = MessagePackBinary.ReadArrayHeader(input, offset, out var readSize); - offset += readSize; - return readArray; + return reader.ReadArrayHeader(); } - catch (Exception e) + catch (Exception ex) { - msgPackException = e; + throw new InvalidDataException($"Reading array length for '{field}' failed.", ex); } - - throw new InvalidDataException($"Reading array length for '{field}' failed.", msgPackException); } - private static object DeserializeObject(byte[] input, ref int offset, Type type, string field, IFormatterResolver resolver) + private static object DeserializeObject(ref MessagePackReader reader, Type type, string field, MessagePackSerializerOptions msgPackSerializerOptions) { - Exception msgPackException = null; try { - var obj = MessagePackSerializer.NonGeneric.Deserialize(type, new ArraySegment<byte>(input, offset, input.Length - offset), resolver); - offset += MessagePackBinary.ReadNextBlock(input, offset); - return obj; + return MessagePackSerializer.Deserialize(type, ref reader, msgPackSerializerOptions); } catch (Exception ex) { - msgPackException = ex; + throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", ex); } - - throw new InvalidDataException($"Deserializing object of the `{type.Name}` type for '{field}' failed.", msgPackException); } internal static List<IFormatterResolver> CreateDefaultFormatterResolvers() @@ -703,10 +664,10 @@ namespace Microsoft.AspNetCore.SignalR.Protocol { public static readonly IFormatterResolver Instance = new SignalRResolver(); - public static readonly IList<IFormatterResolver> Resolvers = new[] + public static readonly IList<IFormatterResolver> Resolvers = new IFormatterResolver[] { - MessagePack.Resolvers.DynamicEnumAsStringResolver.Instance, - MessagePack.Resolvers.ContractlessStandardResolver.Instance, + DynamicEnumAsStringResolver.Instance, + ContractlessStandardResolver.Instance, }; public IMessagePackFormatter<T> GetFormatter<T>() @@ -731,30 +692,5 @@ namespace Microsoft.AspNetCore.SignalR.Protocol } } } - - // Support for users making their own Formatter lists - internal class CombinedResolvers : IFormatterResolver - { - private readonly IList<IFormatterResolver> _resolvers; - - public CombinedResolvers(IList<IFormatterResolver> resolvers) - { - _resolvers = resolvers; - } - - public IMessagePackFormatter<T> GetFormatter<T>() - { - foreach (var resolver in _resolvers) - { - var formatter = resolver.GetFormatter<T>(); - if (formatter != null) - { - return formatter; - } - } - - return null; - } - } } } diff --git a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs index 8a87a67bd695d543f71d01997a0f806830060e9f..2355228bd9c0202783527f37e5463c56e13f3178 100644 --- a/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs +++ b/src/SignalR/common/SignalR.Common/test/Internal/Protocol/MessagePackHubProtocolTests.cs @@ -26,15 +26,20 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol AssertMessages(new byte[] { ArrayBytes(5), 3, 0x80, StringBytes(1), (byte)'0', 0x03, ArrayBytes(1), 42 }, result); } - [Fact] - public void WriteAndParseDateTimeConvertsToUTC() + [Theory] + [InlineData(DateTimeKind.Utc)] + [InlineData(DateTimeKind.Local)] + [InlineData(DateTimeKind.Unspecified)] + public void WriteAndParseDateTimeConvertsToUTC(DateTimeKind dateTimeKind) { - var dateTime = new DateTime(2018, 4, 9); + // The messagepack Timestamp format always converts input DateTime to Utc if they are passed as "DateTimeKind.Local" : + // https://github.com/neuecc/MessagePack-CSharp/pull/520/files#diff-ed970b3daebc708ce49f55d418075979 + var originalDateTime = new DateTime(2018, 4, 9, 0, 0, 0, dateTimeKind); var writer = MemoryBufferWriter.Get(); try { - HubProtocol.WriteMessage(CompletionMessage.WithResult("xyz", dateTime), writer); + HubProtocol.WriteMessage(CompletionMessage.WithResult("xyz", originalDateTime), writer); var bytes = new ReadOnlySequence<byte>(writer.ToArray()); HubProtocol.TryParseMessage(ref bytes, new TestBinder(typeof(DateTime)), out var hubMessage); @@ -44,7 +49,10 @@ namespace Microsoft.AspNetCore.SignalR.Common.Tests.Internal.Protocol // The messagepack Timestamp format specifies that time is stored as seconds since 1970-01-01 UTC // so the library has no choice but to store the time as UTC // https://github.com/msgpack/msgpack/blob/master/spec.md#timestamp-extension-type - Assert.Equal(dateTime.ToUniversalTime(), resultDateTime); + // So If the original DateTiem was a "Local" one, we create a new DateTime equivalent to the original one but converted to Utc + var expectedUtcDateTime = (originalDateTime.Kind == DateTimeKind.Local) ? originalDateTime.ToUniversalTime() : originalDateTime; + + Assert.Equal(expectedUtcDateTime, resultDateTime); } finally { diff --git a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs index c1d98cbddc39e14cd4e6270493976f2ae2587b25..e86fa54f51786250ed9c07de1c83a08b6dd74d81 100644 --- a/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs +++ b/src/SignalR/server/SignalR/test/HubConnectionHandlerTests.cs @@ -2512,33 +2512,15 @@ namespace Microsoft.AspNetCore.SignalR.Tests private class StringFormatter<T> : IMessagePackFormatter<T> { - public T Deserialize(byte[] bytes, int offset, IFormatterResolver formatterResolver, out int readSize) + public T Deserialize(ref MessagePackReader reader, MessagePackSerializerOptions options) { // this method isn't used in our tests - readSize = 0; return default; } - public int Serialize(ref byte[] bytes, int offset, T value, IFormatterResolver formatterResolver) - { - // string of size 15 - bytes[offset] = 0xAF; - bytes[offset + 1] = (byte)'f'; - bytes[offset + 2] = (byte)'o'; - bytes[offset + 3] = (byte)'r'; - bytes[offset + 4] = (byte)'m'; - bytes[offset + 5] = (byte)'a'; - bytes[offset + 6] = (byte)'t'; - bytes[offset + 7] = (byte)'t'; - bytes[offset + 8] = (byte)'e'; - bytes[offset + 9] = (byte)'d'; - bytes[offset + 10] = (byte)'S'; - bytes[offset + 11] = (byte)'t'; - bytes[offset + 12] = (byte)'r'; - bytes[offset + 13] = (byte)'i'; - bytes[offset + 14] = (byte)'n'; - bytes[offset + 15] = (byte)'g'; - return 16; + public void Serialize(ref MessagePackWriter writer, T value, MessagePackSerializerOptions options) + { + writer.Write("formattedString"); } } } diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/MessagePackUtil.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/MessagePackUtil.cs deleted file mode 100644 index 7780bca98850bd21d0d8b2c9ee99fc6fd41b81f5..0000000000000000000000000000000000000000 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/MessagePackUtil.cs +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright (c) .NET Foundation. All rights reserved. -// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. - -using System; -using System.Diagnostics; -using System.Runtime.InteropServices; -using MessagePack; - -namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal -{ - internal static class MessagePackUtil - { - public static int ReadArrayHeader(ref ReadOnlyMemory<byte> data) - { - var arr = GetArray(data); - var val = MessagePackBinary.ReadArrayHeader(arr.Array, arr.Offset, out var readSize); - data = data.Slice(readSize); - return val; - } - - public static int ReadMapHeader(ref ReadOnlyMemory<byte> data) - { - var arr = GetArray(data); - var val = MessagePackBinary.ReadMapHeader(arr.Array, arr.Offset, out var readSize); - data = data.Slice(readSize); - return val; - } - - public static string ReadString(ref ReadOnlyMemory<byte> data) - { - var arr = GetArray(data); - var val = MessagePackBinary.ReadString(arr.Array, arr.Offset, out var readSize); - data = data.Slice(readSize); - return val; - } - - public static byte[] ReadBytes(ref ReadOnlyMemory<byte> data) - { - var arr = GetArray(data); - var val = MessagePackBinary.ReadBytes(arr.Array, arr.Offset, out var readSize); - data = data.Slice(readSize); - return val; - } - - public static int ReadInt32(ref ReadOnlyMemory<byte> data) - { - var arr = GetArray(data); - var val = MessagePackBinary.ReadInt32(arr.Array, arr.Offset, out var readSize); - data = data.Slice(readSize); - return val; - } - - public static byte ReadByte(ref ReadOnlyMemory<byte> data) - { - var arr = GetArray(data); - var val = MessagePackBinary.ReadByte(arr.Array, arr.Offset, out var readSize); - data = data.Slice(readSize); - return val; - } - - private static ArraySegment<byte> GetArray(ReadOnlyMemory<byte> data) - { - var isArray = MemoryMarshal.TryGetArray(data, out var array); - Debug.Assert(isArray); - return array; - } - } -} diff --git a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs index 24426184895c88e2a3c939c030e63ab393304ff0..3126b52f5d49558e31b9afa60d93b4e8c0861b81 100644 --- a/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs +++ b/src/SignalR/server/StackExchangeRedis/src/Internal/RedisProtocol.cs @@ -2,6 +2,7 @@ // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. using System; +using System.Buffers; using System.Collections.Generic; using System.Diagnostics; using System.IO; @@ -43,30 +44,33 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal // * [The output of WriteSerializedHubMessage, which is an 'arr'] // Any additional items are discarded. - var writer = MemoryBufferWriter.Get(); - + var memoryBufferWriter = MemoryBufferWriter.Get(); try { - MessagePackBinary.WriteArrayHeader(writer, 2); + var writer = new MessagePackWriter(memoryBufferWriter); + + writer.WriteArrayHeader(2); if (excludedConnectionIds != null && excludedConnectionIds.Count > 0) { - MessagePackBinary.WriteArrayHeader(writer, excludedConnectionIds.Count); + writer.WriteArrayHeader(excludedConnectionIds.Count); foreach (var id in excludedConnectionIds) { - MessagePackBinary.WriteString(writer, id); + writer.Write(id); } } else { - MessagePackBinary.WriteArrayHeader(writer, 0); + writer.WriteArrayHeader(0); } - WriteHubMessage(writer, new InvocationMessage(methodName, args)); - return writer.ToArray(); + WriteHubMessage(ref writer, new InvocationMessage(methodName, args)); + writer.Flush(); + + return memoryBufferWriter.ToArray(); } finally { - MemoryBufferWriter.Return(writer); + MemoryBufferWriter.Return(memoryBufferWriter); } } @@ -80,21 +84,24 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal // * A 'str': The connection Id // Any additional items are discarded. - var writer = MemoryBufferWriter.Get(); + var memoryBufferWriter = MemoryBufferWriter.Get(); try { - MessagePackBinary.WriteArrayHeader(writer, 5); - MessagePackBinary.WriteInt32(writer, command.Id); - MessagePackBinary.WriteString(writer, command.ServerName); - MessagePackBinary.WriteByte(writer, (byte)command.Action); - MessagePackBinary.WriteString(writer, command.GroupName); - MessagePackBinary.WriteString(writer, command.ConnectionId); - - return writer.ToArray(); + var writer = new MessagePackWriter(memoryBufferWriter); + + writer.WriteArrayHeader(5); + writer.Write(command.Id); + writer.Write(command.ServerName); + writer.Write((byte)command.Action); + writer.Write(command.GroupName); + writer.Write(command.ConnectionId); + writer.Flush(); + + return memoryBufferWriter.ToArray(); } finally { - MemoryBufferWriter.Return(writer); + MemoryBufferWriter.Return(memoryBufferWriter); } } @@ -104,101 +111,110 @@ namespace Microsoft.AspNetCore.SignalR.StackExchangeRedis.Internal // * An 'int': The Id of the command being acknowledged. // Any additional items are discarded. - var writer = MemoryBufferWriter.Get(); + var memoryBufferWriter = MemoryBufferWriter.Get(); try { - MessagePackBinary.WriteArrayHeader(writer, 1); - MessagePackBinary.WriteInt32(writer, messageId); + var writer = new MessagePackWriter(memoryBufferWriter); - return writer.ToArray(); + writer.WriteArrayHeader(1); + writer.Write(messageId); + writer.Flush(); + + return memoryBufferWriter.ToArray(); } finally { - MemoryBufferWriter.Return(writer); + MemoryBufferWriter.Return(memoryBufferWriter); } } public RedisInvocation ReadInvocation(ReadOnlyMemory<byte> data) { // See WriteInvocation for the format - ValidateArraySize(ref data, 2, "Invocation"); + var reader = new MessagePackReader(data); + ValidateArraySize(ref reader, 2, "Invocation"); // Read excluded Ids IReadOnlyList<string> excludedConnectionIds = null; - var idCount = MessagePackUtil.ReadArrayHeader(ref data); + var idCount = reader.ReadArrayHeader(); if (idCount > 0) { var ids = new string[idCount]; for (var i = 0; i < idCount; i++) { - ids[i] = MessagePackUtil.ReadString(ref data); + ids[i] = reader.ReadString(); } excludedConnectionIds = ids; } // Read payload - var message = ReadSerializedHubMessage(ref data); + var message = ReadSerializedHubMessage(ref reader); return new RedisInvocation(message, excludedConnectionIds); } public RedisGroupCommand ReadGroupCommand(ReadOnlyMemory<byte> data) { + var reader = new MessagePackReader(data); + // See WriteGroupCommand for format. - ValidateArraySize(ref data, 5, "GroupCommand"); + ValidateArraySize(ref reader, 5, "GroupCommand"); - var id = MessagePackUtil.ReadInt32(ref data); - var serverName = MessagePackUtil.ReadString(ref data); - var action = (GroupAction)MessagePackUtil.ReadByte(ref data); - var groupName = MessagePackUtil.ReadString(ref data); - var connectionId = MessagePackUtil.ReadString(ref data); + var id = reader.ReadInt32(); + var serverName = reader.ReadString(); + var action = (GroupAction)reader.ReadByte(); + var groupName = reader.ReadString(); + var connectionId = reader.ReadString(); return new RedisGroupCommand(id, serverName, action, groupName, connectionId); } public int ReadAck(ReadOnlyMemory<byte> data) { + var reader = new MessagePackReader(data); + // See WriteAck for format - ValidateArraySize(ref data, 1, "Ack"); - return MessagePackUtil.ReadInt32(ref data); + ValidateArraySize(ref reader, 1, "Ack"); + return reader.ReadInt32(); } - private void WriteHubMessage(Stream stream, HubMessage message) + private void WriteHubMessage(ref MessagePackWriter writer, HubMessage message) { // Written as a MessagePack 'map' where the keys are the name of the protocol (as a MessagePack 'str') // and the values are the serialized blob (as a MessagePack 'bin'). var serializedHubMessages = _messageSerializer.SerializeMessage(message); - MessagePackBinary.WriteMapHeader(stream, serializedHubMessages.Count); + writer.WriteMapHeader(serializedHubMessages.Count); foreach (var serializedMessage in serializedHubMessages) { - MessagePackBinary.WriteString(stream, serializedMessage.ProtocolName); + writer.Write(serializedMessage.ProtocolName); var isArray = MemoryMarshal.TryGetArray(serializedMessage.Serialized, out var array); Debug.Assert(isArray); - MessagePackBinary.WriteBytes(stream, array.Array, array.Offset, array.Count); + writer.Write(array); } } - public static SerializedHubMessage ReadSerializedHubMessage(ref ReadOnlyMemory<byte> data) + public static SerializedHubMessage ReadSerializedHubMessage(ref MessagePackReader reader) { - var count = MessagePackUtil.ReadMapHeader(ref data); + var count = reader.ReadMapHeader(); var serializations = new SerializedMessage[count]; for (var i = 0; i < count; i++) { - var protocol = MessagePackUtil.ReadString(ref data); - var serialized = MessagePackUtil.ReadBytes(ref data); + var protocol = reader.ReadString(); + var serialized = reader.ReadBytes()?.ToArray() ?? Array.Empty<byte>(); + serializations[i] = new SerializedMessage(protocol, serialized); } return new SerializedHubMessage(serializations); } - private static void ValidateArraySize(ref ReadOnlyMemory<byte> data, int expectedLength, string messageType) + private static void ValidateArraySize(ref MessagePackReader reader, int expectedLength, string messageType) { - var length = MessagePackUtil.ReadArrayHeader(ref data); + var length = reader.ReadArrayHeader(); if (length < expectedLength) {