diff --git a/clients/java/signalr/src/main/java/com/microsoft/signalr/HttpHubConnectionBuilder.java b/clients/java/signalr/src/main/java/com/microsoft/signalr/HttpHubConnectionBuilder.java index c2594826a819225b262402a17365c8050a33ba4a..d18c0a8d352a1cfe0d6041fb77dbdf7c0ec7e371 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/signalr/HttpHubConnectionBuilder.java +++ b/clients/java/signalr/src/main/java/com/microsoft/signalr/HttpHubConnectionBuilder.java @@ -3,8 +3,7 @@ package com.microsoft.signalr; -import java.util.concurrent.CompletableFuture; -import java.util.function.Supplier; +import java.time.Duration; import io.reactivex.Single; @@ -15,6 +14,7 @@ public class HttpHubConnectionBuilder { private HttpClient httpClient; private boolean skipNegotiate; private Single<String> accessTokenProvider; + private Duration handshakeResponseTimeout; HttpHubConnectionBuilder(String url) { this.url = url; @@ -56,7 +56,12 @@ public class HttpHubConnectionBuilder { return this; } + HttpHubConnectionBuilder withHandshakeResponseTimeout(Duration timeout) { + this.handshakeResponseTimeout = timeout; + return this; + } + public HubConnection build() { - return new HubConnection(url, transport, skipNegotiate, logger, httpClient, accessTokenProvider); + return new HubConnection(url, transport, skipNegotiate, logger, httpClient, accessTokenProvider, handshakeResponseTimeout); } } \ No newline at end of file diff --git a/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java b/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java index 9d926023cc35e3fadfc7e43e2913095711c5f4d0..a904f9dfbf5757cf233e591ce8f2a4a3bfa34f83 100644 --- a/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java +++ b/clients/java/signalr/src/main/java/com/microsoft/signalr/HubConnection.java @@ -4,11 +4,12 @@ package com.microsoft.signalr; import java.io.IOException; +import java.time.Duration; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; +import java.util.concurrent.*; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; @@ -38,8 +39,10 @@ public class HubConnection { private ConnectionState connectionState = null; private HttpClient httpClient; private String stopError; + private CompletableFuture<Void> handshakeResponseFuture; + private Duration handshakeResponseTimeout = Duration.ofSeconds(15); - HubConnection(String url, Transport transport, boolean skipNegotiate, Logger logger, HttpClient httpClient, Single<String> accessTokenProvider) { + HubConnection(String url, Transport transport, boolean skipNegotiate, Logger logger, HttpClient httpClient, Single<String> accessTokenProvider, Duration handshakeResponseTimeout) { if (url == null || url.isEmpty()) { throw new IllegalArgumentException("A valid url is required."); } @@ -69,19 +72,33 @@ public class HubConnection { this.transport = transport; } + if (handshakeResponseTimeout != null) { + this.handshakeResponseTimeout = handshakeResponseTimeout; + } + this.skipNegotiate = skipNegotiate; this.callback = (payload) -> { if (!handshakeReceived) { int handshakeLength = payload.indexOf(RECORD_SEPARATOR) + 1; String handshakeResponseString = payload.substring(0, handshakeLength - 1); - HandshakeResponseMessage handshakeResponse = HandshakeProtocol.parseHandshakeResponse(handshakeResponseString); + HandshakeResponseMessage handshakeResponse; + try { + handshakeResponse = HandshakeProtocol.parseHandshakeResponse(handshakeResponseString); + } catch (RuntimeException ex) { + RuntimeException exception = new RuntimeException("An invalid handshake response was received from the server.", ex); + handshakeResponseFuture.completeExceptionally(exception); + throw exception; + } if (handshakeResponse.getHandshakeError() != null) { String errorMessage = "Error in handshake " + handshakeResponse.getHandshakeError(); logger.log(LogLevel.Error, errorMessage); - throw new RuntimeException(errorMessage); + RuntimeException exception = new RuntimeException(errorMessage); + handshakeResponseFuture.completeExceptionally(exception); + throw exception; } handshakeReceived = true; + handshakeResponseFuture.complete(null); payload = payload.substring(handshakeLength); // The payload only contained the handshake response so we can return. @@ -134,6 +151,12 @@ public class HubConnection { }; } + private void timeoutHandshakeResponse(long timeout, TimeUnit unit) { + ScheduledExecutorService scheduledThreadPool = Executors.newSingleThreadScheduledExecutor(); + scheduledThreadPool.schedule(() -> handshakeResponseFuture.completeExceptionally( + new TimeoutException("Timed out waiting for the server to respond to the handshake message.")), timeout, unit); + } + private CompletableFuture<NegotiateResponse> handleNegotiate(String url) { HttpRequest request = new HttpRequest(); request.addHeaders(this.headers); @@ -184,8 +207,9 @@ public class HubConnection { return Completable.complete(); } + handshakeResponseFuture = new CompletableFuture<>(); handshakeReceived = false; - CompletableFuture<Void> tokenFuture = new CompletableFuture<>(); + CompletableFuture<Void> tokenFuture = new CompletableFuture<>(); accessTokenProvider.subscribe(token -> { if (token != null && !token.isEmpty()) { this.headers.put("Authorization", "Bearer " + token); @@ -213,15 +237,18 @@ public class HubConnection { return transport.start(url).thenCompose((future) -> { String handshake = HandshakeProtocol.createHandshakeRequestMessage( new HandshakeRequestMessage(protocol.getName(), protocol.getVersion())); - return transport.send(handshake).thenRun(() -> { - hubConnectionStateLock.lock(); - try { - hubConnectionState = HubConnectionState.CONNECTED; - connectionState = new ConnectionState(this); - logger.log(LogLevel.Information, "HubConnection started."); - } finally { - hubConnectionStateLock.unlock(); - } + return transport.send(handshake).thenCompose((innerFuture) -> { + timeoutHandshakeResponse(handshakeResponseTimeout.toMillis(), TimeUnit.MILLISECONDS); + return handshakeResponseFuture.thenRun(() -> { + hubConnectionStateLock.lock(); + try { + hubConnectionState = HubConnectionState.CONNECTED; + connectionState = new ConnectionState(this); + logger.log(LogLevel.Information, "HubConnection started."); + } finally { + hubConnectionStateLock.unlock(); + } + }); }); }); })); @@ -308,6 +335,7 @@ public class HubConnection { connectionState = null; logger.log(LogLevel.Information, "HubConnection stopped."); hubConnectionState = HubConnectionState.DISCONNECTED; + handshakeResponseFuture.complete(null); } finally { hubConnectionStateLock.unlock(); } diff --git a/clients/java/signalr/src/test/java/com/microsoft/signalr/HandshakeProtocolTest.java b/clients/java/signalr/src/test/java/com/microsoft/signalr/HandshakeProtocolTest.java index 98be91207cf0692231ae16a9be60207aef5ccce7..811d39c5580efc09ec6dde2c41d06919963dfbf7 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/signalr/HandshakeProtocolTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/signalr/HandshakeProtocolTest.java @@ -29,4 +29,10 @@ class HandshakeProtocolTest { HandshakeResponseMessage hsr = HandshakeProtocol.parseHandshakeResponse(handshakeResponseWithError); assertEquals(hsr.getHandshakeError(), "Requested protocol 'messagepack' is not available."); } + + @Test + public void InvalidHandshakeResponse() { + String handshakeResponseWithError = "{\"error\": \"Requested proto"; + Throwable exception = assertThrows(RuntimeException.class, ()-> HandshakeProtocol.parseHandshakeResponse(handshakeResponseWithError)); + } } \ No newline at end of file diff --git a/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java b/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java index c76fc4adade2df3fcdffa8533405683e8fbb3dd4..1cc7550a3c3a591524cc4214b79051839421a2f3 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java +++ b/clients/java/signalr/src/test/java/com/microsoft/signalr/HubConnectionTest.java @@ -5,6 +5,7 @@ package com.microsoft.signalr; import static org.junit.jupiter.api.Assertions.*; +import java.time.Duration; import java.util.List; import java.util.concurrent.CancellationException; import java.util.concurrent.CompletableFuture; @@ -34,7 +35,7 @@ class HubConnectionTest { @Test public void transportCloseTriggersStopInHubConnection() throws Exception { - MockTransport mockTransport = new MockTransport(); + MockTransport mockTransport = new MockTransport(true); HubConnection hubConnection = TestUtils.createHubConnection("http://example.com", mockTransport); hubConnection.start().blockingAwait(1000, TimeUnit.MILLISECONDS); assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); @@ -45,7 +46,7 @@ class HubConnectionTest { @Test public void transportCloseWithErrorTriggersStopInHubConnection() throws Exception { - MockTransport mockTransport = new MockTransport(); + MockTransport mockTransport = new MockTransport(true); AtomicReference<String> message = new AtomicReference<>(); HubConnection hubConnection = TestUtils.createHubConnection("http://example.com", mockTransport); String errorMessage = "Example transport error."; @@ -58,12 +59,27 @@ class HubConnectionTest { assertEquals(HubConnectionState.CONNECTED, hubConnection.getConnectionState()); mockTransport.stopWithError(errorMessage); assertEquals(errorMessage, message.get()); + } + + @Test + public void checkHubConnectionStateNoHandShakeResponse() { + MockTransport mockTransport = new MockTransport(); + HubConnection hubConnection = HubConnectionBuilder.create("http://example.com") + .withTransport(mockTransport) + .withHttpClient(new TestHttpClient()) + .shouldSkipNegotiate(true) + .withHandshakeResponseTimeout(Duration.ofMillis(100)) + .build(); + Throwable exception = assertThrows(RuntimeException.class, () -> hubConnection.start().blockingAwait(1000, TimeUnit.MILLISECONDS)); + assertEquals(ExecutionException.class, exception.getCause().getClass()); + assertEquals(TimeoutException.class, exception.getCause().getCause().getClass()); + assertEquals(exception.getCause().getCause().getMessage(), "Timed out waiting for the server to respond to the handshake message."); assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); } @Test - public void constructHubConnectionWithHttpConnectionOptions() throws Exception { - Transport mockTransport = new MockTransport(); + public void constructHubConnectionWithHttpConnectionOptions() { + Transport mockTransport = new MockTransport(true); HubConnection hubConnection = TestUtils.createHubConnection("http://example.com", mockTransport); hubConnection.start(); @@ -88,6 +104,18 @@ class HubConnectionTest { assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); } + @Test + public void invalidHandShakeResponse() throws Exception { + MockTransport mockTransport = new MockTransport(); + HubConnection hubConnection = TestUtils.createHubConnection("http://example.com", mockTransport); + + hubConnection.start(); + + Throwable exception = assertThrows(RuntimeException.class, () -> mockTransport.receiveMessage("{" + RECORD_SEPARATOR)); + assertEquals("An invalid handshake response was received from the server.", exception.getMessage()); + assertEquals(HubConnectionState.DISCONNECTED, hubConnection.getConnectionState()); + } + @Test public void hubConnectionReceiveHandshakeResponseWithError() { MockTransport mockTransport = new MockTransport(); @@ -958,7 +986,7 @@ class HubConnectionTest { "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); - MockTransport transport = new MockTransport(); + MockTransport transport = new MockTransport(true); HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") .withTransport(transport) @@ -977,7 +1005,7 @@ class HubConnectionTest { TestHttpClient client = new TestHttpClient().on("POST", "http://example.com/negotiate", (req) -> CompletableFuture.completedFuture(new HttpResponse(200, "", "{\"error\":\"Test error.\"}"))); - MockTransport transport = new MockTransport(); + MockTransport transport = new MockTransport(true); HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") .withHttpClient(client) @@ -997,7 +1025,7 @@ class HubConnectionTest { (req) -> CompletableFuture.completedFuture(new HttpResponse(200, "", "{\"connectionId\":\"bVOiRPG8-6YiJ6d7ZcTOVQ\",\"" + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}"))); - MockTransport transport = new MockTransport(); + MockTransport transport = new MockTransport(true); HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") .withTransport(transport) @@ -1022,7 +1050,7 @@ class HubConnectionTest { + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")); }); - MockTransport transport = new MockTransport(); + MockTransport transport = new MockTransport(true); HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") .withTransport(transport) @@ -1048,7 +1076,7 @@ class HubConnectionTest { + "availableTransports\":[{\"transport\":\"WebSockets\",\"transferFormats\":[\"Text\",\"Binary\"]}]}")); }); - MockTransport transport = new MockTransport(); + MockTransport transport = new MockTransport(true); HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") .withTransport(transport) @@ -1065,7 +1093,7 @@ class HubConnectionTest { @Test public void hubConnectionCanBeStartedAfterBeingStopped() throws Exception { - MockTransport transport = new MockTransport(); + MockTransport transport = new MockTransport(true); HubConnection hubConnection = HubConnectionBuilder .create("http://example.com") .withTransport(transport) @@ -1084,7 +1112,7 @@ class HubConnectionTest { @Test public void hubConnectionCanBeStartedAfterBeingStoppedAndRedirected() throws Exception { - MockTransport mockTransport = new MockTransport(); + MockTransport mockTransport = new MockTransport(true); TestHttpClient client = new TestHttpClient() .on("POST", "http://example.com/negotiate", (req) -> CompletableFuture .completedFuture(new HttpResponse(200, "", "{\"url\":\"http://testexample.com/\"}"))) diff --git a/clients/java/signalr/src/test/java/com/microsoft/signalr/MockTransport.java b/clients/java/signalr/src/test/java/com/microsoft/signalr/MockTransport.java index 8390a34918826c75cc2de3807e2848746c305e6f..10677d56f4340e3f76b8e42289ce516b597e97b6 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/signalr/MockTransport.java +++ b/clients/java/signalr/src/test/java/com/microsoft/signalr/MockTransport.java @@ -12,10 +12,27 @@ class MockTransport implements Transport { private ArrayList<String> sentMessages = new ArrayList<>(); private String url; private Consumer<String> onClose; + private boolean autoHandshake; + + private static final String RECORD_SEPARATOR = "\u001e"; + + public MockTransport() { + } + + public MockTransport(boolean autoHandshake) { + this.autoHandshake = autoHandshake; + } @Override public CompletableFuture start(String url) { this.url = url; + if (autoHandshake) { + try { + onReceiveCallBack.invoke("{}" + RECORD_SEPARATOR); + } catch (Exception e) { + throw new RuntimeException(e); + } + } return CompletableFuture.completedFuture(null); } diff --git a/clients/java/signalr/src/test/java/com/microsoft/signalr/TestUtils.java b/clients/java/signalr/src/test/java/com/microsoft/signalr/TestUtils.java index eebf0a1968ad171a3a57fa3371bbe82d5062c863..af8366a6f476e8b593bf3991473ecb96c7d04d21 100644 --- a/clients/java/signalr/src/test/java/com/microsoft/signalr/TestUtils.java +++ b/clients/java/signalr/src/test/java/com/microsoft/signalr/TestUtils.java @@ -5,7 +5,7 @@ package com.microsoft.signalr; class TestUtils { static HubConnection createHubConnection(String url) { - return createHubConnection(url, new MockTransport(), new NullLogger(), true, new TestHttpClient()); + return createHubConnection(url, new MockTransport(true), new NullLogger(), true, new TestHttpClient()); } static HubConnection createHubConnection(String url, Transport transport) {