diff --git a/py/selenium/webdriver/common/bidi/network.py b/py/selenium/webdriver/common/bidi/network.py new file mode 100644 index 0000000000000..09535805544a3 --- /dev/null +++ b/py/selenium/webdriver/common/bidi/network.py @@ -0,0 +1,262 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .session import session_subscribe +from .session import session_unsubscribe + + +class Network: + EVENTS = { + 'before_request': 'network.beforeRequestSent', + 'response_started': 'network.responseStarted', + 'response_completed': 'network.responseCompleted', + 'auth_required': 'network.authRequired', + 'fetch_error': 'network.fetchError' + } + + PHASES = { + 'before_request': 'beforeRequestSent', + 'response_started': 'responseStarted', + 'auth_required': 'authRequired' + } + + def __init__(self, conn): + self.conn = conn + self.callbacks = {} + self.subscriptions = { + 'network.responseStarted': [], + 'network.beforeRequestSent': [], + 'network.authRequired': [] + } + + + def command_iterator(self, command): + """Generator to yield command.""" + yield command + return + + def has_callbacks(self): + """Checks if there are any callbacks set.""" + return len(self.callbacks) > 0 + + def __add_intercept(self, phases=None, contexts=None, url_patterns=None): + """Add an intercept to the network.""" + if phases is None: + phases = [] + if contexts is None and url_patterns is None: + params = { + 'phases': phases, + } + elif contexts is None: + params = { + 'phases': phases, + 'urlPatterns': url_patterns + } + elif url_patterns is None: + params = { + 'phases': phases, + 'contexts': contexts + } + else: + params = { + 'phases': phases, + 'contexts': contexts, + 'urlPatterns': url_patterns + } + command = {'method': 'network.addIntercept', 'params': params} + self.conn.execute(self.command_iterator(command)) + + def __remove_intercept(self, intercept=None, request_id=None): + """Remove an intercept from the network.""" + if request_id is not None: + command = {'method': 'network.removeIntercept', 'requestId': request_id} + self.conn.execute(self.command_iterator(command)) + elif intercept is not None: + command = {'method': 'network.removeIntercept', 'intercept': intercept} + self.conn.execute(self.command_iterator(command)) + else: + raise ValueError('Either requestId or intercept must be specified') + + def __continue_with_auth(self, request_id, username, password): + """Continue with authentication.""" + command = {'method': 'network.continueWithAuth', 'params': + { + 'request': request_id, + 'action': 'provideCredentials', + 'credentials': { + 'type': 'password', + 'username': username, + 'password': password + } + } + } + self.conn.execute(self.command_iterator(command)) + + def __on(self, event, callback): + """Set a callback function to subscribe to a network event.""" + event = self.EVENTS.get(event, event) + self.callbacks[event] = callback + if len(self.subscriptions[event]) == 0: + session_subscribe(self.conn, event, self.__handle_event) + + def __handle_event(self, event, data): + """Perform callback function on event.""" + if event in self.callbacks: + self.callbacks[event](data) + + def add_authentication_handler(self, username, password): + """Adds an authentication handler.""" + self.__add_intercept(phases=[self.PHASES['auth_required']]) + self.__on('auth_required', lambda data: self.__continue_with_auth(data['request']['request'], username, password)) + self.subscriptions['auth_required'] = [username, password] + + def remove_authentication_handler(self): + """Removes an authentication handler.""" + self.__remove_intercept(intercept='auth_required') + del self.subscriptions['auth_required'] + session_unsubscribe(self.conn, self.EVENTS['auth_required']) + + def add_request_handler(self, callback, url_pattern=''): + """Adds a request handler that executes a callback function when a + request matches the given URL pattern. + + Parameters: + callback (function): A function to be executed when url is matched by a URL pattern + The callback function receives a `Response` object as its argument. + url_pattern (str, optional): A substring to match against the response URL. + Default is an empty string, which matches all URLs. + + Returns: + str: The request ID of the intercepted response. + """ + self.__add_intercept(phases=[self.PHASES['before_request']]) + request_id = None + def callback_on_url_match(data): + nonlocal request_id + if url_pattern in data['request']['url']: + # create request object to pass to callback + request_id = data['request'].get('requestId', None) + url = data['request'].get('url') + method = data['request'].get('method') + headers = data['request'].get('headers', {}) + body = data['request'].get('body', None) + request = Request(request_id, url, method, headers, body, self) + callback(request) + self.__on('before_request', callback_on_url_match) + self.callbacks[request_id] = callback + if 'before_request' not in self.subscriptions or not self.subscriptions.get('before_request'): + self.subscriptions['before_request'] = [request_id] + else: + self.subscriptions['before_request'].append(request_id) + return request_id + + def remove_request_handler(self, request_id): + """Removes a request handler.""" + self.__remove_intercept(request_id=request_id) + self.subscriptions['before_request'].remove(request_id) + del self.callbacks[request_id] + if len(self.subscriptions['before_request']) == 0: + session_unsubscribe(self.conn, self.EVENTS['before_request']) + + def add_response_handler(self, callback, url_pattern=''): + """Adds a response handler that executes a callback function when a + response matches the given URL pattern. + + Parameters: + callback (function): A function to be executed when url is matched by a url_pattern + The callback function receives a `Response` object as its argument. + url_pattern (str, optional): A substring to match against the response URL. + Default is an empty string, which matches all URLs. + + Returns: + str: The request ID of the intercepted response. + """ + self.__add_intercept(phases=[self.PHASES['response_started']]) + request_id = None + def callback_on_url_match(data): + # create response object to pass to callback + nonlocal request_id + if url_pattern in data['response']['url']: + request_id = data['request'].get('requestId', None) + url = data['response'].get('url') + status_code = data['response'].get('status') + body = data['response'].get('body', None) + headers = data['response'].get('headers', {}) + response = Response(request_id, url, status_code, headers, body, self) + callback(response) + self.__on('response_started', callback_on_url_match) + self.callbacks[request_id] = callback + if 'response_started' not in self.subscriptions or not self.subscriptions.get('response_started'): + self.subscriptions['response_started'] = [request_id] + else: + self.subscriptions['response_started'].append(request_id) + return request_id + + def remove_response_handler(self, response_id): + """Removes a response handler.""" + self.__remove_intercept(request_id=response_id) + self.subscriptions['response_started'].remove(response_id) + del self.callbacks[response_id] + if len(self.subscriptions['response_started']) == 0: + session_unsubscribe(self.conn, self.EVENTS['response_started']) + +class Request: + def __init__(self, request_id, url, method, headers, body, network: Network): + self.request_id = request_id + self.url = url + self.method = method + self.headers = headers + self.body = body + self.network = network + + def continue_request(self): + """Continue after sending a request.""" + params = { + 'requestId': self.request_id + } + if self.url is not None: + params['url'] = self.url + if self.method is not None: + params['method'] = self.method + if self.headers is not None: + params['headers'] = self.headers + if self.body is not None: + params['body'] = self.body + command = {'method': 'network.continueRequest', 'params': params} + self.network.conn.execute(self.command_iterator(command)) + +class Response: + def __init__(self, request_id, url, status_code, headers, body, network: Network): + self.request_id = request_id + self.url = url + self.status_code = status_code + self.headers = headers + self.body = body + self.network = network + + def continue_response(self): + """Continue after receiving a response.""" + params = { + 'requestId': self.request_id, + 'status': self.status_code + } + if self.headers is not None: + params['headers'] = self.headers + if self.body is not None: + params['body'] = self.body + command = {'method': 'network.continueResponse', 'params': params} + self.network.conn.execute(self.command_iterator(command)) diff --git a/py/selenium/webdriver/remote/command.py b/py/selenium/webdriver/remote/command.py index b079ed5406f53..8252949a7a68c 100644 --- a/py/selenium/webdriver/remote/command.py +++ b/py/selenium/webdriver/remote/command.py @@ -26,6 +26,11 @@ class Command: https://w3c.github.io/webdriver/ """ + ADD_INTERCEPT: str = "network.addIntercept" + REMOVE_INTERCEPT: str = "network.removeIntercept" + CONTINUE_RESPONSE: str = "network.continueResponse" + CONTINUE_REQUEST: str = "network.continueRequest" + CONTINUE_WITH_AUTH: str = "network.continueWithAuth" NEW_SESSION: str = "newSession" DELETE_SESSION: str = "deleteSession" NEW_WINDOW: str = "newWindow" diff --git a/py/selenium/webdriver/remote/remote_connection.py b/py/selenium/webdriver/remote/remote_connection.py index 1bc432084b6e2..0228ad21cc2b8 100644 --- a/py/selenium/webdriver/remote/remote_connection.py +++ b/py/selenium/webdriver/remote/remote_connection.py @@ -36,6 +36,11 @@ LOGGER = logging.getLogger(__name__) remote_commands = { + Command.ADD_INTERCEPT: ("POST", "/session/$sessionId/network/intercept"), + Command.REMOVE_INTERCEPT: ("DELETE", "/session/$sessionId/network/intercept/$intercept"), + Command.CONTINUE_RESPONSE: ("POST", "/session/$sessionId/network/response/$requestId"), + Command.CONTINUE_REQUEST: ("POST", "/session/$sessionId/network/request/$requestId"), + Command.CONTINUE_WITH_AUTH: ("POST", "/session/$sessionId/network/auth"), Command.NEW_SESSION: ("POST", "/session"), Command.QUIT: ("DELETE", "/session/$sessionId"), Command.W3C_GET_CURRENT_WINDOW_HANDLE: ("GET", "/session/$sessionId/window"), diff --git a/py/selenium/webdriver/remote/webdriver.py b/py/selenium/webdriver/remote/webdriver.py index 564fb465369df..afc0e88e1ab50 100644 --- a/py/selenium/webdriver/remote/webdriver.py +++ b/py/selenium/webdriver/remote/webdriver.py @@ -41,6 +41,7 @@ from selenium.common.exceptions import NoSuchCookieException from selenium.common.exceptions import NoSuchElementException from selenium.common.exceptions import WebDriverException +from selenium.webdriver.common.bidi.network import Network from selenium.webdriver.common.bidi.script import Script from selenium.webdriver.common.by import By from selenium.webdriver.common.options import ArgOptions @@ -243,6 +244,7 @@ def __init__( self._websocket_connection = None self._script = None + self._network = None def __repr__(self): return f'<{type(self).__module__}.{type(self).__name__} (session="{self.session_id}")>' @@ -1128,6 +1130,16 @@ def _start_bidi(self): self._websocket_connection = WebSocketConnection(ws_url) + @property + def network(self): + if not self._websocket_connection: + self._start_bidi() + + if not hasattr(self, '_network') or self._network is None: + self._network = Network(self._websocket_connection) + + return self._network + def _get_cdp_details(self): import json diff --git a/py/test/selenium/webdriver/common/bidi_network_tests.py b/py/test/selenium/webdriver/common/bidi_network_tests.py new file mode 100644 index 0000000000000..327d2de8ddce2 --- /dev/null +++ b/py/test/selenium/webdriver/common/bidi_network_tests.py @@ -0,0 +1,88 @@ +# Licensed to the Software Freedom Conservancy (SFC) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The SFC licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pytest + +from selenium.webdriver.common.by import By + + +@pytest.mark.xfail_safari +def test_network_initialized(driver): + assert driver.network is not None + +@pytest.mark.xfail_safari +def test_add_response_handler(driver, pages): + passed = [False] + + def callback(response): + passed[0] = True + response.continue_response() + + driver.network.add_response_handler(callback) + pages.load("basicAuth") + assert passed[0], "Callback was NOT successful" + +@pytest.mark.xfail_safari +def test_remove_response_handler(driver, pages): + passed = [False] + + def callback(response): + passed[0] = True + response.continue_response() + + test_response_id = driver.network.add_response_handler(callback) + driver.network.remove_response_handler(response_id=test_response_id) + pages.load("basicAuth") + assert not passed[0], "Callback should NOT be successful" + +@pytest.mark.xfail_safari +def test_add_request_handler(driver, pages): + passed = [False] + + def callback(request): + passed[0] = True + request.continue_request() + + driver.network.add_request_handler(callback) + pages.load("basicAuth") + assert passed[0], "Callback was NOT successful" + +@pytest.mark.xfail_safari +def test_remove_request_handler(driver, pages): + passed = [False] + + def callback(request): + passed[0] = True + request.continue_request() + + test_request_id = driver.network.add_request_handler(callback) + driver.network.remove_request_handler(request_id=test_request_id) + pages.load("basicAuth") + assert not passed[0], "Callback should NOT be successful" + +@pytest.mark.xfail_safari +def test_add_authentication_handler(driver, pages): + driver.network.add_authentication_handler('test','test') + pages.load("basicAuth") + assert driver.find_element(By.TAG_NAME, 'h1').text == 'authorized', "Authentication was NOT successful" + +@pytest.mark.xfail_safari +def test_remove_authentication_handler(driver, pages): + driver.network.add_authentication_handler('test', 'test') + driver.network.remove_authentication_handler() + pages.load("basicAuth") + assert driver.find_element(By.TAG_NAME, 'h1').text != 'authorized', "Authentication was successful"