diff --git a/py/selenium/webdriver/common/bidi/script.py b/py/selenium/webdriver/common/bidi/script.py index a5bc234eb1536..74b8a3568ac3a 100644 --- a/py/selenium/webdriver/common/bidi/script.py +++ b/py/selenium/webdriver/common/bidi/script.py @@ -15,9 +15,12 @@ # specific language governing permissions and limitations # under the License. +import datetime +import math from dataclasses import dataclass from typing import Any, Optional +from selenium.common.exceptions import WebDriverException from selenium.webdriver.common.bidi.common import command_builder from .log import LogEntryAdded @@ -238,12 +241,15 @@ class Script: "realm_destroyed": "script.realmDestroyed", } - def __init__(self, conn): + def __init__(self, conn, driver=None): self.conn = conn + self.driver = driver self.log_entry_subscribed = False self.subscriptions = {} self.callbacks = {} + # High-level APIs for SCRIPT module + def add_console_message_handler(self, handler): self._subscribe_to_log_entries() return self.conn.add_callback(LogEntryAdded, self._handle_log_entry("console", handler)) @@ -258,6 +264,122 @@ def remove_console_message_handler(self, id): remove_javascript_error_handler = remove_console_message_handler + def pin(self, script: str) -> str: + """Pins a script to the current browsing context. + + Parameters: + ----------- + script: The script to pin. + + Returns: + ------- + str: The ID of the pinned script. + """ + return self._add_preload_script(script) + + def unpin(self, script_id: str) -> None: + """Unpins a script from the current browsing context. + + Parameters: + ----------- + script_id: The ID of the pinned script to unpin. + """ + self._remove_preload_script(script_id) + + def execute(self, script: str, *args) -> dict: + """Executes a script in the current browsing context. + + Parameters: + ----------- + script: The script function to execute. + *args: Arguments to pass to the script function. + + Returns: + ------- + dict: The result value from the script execution. + + Raises: + ------ + WebDriverException: If the script execution fails. + """ + + if self.driver is None: + raise WebDriverException("Driver reference is required for script execution") + browsing_context_id = self.driver.current_window_handle + + # Convert arguments to the format expected by BiDi call_function (LocalValue Type) + arguments = [] + for arg in args: + arguments.append(self.__convert_to_local_value(arg)) + + target = {"context": browsing_context_id} + + result = self._call_function( + function_declaration=script, await_promise=True, target=target, arguments=arguments if arguments else None + ) + + if result.type == "success": + return result.result + else: + error_message = "Error while executing script" + if result.exception_details: + if "text" in result.exception_details: + error_message += f": {result.exception_details['text']}" + elif "message" in result.exception_details: + error_message += f": {result.exception_details['message']}" + + raise WebDriverException(error_message) + + def __convert_to_local_value(self, value) -> dict: + """ + Converts a Python value to BiDi LocalValue format. + """ + if value is None: + return {"type": "null"} + elif isinstance(value, bool): + return {"type": "boolean", "value": value} + elif isinstance(value, (int, float)): + if isinstance(value, float): + if math.isnan(value): + return {"type": "number", "value": "NaN"} + elif math.isinf(value): + if value > 0: + return {"type": "number", "value": "Infinity"} + else: + return {"type": "number", "value": "-Infinity"} + elif value == 0.0 and math.copysign(1.0, value) < 0: + return {"type": "number", "value": "-0"} + + JS_MAX_SAFE_INTEGER = 9007199254740991 + if isinstance(value, int) and (value > JS_MAX_SAFE_INTEGER or value < -JS_MAX_SAFE_INTEGER): + return {"type": "bigint", "value": str(value)} + + return {"type": "number", "value": value} + + elif isinstance(value, str): + return {"type": "string", "value": value} + elif isinstance(value, datetime.datetime): + # Convert Python datetime to JavaScript Date (ISO 8601 format) + return {"type": "date", "value": value.isoformat() + "Z" if value.tzinfo is None else value.isoformat()} + elif isinstance(value, datetime.date): + # Convert Python date to JavaScript Date + dt = datetime.datetime.combine(value, datetime.time.min).replace(tzinfo=datetime.timezone.utc) + return {"type": "date", "value": dt.isoformat()} + elif isinstance(value, set): + return {"type": "set", "value": [self.__convert_to_local_value(item) for item in value]} + elif isinstance(value, (list, tuple)): + return {"type": "array", "value": [self.__convert_to_local_value(item) for item in value]} + elif isinstance(value, dict): + return { + "type": "object", + "value": [ + [self.__convert_to_local_value(k), self.__convert_to_local_value(v)] for k, v in value.items() + ], + } + else: + # For other types, convert to string + return {"type": "string", "value": str(value)} + # low-level APIs for script module def _add_preload_script( self, diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 7b19f053f3c99..0ab678103b844 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -1240,7 +1240,7 @@ def script(self): self._start_bidi() if not self._script: - self._script = Script(self._websocket_connection) + self._script = Script(self._websocket_connection, self) return self._script diff --git a/py/test/selenium/webdriver/common/bidi_script_tests.py b/py/test/selenium/webdriver/common/bidi_script_tests.py index 8677d2dbae396..35f8e455573be 100644 --- a/py/test/selenium/webdriver/common/bidi_script_tests.py +++ b/py/test/selenium/webdriver/common/bidi_script_tests.py @@ -60,7 +60,7 @@ def test_logs_console_errors(driver, pages): log_entries = [] def log_error(entry): - if entry.level == "error": + if entry.level == LogLevel.ERROR: log_entries.append(entry) driver.script.add_console_message_handler(log_error) @@ -561,3 +561,312 @@ def test_disown_handles(driver, pages): target={"context": driver.current_window_handle}, arguments=[{"handle": handle}], ) + + +# Tests for high-level SCRIPT API commands - pin, unpin, and execute + + +def test_pin_script(driver, pages): + """Test pinning a script.""" + function_declaration = "() => { window.pinnedScriptExecuted = 'yes'; }" + + script_id = driver.script.pin(function_declaration) + assert script_id is not None + assert isinstance(script_id, str) + + pages.load("blank.html") + + result = driver.script.execute("() => window.pinnedScriptExecuted") + assert result["value"] == "yes" + + +def test_unpin_script(driver, pages): + """Test unpinning a script.""" + function_declaration = "() => { window.unpinnableScript = 'executed'; }" + + script_id = driver.script.pin(function_declaration) + driver.script.unpin(script_id) + + pages.load("blank.html") + + result = driver.script.execute("() => typeof window.unpinnableScript") + assert result["value"] == "undefined" + + +def test_execute_script_with_null_argument(driver, pages): + """Test executing script with undefined argument.""" + pages.load("blank.html") + + result = driver.script.execute( + """(arg) => { + if(arg!==null) + throw Error("Argument should be null, but was "+arg); + return arg; + }""", + None, + ) + + assert result["type"] == "null" + + +def test_execute_script_with_number_argument(driver, pages): + """Test executing script with number argument.""" + pages.load("blank.html") + + result = driver.script.execute( + """(arg) => { + if(arg!==1.4) + throw Error("Argument should be 1.4, but was "+arg); + return arg; + }""", + 1.4, + ) + + assert result["type"] == "number" + assert result["value"] == 1.4 + + +def test_execute_script_with_nan(driver, pages): + """Test executing script with NaN argument.""" + pages.load("blank.html") + + result = driver.script.execute( + """(arg) => { + if(!Number.isNaN(arg)) + throw Error("Argument should be NaN, but was "+arg); + return arg; + }""", + float("nan"), + ) + + assert result["type"] == "number" + assert result["value"] == "NaN" + + +def test_execute_script_with_inf(driver, pages): + """Test executing script with number argument.""" + pages.load("blank.html") + + result = driver.script.execute( + """(arg) => { + if(arg!==Infinity) + throw Error("Argument should be Infinity, but was "+arg); + return arg; + }""", + float("inf"), + ) + + assert result["type"] == "number" + assert result["value"] == "Infinity" + + +def test_execute_script_with_minus_inf(driver, pages): + """Test executing script with number argument.""" + pages.load("blank.html") + + result = driver.script.execute( + """(arg) => { + if(arg!==-Infinity) + throw Error("Argument should be -Infinity, but was "+arg); + return arg; + }""", + float("-inf"), + ) + + assert result["type"] == "number" + assert result["value"] == "-Infinity" + + +def test_execute_script_with_bigint_argument(driver, pages): + """Test executing script with BigInt argument.""" + pages.load("blank.html") + + # Use a large integer that exceeds JavaScript safe integer limit + large_int = 9007199254740992 + result = driver.script.execute( + """(arg) => { + if(arg !== 9007199254740992n) + throw Error("Argument should be 9007199254740992n (BigInt), but was "+arg+" (type: "+typeof arg+")"); + return arg; + }""", + large_int, + ) + + assert result["type"] == "bigint" + assert result["value"] == str(large_int) + + +def test_execute_script_with_boolean_argument(driver, pages): + """Test executing script with boolean argument.""" + pages.load("blank.html") + + result = driver.script.execute( + """(arg) => { + if(arg!==true) + throw Error("Argument should be true, but was "+arg); + return arg; + }""", + True, + ) + + assert result["type"] == "boolean" + assert result["value"] is True + + +def test_execute_script_with_string_argument(driver, pages): + """Test executing script with string argument.""" + pages.load("blank.html") + + result = driver.script.execute( + """(arg) => { + if(arg!=="hello world") + throw Error("Argument should be 'hello world', but was "+arg); + return arg; + }""", + "hello world", + ) + + assert result["type"] == "string" + assert result["value"] == "hello world" + + +def test_execute_script_with_date_argument(driver, pages): + """Test executing script with date argument.""" + import datetime + + pages.load("blank.html") + + date = datetime.datetime(2023, 12, 25, 10, 30, 45) + result = driver.script.execute( + """(arg) => { + if(!(arg instanceof Date)) + throw Error("Argument type should be Date, but was "+ + Object.prototype.toString.call(arg)); + if(arg.getFullYear() !== 2023) + throw Error("Year should be 2023, but was "+arg.getFullYear()); + return arg; + }""", + date, + ) + + assert result["type"] == "date" + assert "2023-12-25T10:30:45" in result["value"] + + +def test_execute_script_with_array_argument(driver, pages): + """Test executing script with array argument.""" + pages.load("blank.html") + + test_list = [1, 2, 3] + + result = driver.script.execute( + """(arg) => { + if(!(arg instanceof Array)) + throw Error("Argument type should be Array, but was "+ + Object.prototype.toString.call(arg)); + if(arg.length !== 3) + throw Error("Array should have 3 elements, but had "+arg.length); + return arg; + }""", + test_list, + ) + + assert result["type"] == "array" + values = result["value"] + assert len(values) == 3 + + +def test_execute_script_with_multiple_arguments(driver, pages): + """Test executing script with multiple arguments.""" + pages.load("blank.html") + + result = driver.script.execute( + """(a, b, c) => { + if(a !== 1) throw Error("First arg should be 1"); + if(b !== "test") throw Error("Second arg should be 'test'"); + if(c !== true) throw Error("Third arg should be true"); + return a + b.length + (c ? 1 : 0); + }""", + 1, + "test", + True, + ) + + assert result["type"] == "number" + assert result["value"] == 6 # 1 + 4 + 1 + + +def test_execute_script_returns_promise(driver, pages): + """Test executing script that returns a promise.""" + pages.load("blank.html") + + result = driver.script.execute( + """() => { + return Promise.resolve("async result"); + }""", + ) + + assert result["type"] == "string" + assert result["value"] == "async result" + + +def test_execute_script_with_exception(driver, pages): + """Test executing script that throws an exception.""" + pages.load("blank.html") + + from selenium.common.exceptions import WebDriverException + + with pytest.raises(WebDriverException) as exc_info: + driver.script.execute( + """() => { + throw new Error("Test error message"); + }""", + ) + + assert "Test error message" in str(exc_info.value) + + +def test_execute_script_accessing_dom(driver, pages): + """Test executing script that accesses DOM elements.""" + pages.load("formPage.html") + + result = driver.script.execute( + """() => { + return document.title; + }""", + ) + + assert result["type"] == "string" + assert result["value"] == "We Leave From Here" + + +def test_execute_script_with_nested_objects(driver, pages): + """Test executing script with nested object arguments.""" + pages.load("blank.html") + + nested_data = { + "user": { + "name": "John", + "age": 30, + "hobbies": ["reading", "coding"], + }, + "settings": {"theme": "dark", "notifications": True}, + } + + result = driver.script.execute( + """(data) => { + return { + userName: data.user.name, + userAge: data.user.age, + hobbyCount: data.user.hobbies.length, + theme: data.settings.theme + }; + }""", + nested_data, + ) + + assert result["type"] == "object" + value_dict = {k: v["value"] for k, v in result["value"]} + assert value_dict["userName"] == "John" + assert value_dict["userAge"] == 30 + assert value_dict["hobbyCount"] == 2