diff --git a/src/websockets/asyncio/client.py b/src/websockets/asyncio/client.py index 63cd2be2e..684cc8b49 100644 --- a/src/websockets/asyncio/client.py +++ b/src/websockets/asyncio/client.py @@ -18,6 +18,8 @@ InvalidProxyMessage, InvalidProxyStatus, InvalidStatus, + InvalidURI, + InvalidRedirectURI, ProxyError, SecurityError, ) @@ -493,7 +495,10 @@ def process_redirect(self, exc: Exception) -> Exception | str: old_ws_uri = parse_uri(self.uri) new_uri = urllib.parse.urljoin(self.uri, exc.response.headers["Location"]) - new_ws_uri = parse_uri(new_uri) + try: + new_ws_uri = parse_uri(new_uri) + except InvalidURI as uri_exception: + raise InvalidRedirectURI(uri_exception.uri, uri_exception.msg) # If connect() received a socket, it is closed and cannot be reused. if self.connection_kwargs.get("sock") is not None: diff --git a/src/websockets/exceptions.py b/src/websockets/exceptions.py index a88deaa66..4a1fc3522 100644 --- a/src/websockets/exceptions.py +++ b/src/websockets/exceptions.py @@ -177,6 +177,16 @@ def __str__(self) -> str: return f"{self.uri} isn't a valid URI: {self.msg}" +class InvalidRedirectURI(InvalidURI): + """ + Raised when redirected to a URI that isn't a valid WebSocket URI. + + """ + + def __str__(self) -> str: + return f"Redirection URI {self.uri} isn't a valid URI: {self.msg}" + + class InvalidProxy(WebSocketException): """ Raised when connecting via a proxy that isn't valid. diff --git a/tests/asyncio/test_client.py b/tests/asyncio/test_client.py index 465ea2bdb..a12bca189 100644 --- a/tests/asyncio/test_client.py +++ b/tests/asyncio/test_client.py @@ -354,6 +354,25 @@ def redirect(connection, request): "cannot follow redirect to ws://invalid/ with a preexisting socket", ) + async def test_not_a_websocket_redirect(self): + """Client raises an explicit error when redirected to an absolute URI that isn't using websocket protocole.""" + + def redirect(connection, request): + response = connection.respond(http.HTTPStatus.FOUND, "") + response.headers["Location"] = "https://not-a-websocket.com" + return response + + async with serve(*args, process_request=redirect) as server: + host, port = get_host_port(server) + with self.assertRaises(InvalidURI) as raised: + async with connect("ws://overridden/", host=host, port=port): + self.fail("did not raise") + + self.assertEqual( + str(raised.exception), + "Redirection URI https://not-a-websocket.com isn't a valid URI: scheme isn't ws or wss", + ) + async def test_invalid_uri(self): """Client receives an invalid URI.""" with self.assertRaises(InvalidURI):