    // Copyright (c) .NET Foundation. All rights reserved.
    // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

    using System;
    using System.Buffers;
    using System.Collections.Generic;
    using System.IO;
    using System.IO.Pipelines;
    using System.Linq;
    using System.Net;
    using System.Net.WebSockets;
    using System.Security.Claims;
    using System.Security.Principal;
    using System.Text;
    using System.Threading;
    using System.Threading.Tasks;
    using Microsoft.AspNetCore.Connections;
    using Microsoft.AspNetCore.Connections.Features;
    using Microsoft.AspNetCore.Http.Connections.Internal;
    using Microsoft.AspNetCore.Http.Connections.Internal.Transports;
    using Microsoft.AspNetCore.Http.Features;
    using Microsoft.AspNetCore.Http.Internal;
    using Microsoft.AspNetCore.Internal;
    using Microsoft.AspNetCore.SignalR.Tests;
    using Microsoft.AspNetCore.Testing;
    using Microsoft.AspNetCore.Testing.xunit;
    using Microsoft.Extensions.DependencyInjection;
    using Microsoft.Extensions.Logging;
    using Microsoft.Extensions.Logging.Testing;
    using Microsoft.Extensions.Options;
    using Microsoft.Extensions.Primitives;
    using Newtonsoft.Json;
    using Newtonsoft.Json.Linq;
    using Xunit;
    
    
    namespace Microsoft.AspNetCore.Http.Connections.Tests
    
    {
    
        public class HttpConnectionDispatcherTests : VerifiableLoggedTest
    
        {
            [Fact]
    
            public async Task NegotiateReservesConnectionIdAndReturnsIt()
    
    David Fowler's avatar
    David Fowler 已提交
            {
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var context = new DefaultHttpContext();
                    var services = new ServiceCollection();
    
                    services.AddSingleton<TestConnectionHandler>();
    
                    services.AddOptions();
                    var ms = new MemoryStream();
                    context.Request.Path = "/foo";
                    context.Request.Method = "POST";
                    context.Response.Body = ms;
    
                    await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions());
    
                    var negotiateResponse = JsonConvert.DeserializeObject<JObject>(Encoding.UTF8.GetString(ms.ToArray()));
                    var connectionId = negotiateResponse.Value<string>("connectionId");
                    Assert.True(manager.TryGetConnection(connectionId, out var connectionContext));
                    Assert.Equal(connectionId, connectionContext.ConnectionId);
                }
    
            [Fact]
            public async Task CheckThatThresholdValuesAreEnforced()
            {
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var context = new DefaultHttpContext();
                    var services = new ServiceCollection();
    
                    services.AddSingleton<TestConnectionHandler>();
    
                    services.AddOptions();
                    var ms = new MemoryStream();
                    context.Request.Path = "/foo";
                    context.Request.Method = "POST";
                    context.Response.Body = ms;
    
                    var options = new HttpConnectionDispatcherOptions { TransportMaxBufferSize = 4, ApplicationMaxBufferSize = 4 };
    
                    await dispatcher.ExecuteNegotiateAsync(context, options);
    
                    var negotiateResponse = JsonConvert.DeserializeObject<JObject>(Encoding.UTF8.GetString(ms.ToArray()));
                    var connectionId = negotiateResponse.Value<string>("connectionId");
    
                    context.Request.QueryString = context.Request.QueryString.Add("id", connectionId);
    
                    Assert.True(manager.TryGetConnection(connectionId, out var connection));
    
                    // Fake actual connection after negotiate to populate the pipes on the connection
    
                    await dispatcher.ExecuteAsync(context, options, c => Task.CompletedTask);
    
    
                    // This write should complete immediately but it exceeds the writer threshold
    
                    var writeTask = connection.Application.Output.WriteAsync(new[] { (byte)'b', (byte)'y', (byte)'t', (byte)'e', (byte)'s' });
    
    
                    Assert.False(writeTask.IsCompleted);
    
    
                    // Reading here puts us below the threshold
    
                    await connection.Transport.Input.ConsumeAsync(5);
    
    
                    await writeTask.AsTask().OrTimeout();
    
            [InlineData(HttpTransportType.LongPolling)]
            [InlineData(HttpTransportType.ServerSentEvents)]
            public async Task CheckThatThresholdValuesAreEnforcedWithSends(HttpTransportType transportType)
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var pipeOptions = new PipeOptions(pauseWriterThreshold: 8, resumeWriterThreshold: 4);
                    var connection = manager.CreateConnection(pipeOptions, pipeOptions);
    
                    connection.TransportType = transportType;
    
    
                    using (var requestBody = new MemoryStream())
                    using (var responseBody = new MemoryStream())
                    {
                        var bytes = Encoding.UTF8.GetBytes("EXTRADATA Hi");
                        requestBody.Write(bytes, 0, bytes.Length);
                        requestBody.Seek(0, SeekOrigin.Begin);
    
                        var context = new DefaultHttpContext();
                        context.Request.Body = requestBody;
                        context.Response.Body = responseBody;
    
                        var services = new ServiceCollection();
    
                        services.AddSingleton<TestConnectionHandler>();
    
                        services.AddOptions();
                        context.Request.Path = "/foo";
                        context.Request.Method = "POST";
                        var values = new Dictionary<string, StringValues>();
                        values["id"] = connection.ConnectionId;
                        var qs = new QueryCollection(values);
                        context.Request.Query = qs;
    
                        var builder = new ConnectionBuilder(services.BuildServiceProvider());
    
                        builder.UseConnectionHandler<TestConnectionHandler>();
    
                        var app = builder.Build();
    
                        // This task should complete immediately but it exceeds the writer threshold
    
                        var executeTask = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
                        Assert.False(executeTask.IsCompleted);
                        await connection.Transport.Input.ConsumeAsync(10);
                        await executeTask.OrTimeout();
    
                        Assert.True(connection.Transport.Input.TryRead(out var result));
                        Assert.Equal("Hi", Encoding.UTF8.GetString(result.Buffer.ToArray()));
                        connection.Transport.Input.AdvanceTo(result.Buffer.End);
                    }
                }
            }
    
    
            [Theory]
    
            [InlineData(HttpTransportType.LongPolling | HttpTransportType.WebSockets | HttpTransportType.ServerSentEvents)]
            [InlineData(HttpTransportType.None)]
    
            [InlineData(HttpTransportType.LongPolling | HttpTransportType.WebSockets)]
            public async Task NegotiateReturnsAvailableTransportsAfterFilteringByOptions(HttpTransportType transports)
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var context = new DefaultHttpContext();
                    context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
    
                    context.Features.Set<IHttpWebSocketFeature>(new TestWebSocketConnectionFeature());
    
                    var services = new ServiceCollection();
    
                    services.AddSingleton<TestConnectionHandler>();
    
                    services.AddOptions();
                    var ms = new MemoryStream();
                    context.Request.Path = "/foo";
                    context.Request.Method = "POST";
                    context.Response.Body = ms;
    
                    await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions { Transports = transports });
    
    
                    var negotiateResponse = JsonConvert.DeserializeObject<JObject>(Encoding.UTF8.GetString(ms.ToArray()));
    
                    var availableTransports = HttpTransportType.None;
    
                    foreach (var transport in negotiateResponse["availableTransports"])
                    {
    
                        var transportType = (HttpTransportType)Enum.Parse(typeof(HttpTransportType), transport.Value<string>("transport"));
    
    David Fowler's avatar
    David Fowler 已提交
    
    
                    Assert.Equal(transports, availableTransports);
                }
    
    David Fowler's avatar
    David Fowler 已提交
            }
    
    
            [InlineData(HttpTransportType.WebSockets)]
            [InlineData(HttpTransportType.ServerSentEvents)]
            [InlineData(HttpTransportType.LongPolling)]
            public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvided(HttpTransportType transportType)
    
    David Fowler's avatar
    David Fowler 已提交
            {
    
                using (StartVerifiableLog())
    
    David Fowler's avatar
    David Fowler 已提交
                {
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    using (var strm = new MemoryStream())
                    {
                        var context = new DefaultHttpContext();
                        context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
                        context.Response.Body = strm;
    
                        var services = new ServiceCollection();
    
                        services.AddSingleton<TestConnectionHandler>();
    
                        services.AddOptions();
                        context.Request.Path = "/foo";
                        context.Request.Method = "GET";
                        var values = new Dictionary<string, StringValues>();
                        values["id"] = "unknown";
                        var qs = new QueryCollection(values);
                        context.Request.Query = qs;
                        SetTransport(context, transportType);
    
    
                        var builder = new ConnectionBuilder(services.BuildServiceProvider());
    
                        builder.UseConnectionHandler<TestConnectionHandler>();
    
                        var app = builder.Build();
    
                        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
    
                        Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
                        await strm.FlushAsync();
                        Assert.Equal("No Connection with that ID", Encoding.UTF8.GetString(strm.ToArray()));
                    }
    
                }
            }
    
            [Fact]
            public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvidedForPost()
            {
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    using (var strm = new MemoryStream())
                    {
                        var context = new DefaultHttpContext();
                        context.Response.Body = strm;
    
                        var services = new ServiceCollection();
    
                        services.AddSingleton<TestConnectionHandler>();
    
                        services.AddOptions();
                        context.Request.Path = "/foo";
                        context.Request.Method = "POST";
                        var values = new Dictionary<string, StringValues>();
                        values["id"] = "unknown";
                        var qs = new QueryCollection(values);
                        context.Request.Query = qs;
    
    
                        var builder = new ConnectionBuilder(services.BuildServiceProvider());
    
                        builder.UseConnectionHandler<TestConnectionHandler>();
    
                        var app = builder.Build();
    
                        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
    
                        Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
                        await strm.FlushAsync();
                        Assert.Equal("No Connection with that ID", Encoding.UTF8.GetString(strm.ToArray()));
                    }
    
    David Fowler's avatar
    David Fowler 已提交
            }
    
    
            [Fact]
            public async Task PostNotAllowedForWebSocketConnections()
            {
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var connection = manager.CreateConnection();
    
                    connection.TransportType = HttpTransportType.WebSockets;
    
                    using (var strm = new MemoryStream())
                    {
                        var context = new DefaultHttpContext();
                        context.Response.Body = strm;
    
                        var services = new ServiceCollection();
    
                        services.AddSingleton<TestConnectionHandler>();
    
                        services.AddOptions();
                        context.Request.Path = "/foo";
                        context.Request.Method = "POST";
                        var values = new Dictionary<string, StringValues>();
                        values["id"] = connection.ConnectionId;
                        var qs = new QueryCollection(values);
                        context.Request.Query = qs;
    
    
                        var builder = new ConnectionBuilder(services.BuildServiceProvider());
    
                        builder.UseConnectionHandler<TestConnectionHandler>();
    
                        var app = builder.Build();
    
                        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
    
                        Assert.Equal(StatusCodes.Status405MethodNotAllowed, context.Response.StatusCode);
                        await strm.FlushAsync();
                        Assert.Equal("POST requests are not allowed for WebSocket connections.", Encoding.UTF8.GetString(strm.ToArray()));
                    }
    
            [Fact]
            public async Task PostReturns404IfConnectionDisposed()
            {
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var connection = manager.CreateConnection();
                    connection.TransportType = HttpTransportType.LongPolling;
                    await connection.DisposeAsync(closeGracefully: false);
    
                    using (var strm = new MemoryStream())
                    {
                        var context = new DefaultHttpContext();
                        context.Response.Body = strm;
    
                        var services = new ServiceCollection();
                        services.AddSingleton<TestConnectionHandler>();
                        services.AddOptions();
                        context.Request.Path = "/foo";
                        context.Request.Method = "POST";
                        var values = new Dictionary<string, StringValues>();
                        values["id"] = connection.ConnectionId;
                        var qs = new QueryCollection(values);
                        context.Request.Query = qs;
    
                        var builder = new ConnectionBuilder(services.BuildServiceProvider());
                        builder.UseConnectionHandler<TestConnectionHandler>();
                        var app = builder.Build();
    
                        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
    
                        Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
                    }
                }
            }
    
            [Theory]
            [InlineData(HttpTransportType.ServerSentEvents)]
            [InlineData(HttpTransportType.WebSockets)]
            public async Task TransportEndingGracefullyWaitsOnApplication(HttpTransportType transportType)
            {
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var connection = manager.CreateConnection();
                    connection.TransportType = transportType;
    
                    using (var strm = new MemoryStream())
                    {
                        var context = new DefaultHttpContext();
                        SetTransport(context, transportType);
                        var cts = new CancellationTokenSource();
                        context.Response.Body = strm;
                        context.RequestAborted = cts.Token;
    
                        var services = new ServiceCollection();
                        services.AddSingleton<TestConnectionHandler>();
                        services.AddOptions();
                        context.Request.Path = "/foo";
                        context.Request.Method = "GET";
                        var values = new Dictionary<string, StringValues>();
                        values["id"] = connection.ConnectionId;
                        var qs = new QueryCollection(values);
                        context.Request.Query = qs;
    
                        var builder = new ConnectionBuilder(services.BuildServiceProvider());
                        builder.Use(next =>
                        {
                            return async connectionContext =>
                            {
                                // Ensure both sides of the pipe are ok
                                var result = await connectionContext.Transport.Input.ReadAsync();
                                Assert.True(result.IsCompleted);
                                await connectionContext.Transport.Output.WriteAsync(result.Buffer.First);
                            };
                        });
    
                        var app = builder.Build();
    
                        var task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
    
                        // Pretend the transport closed because the client disconnected
                        if (context.WebSockets.IsWebSocketRequest)
                        {
                            var ws = (TestWebSocketConnectionFeature)context.Features.Get<IHttpWebSocketFeature>();
                            await ws.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", default);
                        }
                        else
                        {
                            cts.Cancel();
                        }
    
                        await task.OrTimeout();
    
                        await connection.ApplicationTask.OrTimeout();
                    }
                }
            }
    
        /// <summary>
        /// Long-polling variant: cancelling the poll ends the HTTP request, but the application keeps running
        /// until the connection is expired; it should then complete gracefully.
        /// </summary>
        [Fact]
        public async Task TransportEndingGracefullyWaitsOnApplicationLongPolling()
        {

            using (StartVerifiableLog())
                // NOTE(review): this `using` has no braces, so it does not compile as written (a declaration
                // cannot be its embedded statement). The `{` and a matching `}` appear to have been lost from
                // this copy — restore them against upstream.

                // Presumably the 5-second argument is the disconnect timeout that LastSeenUtc is pushed past below — confirm.
                var manager = CreateConnectionManager(LoggerFactory, TimeSpan.FromSeconds(5));

                var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);

                var connection = manager.CreateConnection();
                connection.TransportType = HttpTransportType.LongPolling;

                using (var strm = new MemoryStream())
                {
                    var context = new DefaultHttpContext();
                    SetTransport(context, HttpTransportType.LongPolling);
                    var cts = new CancellationTokenSource();
                    context.Response.Body = strm;
                    // Cancelling cts below simulates the client aborting the poll request.
                    context.RequestAborted = cts.Token;

                    var services = new ServiceCollection();
                    services.AddSingleton<TestConnectionHandler>();
                    services.AddOptions();
                    context.Request.Path = "/foo";
                    context.Request.Method = "GET";
                    var values = new Dictionary<string, StringValues>();
                    values["id"] = connection.ConnectionId;
                    var qs = new QueryCollection(values);
                    context.Request.Query = qs;

                    var builder = new ConnectionBuilder(services.BuildServiceProvider());
                    builder.Use(next =>
                    {
                        return async connectionContext =>
                        {
                            // Ensure both sides of the pipe are ok
                            var result = await connectionContext.Transport.Input.ReadAsync();
                            Assert.True(result.IsCompleted);
                            await connectionContext.Transport.Output.WriteAsync(result.Buffer.First);
                        };
                    });

                    var app = builder.Build();

                    var task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);


                    // Pretend the transport closed because the client disconnected
                    cts.Cancel();

                    await task.OrTimeout();

                    // We've been gone longer than the expiration time
                    connection.LastSeenUtc = DateTime.UtcNow.Subtract(TimeSpan.FromSeconds(10));

                    // The application is still running here because the poll is only killed
                    // by the heartbeat so we pretend to do a scan and this should force the application task to complete
                    // NOTE(review): no scan call follows the comment above in this copy, so nothing visible here
                    // expires the connection, and the await below would presumably time out. The connection
                    // manager's scan call (sync vs. async API unknown from this view) appears to have been lost
                    // — TODO restore it from upstream.

                    // The application task should complete gracefully
                    await connection.ApplicationTask.OrTimeout();
                }
            }
        }
    
    
            [InlineData(HttpTransportType.LongPolling)]
            [InlineData(HttpTransportType.ServerSentEvents)]
            public async Task PostSendsToConnection(HttpTransportType transportType)
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var connection = manager.CreateConnection();
    
                    connection.TransportType = transportType;
    
    
                    using (var requestBody = new MemoryStream())
                    using (var responseBody = new MemoryStream())
                    {
                        var bytes = Encoding.UTF8.GetBytes("Hello World");
                        requestBody.Write(bytes, 0, bytes.Length);
                        requestBody.Seek(0, SeekOrigin.Begin);
    
                        var context = new DefaultHttpContext();
                        context.Request.Body = requestBody;
                        context.Response.Body = responseBody;
    
                        var services = new ServiceCollection();
    
                        services.AddSingleton<TestConnectionHandler>();
    
                        services.AddOptions();
                        context.Request.Path = "/foo";
                        context.Request.Method = "POST";
                        var values = new Dictionary<string, StringValues>();
                        values["id"] = connection.ConnectionId;
                        var qs = new QueryCollection(values);
                        context.Request.Query = qs;
    
    
                        var builder = new ConnectionBuilder(services.BuildServiceProvider());
    
                        builder.UseConnectionHandler<TestConnectionHandler>();
    
                        var app = builder.Build();
    
    
                        Assert.Equal(0, connection.ApplicationStream.Length);
    
    
                        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
    
                        Assert.True(connection.Transport.Input.TryRead(out var result));
                        Assert.Equal("Hello World", Encoding.UTF8.GetString(result.Buffer.ToArray()));
    
                        Assert.Equal(0, connection.ApplicationStream.Length);
    
                        connection.Transport.Input.AdvanceTo(result.Buffer.End);
                    }
                }
            }
    
    
            [Theory]
            [InlineData(HttpTransportType.LongPolling)]
            [InlineData(HttpTransportType.ServerSentEvents)]
            public async Task PostSendsToConnectionInParallel(HttpTransportType transportType)
            {
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var connection = manager.CreateConnection();
    
                    connection.TransportType = transportType;
    
    
                    // Allow a maximum of one caller to use code at one time
                    var callerTracker = new SemaphoreSlim(1, 1);
                    var waitTcs = new TaskCompletionSource<bool>();
    
                    // This tests thread safety of sending multiple pieces of data to a connection at once
                    var executeTask1 = DispatcherExecuteAsync(dispatcher, connection, callerTracker, waitTcs.Task);
                    var executeTask2 = DispatcherExecuteAsync(dispatcher, connection, callerTracker, waitTcs.Task);
    
                    waitTcs.SetResult(true);
    
                    await Task.WhenAll(executeTask1, executeTask2);
                }
    
                async Task DispatcherExecuteAsync(HttpConnectionDispatcher dispatcher, HttpConnectionContext connection, SemaphoreSlim callerTracker, Task waitTask)
                {
                    using (var requestBody = new TrackingMemoryStream(callerTracker, waitTask))
                    {
                        var bytes = Encoding.UTF8.GetBytes("Hello World");
                        requestBody.Write(bytes, 0, bytes.Length);
                        requestBody.Seek(0, SeekOrigin.Begin);
    
                        var context = new DefaultHttpContext();
                        context.Request.Body = requestBody;
    
                        var services = new ServiceCollection();
                        services.AddSingleton<TestConnectionHandler>();
                        services.AddOptions();
                        context.Request.Path = "/foo";
                        context.Request.Method = "POST";
                        var values = new Dictionary<string, StringValues>();
                        values["id"] = connection.ConnectionId;
                        var qs = new QueryCollection(values);
                        context.Request.Query = qs;
    
                        var builder = new ConnectionBuilder(services.BuildServiceProvider());
                        builder.UseConnectionHandler<TestConnectionHandler>();
                        var app = builder.Build();
    
    
                        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
                    }
                }
            }
    
            private class TrackingMemoryStream : MemoryStream
            {
                private readonly SemaphoreSlim _callerTracker;
                private readonly Task _waitTask;
    
                public TrackingMemoryStream(SemaphoreSlim callerTracker, Task waitTask)
                {
                    _callerTracker = callerTracker;
                    _waitTask = waitTask;
                }
    
                public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
                {
                    // Will return false if all available locks from semaphore are taken
                    if (!_callerTracker.Wait(0))
                    {
                        throw new Exception("Too many callers.");
                    }
    
                    try
                    {
                        await _waitTask;
    
                        await base.CopyToAsync(destination, bufferSize, cancellationToken);
                    }
                    finally
                    {
                        _callerTracker.Release();
                    }
                }
            }
    
    
            [Fact]
            public async Task HttpContextFeatureForLongpollingWorksBetweenPolls()
            {
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var connection = manager.CreateConnection();
    
                    connection.TransportType = HttpTransportType.LongPolling;
    
    
                    using (var requestBody = new MemoryStream())
                    using (var responseBody = new MemoryStream())
                    {
                        var context = new DefaultHttpContext();
                        context.Request.Body = requestBody;
                        context.Response.Body = responseBody;
    
                        var services = new ServiceCollection();
    
                        services.AddSingleton<HttpContextConnectionHandler>();
    
                        services.AddOptions();
    
                        // Setup state on the HttpContext
                        context.Request.Path = "/foo";
                        context.Request.Method = "GET";
                        var values = new Dictionary<string, StringValues>();
                        values["id"] = connection.ConnectionId;
                        values["another"] = "value";
                        var qs = new QueryCollection(values);
                        context.Request.Query = qs;
                        context.Request.Headers["header1"] = "h1";
                        context.Request.Headers["header2"] = "h2";
                        context.Request.Headers["header3"] = "h3";
                        context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim("claim1", "claimValue") }));
                        context.TraceIdentifier = "requestid";
                        context.Connection.Id = "connectionid";
                        context.Connection.LocalIpAddress = IPAddress.Loopback;
                        context.Connection.LocalPort = 4563;
                        context.Connection.RemoteIpAddress = IPAddress.IPv6Any;
                        context.Connection.RemotePort = 43456;
    
                        context.SetEndpoint(new Endpoint(null, null, "TestName"));
    
    
                        var builder = new ConnectionBuilder(services.BuildServiceProvider());
    
                        builder.UseConnectionHandler<HttpContextConnectionHandler>();
    
                        var app = builder.Build();
    
                        // Start a poll
    
                        var task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
                        Assert.True(task.IsCompleted);
                        Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
    
                        task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
    
                        // Send to the application
                        var buffer = Encoding.UTF8.GetBytes("Hello World");
                        await connection.Application.Output.WriteAsync(buffer);
    
                        // The poll request should end
                        await task;
    
                        // Make sure the actual response isn't affected
                        Assert.Equal("application/octet-stream", context.Response.ContentType);
    
                        // Now do a new send again without the poll (that request should have ended)
                        await connection.Application.Output.WriteAsync(buffer);
    
                        connection.Application.Output.Complete();
    
                        // Wait for the endpoint to end
                        await connection.ApplicationTask;
    
                        var connectionHttpContext = connection.GetHttpContext();
                        Assert.NotNull(connectionHttpContext);
    
                        Assert.Equal(2, connectionHttpContext.Request.Query.Count);
                        Assert.Equal(connection.ConnectionId, connectionHttpContext.Request.Query["id"]);
                        Assert.Equal("value", connectionHttpContext.Request.Query["another"]);
    
                        Assert.Equal(3, connectionHttpContext.Request.Headers.Count);
                        Assert.Equal("h1", connectionHttpContext.Request.Headers["header1"]);
                        Assert.Equal("h2", connectionHttpContext.Request.Headers["header2"]);
                        Assert.Equal("h3", connectionHttpContext.Request.Headers["header3"]);
                        Assert.Equal("requestid", connectionHttpContext.TraceIdentifier);
                        Assert.Equal("claimValue", connectionHttpContext.User.Claims.FirstOrDefault().Value);
                        Assert.Equal("connectionid", connectionHttpContext.Connection.Id);
                        Assert.Equal(IPAddress.Loopback, connectionHttpContext.Connection.LocalIpAddress);
                        Assert.Equal(4563, connectionHttpContext.Connection.LocalPort);
                        Assert.Equal(IPAddress.IPv6Any, connectionHttpContext.Connection.RemoteIpAddress);
                        Assert.Equal(43456, connectionHttpContext.Connection.RemotePort);
                        Assert.NotNull(connectionHttpContext.RequestServices);
                        Assert.Equal(Stream.Null, connectionHttpContext.Response.Body);
                        Assert.NotNull(connectionHttpContext.Response.Headers);
                        Assert.Equal("application/xml", connectionHttpContext.Response.ContentType);
    
                        var endpointFeature = connectionHttpContext.Features.Get<IEndpointFeature>();
                        Assert.NotNull(endpointFeature);
                        Assert.Equal("TestName", endpointFeature.Endpoint.DisplayName);
    
            [InlineData(HttpTransportType.ServerSentEvents)]
            [InlineData(HttpTransportType.LongPolling)]
            public async Task EndpointsThatRequireConnectionId400WhenNoConnectionIdProvided(HttpTransportType transportType)
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    using (var strm = new MemoryStream())
                    {
                        var context = new DefaultHttpContext();
                        context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
                        context.Response.Body = strm;
                        var services = new ServiceCollection();
                        services.AddOptions();
    
                        services.AddSingleton<TestConnectionHandler>();
    
                        context.Request.Path = "/foo";
                        context.Request.Method = "GET";
    
                        SetTransport(context, transportType);
    
    
                        var builder = new ConnectionBuilder(services.BuildServiceProvider());
    
                        builder.UseConnectionHandler<TestConnectionHandler>();
    
                        var app = builder.Build();
    
                        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
    
                        Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
                        await strm.FlushAsync();
                        Assert.Equal("Connection ID required", Encoding.UTF8.GetString(strm.ToArray()));
                    }
    
            [Theory]
            [InlineData(HttpTransportType.LongPolling)]
            [InlineData(HttpTransportType.ServerSentEvents)]
            public async Task IOExceptionWhenReadingRequestReturns400Response(HttpTransportType transportType)
            {
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var connection = manager.CreateConnection();
                    connection.TransportType = transportType;
    
                    var mockStream = new Mock<Stream>();
                    mockStream.Setup(m => m.CopyToAsync(It.IsAny<Stream>(), It.IsAny<int>(), It.IsAny<CancellationToken>())).Throws(new IOException());
    
                    using (var responseBody = new MemoryStream())
                    {
                        var context = new DefaultHttpContext();
                        context.Request.Body = mockStream.Object;
                        context.Response.Body = responseBody;
    
                        var services = new ServiceCollection();
                        services.AddSingleton<TestConnectionHandler>();
                        services.AddOptions();
                        context.Request.Path = "/foo";
                        context.Request.Method = "POST";
                        var values = new Dictionary<string, StringValues>();
                        values["id"] = connection.ConnectionId;
                        var qs = new QueryCollection(values);
                        context.Request.Query = qs;
    
                        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), c => Task.CompletedTask);
    
                        Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
                    }
                }
            }
    
    
            [Fact]
            public async Task EndpointsThatRequireConnectionId400WhenNoConnectionIdProvidedForPost()
    
    David Fowler's avatar
    David Fowler 已提交
            {
    
                using (StartVerifiableLog())
    
    David Fowler's avatar
    David Fowler 已提交
                {
    
                    var manager = CreateConnectionManager(LoggerFactory);
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    using (var strm = new MemoryStream())
                    {
                        var context = new DefaultHttpContext();
                        context.Response.Body = strm;
                        var services = new ServiceCollection();
                        services.AddOptions();
    
                        services.AddSingleton<TestConnectionHandler>();
    
                        context.Request.Path = "/foo";
                        context.Request.Method = "POST";
    
    
                        var builder = new ConnectionBuilder(services.BuildServiceProvider());
    
                        builder.UseConnectionHandler<TestConnectionHandler>();
    
                        var app = builder.Build();
    
                        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
    
                        Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
                        await strm.FlushAsync();
                        Assert.Equal("Connection ID required", Encoding.UTF8.GetString(strm.ToArray()));
                    }
    
    David Fowler's avatar
    David Fowler 已提交
            }
    
            [Theory]
    
            [InlineData(HttpTransportType.LongPolling, 200)]
    
            [InlineData(HttpTransportType.WebSockets, 404)]
            [InlineData(HttpTransportType.ServerSentEvents, 404)]
            public async Task EndPointThatOnlySupportsLongPollingRejectsOtherTransports(HttpTransportType transportType, int status)
    
                using (StartVerifiableLog())
    
                    await CheckTransportSupported(HttpTransportType.LongPolling, transportType, status, LoggerFactory);
    
            [InlineData(HttpTransportType.ServerSentEvents, 200)]
            [InlineData(HttpTransportType.WebSockets, 404)]
            [InlineData(HttpTransportType.LongPolling, 404)]
            public async Task EndPointThatOnlySupportsSSERejectsOtherTransports(HttpTransportType transportType, int status)
    
                using (StartVerifiableLog())
    
                    await CheckTransportSupported(HttpTransportType.ServerSentEvents, transportType, status, LoggerFactory);
    
            [InlineData(HttpTransportType.WebSockets, 200)]
            [InlineData(HttpTransportType.ServerSentEvents, 404)]
            [InlineData(HttpTransportType.LongPolling, 404)]
            public async Task EndPointThatOnlySupportsWebSockesRejectsOtherTransports(HttpTransportType transportType, int status)
    
                using (StartVerifiableLog())
    
                    await CheckTransportSupported(HttpTransportType.WebSockets, transportType, status, LoggerFactory);
    
            [InlineData(HttpTransportType.LongPolling, 404)]
            public async Task EndPointThatOnlySupportsWebSocketsAndSSERejectsLongPolling(HttpTransportType transportType, int status)
    
                using (StartVerifiableLog())
    
                    await CheckTransportSupported(HttpTransportType.WebSockets | HttpTransportType.ServerSentEvents, transportType, status, LoggerFactory);
    
            [Fact]
            public async Task CompletedEndPointEndsConnection()
            {
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
    
                    var connection = manager.CreateConnection();
    
                    connection.TransportType = HttpTransportType.ServerSentEvents;
    
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var context = MakeRequest("/foo", connection);
    
                    SetTransport(context, HttpTransportType.ServerSentEvents);
    
                    var services = new ServiceCollection();
    
                    services.AddSingleton<ImmediatelyCompleteConnectionHandler>();
    
                    var builder = new ConnectionBuilder(services.BuildServiceProvider());
    
                    builder.UseConnectionHandler<ImmediatelyCompleteConnectionHandler>();
    
                    var app = builder.Build();
    
                    await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
                    Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
    
                    var exists = manager.TryGetConnection(connection.ConnectionId, out _);
    
            [Fact]
            public async Task SynchronusExceptionEndsConnection()
            {
    
                bool ExpectedErrors(WriteContext writeContext)
                {
                    return writeContext.LoggerName == typeof(HttpConnectionManager).FullName &&
                           writeContext.EventId.Name == "FailedDispose";
                }
    
    
                using (StartVerifiableLog(expectedErrorsFilter: ExpectedErrors))
    
                    var manager = CreateConnectionManager(LoggerFactory);
    
                    var connection = manager.CreateConnection();
    
                    connection.TransportType = HttpTransportType.ServerSentEvents;
    
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var context = MakeRequest("/foo", connection);
    
                    SetTransport(context, HttpTransportType.ServerSentEvents);
    
                    var services = new ServiceCollection();
    
                    services.AddSingleton<SynchronusExceptionConnectionHandler>();
    
                    var builder = new ConnectionBuilder(services.BuildServiceProvider());
    
                    builder.UseConnectionHandler<SynchronusExceptionConnectionHandler>();
    
                    var app = builder.Build();
    
                    await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
                    Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
    
                    var exists = manager.TryGetConnection(connection.ConnectionId, out _);
    
            [Fact]
            public async Task CompletedEndPointEndsLongPollingConnection()
            {
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
    
                    var connection = manager.CreateConnection();
    
                    connection.TransportType = HttpTransportType.LongPolling;
    
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var context = MakeRequest("/foo", connection);
    
                    var services = new ServiceCollection();
    
                    services.AddSingleton<ImmediatelyCompleteConnectionHandler>();
    
                    var builder = new ConnectionBuilder(services.BuildServiceProvider());
    
                    builder.UseConnectionHandler<ImmediatelyCompleteConnectionHandler>();
    
                    var app = builder.Build();
    
                    // First poll will 200
                    await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
                    Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
    
    
                    await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    
                    Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);
    
                    var exists = manager.TryGetConnection(connection.ConnectionId, out _);
    
            [Fact]
            public async Task LongPollingTimeoutSets200StatusCode()
            {
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
    
                    var connection = manager.CreateConnection();
    
                    connection.TransportType = HttpTransportType.LongPolling;
    
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var context = MakeRequest("/foo", connection);
    
                    var services = new ServiceCollection();
    
                    services.AddSingleton<TestConnectionHandler>();
    
                    var builder = new ConnectionBuilder(services.BuildServiceProvider());
    
                    builder.UseConnectionHandler<TestConnectionHandler>();
    
                    var app = builder.Build();
    
                    var options = new HttpConnectionDispatcherOptions();
    
                    options.LongPolling.PollTimeout = TimeSpan.FromSeconds(2);
                    await dispatcher.ExecuteAsync(context, options, app).OrTimeout();
    
                    Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
                }
    
            [LogLevel(LogLevel.Trace)]
    
            public async Task WebSocketTransportTimesOutWhenCloseFrameNotReceived()
            {
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
    
                    var connection = manager.CreateConnection();
    
                    connection.TransportType = HttpTransportType.WebSockets;
    
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var context = MakeRequest("/foo", connection);
    
                    SetTransport(context, HttpTransportType.WebSockets);
    
                    var services = new ServiceCollection();
    
                    services.AddSingleton<ImmediatelyCompleteConnectionHandler>();
    
                    var builder = new ConnectionBuilder(services.BuildServiceProvider());
    
                    builder.UseConnectionHandler<ImmediatelyCompleteConnectionHandler>();
    
                    var app = builder.Build();
    
                    var options = new HttpConnectionDispatcherOptions();
    
                    options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1);
    
                    var task = dispatcher.ExecuteAsync(context, options, app);
    
                    await task.OrTimeout();
                }
    
            [InlineData(HttpTransportType.WebSockets)]
            [InlineData(HttpTransportType.ServerSentEvents)]
            public async Task RequestToActiveConnectionId409ForStreamingTransports(HttpTransportType transportType)
    
                using (StartVerifiableLog())
    
                    var manager = CreateConnectionManager(LoggerFactory);
    
                    var connection = manager.CreateConnection();
    
                    connection.TransportType = transportType;
    
                    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    
                    var context1 = MakeRequest("/foo", connection);
                    var context2 = MakeRequest("/foo", connection);