diff --git a/jupyter_server/gateway/gateway_client.py b/jupyter_server/gateway/gateway_client.py index 6faee7135c..8c49913ced 100644 --- a/jupyter_server/gateway/gateway_client.py +++ b/jupyter_server/gateway/gateway_client.py @@ -5,6 +5,7 @@ import logging import os import typing as ty +from abc import ABC, ABCMeta, abstractmethod from datetime import datetime from email.utils import parsedate_to_datetime from http.cookies import Morsel, SimpleCookie @@ -12,8 +13,63 @@ from tornado import web from tornado.httpclient import AsyncHTTPClient, HTTPClientError, HTTPResponse -from traitlets import Bool, Float, Int, TraitError, Unicode, default, observe, validate -from traitlets.config import SingletonConfigurable +from traitlets import ( + Bool, + Float, + Int, + TraitError, + Type, + Unicode, + default, + observe, + validate, +) +from traitlets.config import LoggingConfigurable, SingletonConfigurable + + +class GatewayTokenRenewerMeta(ABCMeta, type(LoggingConfigurable)): # type: ignore + """The metaclass necessary for proper ABC behavior in a Configurable.""" + + pass + + +class GatewayTokenRenewerBase(ABC, LoggingConfigurable, metaclass=GatewayTokenRenewerMeta): + """ + Abstract base class for refreshing tokens used between this server and a Gateway + server. Implementations requiring additional configuration can extend their class + with appropriate configuration values or convey those values via appropriate + environment variables relative to the implementation. + """ + + @abstractmethod + def get_token( + self, + auth_header_key: str, + auth_scheme: ty.Union[str, None], + auth_token: str, + **kwargs: ty.Any, + ) -> str: + """ + Given the current authorization header key, scheme, and token, this method returns + a (potentially renewed) token for use against the Gateway server. + """ + pass + + +class NoOpTokenRenewer(GatewayTokenRenewerBase): + """NoOpTokenRenewer is the default value to the GatewayClient trait + `gateway_token_renewer` and merely returns the provided token. + """ + + def get_token( + self, + auth_header_key: str, + auth_scheme: ty.Union[str, None], + auth_token: str, + **kwargs: ty.Any, + ) -> str: + """This implementation simply returns the current authorization token.""" + return auth_token class GatewayClient(SingletonConfigurable): @@ -28,9 +84,9 @@ class GatewayClient(SingletonConfigurable): allow_none=True, config=True, help="""The url of the Kernel or Enterprise Gateway server where - kernel specifications are defined and kernel management takes place. - If defined, this Notebook server acts as a proxy for all kernel - management and kernel specification retrieval. (JUPYTER_GATEWAY_URL env var) +kernel specifications are defined and kernel management takes place. +If defined, this Notebook server acts as a proxy for all kernel +management and kernel specification retrieval. (JUPYTER_GATEWAY_URL env var) """, ) @@ -54,7 +110,7 @@ def _url_validate(self, proposal): allow_none=True, config=True, help="""The websocket url of the Kernel or Enterprise Gateway server. If not provided, this value - will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var) +will correspond to the value of the Gateway url with 'ws' in place of 'http'. (JUPYTER_GATEWAY_WS_URL env var) """, ) @@ -109,7 +165,7 @@ def _kernelspecs_endpoint_default(self): default_value=kernelspecs_resource_endpoint_default_value, config=True, help="""The gateway endpoint for accessing kernelspecs resources - (JUPYTER_GATEWAY_KERNELSPECS_RESOURCE_ENDPOINT env var)""", +(JUPYTER_GATEWAY_KERNELSPECS_RESOURCE_ENDPOINT env var)""", ) @default("kernelspecs_resource_endpoint") @@ -125,14 +181,12 @@ def _kernelspecs_resource_endpoint_default(self): default_value=connect_timeout_default_value, config=True, help="""The time allowed for HTTP connection establishment with the Gateway server. - (JUPYTER_GATEWAY_CONNECT_TIMEOUT env var)""", +(JUPYTER_GATEWAY_CONNECT_TIMEOUT env var)""", ) @default("connect_timeout") def connect_timeout_default(self): - return float( - os.environ.get("JUPYTER_GATEWAY_CONNECT_TIMEOUT", self.connect_timeout_default_value) - ) + return float(os.environ.get(self.connect_timeout_env, self.connect_timeout_default_value)) request_timeout_default_value = 42.0 request_timeout_env = "JUPYTER_GATEWAY_REQUEST_TIMEOUT" @@ -144,9 +198,7 @@ def connect_timeout_default(self): @default("request_timeout") def request_timeout_default(self): - return float( - os.environ.get("JUPYTER_GATEWAY_REQUEST_TIMEOUT", self.request_timeout_default_value) - ) + return float(os.environ.get(self.request_timeout_env, self.request_timeout_default_value)) client_key = Unicode( default_value=None, @@ -228,36 +280,54 @@ def _http_pwd_default(self): def _headers_default(self): return os.environ.get(self.headers_env, self.headers_default_value) + auth_header_key_default_value = "Authorization" + auth_header_key = Unicode( + config=True, + help="""The authorization header's key name (typically 'Authorization') used in the HTTP headers. The +header will be formatted as:: + +{'{auth_header_key}': '{auth_scheme} {auth_token}'} + +If the authorization header key takes a single value, `auth_scheme` should be set to None and +'auth_token' should be configured to use the appropriate value. + +(JUPYTER_GATEWAY_AUTH_HEADER_KEY env var)""", + ) + auth_header_key_env = "JUPYTER_GATEWAY_AUTH_HEADER_KEY" + + @default("auth_header_key") + def _auth_header_key_default(self): + return os.environ.get(self.auth_header_key_env, self.auth_header_key_default_value) + + auth_token_default_value = "" auth_token = Unicode( default_value=None, allow_none=True, config=True, help="""The authorization token used in the HTTP headers. The header will be formatted as:: - { - 'Authorization': '{auth_scheme} {auth_token}' - } +{'{auth_header_key}': '{auth_scheme} {auth_token}'} - (JUPYTER_GATEWAY_AUTH_TOKEN env var)""", +(JUPYTER_GATEWAY_AUTH_TOKEN env var)""", ) auth_token_env = "JUPYTER_GATEWAY_AUTH_TOKEN" @default("auth_token") def _auth_token_default(self): - return os.environ.get(self.auth_token_env, "") + return os.environ.get(self.auth_token_env, self.auth_token_default_value) + auth_scheme_default_value = "token" # This value is purely for backwards compatibility auth_scheme = Unicode( - default_value=None, allow_none=True, config=True, help="""The auth scheme, added as a prefix to the authorization token used in the HTTP headers. - (JUPYTER_GATEWAY_AUTH_SCHEME env var)""", +(JUPYTER_GATEWAY_AUTH_SCHEME env var)""", ) auth_scheme_env = "JUPYTER_GATEWAY_AUTH_SCHEME" @default("auth_scheme") def _auth_scheme_default(self): - return os.environ.get(self.auth_scheme_env, "token") + return os.environ.get(self.auth_scheme_env, self.auth_scheme_default_value) validate_cert_default_value = True validate_cert_env = "JUPYTER_GATEWAY_VALIDATE_CERT" @@ -265,7 +335,7 @@ def _auth_scheme_default(self): default_value=validate_cert_default_value, config=True, help="""For HTTPS requests, determines if server's certificate should be validated or not. - (JUPYTER_GATEWAY_VALIDATE_CERT env var)""", +(JUPYTER_GATEWAY_VALIDATE_CERT env var)""", ) @default("validate_cert") @@ -275,29 +345,22 @@ def validate_cert_default(self): not in ["no", "false"] ) - def __init__(self, **kwargs): - super().__init__(**kwargs) - self._static_args = {} # initialized on first use - - # store of cookies with store time - self._cookies = {} # type: ty.Dict[str, ty.Tuple[Morsel, datetime]] - allowed_envs_default_value = "" allowed_envs_env = "JUPYTER_GATEWAY_ALLOWED_ENVS" allowed_envs = Unicode( default_value=allowed_envs_default_value, config=True, help="""A comma-separated list of environment variable names that will be included, along with - their values, in the kernel startup request. The corresponding `allowed_envs` configuration - value must also be set on the Gateway server - since that configuration value indicates which - environmental values to make available to the kernel. (JUPYTER_GATEWAY_ALLOWED_ENVS env var)""", +their values, in the kernel startup request. The corresponding `client_envs` configuration +value must also be set on the Gateway server - since that configuration value indicates which +environmental values to make available to the kernel. (JUPYTER_GATEWAY_ALLOWED_ENVS env var)""", ) @default("allowed_envs") def _allowed_envs_default(self): return os.environ.get( - "JUPYTER_GATEWAY_ENV_WHITELIST", - os.environ.get(self.allowed_envs_env, self.allowed_envs_default_value), + self.allowed_envs_env, + os.environ.get("JUPYTER_GATEWAY_ENV_WHITELIST", self.allowed_envs_default_value), ) env_whitelist = Unicode( @@ -312,16 +375,16 @@ def _allowed_envs_default(self): default_value=gateway_retry_interval_default_value, config=True, help="""The time allowed for HTTP reconnection with the Gateway server for the first time. - Next will be JUPYTER_GATEWAY_RETRY_INTERVAL multiplied by two in factor of numbers of retries - but less than JUPYTER_GATEWAY_RETRY_INTERVAL_MAX. - (JUPYTER_GATEWAY_RETRY_INTERVAL env var)""", +Next will be JUPYTER_GATEWAY_RETRY_INTERVAL multiplied by two in factor of numbers of retries +but less than JUPYTER_GATEWAY_RETRY_INTERVAL_MAX. +(JUPYTER_GATEWAY_RETRY_INTERVAL env var)""", ) @default("gateway_retry_interval") def gateway_retry_interval_default(self): return float( os.environ.get( - "JUPYTER_GATEWAY_RETRY_INTERVAL", + self.gateway_retry_interval_env, self.gateway_retry_interval_default_value, ) ) @@ -332,14 +395,14 @@ def gateway_retry_interval_default(self): default_value=gateway_retry_interval_max_default_value, config=True, help="""The maximum time allowed for HTTP reconnection retry with the Gateway server. - (JUPYTER_GATEWAY_RETRY_INTERVAL_MAX env var)""", +(JUPYTER_GATEWAY_RETRY_INTERVAL_MAX env var)""", ) @default("gateway_retry_interval_max") def gateway_retry_interval_max_default(self): return float( os.environ.get( - "JUPYTER_GATEWAY_RETRY_INTERVAL_MAX", + self.gateway_retry_interval_max_env, self.gateway_retry_interval_max_default_value, ) ) @@ -350,13 +413,27 @@ def gateway_retry_interval_max_default(self): default_value=gateway_retry_max_default_value, config=True, help="""The maximum retries allowed for HTTP reconnection with the Gateway server. - (JUPYTER_GATEWAY_RETRY_MAX env var)""", +(JUPYTER_GATEWAY_RETRY_MAX env var)""", ) @default("gateway_retry_max") def gateway_retry_max_default(self): - return int( - os.environ.get("JUPYTER_GATEWAY_RETRY_MAX", self.gateway_retry_max_default_value) + return int(os.environ.get(self.gateway_retry_max_env, self.gateway_retry_max_default_value)) + + gateway_token_renewer_class_default_value = ( + "jupyter_server.gateway.gateway_client.NoOpTokenRenewer" + ) + gateway_token_renewer_class_env = "JUPYTER_GATEWAY_TOKEN_RENEWER_CLASS" + gateway_token_renewer_class = Type( + klass=GatewayTokenRenewerBase, + config=True, + help="""The class to use for Gateway token renewal. (JUPYTER_GATEWAY_TOKEN_RENEWER_CLASS env var)""", + ) + + @default("gateway_token_renewer_class") + def gateway_token_renewer_class_default(self): + return os.environ.get( + self.gateway_token_renewer_class_env, self.gateway_token_renewer_class_default_value ) launch_timeout_pad_default_value = 2.0 @@ -365,15 +442,15 @@ def gateway_retry_max_default(self): default_value=launch_timeout_pad_default_value, config=True, help="""Timeout pad to be ensured between KERNEL_LAUNCH_TIMEOUT and request_timeout - such that request_timeout >= KERNEL_LAUNCH_TIMEOUT + launch_timeout_pad. - (JUPYTER_GATEWAY_LAUNCH_TIMEOUT_PAD env var)""", +such that request_timeout >= KERNEL_LAUNCH_TIMEOUT + launch_timeout_pad. +(JUPYTER_GATEWAY_LAUNCH_TIMEOUT_PAD env var)""", ) @default("launch_timeout_pad") def launch_timeout_pad_default(self): return float( os.environ.get( - "JUPYTER_GATEWAY_LAUNCH_TIMEOUT_PAD", + self.launch_timeout_pad_env, self.launch_timeout_pad_default_value, ) ) @@ -431,10 +508,21 @@ def gateway_enabled(self): # Ensure KERNEL_LAUNCH_TIMEOUT has a default value. KERNEL_LAUNCH_TIMEOUT = int(os.environ.get("KERNEL_LAUNCH_TIMEOUT", 40)) - def init_static_args(self): - """Initialize arguments used on every request. Since these are static values, we'll - perform this operation once. + _connection_args: dict # initialized on first use + + gateway_token_renewer: GatewayTokenRenewerBase + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._connection_args = {} # initialized on first use + self.gateway_token_renewer = self.gateway_token_renewer_class(parent=self, log=self.log) + + # store of cookies with store time + self._cookies = {} # type: ty.Dict[str, ty.Tuple[Morsel, datetime]] + def init_connection_args(self): + """Initialize arguments used on every request. Since these are primarily static values, + we'll perform this operation once. """ # Ensure that request timeout and KERNEL_LAUNCH_TIMEOUT are in sync, taking the # greater value of the two and taking into account the following relation: @@ -451,39 +539,58 @@ def init_static_args(self): # Ensure any adjustments are reflected in env. os.environ["KERNEL_LAUNCH_TIMEOUT"] = str(GatewayClient.KERNEL_LAUNCH_TIMEOUT) - self._static_args["headers"] = json.loads(self.headers) - if "Authorization" not in self._static_args["headers"].keys(): - self._static_args["headers"].update( - {"Authorization": f"{self.auth_scheme} {self.auth_token}"} + self._connection_args["headers"] = json.loads(self.headers) + if self.auth_header_key not in self._connection_args["headers"].keys(): + self._connection_args["headers"].update( + {f"{self.auth_header_key}": f"{self.auth_scheme} {self.auth_token}"} ) - self._static_args["connect_timeout"] = self.connect_timeout - self._static_args["request_timeout"] = self.request_timeout - self._static_args["validate_cert"] = self.validate_cert + self._connection_args["connect_timeout"] = self.connect_timeout + self._connection_args["request_timeout"] = self.request_timeout + self._connection_args["validate_cert"] = self.validate_cert if self.client_cert: - self._static_args["client_cert"] = self.client_cert - self._static_args["client_key"] = self.client_key + self._connection_args["client_cert"] = self.client_cert + self._connection_args["client_key"] = self.client_key if self.ca_certs: - self._static_args["ca_certs"] = self.ca_certs + self._connection_args["ca_certs"] = self.ca_certs if self.http_user: - self._static_args["auth_username"] = self.http_user + self._connection_args["auth_username"] = self.http_user if self.http_pwd: - self._static_args["auth_password"] = self.http_pwd + self._connection_args["auth_password"] = self.http_pwd def load_connection_args(self, **kwargs): """Merges the static args relative to the connection, with the given keyword arguments. If statics have yet to be initialized, we'll do that here. """ - if len(self._static_args) == 0: - self.init_static_args() + if len(self._connection_args) == 0: + self.init_connection_args() + + # Give token renewal a shot at renewing the token + prev_auth_token = self.auth_token + try: + self.auth_token = self.gateway_token_renewer.get_token( + self.auth_header_key, self.auth_scheme, self.auth_token + ) + except Exception as ex: + self.log.error( + f"An exception occurred attempting to renew the " + f"Gateway authorization token using an instance of class " + f"'{self.gateway_token_renewer_class}'. The request will " + f"proceed using the current token value. Exception was: {ex}" + ) + self.auth_token = prev_auth_token - for arg, static_value in self._static_args.items(): + for arg, value in self._connection_args.items(): if arg == "headers": given_value = kwargs.setdefault(arg, {}) if isinstance(given_value, dict): - given_value.update(static_value) + given_value.update(value) + # Ensure the auth header is current + given_value.update( + {f"{self.auth_header_key}": f"{self.auth_scheme} {self.auth_token}"} + ) else: - kwargs[arg] = static_value + kwargs[arg] = value if self.accept_cookies: self._update_cookie_header(kwargs) diff --git a/tests/test_gateway.py b/tests/test_gateway.py index 1a29df70e8..6643b53a16 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -9,13 +9,20 @@ from http.cookies import SimpleCookie from io import BytesIO from queue import Empty +from typing import Any, Union from unittest.mock import MagicMock, patch import pytest import tornado from tornado.httpclient import HTTPRequest, HTTPResponse from tornado.web import HTTPError +from traitlets import Int, Unicode +from traitlets.config import Config +from jupyter_server.gateway.gateway_client import ( + GatewayTokenRenewerBase, + NoOpTokenRenewer, +) from jupyter_server.gateway.managers import ( ChannelQueue, GatewayClient, @@ -183,6 +190,27 @@ def helper(*args, **kwargs): return helper +class CustomTestTokenRenewer(GatewayTokenRenewerBase): + + TEST_EXPECTED_TOKEN_VALUE = "Use this token value: 42" + + # The following are configured by the config test to ensure they flow + config_var_1: int = Int(config=True) # configured to: 42 + config_var_2: str = Unicode(config=True) # configured to: "Use this token value: " + + def get_token( + self, auth_header_key: str, auth_scheme: Union[str, None], auth_token: str, **kwargs: Any + ) -> str: + return f"{self.config_var_2}{self.config_var_1}" + + +@pytest.fixture() +def jp_server_config(): + return Config( + {"CustomTestTokenRenewer": {"config_var_1": 42, "config_var_2": "Use this token value: "}} + ) + + @pytest.fixture def init_gateway(monkeypatch): """Initializes the server for use as a gateway client.""" @@ -214,7 +242,7 @@ async def test_gateway_env_options(init_gateway, jp_serverapp): assert jp_serverapp.gateway_config.accept_cookies is False assert jp_serverapp.gateway_config.allowed_envs == "FOO,BAR" - GatewayClient.instance().init_static_args() + GatewayClient.instance().init_connection_args() assert GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT == 43 @@ -237,19 +265,51 @@ async def test_gateway_cli_options(jp_configurable_serverapp, capsys): assert app.gateway_config.connect_timeout == 44.4 assert app.gateway_config.request_timeout == 96.0 assert app.gateway_config.launch_timeout_pad == 5.1 + assert app.gateway_config.gateway_token_renewer_class == NoOpTokenRenewer assert app.gateway_config.allowed_envs == "FOO,BAR" captured = capsys.readouterr() assert ( "env_whitelist is deprecated in jupyter_server 2.0, use GatewayClient.allowed_envs" in captured.err ) - GatewayClient.instance().init_static_args() + gw_client = GatewayClient.instance() + gw_client.init_connection_args() assert ( - GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT == 90 + gw_client.KERNEL_LAUNCH_TIMEOUT == 90 ) # Ensure KLT gets set from request-timeout - launch_timeout_pad GatewayClient.clear_instance() +@pytest.mark.parametrize("renewer_type", ["default", "custom"]) +async def test_token_renewer_config(jp_server_config, jp_configurable_serverapp, renewer_type): + argv = ["--gateway-url=" + mock_gateway_url] + if renewer_type == "custom": + argv.append( + "--GatewayClient.gateway_token_renewer_class=tests.test_gateway.CustomTestTokenRenewer" + ) + + GatewayClient.clear_instance() + app = jp_configurable_serverapp(argv=argv) + + assert app.gateway_config.gateway_enabled is True + assert app.gateway_config.url == mock_gateway_url + gw_client = GatewayClient.instance() + gw_client.init_connection_args() + assert isinstance(gw_client.gateway_token_renewer, GatewayTokenRenewerBase) + if renewer_type == "default": + assert isinstance(gw_client.gateway_token_renewer, NoOpTokenRenewer) + token = gw_client.gateway_token_renewer.get_token( + gw_client.auth_header_key, gw_client.auth_scheme, gw_client.auth_token + ) + assert token == gw_client.auth_token + else: + assert isinstance(gw_client.gateway_token_renewer, CustomTestTokenRenewer) + token = gw_client.gateway_token_renewer.get_token( + gw_client.auth_header_key, gw_client.auth_scheme, gw_client.auth_token + ) + assert token == CustomTestTokenRenewer.TEST_EXPECTED_TOKEN_VALUE + + @pytest.mark.parametrize( "request_timeout,kernel_launch_timeout,expected_request_timeout,expected_kernel_launch_timeout", [(50, 10, 50, 45), (10, 50, 55, 50)], @@ -271,7 +331,7 @@ async def test_gateway_request_timeout_pad_option( app = jp_configurable_serverapp(argv=argv) monkeypatch.setattr(GatewayClient, "KERNEL_LAUNCH_TIMEOUT", kernel_launch_timeout) - GatewayClient.instance().init_static_args() + GatewayClient.instance().init_connection_args() assert app.gateway_config.request_timeout == expected_request_timeout assert GatewayClient.instance().KERNEL_LAUNCH_TIMEOUT == expected_kernel_launch_timeout