diff --git a/py/selenium/webdriver/common/by.py b/py/selenium/webdriver/common/by.py index 56a3f96d6fbb8..65a74649f070e 100644 --- a/py/selenium/webdriver/common/by.py +++ b/py/selenium/webdriver/common/by.py @@ -16,7 +16,9 @@ # under the License. """The By implementation.""" +from typing import Dict from typing import Literal +from typing import Optional class By: @@ -31,5 +33,19 @@ class By: CLASS_NAME = "class name" CSS_SELECTOR = "css selector" + _custom_finders: Dict[str, str] = {} + + @classmethod + def register_custom_finder(cls, name: str, strategy: str) -> None: + cls._custom_finders[name] = strategy + + @classmethod + def get_finder(cls, name: str) -> Optional[str]: + return cls._custom_finders.get(name) or getattr(cls, name.upper(), None) + + @classmethod + def clear_custom_finders(cls) -> None: + cls._custom_finders.clear() + ByType = Literal["id", "xpath", "link text", "partial link text", "name", "tag name", "class name", "css selector"] diff --git a/py/selenium/webdriver/remote/locator_converter.py b/py/selenium/webdriver/remote/locator_converter.py new file mode 100644 index 0000000000000..b43da73ef47cd --- /dev/null +++ b/py/selenium/webdriver/remote/locator_converter.py @@ -0,0 +1,28 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +class LocatorConverter: + def convert(self, by, value): + # Default conversion logic + if by == "id": + return "css selector", f'[id="{value}"]' + elif by == "class name": + return "css selector", f".{value}" + elif by == "name": + return "css selector", f'[name="{value}"]' + return by, value diff --git a/py/selenium/webdriver/remote/remote_connection.py b/py/selenium/webdriver/remote/remote_connection.py index e409d2b9c104b..0a4bde22c40d2 100644 --- a/py/selenium/webdriver/remote/remote_connection.py +++ b/py/selenium/webdriver/remote/remote_connection.py @@ -143,6 +143,14 @@ class RemoteConnection: ) _ca_certs = os.getenv("REQUESTS_CA_BUNDLE") if "REQUESTS_CA_BUNDLE" in os.environ else certifi.where() + system = platform.system().lower() + if system == "darwin": + system = "mac" + + # Class variables for headers + extra_headers = None + user_agent = f"selenium/{__version__} (python {system})" + @classmethod def get_timeout(cls): """:Returns: @@ -196,14 +204,10 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False): - keep_alive (Boolean) - Is this a keep-alive connection (default: False) """ - system = platform.system().lower() - if system == "darwin": - system = "mac" - headers = { "Accept": "application/json", "Content-Type": "application/json;charset=UTF-8", - "User-Agent": f"selenium/{__version__} (python {system})", + "User-Agent": cls.user_agent, } if parsed_url.username: @@ -213,6 +217,9 @@ def get_remote_connection_headers(cls, parsed_url, keep_alive=False): if keep_alive: headers.update({"Connection": "keep-alive"}) + if cls.extra_headers: + headers.update(cls.extra_headers) + return headers def _get_proxy_url(self): @@ -236,7 +243,12 @@ def _separate_http_proxy_auth(self): def _get_connection_manager(self): pool_manager_init_args = {"timeout": self.get_timeout()} - if self._ca_certs: + pool_manager_init_args.update(self._init_args_for_pool_manager.get("init_args_for_pool_manager", {})) + + if self._ignore_certificates: + pool_manager_init_args["cert_reqs"] = "CERT_NONE" + urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) + elif self._ca_certs: pool_manager_init_args["cert_reqs"] = "CERT_REQUIRED" pool_manager_init_args["ca_certs"] = self._ca_certs @@ -252,9 +264,18 @@ def _get_connection_manager(self): return urllib3.PoolManager(**pool_manager_init_args) - def __init__(self, remote_server_addr: str, keep_alive: bool = False, ignore_proxy: bool = False): + def __init__( + self, + remote_server_addr: str, + keep_alive: bool = False, + ignore_proxy: bool = False, + ignore_certificates: bool = False, + init_args_for_pool_manager: dict = None, + ): self.keep_alive = keep_alive self._url = remote_server_addr + self._ignore_certificates = ignore_certificates + self._init_args_for_pool_manager = init_args_for_pool_manager or {} # Env var NO_PROXY will override this part of the code _no_proxy = os.environ.get("no_proxy", os.environ.get("NO_PROXY")) @@ -280,6 +301,16 @@ def __init__(self, remote_server_addr: str, keep_alive: bool = False, ignore_pro self._conn = self._get_connection_manager() self._commands = remote_commands + extra_commands = {} + + def add_command(self, name, method, url): + """Register a new command.""" + self._commands[name] = (method, url) + + def get_command(self, name: str): + """Retrieve a command if it exists.""" + return self._commands.get(name) + def execute(self, command, params): """Send a command to the remote server. @@ -291,7 +322,7 @@ def execute(self, command, params): - params - A dictionary of named parameters to send with the command as its JSON payload. """ - command_info = self._commands[command] + command_info = self._commands.get(command) or self.extra_commands.get(command) assert command_info is not None, f"Unrecognised command {command}" path_string = command_info[1] path = string.Template(path_string).substitute(params) diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 41c4645bdc686..8ef6292012089 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -58,6 +58,7 @@ from .errorhandler import ErrorHandler from .file_detector import FileDetector from .file_detector import LocalFileDetector +from .locator_converter import LocatorConverter from .mobile import Mobile from .remote_connection import RemoteConnection from .script_key import ScriptKey @@ -171,6 +172,8 @@ def __init__( keep_alive: bool = True, file_detector: Optional[FileDetector] = None, options: Optional[Union[BaseOptions, List[BaseOptions]]] = None, + locator_converter: Optional[LocatorConverter] = None, + web_element_cls: Optional[type] = None, ) -> None: """Create a new driver that will issue commands using the wire protocol. @@ -183,6 +186,8 @@ def __init__( - file_detector - Pass custom file detector object during instantiation. If None, then default LocalFileDetector() will be used. - options - instance of a driver options.Options class + - locator_converter - Custom locator converter to use. Defaults to None. + - web_element_cls - Custom class to use for web elements. Defaults to WebElement. """ if isinstance(options, list): @@ -207,6 +212,8 @@ def __init__( self._switch_to = SwitchTo(self) self._mobile = Mobile(self) self.file_detector = file_detector or LocalFileDetector() + self.locator_converter = locator_converter or LocatorConverter() + self._web_element_cls = web_element_cls or self._web_element_cls self._authenticator_id = None self.start_client() self.start_session(capabilities) @@ -729,22 +736,14 @@ def find_element(self, by=By.ID, value: Optional[str] = None) -> WebElement: :rtype: WebElement """ + by, value = self.locator_converter.convert(by, value) + if isinstance(by, RelativeBy): elements = self.find_elements(by=by, value=value) if not elements: raise NoSuchElementException(f"Cannot locate relative element with: {by.root}") return elements[0] - if by == By.ID: - by = By.CSS_SELECTOR - value = f'[id="{value}"]' - elif by == By.CLASS_NAME: - by = By.CSS_SELECTOR - value = f".{value}" - elif by == By.NAME: - by = By.CSS_SELECTOR - value = f'[name="{value}"]' - return self.execute(Command.FIND_ELEMENT, {"using": by, "value": value})["value"] def find_elements(self, by=By.ID, value: Optional[str] = None) -> List[WebElement]: @@ -757,22 +756,14 @@ def find_elements(self, by=By.ID, value: Optional[str] = None) -> List[WebElemen :rtype: list of WebElement """ + by, value = self.locator_converter.convert(by, value) + if isinstance(by, RelativeBy): _pkg = ".".join(__name__.split(".")[:-1]) raw_function = pkgutil.get_data(_pkg, "findElements.js").decode("utf8") find_element_js = f"/* findElements */return ({raw_function}).apply(null, arguments);" return self.execute_script(find_element_js, by.to_dict()) - if by == By.ID: - by = By.CSS_SELECTOR - value = f'[id="{value}"]' - elif by == By.CLASS_NAME: - by = By.CSS_SELECTOR - value = f".{value}" - elif by == By.NAME: - by = By.CSS_SELECTOR - value = f'[name="{value}"]' - # Return empty list if driver returns null # See https://github.com/SeleniumHQ/selenium/issues/4555 return self.execute(Command.FIND_ELEMENTS, {"using": by, "value": value})["value"] or [] diff --git a/py/selenium/webdriver/remote/webelement.py b/py/selenium/webdriver/remote/webelement.py index ef60757294caa..08c772eaad56e 100644 --- a/py/selenium/webdriver/remote/webelement.py +++ b/py/selenium/webdriver/remote/webelement.py @@ -404,16 +404,7 @@ def find_element(self, by=By.ID, value=None) -> WebElement: :rtype: WebElement """ - if by == By.ID: - by = By.CSS_SELECTOR - value = f'[id="{value}"]' - elif by == By.CLASS_NAME: - by = By.CSS_SELECTOR - value = f".{value}" - elif by == By.NAME: - by = By.CSS_SELECTOR - value = f'[name="{value}"]' - + by, value = self._parent.locator_converter.convert(by, value) return self._execute(Command.FIND_CHILD_ELEMENT, {"using": by, "value": value})["value"] def find_elements(self, by=By.ID, value=None) -> List[WebElement]: @@ -426,16 +417,7 @@ def find_elements(self, by=By.ID, value=None) -> List[WebElement]: :rtype: list of WebElement """ - if by == By.ID: - by = By.CSS_SELECTOR - value = f'[id="{value}"]' - elif by == By.CLASS_NAME: - by = By.CSS_SELECTOR - value = f".{value}" - elif by == By.NAME: - by = By.CSS_SELECTOR - value = f'[name="{value}"]' - + by, value = self._parent.locator_converter.convert(by, value) return self._execute(Command.FIND_CHILD_ELEMENTS, {"using": by, "value": value})["value"] def __hash__(self) -> int: diff --git a/py/test/selenium/webdriver/common/driver_element_finding_tests.py b/py/test/selenium/webdriver/common/driver_element_finding_tests.py index 2578ac2f57861..205edb92e1f88 100644 --- a/py/test/selenium/webdriver/common/driver_element_finding_tests.py +++ b/py/test/selenium/webdriver/common/driver_element_finding_tests.py @@ -715,3 +715,21 @@ def test_should_not_be_able_to_find_an_element_on_a_blank_page(driver, pages): driver.get("about:blank") with pytest.raises(NoSuchElementException): driver.find_element(By.TAG_NAME, "a") + + +# custom finders tests + + +def test_register_and_get_custom_finder(): + By.register_custom_finder("custom", "custom strategy") + assert By.get_finder("custom") == "custom strategy" + + +def test_get_nonexistent_finder(): + assert By.get_finder("nonexistent") is None + + +def test_clear_custom_finders(): + By.register_custom_finder("custom", "custom strategy") + By.clear_custom_finders() + assert By.get_finder("custom") is None diff --git a/py/test/selenium/webdriver/remote/custom_element_tests.py b/py/test/selenium/webdriver/remote/custom_element_tests.py new file mode 100644 index 0000000000000..3fccb52ad3119 --- /dev/null +++ b/py/test/selenium/webdriver/remote/custom_element_tests.py @@ -0,0 +1,50 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from selenium.webdriver.common.by import By +from selenium.webdriver.remote.webelement import WebElement + + +# Custom element class +class MyCustomElement(WebElement): + def custom_method(self): + return "Custom element method" + + +def test_find_element_with_custom_class(driver, pages): + """Test to ensure custom element class is used for a single element.""" + driver._web_element_cls = MyCustomElement + pages.load("simpleTest.html") + element = driver.find_element(By.TAG_NAME, "body") + assert isinstance(element, MyCustomElement) + assert element.custom_method() == "Custom element method" + + +def test_find_elements_with_custom_class(driver, pages): + """Test to ensure custom element class is used for multiple elements.""" + driver._web_element_cls = MyCustomElement + pages.load("simpleTest.html") + elements = driver.find_elements(By.TAG_NAME, "div") + assert all(isinstance(el, MyCustomElement) for el in elements) + assert all(el.custom_method() == "Custom element method" for el in elements) + + +def test_default_element_class(driver, pages): + """Test to ensure default WebElement class is used.""" + pages.load("simpleTest.html") + element = driver.find_element(By.TAG_NAME, "body") + assert isinstance(element, WebElement) + assert not hasattr(element, "custom_method") diff --git a/py/test/selenium/webdriver/remote/remote_custom_locator_tests.py b/py/test/selenium/webdriver/remote/remote_custom_locator_tests.py new file mode 100644 index 0000000000000..e235f2ee2e999 --- /dev/null +++ b/py/test/selenium/webdriver/remote/remote_custom_locator_tests.py @@ -0,0 +1,40 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from selenium.webdriver.remote.locator_converter import LocatorConverter + + +class CustomLocatorConverter(LocatorConverter): + def convert(self, by, value): + # Custom conversion logic + if by == "custom": + return "css selector", f'[custom-attr="{value}"]' + return super().convert(by, value) + + +def test_find_element_with_custom_locator(driver): + driver.get("data:text/html,
Test
") + element = driver.find_element("custom", "example") + assert element is not None + assert element.text == "Test" + + +def test_find_elements_with_custom_locator(driver): + driver.get("data:text/html,
Test1
Test2
") + elements = driver.find_elements("custom", "example") + assert len(elements) == 2 + assert elements[0].text == "Test1" + assert elements[1].text == "Test2" diff --git a/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py b/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py index 2c798365d85f5..260214e5c918d 100644 --- a/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py +++ b/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. +from unittest.mock import patch from urllib import parse import pytest @@ -24,6 +25,31 @@ from selenium.webdriver.remote.remote_connection import RemoteConnection +@pytest.fixture +def remote_connection(): + """Fixture to create a RemoteConnection instance.""" + return RemoteConnection("http://localhost:4444") + + +def test_add_command(remote_connection): + """Test adding a custom command to the connection.""" + remote_connection.add_command("CUSTOM_COMMAND", "PUT", "/session/$sessionId/custom") + assert remote_connection.get_command("CUSTOM_COMMAND") == ("PUT", "/session/$sessionId/custom") + + +@patch("selenium.webdriver.remote.remote_connection.RemoteConnection._request") +def test_execute_custom_command(mock_request, remote_connection): + """Test executing a custom command through the connection.""" + remote_connection.add_command("CUSTOM_COMMAND", "PUT", "/session/$sessionId/custom") + mock_request.return_value = {"status": 200, "value": "OK"} + + params = {"sessionId": "12345"} + response = remote_connection.execute("CUSTOM_COMMAND", params) + + mock_request.assert_called_once_with("PUT", "http://localhost:4444/session/12345/custom", body="{}") + assert response == {"status": 200, "value": "OK"} + + def test_get_remote_connection_headers_defaults(): url = "http://remote" headers = RemoteConnection.get_remote_connection_headers(parse.urlparse(url)) @@ -32,7 +58,7 @@ def test_get_remote_connection_headers_defaults(): assert headers.get("Accept") == "application/json" assert headers.get("Content-Type") == "application/json;charset=UTF-8" assert headers.get("User-Agent").startswith(f"selenium/{__version__} (python ") - assert headers.get("User-Agent").split(" ")[-1] in {"windows)", "mac)", "linux)"} + assert headers.get("User-Agent").split(" ")[-1] in {"windows)", "mac)", "linux)", "mac", "windows", "linux"} def test_get_remote_connection_headers_adds_auth_header_if_pass(): @@ -239,3 +265,52 @@ def mock_no_proxy_settings(monkeypatch): monkeypatch.setenv("http_proxy", http_proxy) monkeypatch.setenv("no_proxy", "65.253.214.253,localhost,127.0.0.1,*zyz.xx,::1") monkeypatch.setenv("NO_PROXY", "65.253.214.253,localhost,127.0.0.1,*zyz.xx,::1,127.0.0.0/8") + + +@patch("selenium.webdriver.remote.remote_connection.RemoteConnection.get_remote_connection_headers") +def test_override_user_agent_in_headers(mock_get_remote_connection_headers, remote_connection): + RemoteConnection.user_agent = "custom-agent/1.0 (python 3.8)" + + mock_get_remote_connection_headers.return_value = { + "Accept": "application/json", + "Content-Type": "application/json;charset=UTF-8", + "User-Agent": "custom-agent/1.0 (python 3.8)", + } + + headers = RemoteConnection.get_remote_connection_headers(parse.urlparse("http://remote")) + + assert headers.get("User-Agent") == "custom-agent/1.0 (python 3.8)" + assert headers.get("Accept") == "application/json" + assert headers.get("Content-Type") == "application/json;charset=UTF-8" + + +@patch("selenium.webdriver.remote.remote_connection.RemoteConnection._request") +def test_register_extra_headers(mock_request, remote_connection): + RemoteConnection.extra_headers = {"Foo": "bar"} + + mock_request.return_value = {"status": 200, "value": "OK"} + remote_connection.execute("newSession", {}) + + mock_request.assert_called_once_with("POST", "http://localhost:4444/session", body="{}") + headers = RemoteConnection.get_remote_connection_headers(parse.urlparse("http://localhost:4444"), False) + assert headers["Foo"] == "bar" + + +def test_get_connection_manager_ignores_certificates(monkeypatch): + monkeypatch.setattr(RemoteConnection, "get_timeout", lambda _: 10) + remote_connection = RemoteConnection("http://remote", ignore_certificates=True) + conn = remote_connection._get_connection_manager() + + assert conn.connection_pool_kw["timeout"] == 10 + assert conn.connection_pool_kw["cert_reqs"] == "CERT_NONE" + assert isinstance(conn, urllib3.PoolManager) + + +def test_get_connection_manager_with_custom_args(): + custom_args = {"init_args_for_pool_manager": {"retries": 3, "block": True}} + remote_connection = RemoteConnection("http://remote", keep_alive=False, init_args_for_pool_manager=custom_args) + conn = remote_connection._get_connection_manager() + + assert isinstance(conn, urllib3.PoolManager) + assert conn.connection_pool_kw["retries"] == 3 + assert conn.connection_pool_kw["block"] is True