Skip to content
19 changes: 12 additions & 7 deletions py/selenium/webdriver/remote/client_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,22 +50,19 @@ class ClientConfig:
keep_alive = _ClientConfigDescriptor("_keep_alive")
"""Gets and Sets Keep Alive value."""
proxy = _ClientConfigDescriptor("_proxy")
"""Gets and Sets the proxy used for communicating to the driver/server."""
"""Gets and Sets the proxy used for communicating with the driver/server."""
ignore_certificates = _ClientConfigDescriptor("_ignore_certificates")
"""Gets and Sets the ignore certificate check value."""
init_args_for_pool_manager = _ClientConfigDescriptor("_init_args_for_pool_manager")
"""Gets and Sets the ignore certificate check."""
timeout = _ClientConfigDescriptor("_timeout")
"""Gets and Sets the timeout (in seconds) used for communicating to the
driver/server."""
"""Gets and Sets the timeout (in seconds) used for communicating with the driver/server."""
ca_certs = _ClientConfigDescriptor("_ca_certs")
"""Gets and Sets the path to bundle of CA certificates."""
username = _ClientConfigDescriptor("_username")
"""Gets and Sets the username used for basic authentication to the
remote."""
"""Gets and Sets the username used for basic authentication to the remote."""
password = _ClientConfigDescriptor("_password")
"""Gets and Sets the password used for basic authentication to the
remote."""
"""Gets and Sets the password used for basic authentication to the remote."""
auth_type = _ClientConfigDescriptor("_auth_type")
"""Gets and Sets the type of authentication to the remote server."""
token = _ClientConfigDescriptor("_token")
Expand All @@ -74,6 +71,10 @@ class ClientConfig:
"""Gets and Sets user agent to be added to the request headers."""
extra_headers = _ClientConfigDescriptor("_extra_headers")
"""Gets and Sets extra headers to be added to the request."""
websocket_timeout = _ClientConfigDescriptor("_websocket_timeout")
"""Gets and Sets the WebSocket response wait timeout (in seconds) used for communicating with the browser."""
websocket_interval = _ClientConfigDescriptor("_websocket_interval")
"""Gets and Sets the WebSocket response wait interval (in seconds) used for communicating with the browser."""

def __init__(
self,
Expand All @@ -90,6 +91,8 @@ def __init__(
token: Optional[str] = None,
user_agent: Optional[str] = None,
extra_headers: Optional[dict] = None,
websocket_timeout: Optional[float] = 30.0,
websocket_interval: Optional[float] = 0.1,
) -> None:
self.remote_server_addr = remote_server_addr
self.keep_alive = keep_alive
Expand All @@ -103,6 +106,8 @@ def __init__(
self.token = token
self.user_agent = user_agent
self.extra_headers = extra_headers
self.websocket_timeout = websocket_timeout
self.websocket_interval = websocket_interval

self.ca_certs = (
(os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where())
Expand Down
9 changes: 7 additions & 2 deletions py/selenium/webdriver/remote/webdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""The WebDriver implementation."""

import base64
Expand Down Expand Up @@ -1211,7 +1212,9 @@ def start_devtools(self):
return self._devtools, self._websocket_connection
if self.caps["browserName"].lower() == "firefox":
raise RuntimeError("CDP support for Firefox has been removed. Please switch to WebDriver BiDi.")
self._websocket_connection = WebSocketConnection(ws_url)
self._websocket_connection = WebSocketConnection(
ws_url, self.client_config.websocket_timeout, self.client_config.websocket_interval
)
targets = self._websocket_connection.execute(self._devtools.target.get_targets())
for target in targets:
if target.target_id == self.current_window_handle:
Expand Down Expand Up @@ -1260,7 +1263,9 @@ def _start_bidi(self):
else:
raise WebDriverException("Unable to find url to connect to from capabilities")

self._websocket_connection = WebSocketConnection(ws_url)
self._websocket_connection = WebSocketConnection(
ws_url, self.client_config.websocket_timeout, self.client_config.websocket_interval
)

@property
def network(self):
Expand Down
23 changes: 14 additions & 9 deletions py/selenium/webdriver/remote/websocket_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import json
import logging
from ssl import CERT_NONE
Expand All @@ -28,16 +29,20 @@


class WebSocketConnection:
_response_wait_timeout = 30
_response_wait_interval = 0.1

_max_log_message_size = 9999

def __init__(self, url):
self.callbacks = {}
self.session_id = None
def __init__(self, url, timeout, interval):
if not isinstance(timeout, (int, float)) or timeout < 0:
raise WebDriverException("timeout must be a positive number")
if not isinstance(interval, (int, float)) or timeout < 0:
raise WebDriverException("interval must be a positive number")

self.url = url
self.response_wait_timeout = timeout
self.response_wait_interval = interval

self.callbacks = {}
self.session_id = None
self._id = 0
self._messages = {}
self._started = False
Expand All @@ -46,7 +51,7 @@ def __init__(self, url):
self._wait_until(lambda: self._started)

def close(self):
self._ws_thread.join(timeout=self._response_wait_timeout)
self._ws_thread.join(timeout=self.response_wait_timeout)
self._ws.close()
self._started = False
self._ws = None
Expand Down Expand Up @@ -142,8 +147,8 @@ def _process_message(self, message):
Thread(target=callback, args=(params,)).start()

def _wait_until(self, condition):
timeout = self._response_wait_timeout
interval = self._response_wait_interval
timeout = self.response_wait_timeout
interval = self.response_wait_interval

while timeout > 0:
result = condition()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ def test_execute_custom_command(mock_request, remote_connection):
assert response == {"status": 200, "value": "OK"}


def test_default_websocket_settings():
config = ClientConfig(remote_server_addr="http://localhost:4444")
assert config.websocket_timeout == 30.0
assert config.websocket_interval == 0.1


def test_get_remote_connection_headers_defaults():
url = "http://remote"
headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url))
Expand Down
Loading