diff --git a/py/selenium/webdriver/remote/client_config.py b/py/selenium/webdriver/remote/client_config.py index 8fae7571c026c0..8c62b051aecf78 100644 --- a/py/selenium/webdriver/remote/client_config.py +++ b/py/selenium/webdriver/remote/client_config.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import base64 import os from urllib import parse @@ -26,11 +27,19 @@ def __init__( self, remote_server_addr: str, keep_alive: bool = True, - proxy=None, + proxy: Proxy = Proxy(raw={"proxyType": ProxyType.SYSTEM}), + username: str = None, + password: str = None, + auth_type: str = "Basic", + token: str = None, ) -> None: self.remote_server_addr = remote_server_addr self.keep_alive = keep_alive self.proxy = proxy + self.username = username + self.password = password + self.auth_type = auth_type + self.token = token @property def remote_server_addr(self) -> str: @@ -57,8 +66,6 @@ def keep_alive(self, value: bool) -> None: @property def proxy(self) -> Proxy: """:Returns: The proxy used for communicating to the driver/server.""" - - self._proxy = self._proxy or Proxy(raw={"proxyType": ProxyType.SYSTEM}) return self._proxy @proxy.setter @@ -71,17 +78,49 @@ def proxy(self, proxy: Proxy) -> None: """ self._proxy = proxy - def get_proxy_url(self): + @property + def username(self) -> str: + return self._username + + @username.setter + def username(self, value: str) -> None: + self._username = value + + @property + def password(self) -> str: + return self._password + + @password.setter + def password(self, value: str) -> None: + self._password = value + + @property + def auth_type(self) -> str: + return self._auth_type + + @auth_type.setter + def auth_type(self, value: str) -> None: + self._auth_type = value + + @property + def token(self) -> str: + return self._token + + @token.setter + def token(self, value: str) -> None: + self._token = value + + def get_proxy_url(self) -> str: if self.proxy.proxy_type == ProxyType.DIRECT: return None elif self.proxy.proxy_type == ProxyType.SYSTEM: _no_proxy = os.environ.get("no_proxy", os.environ.get("NO_PROXY")) if _no_proxy: - for npu in _no_proxy.split(","): - npu = npu.strip() - if npu == "*": + for entry in _no_proxy.split(","): + entry = entry.strip() + if entry == "*": return None - n_url = parse.urlparse(npu) + n_url = parse.urlparse(entry) remote_add = parse.urlparse(self.remote_server_addr) if n_url.netloc: if remote_add.netloc == n_url.netloc: @@ -102,3 +141,15 @@ def get_proxy_url(self): return None else: return None + + def get_auth_header(self): + auth_type = self.auth_type.lower() + if auth_type == "basic" and self.username and self.password: + credentials = f"{self.username}:{self.password}" + encoded_credentials = base64.b64encode(credentials.encode()).decode() + return {"Authorization": f"Basic {encoded_credentials}"} + elif auth_type == "bearer" and self.token: + return {"Authorization": f"Bearer {self.token}"} + elif auth_type == "oauth" and self.token: + return {"Authorization": f"OAuth {self.token}"} + return None diff --git a/py/selenium/webdriver/remote/remote_connection.py b/py/selenium/webdriver/remote/remote_connection.py index cf34d5e8d31a0f..194cee4e9180b5 100644 --- a/py/selenium/webdriver/remote/remote_connection.py +++ b/py/selenium/webdriver/remote/remote_connection.py @@ -323,6 +323,11 @@ def _request(self, method, url, body=None): """ parsed_url = parse.urlparse(url) headers = self.get_remote_connection_headers(parsed_url, self._client_config.keep_alive) + auth_header = self._client_config.get_auth_header() + + if auth_header: + headers.update(auth_header) + if body and method not in ("POST", "PUT"): body = None diff --git a/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py b/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py index b167220a9bb199..cf2235d42ca790 100644 --- a/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py +++ b/py/test/unit/selenium/webdriver/remote/remote_connection_tests.py @@ -22,6 +22,7 @@ from selenium import __version__ from selenium.webdriver.remote.remote_connection import RemoteConnection +from selenium.webdriver.remote.remote_connection import ClientConfig def test_get_remote_connection_headers_defaults(): @@ -54,6 +55,19 @@ def test_get_proxy_url_http(mock_proxy_settings): assert proxy_url == proxy +def test_get_auth_header_if_client_config_pass(): + custom_config = ClientConfig( + remote_server_addr="http://remote", + keep_alive=True, + username="user", + password="pass", + auth_type="Basic" + ) + remote_connection = RemoteConnection(custom_config.remote_server_addr, client_config=custom_config) + headers = remote_connection._client_config.get_auth_header() + assert headers.get("Authorization") == "Basic dXNlcjpwYXNz" + + def test_get_proxy_url_https(mock_proxy_settings): proxy = "http://https_proxy.com:8080" remote_connection = RemoteConnection("https://remote", keep_alive=False)