更新
更旧
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Linq;
using System.Net;
using System.Security.Principal;
using System.Threading;
using Microsoft.AspNetCore.Connections;
using Microsoft.AspNetCore.Connections.Features;
using Microsoft.AspNetCore.Http.Connections.Internal;
using Microsoft.AspNetCore.Http.Connections.Internal.Transports;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Http.Internal;
using Microsoft.AspNetCore.Testing;
using Microsoft.Extensions.DependencyInjection;
using Microsoft.Extensions.Logging.Testing;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
namespace Microsoft.AspNetCore.Http.Connections.Tests
public class HttpConnectionDispatcherTests : VerifiableLoggedTest
[Fact]
public async Task NegotiateReservesConnectionIdAndReturnsIt()
{
    // A negotiate request must reserve a connection in the manager and return its id
    // in the "connectionId" field of the JSON response body.
    var manager = CreateConnectionManager(LoggerFactory);
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    var context = new DefaultHttpContext();

    using (var ms = new MemoryStream())
    {
        context.Request.Path = "/foo";
        context.Request.Method = "POST";
        context.Response.Body = ms;

        await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions());

        var negotiateResponse = JsonConvert.DeserializeObject<JObject>(Encoding.UTF8.GetString(ms.ToArray()));
        var connectionId = negotiateResponse.Value<string>("connectionId");

        // The returned id must resolve to the connection that the manager now tracks.
        Assert.True(manager.TryGetConnection(connectionId, out var connectionContext));
        Assert.Equal(connectionId, connectionContext.ConnectionId);
    }
}
[Fact]
public async Task CheckThatThresholdValuesAreEnforced()
{
    // Negotiate with 4-byte transport/application buffer limits. Then verify that a 5-byte
    // write into the application pipe is held back (not completed synchronously).
    var manager = CreateConnectionManager(LoggerFactory);
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    var context = new DefaultHttpContext();

    using (var ms = new MemoryStream())
    {
        context.Request.Path = "/foo";
        context.Request.Method = "POST";
        context.Response.Body = ms;

        var options = new HttpConnectionDispatcherOptions { TransportMaxBufferSize = 4, ApplicationMaxBufferSize = 4 };
        await dispatcher.ExecuteNegotiateAsync(context, options);

        var negotiateResponse = JsonConvert.DeserializeObject<JObject>(Encoding.UTF8.GetString(ms.ToArray()));
        var connectionId = negotiateResponse.Value<string>("connectionId");
        context.Request.QueryString = context.Request.QueryString.Add("id", connectionId);
        Assert.True(manager.TryGetConnection(connectionId, out var connection));

        // Fake actual connection after negotiate to populate the pipes on the connection
        await dispatcher.ExecuteAsync(context, options, c => Task.CompletedTask);

        // This write should NOT complete immediately because 5 bytes exceeds the 4-byte writer threshold
        var writeTask = connection.Application.Output.WriteAsync(new[] { (byte)'b', (byte)'y', (byte)'t', (byte)'e', (byte)'s' });
        Assert.False(writeTask.IsCompleted);

        // Drain the 5 written bytes from the transport side (ConsumeAsync is a test helper).
        await connection.Transport.Input.ConsumeAsync(5);
    }
}
[Theory]
[InlineData(HttpTransportType.LongPolling)]
[InlineData(HttpTransportType.ServerSentEvents)]
public async Task CheckThatThresholdValuesAreEnforcedWithSends(HttpTransportType transportType)
{
    // A POST body larger than the pipe's pause threshold must not be fully accepted until
    // the transport side drains part of it.
    var manager = CreateConnectionManager(LoggerFactory);
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);

    // Writers pause at 8 buffered bytes and resume at 4.
    var pipeOptions = new PipeOptions(pauseWriterThreshold: 8, resumeWriterThreshold: 4);
    var connection = manager.CreateConnection(pipeOptions, pipeOptions);
    connection.TransportType = transportType;

    using (var requestBody = new MemoryStream())
    using (var responseBody = new MemoryStream())
    {
        // 12 bytes: larger than the 8-byte pause threshold.
        var bytes = Encoding.UTF8.GetBytes("EXTRADATA Hi");
        requestBody.Write(bytes, 0, bytes.Length);
        requestBody.Seek(0, SeekOrigin.Begin);

        var context = new DefaultHttpContext();
        context.Request.Body = requestBody;
        context.Response.Body = responseBody;
        context.Request.Path = "/foo";
        context.Request.Method = "POST";
        context.Request.Query = new QueryCollection(new Dictionary<string, StringValues>
        {
            ["id"] = connection.ConnectionId
        });

        var services = new ServiceCollection();
        services.AddSingleton<TestConnectionHandler>();
        services.AddOptions();
        var builder = new ConnectionBuilder(services.BuildServiceProvider());
        builder.UseConnectionHandler<TestConnectionHandler>();
        var app = builder.Build();

        // This send should NOT complete immediately because it exceeds the writer threshold
        var executeTask = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
        Assert.False(executeTask.IsCompleted);

        // Drain "EXTRADATA " (10 bytes) from the transport side so the dispatcher can finish writing.
        await connection.Transport.Input.ConsumeAsync(10);
        await executeTask.OrTimeout();

        // Only the trailing "Hi" should remain in the pipe.
        Assert.True(connection.Transport.Input.TryRead(out var result));
        Assert.Equal("Hi", Encoding.UTF8.GetString(result.Buffer.ToArray()));
        connection.Transport.Input.AdvanceTo(result.Buffer.End);
    }
}
[Theory]
[InlineData(HttpTransportType.LongPolling | HttpTransportType.WebSockets | HttpTransportType.ServerSentEvents)]
[InlineData(HttpTransportType.None)]
[InlineData(HttpTransportType.LongPolling | HttpTransportType.WebSockets)]
public async Task NegotiateReturnsAvailableTransportsAfterFilteringByOptions(HttpTransportType transports)
{
    // The negotiate response must advertise exactly the transports enabled in the options.
    var manager = CreateConnectionManager(LoggerFactory);
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    var context = new DefaultHttpContext();
    context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
    context.Features.Set<IHttpWebSocketFeature>(new TestWebSocketConnectionFeature());

    using (var ms = new MemoryStream())
    {
        context.Request.Path = "/foo";
        context.Request.Method = "POST";
        context.Response.Body = ms;

        await dispatcher.ExecuteNegotiateAsync(context, new HttpConnectionDispatcherOptions { Transports = transports });

        var negotiateResponse = JsonConvert.DeserializeObject<JObject>(Encoding.UTF8.GetString(ms.ToArray()));

        // Fold every advertised transport into one flags value so it can be compared
        // with the configured set.
        var availableTransports = HttpTransportType.None;
        foreach (var transport in negotiateResponse["availableTransports"])
        {
            var transportType = (HttpTransportType)Enum.Parse(typeof(HttpTransportType), transport.Value<string>("transport"));
            availableTransports |= transportType;
        }

        // Assert after the loop. The None case (empty list) is then still checked, and a
        // partially accumulated value is never compared.
        Assert.Equal(transports, availableTransports);
    }
}
[Theory]
[InlineData(HttpTransportType.WebSockets)]
[InlineData(HttpTransportType.ServerSentEvents)]
[InlineData(HttpTransportType.LongPolling)]
public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvided(HttpTransportType transportType)
{
    // A GET for a connection id the manager has never issued must return 404 with a fixed message.
    var manager = CreateConnectionManager(LoggerFactory);
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);

    using (var strm = new MemoryStream())
    {
        var context = new DefaultHttpContext();
        context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
        context.Response.Body = strm;
        context.Request.Path = "/foo";
        context.Request.Method = "GET";
        context.Request.Query = new QueryCollection(new Dictionary<string, StringValues>
        {
            ["id"] = "unknown"
        });
        SetTransport(context, transportType);

        var services = new ServiceCollection();
        services.AddSingleton<TestConnectionHandler>();
        services.AddOptions();
        var builder = new ConnectionBuilder(services.BuildServiceProvider());
        builder.UseConnectionHandler<TestConnectionHandler>();
        var app = builder.Build();

        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);

        Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
        await strm.FlushAsync();
        Assert.Equal("No Connection with that ID", Encoding.UTF8.GetString(strm.ToArray()));
    }
}
[Fact]
public async Task EndpointsThatAcceptConnectionId404WhenUnknownConnectionIdProvidedForPost()
{
    // A POST for a connection id the manager has never issued must return 404 with a fixed message.
    var manager = CreateConnectionManager(LoggerFactory);
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);

    using (var strm = new MemoryStream())
    {
        var context = new DefaultHttpContext();
        context.Response.Body = strm;
        context.Request.Path = "/foo";
        context.Request.Method = "POST";
        context.Request.Query = new QueryCollection(new Dictionary<string, StringValues>
        {
            ["id"] = "unknown"
        });

        var services = new ServiceCollection();
        services.AddSingleton<TestConnectionHandler>();
        services.AddOptions();
        var builder = new ConnectionBuilder(services.BuildServiceProvider());
        builder.UseConnectionHandler<TestConnectionHandler>();
        var app = builder.Build();

        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);

        Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
        await strm.FlushAsync();
        Assert.Equal("No Connection with that ID", Encoding.UTF8.GetString(strm.ToArray()));
    }
}
[Fact]
public async Task PostNotAllowedForWebSocketConnections()
{
    // Sends on a WebSocket connection go over the socket, so a POST to it must be rejected with 405.
    var manager = CreateConnectionManager(LoggerFactory);
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    var connection = manager.CreateConnection();
    connection.TransportType = HttpTransportType.WebSockets;

    using (var strm = new MemoryStream())
    {
        var context = new DefaultHttpContext();
        context.Response.Body = strm;
        context.Request.Path = "/foo";
        context.Request.Method = "POST";
        context.Request.Query = new QueryCollection(new Dictionary<string, StringValues>
        {
            ["id"] = connection.ConnectionId
        });

        var services = new ServiceCollection();
        services.AddSingleton<TestConnectionHandler>();
        services.AddOptions();
        var builder = new ConnectionBuilder(services.BuildServiceProvider());
        builder.UseConnectionHandler<TestConnectionHandler>();
        var app = builder.Build();

        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);

        Assert.Equal(StatusCodes.Status405MethodNotAllowed, context.Response.StatusCode);
        await strm.FlushAsync();
        Assert.Equal("POST requests are not allowed for WebSocket connections.", Encoding.UTF8.GetString(strm.ToArray()));
    }
}
[Fact]
public async Task PostReturns404IfConnectionDisposed()
{
    // A POST to a connection that has already been disposed must return 404.
    var manager = CreateConnectionManager(LoggerFactory);
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    var connection = manager.CreateConnection();
    connection.TransportType = HttpTransportType.LongPolling;
    await connection.DisposeAsync(closeGracefully: false);

    using (var strm = new MemoryStream())
    {
        var context = new DefaultHttpContext();
        context.Response.Body = strm;
        context.Request.Path = "/foo";
        context.Request.Method = "POST";
        context.Request.Query = new QueryCollection(new Dictionary<string, StringValues>
        {
            ["id"] = connection.ConnectionId
        });

        var services = new ServiceCollection();
        services.AddSingleton<TestConnectionHandler>();
        services.AddOptions();
        var builder = new ConnectionBuilder(services.BuildServiceProvider());
        builder.UseConnectionHandler<TestConnectionHandler>();
        var app = builder.Build();

        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);

        Assert.Equal(StatusCodes.Status404NotFound, context.Response.StatusCode);
    }
}
// Verifies that when a streaming transport (SSE or WebSockets) ends because the client went
// away, both the dispatcher's ExecuteAsync task and the connection's ApplicationTask complete.
[Theory]
[InlineData(HttpTransportType.ServerSentEvents)]
[InlineData(HttpTransportType.WebSockets)]
public async Task TransportEndingGracefullyWaitsOnApplication(HttpTransportType transportType)
{
var manager = CreateConnectionManager(LoggerFactory);
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
// NOTE(review): the bare integers below are not C# statements. They look like line-number
// residue from a collapsed diff view, and the real code for this region is missing. This
// method will not compile until it is restored. TODO restore from source control.
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
var connection = manager.CreateConnection();
connection.TransportType = transportType;
using (var strm = new MemoryStream())
{
var context = new DefaultHttpContext();
SetTransport(context, transportType);
// Cancelling this token stands in for the client aborting the HTTP request (used for SSE below).
var cts = new CancellationTokenSource();
context.Response.Body = strm;
context.RequestAborted = cts.Token;
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
services.AddOptions();
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
var builder = new ConnectionBuilder(services.BuildServiceProvider());
// Inline application: read once from the transport, expect the input to be completed,
// then write the first buffer segment back out.
builder.Use(next =>
{
return async connectionContext =>
{
// Ensure both sides of the pipe are ok
var result = await connectionContext.Transport.Input.ReadAsync();
Assert.True(result.IsCompleted);
await connectionContext.Transport.Output.WriteAsync(result.Buffer.First);
};
});
var app = builder.Build();
var task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
// Pretend the transport closed because the client disconnected
// (WebSockets: close the test client's output; SSE: abort the request token).
if (context.WebSockets.IsWebSocketRequest)
{
var ws = (TestWebSocketConnectionFeature)context.Features.Get<IHttpWebSocketFeature>();
await ws.Client.CloseOutputAsync(WebSocketCloseStatus.NormalClosure, "", default);
}
else
{
cts.Cancel();
}
// Both the request and the application must finish within the test timeout.
await task.OrTimeout();
await connection.ApplicationTask.OrTimeout();
}
}
}
// Long-polling variant: after the poll request is aborted, the application keeps running until a
// manager scan finds the connection stale, after which ApplicationTask must complete.
[Fact]
public async Task TransportEndingGracefullyWaitsOnApplicationLongPolling()
{
// NOTE(review): the TimeSpan argument presumably sets the connection expiry/disconnect timeout
// used by manager.Scan() below. Confirm against CreateConnectionManager.
var manager = CreateConnectionManager(LoggerFactory, TimeSpan.FromSeconds(5));
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
// NOTE(review): the bare integers below are not C# statements. They look like line-number
// residue from a collapsed diff view, and the real code for this region is missing. This
// method will not compile until it is restored. TODO restore from source control.
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
using (var strm = new MemoryStream())
{
var context = new DefaultHttpContext();
SetTransport(context, HttpTransportType.LongPolling);
// Cancelling this token stands in for the client aborting the poll request.
var cts = new CancellationTokenSource();
context.Response.Body = strm;
context.RequestAborted = cts.Token;
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
services.AddOptions();
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
var builder = new ConnectionBuilder(services.BuildServiceProvider());
// Inline application: read once from the transport, expect the input to be completed,
// then write the first buffer segment back out.
builder.Use(next =>
{
return async connectionContext =>
{
// Ensure both sides of the pipe are ok
var result = await connectionContext.Transport.Input.ReadAsync();
Assert.True(result.IsCompleted);
await connectionContext.Transport.Output.WriteAsync(result.Buffer.First);
};
});
var app = builder.Build();
var task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
// Pretend the transport closed because the client disconnected
cts.Cancel();
await task.OrTimeout();
// We've been gone longer than the expiration time
// (10s ago, versus the 5s TimeSpan passed to CreateConnectionManager above).
connection.LastSeenUtc = DateTime.UtcNow.Subtract(TimeSpan.FromSeconds(10));
// The application is still running here because the poll is only killed
// by the heartbeat so we pretend to do a scan and this should force the application task to complete
manager.Scan();
// The application task should complete gracefully
await connection.ApplicationTask.OrTimeout();
}
}
}
[Theory]
[InlineData(HttpTransportType.LongPolling)]
[InlineData(HttpTransportType.ServerSentEvents)]
public async Task PostSendsToConnection(HttpTransportType transportType)
{
    // The body of a POST to an existing connection must appear, unchanged, on the
    // connection's transport input pipe.
    var manager = CreateConnectionManager(LoggerFactory);
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    var connection = manager.CreateConnection();
    connection.TransportType = transportType;

    using (var requestBody = new MemoryStream())
    using (var responseBody = new MemoryStream())
    {
        var bytes = Encoding.UTF8.GetBytes("Hello World");
        requestBody.Write(bytes, 0, bytes.Length);
        requestBody.Seek(0, SeekOrigin.Begin);

        var context = new DefaultHttpContext();
        context.Request.Body = requestBody;
        context.Response.Body = responseBody;
        context.Request.Path = "/foo";
        context.Request.Method = "POST";
        context.Request.Query = new QueryCollection(new Dictionary<string, StringValues>
        {
            ["id"] = connection.ConnectionId
        });

        var services = new ServiceCollection();
        services.AddSingleton<TestConnectionHandler>();
        services.AddOptions();
        var builder = new ConnectionBuilder(services.BuildServiceProvider());
        builder.UseConnectionHandler<TestConnectionHandler>();
        var app = builder.Build();

        Assert.Equal(0, connection.ApplicationStream.Length);

        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);

        Assert.True(connection.Transport.Input.TryRead(out var result));
        Assert.Equal("Hello World", Encoding.UTF8.GetString(result.Buffer.ToArray()));
        Assert.Equal(0, connection.ApplicationStream.Length);
        connection.Transport.Input.AdvanceTo(result.Buffer.End);
    }
}
// Starts two POSTs to the same connection at once. Both are held on a shared gate task and then
// released together, so TrackingMemoryStream can detect overlapping body copies.
[Theory]
[InlineData(HttpTransportType.LongPolling)]
[InlineData(HttpTransportType.ServerSentEvents)]
public async Task PostSendsToConnectionInParallel(HttpTransportType transportType)
{
var manager = CreateConnectionManager(LoggerFactory);
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = transportType;
// NOTE(review): the bare integers below are not C# statements. They look like line-number
// residue from a collapsed diff view, and the real code for this region is missing. This
// method will not compile until it is restored. TODO restore from source control.
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
// Allow a maximum of one caller to use code at one time
var callerTracker = new SemaphoreSlim(1, 1);
// Both requests await this task inside CopyToAsync; completing it releases them together.
var waitTcs = new TaskCompletionSource<bool>();
// This tests thread safety of sending multiple pieces of data to a connection at once
var executeTask1 = DispatcherExecuteAsync(dispatcher, connection, callerTracker, waitTcs.Task);
var executeTask2 = DispatcherExecuteAsync(dispatcher, connection, callerTracker, waitTcs.Task);
waitTcs.SetResult(true);
await Task.WhenAll(executeTask1, executeTask2);
}
// Helper for PostSendsToConnectionInParallel: POSTs "Hello World" to the given connection through
// the dispatcher. The body is a TrackingMemoryStream sharing callerTracker/waitTask, so overlapping
// body copies are detected (this assumes the dispatcher reads the body via CopyToAsync; confirm).
async Task DispatcherExecuteAsync(HttpConnectionDispatcher dispatcher, HttpConnectionContext connection, SemaphoreSlim callerTracker, Task waitTask)
{
using (var requestBody = new TrackingMemoryStream(callerTracker, waitTask))
{
var bytes = Encoding.UTF8.GetBytes("Hello World");
requestBody.Write(bytes, 0, bytes.Length);
requestBody.Seek(0, SeekOrigin.Begin);
var context = new DefaultHttpContext();
context.Request.Body = requestBody;
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
services.AddOptions();
context.Request.Path = "/foo";
context.Request.Method = "POST";
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<TestConnectionHandler>();
var app = builder.Build();
await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
// NOTE(review): the bare integers below are not C# statements. They look like line-number
// residue from a collapsed diff view, and the real code for this region is missing. This
// method will not compile until it is restored. TODO restore from source control.
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
}
}
}
/// <summary>
/// A <see cref="MemoryStream"/> used as a request body that detects concurrent readers.
/// Each <see cref="CopyToAsync(Stream, int, CancellationToken)"/> call must take a slot from a
/// shared semaphore without waiting. It then awaits a shared gate task before copying, so
/// callers started together really overlap.
/// </summary>
private class TrackingMemoryStream : MemoryStream
{
    private readonly SemaphoreSlim _callerTracker;
    private readonly Task _waitTask;

    /// <param name="callerTracker">Semaphore whose free count is the number of copies allowed at once.</param>
    /// <param name="waitTask">Gate every copy awaits before copying its data.</param>
    /// <exception cref="ArgumentNullException">Either argument is null.</exception>
    public TrackingMemoryStream(SemaphoreSlim callerTracker, Task waitTask)
    {
        _callerTracker = callerTracker ?? throw new ArgumentNullException(nameof(callerTracker));
        _waitTask = waitTask ?? throw new ArgumentNullException(nameof(waitTask));
    }

    /// <summary>Copies the stream after taking a slot and waiting for the gate task.</summary>
    /// <exception cref="InvalidOperationException">No semaphore slot was free, meaning another copy is in progress.</exception>
    public override async Task CopyToAsync(Stream destination, int bufferSize, CancellationToken cancellationToken)
    {
        // Wait(0) never blocks: it returns false immediately when every slot is already taken.
        if (!_callerTracker.Wait(0))
        {
            throw new InvalidOperationException("Too many callers.");
        }

        try
        {
            await _waitTask;
            await base.CopyToAsync(destination, bufferSize, cancellationToken);
        }
        finally
        {
            // Give the slot back even if the copy fails.
            _callerTracker.Release();
        }
    }
}
// Verifies that request state set on the poll's HttpContext (query, headers, user, trace id,
// connection info, endpoint) is visible through connection.GetHttpContext() after the
// long-polling application finishes. It also checks that the live poll response keeps its
// own content type.
[Fact]
public async Task HttpContextFeatureForLongpollingWorksBetweenPolls()
{
var manager = CreateConnectionManager(LoggerFactory);
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = HttpTransportType.LongPolling;
using (var requestBody = new MemoryStream())
using (var responseBody = new MemoryStream())
{
var context = new DefaultHttpContext();
context.Request.Body = requestBody;
context.Response.Body = responseBody;
var services = new ServiceCollection();
services.AddSingleton<HttpContextConnectionHandler>();
services.AddOptions();
// Setup state on the HttpContext
context.Request.Path = "/foo";
context.Request.Method = "GET";
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
values["another"] = "value";
var qs = new QueryCollection(values);
context.Request.Query = qs;
context.Request.Headers["header1"] = "h1";
context.Request.Headers["header2"] = "h2";
context.Request.Headers["header3"] = "h3";
context.User = new ClaimsPrincipal(new ClaimsIdentity(new[] { new Claim("claim1", "claimValue") }));
context.TraceIdentifier = "requestid";
context.Connection.Id = "connectionid";
context.Connection.LocalIpAddress = IPAddress.Loopback;
context.Connection.LocalPort = 4563;
context.Connection.RemoteIpAddress = IPAddress.IPv6Any;
context.Connection.RemotePort = 43456;
context.SetEndpoint(new Endpoint(null, null, "TestName"));
var builder = new ConnectionBuilder(services.BuildServiceProvider());
builder.UseConnectionHandler<HttpContextConnectionHandler>();
var app = builder.Build();
// Start a poll
// (the first poll is asserted to complete synchronously with 200)
var task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
Assert.True(task.IsCompleted);
Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
// Second poll: stays pending until data is written to the application output below.
task = dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
// NOTE(review): the bare integers below are not C# statements. They look like line-number
// residue from a collapsed diff view, and the real code for this region is missing. This
// method will not compile until it is restored. TODO restore from source control.
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
// Send to the application
var buffer = Encoding.UTF8.GetBytes("Hello World");
await connection.Application.Output.WriteAsync(buffer);
// The poll request should end
await task;
// Make sure the actual response isn't affected
Assert.Equal("application/octet-stream", context.Response.ContentType);
// Now do a new send again without the poll (that request should have ended)
await connection.Application.Output.WriteAsync(buffer);
connection.Application.Output.Complete();
// Wait for the endpoint to end
await connection.ApplicationTask;
// Inspect the HttpContext the connection exposes; it must carry the request state set above.
var connectionHttpContext = connection.GetHttpContext();
Assert.NotNull(connectionHttpContext);
Assert.Equal(2, connectionHttpContext.Request.Query.Count);
Assert.Equal(connection.ConnectionId, connectionHttpContext.Request.Query["id"]);
Assert.Equal("value", connectionHttpContext.Request.Query["another"]);
Assert.Equal(3, connectionHttpContext.Request.Headers.Count);
Assert.Equal("h1", connectionHttpContext.Request.Headers["header1"]);
Assert.Equal("h2", connectionHttpContext.Request.Headers["header2"]);
Assert.Equal("h3", connectionHttpContext.Request.Headers["header3"]);
Assert.Equal("requestid", connectionHttpContext.TraceIdentifier);
Assert.Equal("claimValue", connectionHttpContext.User.Claims.FirstOrDefault().Value);
Assert.Equal("connectionid", connectionHttpContext.Connection.Id);
Assert.Equal(IPAddress.Loopback, connectionHttpContext.Connection.LocalIpAddress);
Assert.Equal(4563, connectionHttpContext.Connection.LocalPort);
Assert.Equal(IPAddress.IPv6Any, connectionHttpContext.Connection.RemoteIpAddress);
Assert.Equal(43456, connectionHttpContext.Connection.RemotePort);
Assert.NotNull(connectionHttpContext.RequestServices);
Assert.Equal(Stream.Null, connectionHttpContext.Response.Body);
Assert.NotNull(connectionHttpContext.Response.Headers);
// NOTE(review): "application/xml" is not set anywhere in this method; presumably
// HttpContextConnectionHandler sets it on the connection's context. Confirm.
Assert.Equal("application/xml", connectionHttpContext.Response.ContentType);
var endpointFeature = connectionHttpContext.Features.Get<IEndpointFeature>();
Assert.NotNull(endpointFeature);
Assert.Equal("TestName", endpointFeature.Endpoint.DisplayName);
}
}
}
[Theory]
[InlineData(HttpTransportType.ServerSentEvents)]
[InlineData(HttpTransportType.LongPolling)]
public async Task EndpointsThatRequireConnectionId400WhenNoConnectionIdProvided(HttpTransportType transportType)
{
    // A GET without an "id" query value must be rejected with 400 and a fixed message.
    var manager = CreateConnectionManager(LoggerFactory);
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);

    using (var strm = new MemoryStream())
    {
        var context = new DefaultHttpContext();
        context.Features.Set<IHttpResponseFeature>(new ResponseFeature());
        context.Response.Body = strm;
        context.Request.Path = "/foo";
        context.Request.Method = "GET";
        SetTransport(context, transportType);

        var services = new ServiceCollection();
        services.AddOptions();
        services.AddSingleton<TestConnectionHandler>();
        var builder = new ConnectionBuilder(services.BuildServiceProvider());
        builder.UseConnectionHandler<TestConnectionHandler>();
        var app = builder.Build();

        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);

        Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
        await strm.FlushAsync();
        Assert.Equal("Connection ID required", Encoding.UTF8.GetString(strm.ToArray()));
    }
}
// Verifies that when copying the POST body throws IOException (simulated with a mocked Stream),
// the dispatcher responds with 400 instead of propagating the exception.
[Theory]
[InlineData(HttpTransportType.LongPolling)]
[InlineData(HttpTransportType.ServerSentEvents)]
public async Task IOExceptionWhenReadingRequestReturns400Response(HttpTransportType transportType)
{
var manager = CreateConnectionManager(LoggerFactory);
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
// NOTE(review): the bare integers below are not C# statements. They look like line-number
// residue from a collapsed diff view, and the real code for this region is missing. This
// method will not compile until it is restored. TODO restore from source control.
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
var connection = manager.CreateConnection();
connection.TransportType = transportType;
// Request body whose CopyToAsync always throws IOException.
var mockStream = new Mock<Stream>();
mockStream.Setup(m => m.CopyToAsync(It.IsAny<Stream>(), It.IsAny<int>(), It.IsAny<CancellationToken>())).Throws(new IOException());
using (var responseBody = new MemoryStream())
{
var context = new DefaultHttpContext();
context.Request.Body = mockStream.Object;
context.Response.Body = responseBody;
var services = new ServiceCollection();
services.AddSingleton<TestConnectionHandler>();
services.AddOptions();
context.Request.Path = "/foo";
context.Request.Method = "POST";
var values = new Dictionary<string, StringValues>();
values["id"] = connection.ConnectionId;
var qs = new QueryCollection(values);
context.Request.Query = qs;
// The application delegate is a no-op; only the dispatcher's handling of the body read matters here.
await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), c => Task.CompletedTask);
Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
}
}
}
[Fact]
public async Task EndpointsThatRequireConnectionId400WhenNoConnectionIdProvidedForPost()
{
    // A POST without an "id" query value must be rejected with 400 and a fixed message.
    var manager = CreateConnectionManager(LoggerFactory);
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);

    using (var strm = new MemoryStream())
    {
        var context = new DefaultHttpContext();
        context.Response.Body = strm;
        context.Request.Path = "/foo";
        context.Request.Method = "POST";

        var services = new ServiceCollection();
        services.AddOptions();
        services.AddSingleton<TestConnectionHandler>();
        var builder = new ConnectionBuilder(services.BuildServiceProvider());
        builder.UseConnectionHandler<TestConnectionHandler>();
        var app = builder.Build();

        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);

        Assert.Equal(StatusCodes.Status400BadRequest, context.Response.StatusCode);
        await strm.FlushAsync();
        Assert.Equal("Connection ID required", Encoding.UTF8.GetString(strm.ToArray()));
    }
}
[Theory]
[InlineData(HttpTransportType.LongPolling, 200)]
[InlineData(HttpTransportType.WebSockets, 404)]
[InlineData(HttpTransportType.ServerSentEvents, 404)]
public async Task EndPointThatOnlySupportsLongPollingRejectsOtherTransports(HttpTransportType transportType, int status)
{
    // Endpoint enabled for LongPolling only: the rows expect 200 for LongPolling and 404 otherwise.
    await CheckTransportSupported(HttpTransportType.LongPolling, transportType, status, LoggerFactory);
}
[Theory]
[InlineData(HttpTransportType.ServerSentEvents, 200)]
[InlineData(HttpTransportType.WebSockets, 404)]
[InlineData(HttpTransportType.LongPolling, 404)]
public async Task EndPointThatOnlySupportsSSERejectsOtherTransports(HttpTransportType transportType, int status)
{
    // Endpoint enabled for ServerSentEvents only: the rows expect 200 for SSE and 404 otherwise.
    await CheckTransportSupported(HttpTransportType.ServerSentEvents, transportType, status, LoggerFactory);
}
[Theory]
[InlineData(HttpTransportType.WebSockets, 200)]
[InlineData(HttpTransportType.ServerSentEvents, 404)]
[InlineData(HttpTransportType.LongPolling, 404)]
public async Task EndPointThatOnlySupportsWebSockesRejectsOtherTransports(HttpTransportType transportType, int status)
{
    // Endpoint enabled for WebSockets only: the rows expect 200 for WebSockets and 404 otherwise.
    await CheckTransportSupported(HttpTransportType.WebSockets, transportType, status, LoggerFactory);
}
[Theory]
[InlineData(HttpTransportType.LongPolling, 404)]
public async Task EndPointThatOnlySupportsWebSocketsAndSSERejectsLongPolling(HttpTransportType transportType, int status)
{
    // Endpoint enabled for WebSockets and SSE: a LongPolling request is expected to 404.
    await CheckTransportSupported(HttpTransportType.WebSockets | HttpTransportType.ServerSentEvents, transportType, status, LoggerFactory);
}
[Fact]
public async Task CompletedEndPointEndsConnection()
{
    // When the application completes, the SSE request returns 200 and the connection is
    // removed from the manager.
    var manager = CreateConnectionManager(LoggerFactory);
    var connection = manager.CreateConnection();
    connection.TransportType = HttpTransportType.ServerSentEvents;
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    var context = MakeRequest("/foo", connection);
    SetTransport(context, HttpTransportType.ServerSentEvents);

    var services = new ServiceCollection();
    services.AddSingleton<ImmediatelyCompleteConnectionHandler>();
    var builder = new ConnectionBuilder(services.BuildServiceProvider());
    builder.UseConnectionHandler<ImmediatelyCompleteConnectionHandler>();
    var app = builder.Build();

    await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);

    Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
    var exists = manager.TryGetConnection(connection.ConnectionId, out _);
    Assert.False(exists);
}
[Fact]
public async Task SynchronusExceptionEndsConnection()
{
    // Errors matching this filter (a "FailedDispose" event from HttpConnectionManager) are
    // passed to the verifiable log as expected, so they do not fail the test.
    bool ExpectedErrors(WriteContext writeContext)
    {
        return writeContext.LoggerName == typeof(HttpConnectionManager).FullName &&
               writeContext.EventId.Name == "FailedDispose";
    }

    // The braces matter: without them the using statement's body would be a lone declaration,
    // which does not compile.
    using (StartVerifiableLog(expectedErrorsFilter: ExpectedErrors))
    {
        var manager = CreateConnectionManager(LoggerFactory);
        var connection = manager.CreateConnection();
        connection.TransportType = HttpTransportType.ServerSentEvents;
        var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
        var context = MakeRequest("/foo", connection);
        SetTransport(context, HttpTransportType.ServerSentEvents);

        var services = new ServiceCollection();
        services.AddSingleton<SynchronusExceptionConnectionHandler>();
        var builder = new ConnectionBuilder(services.BuildServiceProvider());
        builder.UseConnectionHandler<SynchronusExceptionConnectionHandler>();
        var app = builder.Build();

        await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);

        // The request still ends with 200, and the connection is removed from the manager.
        Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
        var exists = manager.TryGetConnection(connection.ConnectionId, out _);
        Assert.False(exists);
    }
}
[Fact]
public async Task CompletedEndPointEndsLongPollingConnection()
{
    // With an application that completes immediately, the first poll returns 200 and the
    // second returns 204. After that the connection is no longer in the manager.
    var manager = CreateConnectionManager(LoggerFactory);
    var connection = manager.CreateConnection();
    connection.TransportType = HttpTransportType.LongPolling;
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    var context = MakeRequest("/foo", connection);

    var services = new ServiceCollection();
    services.AddSingleton<ImmediatelyCompleteConnectionHandler>();
    var builder = new ConnectionBuilder(services.BuildServiceProvider());
    builder.UseConnectionHandler<ImmediatelyCompleteConnectionHandler>();
    var app = builder.Build();

    // First poll will 200
    await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);

    // Second poll sees the finished application and returns 204
    await dispatcher.ExecuteAsync(context, new HttpConnectionDispatcherOptions(), app);
    Assert.Equal(StatusCodes.Status204NoContent, context.Response.StatusCode);

    var exists = manager.TryGetConnection(connection.ConnectionId, out _);
    Assert.False(exists);
}
[Fact]
public async Task LongPollingTimeoutSets200StatusCode()
{
    // A poll that ends because PollTimeout (2s here) elapses must still report 200.
    var manager = CreateConnectionManager(LoggerFactory);
    var connection = manager.CreateConnection();
    connection.TransportType = HttpTransportType.LongPolling;
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    var context = MakeRequest("/foo", connection);

    var services = new ServiceCollection();
    services.AddSingleton<TestConnectionHandler>();
    var builder = new ConnectionBuilder(services.BuildServiceProvider());
    builder.UseConnectionHandler<TestConnectionHandler>();
    var app = builder.Build();

    var options = new HttpConnectionDispatcherOptions();
    options.LongPolling.PollTimeout = TimeSpan.FromSeconds(2);

    await dispatcher.ExecuteAsync(context, options, app).OrTimeout();

    Assert.Equal(StatusCodes.Status200OK, context.Response.StatusCode);
}
[Fact]
public async Task WebSocketTransportTimesOutWhenCloseFrameNotReceived()
{
    // The WebSocket request must finish within the test timeout even though the test client
    // never sends a close frame. Presumably the 1s CloseTimeout below is what ends it; confirm.
    var manager = CreateConnectionManager(LoggerFactory);
    var connection = manager.CreateConnection();
    connection.TransportType = HttpTransportType.WebSockets;
    var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
    var context = MakeRequest("/foo", connection);
    SetTransport(context, HttpTransportType.WebSockets);

    var services = new ServiceCollection();
    services.AddSingleton<ImmediatelyCompleteConnectionHandler>();
    var builder = new ConnectionBuilder(services.BuildServiceProvider());
    builder.UseConnectionHandler<ImmediatelyCompleteConnectionHandler>();
    var app = builder.Build();

    var options = new HttpConnectionDispatcherOptions();
    options.WebSockets.CloseTimeout = TimeSpan.FromSeconds(1);

    var task = dispatcher.ExecuteAsync(context, options, app);
    await task.OrTimeout();
}
[InlineData(HttpTransportType.WebSockets)]
[InlineData(HttpTransportType.ServerSentEvents)]
public async Task RequestToActiveConnectionId409ForStreamingTransports(HttpTransportType transportType)
var manager = CreateConnectionManager(LoggerFactory);
var connection = manager.CreateConnection();
connection.TransportType = transportType;
var dispatcher = new HttpConnectionDispatcher(manager, LoggerFactory);
var context1 = MakeRequest("/foo", connection);
var context2 = MakeRequest("/foo", connection);