diff --git a/src/test/java/org/java_websocket/issues/Issue997Test.java b/src/test/java/org/java_websocket/issues/Issue997Test.java index 2bba7511..cd8486b1 100644 --- a/src/test/java/org/java_websocket/issues/Issue997Test.java +++ b/src/test/java/org/java_websocket/issues/Issue997Test.java @@ -41,6 +41,7 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLParameters; + import org.java_websocket.WebSocket; import org.java_websocket.client.WebSocketClient; import org.java_websocket.handshake.ClientHandshake; @@ -57,163 +58,163 @@ public class Issue997Test { - @Test() - @Timeout(2000) - public void test_localServer_ServerLocalhost_Client127_CheckActive() - throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException { - SSLWebSocketClient client = testIssueWithLocalServer("127.0.0.1", SocketUtil.getAvailablePort(), - SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), - "HTTPS"); - assertFalse(client.onOpen); - assertTrue(client.onSSLError); - } - - @Test() - @Timeout(2000) - public void test_localServer_ServerLocalhost_Client127_CheckInactive() - throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException { - SSLWebSocketClient client = testIssueWithLocalServer("127.0.0.1", SocketUtil.getAvailablePort(), - SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), ""); - assertTrue(client.onOpen); - assertFalse(client.onSSLError); - } - - @Test() - @Timeout(2000) - public void test_localServer_ServerLocalhost_Client127_CheckDefault() - throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException { - SSLWebSocketClient client = testIssueWithLocalServer("127.0.0.1", SocketUtil.getAvailablePort(), - SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), null); - assertFalse(client.onOpen); - assertTrue(client.onSSLError); - } - - @Test() - @Timeout(2000) - public void test_localServer_ServerLocalhost_ClientLocalhost_CheckActive() - throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException { - SSLWebSocketClient client = testIssueWithLocalServer("localhost", SocketUtil.getAvailablePort(), - SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), - "HTTPS"); - assertTrue(client.onOpen); - assertFalse(client.onSSLError); - } - - @Test() - @Timeout(2000) - public void test_localServer_ServerLocalhost_ClientLocalhost_CheckInactive() - throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException { - SSLWebSocketClient client = testIssueWithLocalServer("localhost", SocketUtil.getAvailablePort(), - SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), ""); - assertTrue(client.onOpen); - assertFalse(client.onSSLError); - } - - @Test() - @Timeout(2000) - public void test_localServer_ServerLocalhost_ClientLocalhost_CheckDefault() - throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException { - SSLWebSocketClient client = testIssueWithLocalServer("localhost", SocketUtil.getAvailablePort(), - SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), null); - assertTrue(client.onOpen); - assertFalse(client.onSSLError); - } - - - public SSLWebSocketClient testIssueWithLocalServer(String address, int port, - SSLContext serverContext, SSLContext clientContext, String endpointIdentificationAlgorithm) - throws IOException, URISyntaxException, InterruptedException { - CountDownLatch countServerDownLatch = new CountDownLatch(1); - SSLWebSocketClient client = new SSLWebSocketClient(address, port, - endpointIdentificationAlgorithm); - WebSocketServer server = new SSLWebSocketServer(port, countServerDownLatch); - - server.setWebSocketFactory(new DefaultSSLWebSocketServerFactory(serverContext)); - if (clientContext != null) { - client.setSocketFactory(clientContext.getSocketFactory()); + @Test() + @Timeout(2000) + public void test_localServer_ServerLocalhost_Client127_CheckActive() + throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException { + SSLWebSocketClient client = testIssueWithLocalServer("127.0.0.1", SocketUtil.getAvailablePort(), + SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), + "HTTPS"); + assertFalse(client.onOpen, "client is not open"); + assertTrue(client.onSSLError, "client has caught a SSLHandshakeException"); } - server.start(); - countServerDownLatch.await(); - client.connectBlocking(1, TimeUnit.SECONDS); - return client; - } - - - private static class SSLWebSocketClient extends WebSocketClient { - private final String endpointIdentificationAlgorithm; - public boolean onSSLError = false; - public boolean onOpen = false; - - public SSLWebSocketClient(String address, int port, String endpointIdentificationAlgorithm) - throws URISyntaxException { - super(new URI("wss://" + address + ':' + port)); - this.endpointIdentificationAlgorithm = endpointIdentificationAlgorithm; + @Test() + @Timeout(2000) + public void test_localServer_ServerLocalhost_Client127_CheckInactive() + throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException { + SSLWebSocketClient client = testIssueWithLocalServer("127.0.0.1", SocketUtil.getAvailablePort(), + SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), ""); + assertTrue(client.onOpen, "client is open"); + assertFalse(client.onSSLError, "client has not caught a SSLHandshakeException"); } - @Override - public void onOpen(ServerHandshake handshakedata) { - this.onOpen = true; + @Test() + @Timeout(2000) + public void test_localServer_ServerLocalhost_Client127_CheckDefault() + throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException { + SSLWebSocketClient client = testIssueWithLocalServer("127.0.0.1", SocketUtil.getAvailablePort(), + SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), null); + assertFalse(client.onOpen, "client is not open"); + assertTrue(client.onSSLError, "client has caught a SSLHandshakeException"); } - @Override - public void onMessage(String message) { + @Test() + @Timeout(2000) + public void test_localServer_ServerLocalhost_ClientLocalhost_CheckActive() + throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException { + SSLWebSocketClient client = testIssueWithLocalServer("localhost", SocketUtil.getAvailablePort(), + SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), + "HTTPS"); + assertTrue(client.onOpen, "client is open"); + assertFalse(client.onSSLError, "client has not caught a SSLHandshakeException"); } - @Override - public void onClose(int code, String reason, boolean remote) { + @Test() + @Timeout(2000) + public void test_localServer_ServerLocalhost_ClientLocalhost_CheckInactive() + throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException { + SSLWebSocketClient client = testIssueWithLocalServer("localhost", SocketUtil.getAvailablePort(), + SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), ""); + assertTrue(client.onOpen, "client is open"); + assertFalse(client.onSSLError, "client has not caught a SSLHandshakeException"); } - @Override - public void onError(Exception ex) { - if (ex instanceof SSLHandshakeException) { - this.onSSLError = true; - } + @Test() + @Timeout(2000) + public void test_localServer_ServerLocalhost_ClientLocalhost_CheckDefault() + throws CertificateException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyManagementException, KeyStoreException, IOException, URISyntaxException, InterruptedException { + SSLWebSocketClient client = testIssueWithLocalServer("localhost", SocketUtil.getAvailablePort(), + SSLContextUtil.getLocalhostOnlyContext(), SSLContextUtil.getLocalhostOnlyContext(), null); + assertTrue(client.onOpen, "client is open"); + assertFalse(client.onSSLError, "client has not caught a SSLHandshakeException"); } - @Override - protected void onSetSSLParameters(SSLParameters sslParameters) { - // Always call super to ensure hostname validation is active by default - super.onSetSSLParameters(sslParameters); - if (endpointIdentificationAlgorithm != null) { - sslParameters.setEndpointIdentificationAlgorithm(endpointIdentificationAlgorithm); - } + + public SSLWebSocketClient testIssueWithLocalServer(String address, int port, + SSLContext serverContext, SSLContext clientContext, String endpointIdentificationAlgorithm) + throws IOException, URISyntaxException, InterruptedException { + CountDownLatch countServerDownLatch = new CountDownLatch(1); + SSLWebSocketClient client = new SSLWebSocketClient(address, port, + endpointIdentificationAlgorithm); + WebSocketServer server = new SSLWebSocketServer(port, countServerDownLatch); + + server.setWebSocketFactory(new DefaultSSLWebSocketServerFactory(serverContext)); + if (clientContext != null) { + client.setSocketFactory(clientContext.getSocketFactory()); + } + server.start(); + countServerDownLatch.await(); + client.connectBlocking(1, TimeUnit.SECONDS); + return client; } - } + private static class SSLWebSocketClient extends WebSocketClient { - private static class SSLWebSocketServer extends WebSocketServer { + private final String endpointIdentificationAlgorithm; + public boolean onSSLError = false; + public boolean onOpen = false; - private final CountDownLatch countServerDownLatch; + public SSLWebSocketClient(String address, int port, String endpointIdentificationAlgorithm) + throws URISyntaxException { + super(new URI("wss://" + address + ':' + port)); + this.endpointIdentificationAlgorithm = endpointIdentificationAlgorithm; + } + @Override + public void onOpen(ServerHandshake handshakedata) { + this.onOpen = true; + } - public SSLWebSocketServer(int port, CountDownLatch countServerDownLatch) { - super(new InetSocketAddress(port)); - this.countServerDownLatch = countServerDownLatch; - } + @Override + public void onMessage(String message) { + } - @Override - public void onOpen(WebSocket conn, ClientHandshake handshake) { - } + @Override + public void onClose(int code, String reason, boolean remote) { + } - @Override - public void onClose(WebSocket conn, int code, String reason, boolean remote) { - } + @Override + public void onError(Exception ex) { + if (ex instanceof SSLHandshakeException) { + this.onSSLError = true; + } + } - @Override - public void onMessage(WebSocket conn, String message) { + @Override + protected void onSetSSLParameters(SSLParameters sslParameters) { + // Always call super to ensure hostname validation is active by default + super.onSetSSLParameters(sslParameters); + if (endpointIdentificationAlgorithm != null) { + sslParameters.setEndpointIdentificationAlgorithm(endpointIdentificationAlgorithm); + } + } } - @Override - public void onError(WebSocket conn, Exception ex) { - ex.printStackTrace(); - } - @Override - public void onStart() { - countServerDownLatch.countDown(); + private static class SSLWebSocketServer extends WebSocketServer { + + private final CountDownLatch countServerDownLatch; + + + public SSLWebSocketServer(int port, CountDownLatch countServerDownLatch) { + super(new InetSocketAddress(port)); + this.countServerDownLatch = countServerDownLatch; + } + + @Override + public void onOpen(WebSocket conn, ClientHandshake handshake) { + } + + @Override + public void onClose(WebSocket conn, int code, String reason, boolean remote) { + } + + @Override + public void onMessage(WebSocket conn, String message) { + + } + + @Override + public void onError(WebSocket conn, Exception ex) { + ex.printStackTrace(); + } + + @Override + public void onStart() { + countServerDownLatch.countDown(); + } } - } }