From c0b917a87c6bec9a282ed8059d373c043377933e Mon Sep 17 00:00:00 2001 From: Dustin Ngo Date: Tue, 17 Sep 2024 17:26:48 -0400 Subject: [PATCH] feat: Wire up API keys via env var for Phoenix clients and experiments (#4617) * feat: Wire up API keys via env var for Phoenix clients and experiments * Unify header management * Capitalize `authorization` header key * Capitalization shouldn't matter * Recapitalize b/c REST --- src/phoenix/experiments/functions.py | 5 +---- src/phoenix/session/client.py | 7 +------ src/phoenix/utilities/client.py | 16 ++++++++++++++++ 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/src/phoenix/experiments/functions.py b/src/phoenix/experiments/functions.py index e682cb2a41..3b86c1bb6f 100644 --- a/src/phoenix/experiments/functions.py +++ b/src/phoenix/experiments/functions.py @@ -41,7 +41,7 @@ from opentelemetry.trace import Status, StatusCode, Tracer from typing_extensions import TypeAlias -from phoenix.config import get_base_url, get_env_client_headers +from phoenix.config import get_base_url from phoenix.evals.executors import get_executor_on_sync_context from phoenix.evals.models.rate_limiters import RateLimiter from phoenix.evals.utils import get_tqdm_progress_bar_formatter @@ -77,13 +77,10 @@ def _phoenix_clients() -> Tuple[httpx.Client, httpx.AsyncClient]: - headers = get_env_client_headers() return VersionedClient( base_url=get_base_url(), - headers=headers, ), VersionedAsyncClient( base_url=get_base_url(), - headers=headers, ) diff --git a/src/phoenix/session/client.py b/src/phoenix/session/client.py index 2cac6a5cd3..2f66b48e4f 100644 --- a/src/phoenix/session/client.py +++ b/src/phoenix/session/client.py @@ -35,10 +35,8 @@ from typing_extensions import TypeAlias, assert_never from phoenix.config import ( - get_env_client_headers, get_env_collector_endpoint, get_env_host, - get_env_phoenix_api_key, get_env_port, get_env_project_name, ) @@ -88,15 +86,12 @@ def __init__( ) if kwargs: raise TypeError(f"Unexpected keyword arguments: {', '.join(kwargs)}") - headers = dict((headers or get_env_client_headers() or {})) + headers = dict(headers or {}) if api_key: headers = { **{k: v for k, v in (headers or {}).items() if k.lower() != "authorization"}, "Authorization": f"Bearer {api_key}", } - elif api_key := get_env_phoenix_api_key(): - if not headers or ("authorization" not in [k.lower() for k in headers]): - headers = {**(headers or {}), "Authorization": f"Bearer {api_key}"} host = get_env_host() if host == "0.0.0.0": host = "127.0.0.1" diff --git a/src/phoenix/utilities/client.py b/src/phoenix/utilities/client.py index ceb7e3211a..8cb956fb64 100644 --- a/src/phoenix/utilities/client.py +++ b/src/phoenix/utilities/client.py @@ -3,6 +3,8 @@ import httpx +from phoenix.config import get_env_client_headers, get_env_phoenix_api_key + PHOENIX_SERVER_VERSION_HEADER = "x-phoenix-server-version" @@ -15,6 +17,13 @@ def __init__(self, *args: Any, **kwargs: Any): from phoenix import __version__ as phoenix_version super().__init__(*args, **kwargs) + + if env_headers := get_env_client_headers(): + self.headers.update(env_headers) + if "authorization" not in [k.lower() for k in self.headers]: + if api_key := get_env_phoenix_api_key(): + self.headers["Authorization"] = f"Bearer {api_key}" + self._client_phoenix_version = phoenix_version self._warned_on_minor_version_mismatch = False @@ -72,6 +81,13 @@ def __init__(self, *args: Any, **kwargs: Any): from phoenix import __version__ as phoenix_version super().__init__(*args, **kwargs) + + if env_headers := get_env_client_headers(): + self.headers.update(env_headers) + if "authorization" not in [k.lower() for k in self.headers]: + if api_key := get_env_phoenix_api_key(): + self.headers["Authorization"] = f"Bearer {api_key}" + self._client_phoenix_version = phoenix_version self._warned_on_minor_version_mismatch = False