From 980d8081e071940e10579a23f1c499d3ad932389 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 15 Sep 2024 00:29:13 -0700 Subject: [PATCH 01/29] implement authorization code flow --- app/src/pages/auth/LoginPage.tsx | 35 ++++- app/src/pages/auth/oAuthCallbackLoader.ts | 41 ++++++ app/src/window.d.ts | 6 + pyproject.toml | 2 + src/phoenix/config.py | 77 +++++++++++ src/phoenix/server/api/routers/__init__.py | 2 + src/phoenix/server/api/routers/oauth.py | 39 ++++++ src/phoenix/server/app.py | 26 +++- src/phoenix/server/main.py | 2 + src/phoenix/server/oauth.py | 145 +++++++++++++++++++++ src/phoenix/server/templates/index.html | 1 + 11 files changed, 374 insertions(+), 2 deletions(-) create mode 100644 app/src/pages/auth/oAuthCallbackLoader.ts create mode 100644 src/phoenix/server/api/routers/oauth.py create mode 100644 src/phoenix/server/oauth.py diff --git a/app/src/pages/auth/LoginPage.tsx b/app/src/pages/auth/LoginPage.tsx index a66701ca67..1660d9d757 100644 --- a/app/src/pages/auth/LoginPage.tsx +++ b/app/src/pages/auth/LoginPage.tsx @@ -1,12 +1,14 @@ import React from "react"; +import { css } from "@emotion/react"; -import { Flex, View } from "@arizeai/components"; +import { Button, Flex, Form, View } from "@arizeai/components"; import { AuthLayout } from "./AuthLayout"; import { LoginForm } from "./LoginForm"; import { PhoenixLogo } from "./PhoenixLogo"; export function LoginPage() { + const oAuthIdps = window.Config.oAuthIdps; return ( @@ -15,6 +17,37 @@ export function LoginPage() { + {oAuthIdps.map((idp) => ( + + ))} ); } + +type OAuthLoginFormProps = { + idpId: string; + idpDisplayName: string; +}; +export function OAuthLoginForm({ idpId, idpDisplayName }: OAuthLoginFormProps) { + return ( +
+
+ +
+
+ ); +} diff --git a/app/src/pages/auth/oAuthCallbackLoader.ts b/app/src/pages/auth/oAuthCallbackLoader.ts new file mode 100644 index 0000000000..251ed76915 --- /dev/null +++ b/app/src/pages/auth/oAuthCallbackLoader.ts @@ -0,0 +1,41 @@ +// import { redirect } from "react-router"; +import { LoaderFunctionArgs } from "react-router-dom"; + +export async function oAuthCallbackLoader(args: LoaderFunctionArgs) { + const queryParameters = new URL(args.request.url).searchParams; + const authorizationCode = queryParameters.get("code"); + const state = queryParameters.get("state"); + const actualState = sessionStorage.getItem("oAuthState"); + sessionStorage.removeItem("oAuthState"); + if ( + authorizationCode == undefined || + state == undefined || + actualState == undefined || + state !== actualState + ) { + // todo: display error message + return null; + } + const origin = new URL(window.location.href).origin; + const redirectUri = `${origin}/oauth-callback`; + try { + const response = await fetch("/auth/oauth-tokens", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify({ + authorization_code: authorizationCode, + redirect_uri: redirectUri, + }), + }); + if (!response.ok) { + // todo: parse response body and display error message + return null; + } + } catch (error) { + // todo: display error + } + // redirect("/"); + return null; +} diff --git a/app/src/window.d.ts b/app/src/window.d.ts index f50eca98f9..08ab1cb825 100644 --- a/app/src/window.d.ts +++ b/app/src/window.d.ts @@ -1,5 +1,10 @@ export {}; +type OAuthIdp = { + id: string; + displayName: string; +}; + declare global { interface Window { Config: { @@ -15,6 +20,7 @@ declare global { nSamples: number; }; authenticationEnabled: boolean; + oAuthIdps: OAuthIdp[]; }; } } diff --git a/pyproject.toml b/pyproject.toml index 4c1daa327c..f88a32ed3d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,6 +62,7 @@ dependencies = [ "fastapi-mail", "pydantic>=1.0,!=2.0.*,<3", # exclude 2.0.* since it does not support the `json_encoders` configuration setting "pyjwt", + "authlib", ] dynamic = ["version"] @@ -407,6 +408,7 @@ module = [ "grpc.*", "py_grpc_prometheus.*", "orjson", # suppress fastapi internal type errors + "authlib.*", ] ignore_missing_imports = true diff --git a/src/phoenix/config.py b/src/phoenix/config.py index 74d72ec465..dfb11f9339 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -1,13 +1,20 @@ import os import re import tempfile +from dataclasses import dataclass from datetime import timedelta from logging import getLogger from pathlib import Path from typing import Dict, List, Optional, Tuple, overload +from typing_extensions import TypeAlias + from phoenix.utilities.re import parse_env_headers +IdpId: TypeAlias = str +EnvVarName: TypeAlias = str +EnvVarValue: TypeAlias = str + logger = getLogger(__name__) # Phoenix environment variables @@ -317,6 +324,72 @@ def get_env_smtp_validate_certs() -> bool: return _bool_val(ENV_PHOENIX_SMTP_VALIDATE_CERTS, True) +@dataclass(frozen=True) +class OAuthClientConfig: + idp_id: str + display_name: str + client_id: str + client_secret: str + server_metadata_url: Optional[str] = None + authorize_url: Optional[str] = None + access_token_url: Optional[str] = None + + @classmethod + def from_env(cls, idp_id: str) -> "OAuthClientConfig": + idp_id_upper = idp_id.upper() + if ( + client_id := os.getenv(client_id_env_var := f"PHOENIX_OAUTH_{idp_id_upper}_CLIENT_ID") + ) is None: + raise ValueError( + f"A client id must be set for the {idp_id} OAuth IDP " + f"via the {client_id_env_var} environment variable" + ) + if ( + client_secret := os.getenv( + client_secret_env_var := f"PHOENIX_OAUTH_{idp_id_upper}_CLIENT_SECRET" + ) + ) is None: + raise ValueError( + f"A client secret must be set for the {idp_id} OAuth IDP " + f"via the {client_secret_env_var} environment variable" + ) + return cls( + idp_id=idp_id, + display_name=os.getenv( + f"PHOENIX_OAUTH_{idp_id_upper}_DISPLAY_NAME", get_default_idp_display_name(idp_id) + ), + client_id=client_id, + client_secret=client_secret, + server_metadata_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_SERVER_METADATA_URL"), + access_token_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_ACCESS_TOKEN_URL"), + authorize_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_AUTHORIZE_URL"), + ) + + def __post_init__(self) -> None: + assert self.idp_id + if not self.display_name: + raise ValueError(f"OAuth display name for {self.idp_id} cannot be empty") + if not self.client_id: + raise ValueError(f"OAuth client id for {self.idp_id} cannot be empty") + if not self.client_secret: + raise ValueError(f"OAuth client secret for {self.idp_id} cannot be empty") + + +def get_env_oauth_settings() -> List[OAuthClientConfig]: + """ + Get OAuth settings from environment variables. + """ + + idp_ids = set() + pattern = re.compile( + r"^PHOENIX_OAUTH_(\w+)_(DISPLAY_NAME|CLIENT_ID|CLIENT_SECRET|SERVER_METADATA_URL|ACCESS_TOKEN_URL|AUTHORIZE_URL)$" + ) + for env_var in os.environ: + if (match := pattern.match(env_var)) is not None and (idp_id := match.group(1).lower()): + idp_ids.add(idp_id) + return [OAuthClientConfig.from_env(idp_id) for idp_id in sorted(idp_ids)] + + PHOENIX_DIR = Path(__file__).resolve().parent # Server config SERVER_DIR = PHOENIX_DIR / "server" @@ -485,5 +558,9 @@ def get_web_base_url() -> str: return get_base_url() +def get_default_idp_display_name(ipd_id: IdpId) -> str: + return ipd_id.replace("_", " ").title() + + DEFAULT_PROJECT_NAME = "default" _KUBERNETES_PHOENIX_PORT_PATTERN = re.compile(r"^tcp://\d{1,3}[.]\d{1,3}[.]\d{1,3}[.]\d{1,3}:\d+$") diff --git a/src/phoenix/server/api/routers/__init__.py b/src/phoenix/server/api/routers/__init__.py index 354d5d106d..8c65c0c768 100644 --- a/src/phoenix/server/api/routers/__init__.py +++ b/src/phoenix/server/api/routers/__init__.py @@ -1,9 +1,11 @@ from .auth import router as auth_router from .embeddings import create_embeddings_router +from .oauth import router as oauth_router from .v1 import create_v1_router __all__ = [ "auth_router", "create_embeddings_router", "create_v1_router", + "oauth_router", ] diff --git a/src/phoenix/server/api/routers/oauth.py b/src/phoenix/server/api/routers/oauth.py new file mode 100644 index 0000000000..db4e82897e --- /dev/null +++ b/src/phoenix/server/api/routers/oauth.py @@ -0,0 +1,39 @@ +from authlib.integrations.starlette_client import OAuthError +from authlib.integrations.starlette_client import StarletteOAuth2App as OAuthClient +from fastapi import APIRouter, Depends, HTTPException, Request +from starlette.responses import RedirectResponse +from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND + +from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_rate_limiter + +rate_limiter = ServerRateLimiter( + per_second_rate_limit=0.2, + enforcement_window_seconds=30, + partition_seconds=60, + active_partitions=2, +) +login_rate_limiter = fastapi_rate_limiter(rate_limiter, paths=["/login"]) +router = APIRouter( + prefix="/oauth", include_in_schema=False, dependencies=[Depends(login_rate_limiter)] +) + + +@router.post("/{idp}/login") +async def login(request: Request, idp: str) -> RedirectResponse: + if not isinstance(oauth_client := request.app.state.oauth_clients.get_client(idp), OAuthClient): + raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp}") + redirect_uri = request.url_for("create_tokens", idp=idp) + response: RedirectResponse = await oauth_client.authorize_redirect(request, redirect_uri) + return response + + +@router.get("/{idp}/tokens") +async def create_tokens(request: Request, idp: str) -> RedirectResponse: + if not isinstance(oauth_client := request.app.state.oauth_clients.get_client(idp), OAuthClient): + raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp}") + try: + token = await oauth_client.authorize_access_token(request) + except OAuthError as error: + raise HTTPException(HTTP_401_UNAUTHORIZED, detail=str(error)) + print(f"{token=}") + return RedirectResponse(url="/") diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index c58b2ea4ed..1750119f33 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -20,7 +20,9 @@ List, NamedTuple, Optional, + Sequence, Tuple, + TypedDict, Union, cast, ) @@ -50,6 +52,7 @@ from phoenix.config import ( DEFAULT_PROJECT_NAME, SERVER_DIR, + OAuthClientConfig, get_env_host, get_env_port, server_instrumentation_is_enabled, @@ -90,7 +93,12 @@ UserRolesDataLoader, UsersDataLoader, ) -from phoenix.server.api.routers import auth_router, create_embeddings_router, create_v1_router +from phoenix.server.api.routers import ( + auth_router, + create_embeddings_router, + create_v1_router, + oauth_router, +) from phoenix.server.api.routers.v1 import REST_API_VERSION from phoenix.server.api.schema import schema from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated @@ -99,6 +107,7 @@ from phoenix.server.email.types import EmailSender from phoenix.server.grpc_server import GrpcServer from phoenix.server.jwt_store import JwtStore +from phoenix.server.oauth import OAuthClients from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider from phoenix.server.types import ( CanGetLastUpdatedAt, @@ -144,6 +153,11 @@ _Callback: TypeAlias = Callable[[], Union[None, Awaitable[None]]] +class OAuthIdp(TypedDict): + id: str + displayName: str + + class AppConfig(NamedTuple): has_inferences: bool """ Whether the model has inferences (e.g. a primary dataset) """ @@ -155,6 +169,7 @@ class AppConfig(NamedTuple): web_manifest_path: Path authentication_enabled: bool """ Whether authentication is enabled """ + oauth_idps: Sequence[OAuthIdp] class Static(StaticFiles): @@ -203,6 +218,7 @@ async def get_response(self, path: str, scope: Scope) -> Response: "is_development": self._app_config.is_development, "manifest": self._web_manifest, "authentication_enabled": self._app_config.authentication_enabled, + "oauth_idps": self._app_config.oauth_idps, }, ) except Exception as e: @@ -611,6 +627,7 @@ def create_app( refresh_token_expiry: Optional[timedelta] = None, scaffolder_config: Optional[ScaffolderConfig] = None, email_sender: Optional[EmailSender] = None, + oauth_client_configs: Optional[List[OAuthClientConfig]] = None, ) -> FastAPI: startup_callbacks_list: List[_Callback] = list(startup_callbacks) shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks) @@ -723,9 +740,14 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: app.include_router(graphql_router) if authentication_enabled: app.include_router(auth_router) + app.include_router(oauth_router) app.add_middleware(GZipMiddleware) web_manifest_path = SERVER_DIR / "static" / ".vite" / "manifest.json" if serve_ui and web_manifest_path.is_file(): + oauth_idps = [ + OAuthIdp(id=config.idp_id, displayName=config.display_name) + for config in oauth_client_configs or [] + ] app.mount( "/", app=Static( @@ -739,6 +761,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: is_development=dev, authentication_enabled=authentication_enabled, web_manifest_path=web_manifest_path, + oauth_idps=oauth_idps, ), ), name="static", @@ -748,6 +771,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: app.state.password_reset_token_expiry = password_reset_token_expiry app.state.access_token_expiry = access_token_expiry app.state.refresh_token_expiry = refresh_token_expiry + app.state.oauth_clients = OAuthClients.from_configs(oauth_client_configs or []) app.state.db = db app.state.email_sender = email_sender app = _add_get_secret_method(app=app, secret=secret) diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index 8e788b2dfe..bbf63bcb1f 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -26,6 +26,7 @@ get_env_grpc_port, get_env_host, get_env_host_root_path, + get_env_oauth_settings, get_env_password_reset_token_expiry, get_env_port, get_env_refresh_token_expiry, @@ -417,6 +418,7 @@ def _get_pid_file() -> Path: refresh_token_expiry=get_env_refresh_token_expiry(), scaffolder_config=scaffolder_config, email_sender=email_sender, + oauth_client_configs=get_env_oauth_settings(), ) server = Server(config=Config(app, host=host, port=port, root_path=host_root_path)) # type: ignore Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start() diff --git a/src/phoenix/server/oauth.py b/src/phoenix/server/oauth.py new file mode 100644 index 0000000000..5ce659d7da --- /dev/null +++ b/src/phoenix/server/oauth.py @@ -0,0 +1,145 @@ +from dataclasses import asdict, dataclass +from datetime import datetime, timedelta +from types import MappingProxyType +from typing import Any, Dict, Generic, List, Optional, Tuple + +from authlib.integrations.starlette_client import OAuth +from authlib.integrations.starlette_client import StarletteOAuth2App as OAuthClient +from typing_extensions import TypeAlias, TypeVar + +from phoenix.config import OAuthClientConfig + +IdpId: TypeAlias = str + + +class OAuthClients: + def __init__(self) -> None: + self._clients: Dict[IdpId, OAuthClient] = {} + self._oauth = OAuth(cache=_OAuthClientTTLCache[str, Any]()) + + def add_client(self, config: OAuthClientConfig) -> None: + if (idp_id := config.idp_id) in self._clients: + raise ValueError(f"oauth client already registered: {idp_id}") + config = _apply_oauth_config_defaults(config) + server_metadata_url = config.server_metadata_url + authorize_url = config.authorize_url + access_token_url = config.access_token_url + if not (server_metadata_url or (authorize_url and access_token_url)): + raise ValueError( + f"{idp_id} OAuth client must have either a server metadata URL," + " or authorize and access token URLs" + ) + client = self._oauth.register( + idp_id, + client_id=config.client_id, + client_secret=config.client_secret, + server_metadata_url=server_metadata_url, + authorize_url=authorize_url, + access_token_url=access_token_url, + client_kwargs={"scope": "openid email profile"}, + ) + assert isinstance(client, OAuthClient) + self._clients[config.idp_id] = client + + def get_client(self, idp_id: IdpId) -> OAuthClient: + if (client := self._clients.get(idp_id)) is None: + raise ValueError(f"unknown or unregistered oauth client: {idp_id}") + return client + + @classmethod + def from_configs(cls, configs: List[OAuthClientConfig]) -> "OAuthClients": + oauth_clients = cls() + for config in configs: + oauth_clients.add_client(config) + return oauth_clients + + +@dataclass +class OAuthClientDefaultConfig: + idp_id: IdpId + display_name: Optional[str] = None + server_metadata_url: Optional[str] = None + authorize_url: Optional[str] = None + access_token_url: Optional[str] = None + + +def _apply_oauth_config_defaults(config: OAuthClientConfig) -> OAuthClientConfig: + if (default_config := _OAUTH_CLIENT_DEFAULT_CONFIGS.get(config.idp_id)) is None: + return config + return OAuthClientConfig( + **{ + **{k: v for k, v in asdict(default_config).items() if v is not None}, + **{k: v for k, v in asdict(config).items() if v is not None}, + } + ) + + +_OAUTH_CLIENT_DEFAULT_CONFIGS = MappingProxyType( + { + config.idp_id: config + for config in ( + OAuthClientDefaultConfig( + idp_id="google", + server_metadata_url="https://accounts.google.com/.well-known/openid-configuration", + ), + ) + } +) + +_CacheKey = TypeVar("_CacheKey") +_CacheValue = TypeVar("_CacheValue") +_Expiry: TypeAlias = datetime +_MINUTE = timedelta(minutes=1) + + +class _OAuthClientTTLCache(Generic[_CacheKey, _CacheValue]): + """ + A TTL cache satisfying the interface required by the Authlib Starlette + integration. Provides an alternative to starlette session middleware. + """ + + def __init__(self, cleanup_interval: timedelta = 10 * _MINUTE) -> None: + self._data: Dict[_CacheKey, Tuple[_CacheValue, _Expiry]] = {} + self._last_cleanup_time = datetime.now() + self._cleanup_interval = cleanup_interval + + async def get(self, key: _CacheKey) -> Optional[_CacheValue]: + """ + Retrieves the value associated with the given key if it exists and has + not expired, otherwise, returns None. + """ + if (value_and_expiry := self._data.get(key)) is None: + return None + value, expiry = value_and_expiry + if datetime.now() < expiry: + return value + self._data.pop(key, None) + return None + + async def set(self, key: _CacheKey, value: _CacheValue, expires: int) -> None: + """ + Sets the value associated with the given key to the provided value with + the given expiry time in seconds. + """ + self._remove_expired_keys_if_cleanup_interval_exceeded() + expiry = datetime.now() + timedelta(seconds=expires) + self._data[key] = (value, expiry) + + async def delete(self, key: _CacheKey) -> None: + """ + Removes the value associated with the given key if it exists. + """ + self._remove_expired_keys_if_cleanup_interval_exceeded() + self._data.pop(key, None) + + def _remove_expired_keys_if_cleanup_interval_exceeded(self) -> None: + time_since_last_cleanup = datetime.now() - self._last_cleanup_time + if time_since_last_cleanup > self._cleanup_interval: + self._remove_expired_keys() + + def _remove_expired_keys(self) -> None: + current_time = datetime.now() + delete_keys = [key for key, (_, expiry) in self._data.items() if expiry <= current_time] + for key in delete_keys: + self._data.pop(key, None) + self._last_cleanup_time = current_time diff --git a/src/phoenix/server/templates/index.html b/src/phoenix/server/templates/index.html index 4ef66f7942..748288cbe8 100644 --- a/src/phoenix/server/templates/index.html +++ b/src/phoenix/server/templates/index.html @@ -87,6 +87,7 @@ nSamples: parseInt("{{n_samples}}"), }, authenticationEnabled: Boolean("{{authentication_enabled}}" == "True"), + oAuthIdps: {{ oauth_idps | tojson }}, }), writable: false }); From 4e5e8e78920b2df1fb536a1c100db9f6ccc5b693 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Sun, 15 Sep 2024 13:17:00 -0700 Subject: [PATCH 02/29] update database --- src/phoenix/db/enums.py | 4 ++ src/phoenix/db/facilitator.py | 38 ++++++++++++++++-- .../versions/cd164e83824f_users_and_tokens.py | 40 ++++++++++++++++--- src/phoenix/db/models.py | 37 +++++++++++++++-- .../server/api/mutations/user_mutations.py | 24 +++++++++-- 5 files changed, 127 insertions(+), 16 deletions(-) diff --git a/src/phoenix/db/enums.py b/src/phoenix/db/enums.py index 64450eccb6..ba002bdd5b 100644 --- a/src/phoenix/db/enums.py +++ b/src/phoenix/db/enums.py @@ -16,6 +16,10 @@ class AuthMethod(Enum): LOCAL = "LOCAL" +class IdentityProviderName(Enum): + LOCAL = "local" + + COLUMN_ENUMS: Mapping[InstrumentedAttribute[str], Type[Enum]] = { models.UserRole.name: UserRole, } diff --git a/src/phoenix/db/facilitator.py b/src/phoenix/db/facilitator.py index 0ef131c414..9e114fb714 100644 --- a/src/phoenix/db/facilitator.py +++ b/src/phoenix/db/facilitator.py @@ -5,6 +5,7 @@ from functools import partial from sqlalchemy import ( + and_, distinct, insert, select, @@ -19,7 +20,7 @@ compute_password_hash, ) from phoenix.db import models -from phoenix.db.enums import COLUMN_ENUMS, AuthMethod, UserRole +from phoenix.db.enums import COLUMN_ENUMS, AuthMethod, IdentityProviderName, UserRole from phoenix.server.types import DbSessionFactory @@ -38,6 +39,7 @@ async def __call__(self) -> None: async with self._db() as session: for fn in ( _ensure_enums, + _ensure_identity_providers, _ensure_user_roles, ): async with session.begin_nested(): @@ -61,6 +63,26 @@ async def _ensure_enums(session: AsyncSession) -> None: await session.execute(insert(table), [{column.key: v} for v in missing]) +async def _ensure_identity_providers(session: AsyncSession) -> None: + """ + Ensures that the local identity provider is present in the database. + """ + local_idp = await session.scalar( + select(models.IdentityProvider).where( + models.IdentityProvider.name == IdentityProviderName.LOCAL.value + ) + ) + if local_idp is None: + local_idp = models.IdentityProvider( + name=IdentityProviderName.LOCAL.value, auth_method=AuthMethod.LOCAL.value + ) + session.add( + models.IdentityProvider( + name=IdentityProviderName.LOCAL.value, auth_method=AuthMethod.LOCAL.value + ) + ) + + async def _ensure_user_roles(session: AsyncSession) -> None: """ Ensure that the system and admin roles are present in the database. If they are not, they will @@ -79,13 +101,23 @@ async def _ensure_user_roles(session: AsyncSession) -> None: select(distinct(models.UserRole.name)).join_from(models.User, models.UserRole) ) ] + local_idp_id = ( + select(models.IdentityProvider.id) + .where( + and_( + models.IdentityProvider.name == IdentityProviderName.LOCAL.value, + models.IdentityProvider.auth_method == AuthMethod.LOCAL.value, + ) + ) + .scalar_subquery() + ) if (system_role := UserRole.SYSTEM.value) not in existing_roles and ( system_role_id := role_ids.get(system_role) ) is not None: system_user = models.User( user_role_id=system_role_id, email="system@localhost", - auth_method=AuthMethod.LOCAL.value, + identity_provider_id=local_idp_id, reset_password=False, ) session.add(system_user) @@ -100,7 +132,7 @@ async def _ensure_user_roles(session: AsyncSession) -> None: user_role_id=admin_role_id, username=DEFAULT_ADMIN_USERNAME, email=DEFAULT_ADMIN_EMAIL, - auth_method=AuthMethod.LOCAL.value, + identity_provider_id=local_idp_id, password_salt=salt, password_hash=hash_, reset_password=True, diff --git a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py index e351f4957f..0029c6808f 100644 --- a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +++ b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py @@ -19,6 +19,27 @@ def upgrade() -> None: + op.create_table( + "identity_providers", + sa.Column("id", sa.Integer, primary_key=True), + sa.Column( + "name", + sa.String, + index=True, + nullable=False, + ), + sa.Column( + "auth_method", + sa.String, + sa.CheckConstraint("auth_method IN ('LOCAL', 'OAUTH')", "valid_auth_method"), + index=True, + nullable=False, + ), + sa.UniqueConstraint( + "name", + "auth_method", + ), + ) op.create_table( "user_roles", sa.Column("id", sa.Integer, primary_key=True), @@ -27,6 +48,7 @@ def upgrade() -> None: sa.String, nullable=False, unique=True, + index=True, ), ) op.create_table( @@ -39,14 +61,17 @@ def upgrade() -> None: nullable=False, index=True, ), - sa.Column("username", sa.String, nullable=True, unique=True, index=True), - sa.Column("email", sa.String, nullable=False, unique=True, index=True), sa.Column( - "auth_method", - sa.String, - sa.CheckConstraint("auth_method IN ('LOCAL')", "valid_auth_method"), + "identity_provider_id", + sa.Integer, + sa.ForeignKey("identity_providers.id", ondelete="CASCADE"), + index=True, nullable=False, ), + sa.Column("identity_provider_user_id", sa.Integer, index=True, nullable=True), + sa.Column("username", sa.String, nullable=True, unique=True, index=True), + sa.Column("email", sa.String, nullable=False, unique=True, index=True), + sa.Column("profile_picture_url", sa.String, nullable=True), sa.Column("password_hash", sa.LargeBinary, nullable=True), sa.Column("password_salt", sa.LargeBinary, nullable=True), sa.Column("reset_password", sa.Boolean, nullable=False), @@ -69,6 +94,10 @@ def upgrade() -> None: nullable=True, ), sa.CheckConstraint("password_hash is null or password_salt is not null", name="salt"), + sa.UniqueConstraint( + "identity_provider_id", + "identity_provider_user_id", + ), sqlite_autoincrement=True, ) op.create_table( @@ -147,3 +176,4 @@ def downgrade() -> None: op.drop_table("password_reset_tokens") op.drop_table("users") op.drop_table("user_roles") + op.drop_table("identity_providers") diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 7a2ffa4633..6b0e3362c5 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -620,10 +620,28 @@ class ExperimentRunAnnotation(Base): ) +class IdentityProvider(Base): + __tablename__ = "identity_providers" + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(index=True, nullable=False) + auth_method: Mapped[str] = mapped_column( + CheckConstraint("auth_method IN ('LOCAL', 'OAUTH')", name="valid_auth_method"), + index=True, + ) + users: Mapped[List["User"]] = relationship("User", back_populates="identity_provider") + + __table_args__ = ( + UniqueConstraint( + "name", + "auth_method", + ), + ) + + class UserRole(Base): __tablename__ = "user_roles" id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str] = mapped_column(unique=True) + name: Mapped[str] = mapped_column(unique=True, index=True) users: Mapped[List["User"]] = relationship("User", back_populates="role") @@ -635,11 +653,18 @@ class User(Base): index=True, ) role: Mapped["UserRole"] = relationship("UserRole", back_populates="users") + identity_provider_id: Mapped[int] = mapped_column( + ForeignKey("identity_providers.id", ondelete="CASCADE"), + index=True, + nullable=False, + ) + identity_provider: Mapped["IdentityProvider"] = relationship( + "IdentityProvider", back_populates="users" + ) + identity_provider_user_id: Mapped[Optional[str]] = mapped_column(index=True, nullable=True) username: Mapped[Optional[str]] = mapped_column(nullable=True, unique=True, index=True) email: Mapped[str] = mapped_column(nullable=False, unique=True, index=True) - auth_method: Mapped[str] = mapped_column( - CheckConstraint("auth_method IN ('LOCAL')", name="valid_auth_method") - ) + profile_picture_url: Mapped[Optional[str]] password_hash: Mapped[Optional[bytes]] password_salt: Mapped[Optional[bytes]] reset_password: Mapped[bool] @@ -660,6 +685,10 @@ class User(Base): api_keys: Mapped[List["ApiKey"]] = relationship("ApiKey", back_populates="user") __table_args__ = ( CheckConstraint("password_hash is null or password_salt is not null", name="salt"), + UniqueConstraint( + "identity_provider_id", + "identity_provider_user_id", + ), dict(sqlite_autoincrement=True), ) diff --git a/src/phoenix/server/api/mutations/user_mutations.py b/src/phoenix/server/api/mutations/user_mutations.py index cfae3afba3..9610d959e7 100644 --- a/src/phoenix/server/api/mutations/user_mutations.py +++ b/src/phoenix/server/api/mutations/user_mutations.py @@ -91,11 +91,21 @@ async def create_user( validate_password_format(password := input.password) salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH) password_hash = await info.context.hash_password(password, salt) + local_idp_id = ( + select(models.IdentityProvider.id) + .where( + and_( + models.IdentityProvider.name == enums.IdentityProviderName.LOCAL.value, + models.IdentityProvider.auth_method == enums.AuthMethod.LOCAL.value, + ) + ) + .scalar_subquery() + ) user = models.User( reset_password=True, username=input.username or None, email=email, - auth_method=enums.AuthMethod.LOCAL.value, + identity_provider_id=local_idp_id, password_hash=password_hash, password_salt=salt, ) @@ -138,7 +148,10 @@ async def patch_user( raise NotFound(f"Role {input.new_role.value} not found") user.user_role_id = user_role_id if password := input.new_password: - if user.auth_method != enums.AuthMethod.LOCAL.value: + if ( + (idp := user.identity_provider).name != enums.IdentityProviderName.LOCAL.value + or idp.auth_method != enums.AuthMethod.LOCAL.value + ): raise Conflict("Cannot modify password for non-local user") validate_password_format(password) user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH) @@ -171,7 +184,10 @@ async def patch_viewer( raise NotFound("User not found") stack.enter_context(session.no_autoflush) if password := input.new_password: - if user.auth_method != enums.AuthMethod.LOCAL.value: + if ( + (idp := user.identity_provider).name != enums.IdentityProviderName.LOCAL.value + or idp.auth_method != enums.AuthMethod.LOCAL.value + ): raise Conflict("Cannot modify password for non-local user") if not ( current_password := input.current_password @@ -322,7 +338,7 @@ def _select_user_by_id(user_id: int) -> Select[Tuple[models.User]]: return ( select(models.User) .where(and_(models.User.id == user_id, models.User.deleted_at.is_(None))) - .options(joinedload(models.User.role)) + .options(joinedload(models.User.role), joinedload(models.User.identity_provider)) ) From 54f6cc753dd4baaf5668459944825e05625230ad Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 16 Sep 2024 01:20:57 -0700 Subject: [PATCH 03/29] working end-to-end --- src/phoenix/auth.py | 17 +- src/phoenix/db/enums.py | 1 + src/phoenix/server/api/routers/auth.py | 57 ++---- src/phoenix/server/api/routers/oauth.py | 252 ++++++++++++++++++++++-- src/phoenix/server/api/types/User.py | 2 +- src/phoenix/server/bearer_auth.py | 58 +++++- 6 files changed, 319 insertions(+), 68 deletions(-) diff --git a/src/phoenix/auth.py b/src/phoenix/auth.py index 83713aa606..72b65ecd83 100644 --- a/src/phoenix/auth.py +++ b/src/phoenix/auth.py @@ -7,10 +7,13 @@ from hashlib import pbkdf2_hmac from typing import Any, Literal, Optional, Protocol -from fastapi import Response +from starlette.responses import Response +from typing_extensions import TypeVar from phoenix.config import get_env_phoenix_use_secure_cookies +ResponseType = TypeVar("ResponseType", bound=Response) + def compute_password_hash(*, password: str, salt: bytes) -> bytes: """ @@ -66,8 +69,8 @@ def validate_password_format(password: str) -> None: def set_access_token_cookie( - *, response: Response, access_token: str, max_age: timedelta -) -> Response: + *, response: ResponseType, access_token: str, max_age: timedelta +) -> ResponseType: return _set_token_cookie( response=response, cookie_name=PHOENIX_ACCESS_TOKEN_COOKIE_NAME, @@ -77,8 +80,8 @@ def set_access_token_cookie( def set_refresh_token_cookie( - *, response: Response, refresh_token: str, max_age: timedelta -) -> Response: + *, response: ResponseType, refresh_token: str, max_age: timedelta +) -> ResponseType: return _set_token_cookie( response=response, cookie_name=PHOENIX_REFRESH_TOKEN_COOKIE_NAME, @@ -88,8 +91,8 @@ def set_refresh_token_cookie( def _set_token_cookie( - response: Response, cookie_name: str, cookie_max_age: timedelta, token: str -) -> Response: + response: ResponseType, cookie_name: str, cookie_max_age: timedelta, token: str +) -> ResponseType: response.set_cookie( key=cookie_name, value=token, diff --git a/src/phoenix/db/enums.py b/src/phoenix/db/enums.py index ba002bdd5b..71f6c2e528 100644 --- a/src/phoenix/db/enums.py +++ b/src/phoenix/db/enums.py @@ -14,6 +14,7 @@ class UserRole(Enum): class AuthMethod(Enum): LOCAL = "LOCAL" + OAUTH = "OAUTH" class IdentityProviderName(Enum): diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py index c2d7966f0c..c98cd5f015 100644 --- a/src/phoenix/server/api/routers/auth.py +++ b/src/phoenix/server/api/routers/auth.py @@ -30,17 +30,15 @@ ) from phoenix.config import get_base_url from phoenix.db import enums, models -from phoenix.db.enums import UserRole -from phoenix.server.bearer_auth import PhoenixUser +from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens from phoenix.server.email.templates.types import PasswordResetTemplateBody from phoenix.server.email.types import EmailSender +from phoenix.server.jwt_store import JwtStore from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_rate_limiter from phoenix.server.types import ( - AccessTokenAttributes, AccessTokenClaims, PasswordResetTokenClaims, PasswordResetTokenId, - RefreshTokenAttributes, RefreshTokenClaims, TokenStore, UserId, @@ -71,6 +69,7 @@ async def login(request: Request) -> Response: assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta) assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta) + token_store: JwtStore = request.app.state.get_token_store() data = await request.json() email = data.get("email") password = data.get("password") @@ -96,27 +95,12 @@ async def login(request: Request) -> Response: if not await loop.run_in_executor(None, password_is_valid): raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail=LOGIN_FAILED_MESSAGE) - issued_at = datetime.now(timezone.utc) - refresh_token_claims = RefreshTokenClaims( - subject=UserId(user.id), - issued_at=issued_at, - expiration_time=issued_at + refresh_token_expiry, - attributes=RefreshTokenAttributes( - user_role=UserRole(user.role.name), - ), - ) - token_store: TokenStore = request.app.state.get_token_store() - refresh_token, refresh_token_id = await token_store.create_refresh_token(refresh_token_claims) - access_token_claims = AccessTokenClaims( - subject=UserId(user.id), - issued_at=issued_at, - expiration_time=issued_at + access_token_expiry, - attributes=AccessTokenAttributes( - user_role=UserRole(user.role.name), - refresh_token_id=refresh_token_id, - ), + access_token, refresh_token = await create_access_and_refresh_tokens( + token_store=token_store, + user=user, + access_token_expiry=access_token_expiry, + refresh_token_expiry=refresh_token_expiry, ) - access_token, _ = await token_store.create_access_token(access_token_claims) response = Response(status_code=HTTP_204_NO_CONTENT) response = set_access_token_cookie( response=response, access_token=access_token, max_age=access_token_expiry @@ -177,27 +161,12 @@ async def refresh_tokens(request: Request) -> Response: ) ) is None: raise HTTPException(status_code=HTTP_404_NOT_FOUND, detail="User not found") - user_role = UserRole(user.role.name) - issued_at = datetime.now(timezone.utc) - refresh_token_claims = RefreshTokenClaims( - subject=UserId(user.id), - issued_at=issued_at, - expiration_time=issued_at + refresh_token_expiry, - attributes=RefreshTokenAttributes( - user_role=user_role, - ), - ) - refresh_token, refresh_token_id = await token_store.create_refresh_token(refresh_token_claims) - access_token_claims = AccessTokenClaims( - subject=UserId(user.id), - issued_at=issued_at, - expiration_time=issued_at + access_token_expiry, - attributes=AccessTokenAttributes( - user_role=user_role, - refresh_token_id=refresh_token_id, - ), + access_token, refresh_token = await create_access_and_refresh_tokens( + token_store=token_store, + user=user, + access_token_expiry=access_token_expiry, + refresh_token_expiry=refresh_token_expiry, ) - access_token, _ = await token_store.create_access_token(access_token_claims) response = Response(status_code=HTTP_204_NO_CONTENT) response = set_access_token_cookie( response=response, access_token=access_token, max_age=access_token_expiry diff --git a/src/phoenix/server/api/routers/oauth.py b/src/phoenix/server/api/routers/oauth.py index db4e82897e..d0bb75b502 100644 --- a/src/phoenix/server/api/routers/oauth.py +++ b/src/phoenix/server/api/routers/oauth.py @@ -1,11 +1,29 @@ +from dataclasses import dataclass +from datetime import timedelta +from typing import Any, Dict, Optional + from authlib.integrations.starlette_client import OAuthError from authlib.integrations.starlette_client import StarletteOAuth2App as OAuthClient -from fastapi import APIRouter, Depends, HTTPException, Request +from fastapi import APIRouter, Depends, HTTPException, Path, Request +from sqlalchemy import and_, insert, or_, select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload from starlette.responses import RedirectResponse from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND +from typing_extensions import Annotated +from phoenix.auth import ( + set_access_token_cookie, + set_refresh_token_cookie, +) +from phoenix.db import models +from phoenix.db.enums import AuthMethod, UserRole +from phoenix.server.bearer_auth import create_access_and_refresh_tokens +from phoenix.server.jwt_store import JwtStore from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_rate_limiter +ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+" + rate_limiter = ServerRateLimiter( per_second_rate_limit=0.2, enforcement_window_seconds=30, @@ -18,22 +36,232 @@ ) -@router.post("/{idp}/login") -async def login(request: Request, idp: str) -> RedirectResponse: - if not isinstance(oauth_client := request.app.state.oauth_clients.get_client(idp), OAuthClient): - raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp}") - redirect_uri = request.url_for("create_tokens", idp=idp) +@router.post("/{idp_name}/login") +async def login( + request: Request, + idp_name: Annotated[str, Path(min_length=1, pattern=ALPHANUMS_AND_UNDERSCORES)], +) -> RedirectResponse: + if not isinstance( + oauth_client := request.app.state.oauth_clients.get_client(idp_name), OAuthClient + ): + raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp_name}") + redirect_uri = request.url_for("create_tokens", idp_name=idp_name) response: RedirectResponse = await oauth_client.authorize_redirect(request, redirect_uri) return response -@router.get("/{idp}/tokens") -async def create_tokens(request: Request, idp: str) -> RedirectResponse: - if not isinstance(oauth_client := request.app.state.oauth_clients.get_client(idp), OAuthClient): - raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp}") +@router.get("/{idp_name}/tokens") +async def create_tokens( + request: Request, + idp_name: Annotated[str, Path(min_length=1, pattern=ALPHANUMS_AND_UNDERSCORES)], +) -> RedirectResponse: + assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta) + assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta) + token_store: JwtStore = request.app.state.get_token_store() + if not isinstance( + oauth_client := request.app.state.oauth_clients.get_client(idp_name), OAuthClient + ): + raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp_name}") try: token = await oauth_client.authorize_access_token(request) except OAuthError as error: raise HTTPException(HTTP_401_UNAUTHORIZED, detail=str(error)) - print(f"{token=}") - return RedirectResponse(url="/") + user_info = _get_user_info(token) + async with request.app.state.db() as session: + user = await _ensure_user_exists_and_is_up_to_date( + session, idp_name=idp_name, user_info=user_info + ) + access_token, refresh_token = await create_access_and_refresh_tokens( + user=user, + token_store=token_store, + access_token_expiry=access_token_expiry, + refresh_token_expiry=refresh_token_expiry, + ) + response = RedirectResponse(url="/") + response = set_access_token_cookie( + response=response, access_token=access_token, max_age=access_token_expiry + ) + response = set_refresh_token_cookie( + response=response, refresh_token=refresh_token, max_age=refresh_token_expiry + ) + return response + + +@dataclass +class UserInfo: + idp_user_id: str + email: str + username: Optional[str] + profile_picture_url: Optional[str] + + +def _get_user_info(token: Dict[str, Any]) -> UserInfo: + assert isinstance(token.get("access_token"), str) + assert isinstance(token_type := token.get("token_type"), str) + assert token_type.lower() == "bearer" + assert isinstance(user_info := token.get("userinfo"), dict) + assert isinstance(subject := user_info.get("sub"), (str, int)) + idp_user_id = str(subject) + assert isinstance(email := user_info.get("email"), str) + assert isinstance(username := user_info.get("name"), str) or username is None + assert ( + isinstance(profile_picture_url := user_info.get("picture"), str) + or profile_picture_url is None + ) + return UserInfo( + idp_user_id=idp_user_id, + email=email, + username=username, + profile_picture_url=profile_picture_url, + ) + + +async def _ensure_identity_provider_exists( + session: AsyncSession, /, *, idp_name: str +) -> models.IdentityProvider: + idp = await session.scalar( + select(models.IdentityProvider).where( + and_( + models.IdentityProvider.name == idp_name, + models.IdentityProvider.auth_method == AuthMethod.OAUTH.value, + ) + ) + ) + if idp is not None: + return idp + idp = await session.scalar( + insert(models.IdentityProvider) + .returning(models.IdentityProvider) + .values(name=idp_name, auth_method=AuthMethod.OAUTH.value) + ) + assert isinstance(idp, models.IdentityProvider) + return idp + + +async def _ensure_user_exists_and_is_up_to_date( + session: AsyncSession, /, *, idp_name: str, user_info: UserInfo +) -> models.User: + idp = await _ensure_identity_provider_exists(session, idp_name=idp_name) + user = await _get_user(session, idp_id=idp.id, idp_user_id=user_info.idp_user_id) + if user is None: + user = await _create_user(session, user_info=user_info, idp=idp) + elif _db_user_is_outdated(user=user, user_info=user_info): + user = await _update_user(session, user_id=user.id, user_info=user_info) + return user + + +async def _get_user( + session: AsyncSession, /, *, idp_id: int, idp_user_id: str +) -> Optional[models.User]: + user = await session.scalar( + select(models.User) + .where( + and_( + models.User.identity_provider_id == idp_id, + models.User.identity_provider_user_id == idp_user_id, + ) + ) + .options(joinedload(models.User.role)) + ) + return user + + +async def _ensure_email_and_username_are_not_used_by_other_idps( + session: AsyncSession, /, *, email: str, username: Optional[str], idp_id: int, idp_name: str +) -> None: + # todo: simplify query + conflicting_users = ( + await session.scalars( + select(models.User).where( + and_( + or_(models.User.email == email, models.User.username == username), + models.User.identity_provider_id != idp_id, + ) + ) + ) + ).all() + for user in conflicting_users: + if user.email == email: + raise EmailAlreadyInUse( + f"An account for {email} is already in use. " + f"This email cannot be re-used with {idp_name} OAuth." + ) + if username and user.username == username: + raise UsernameAlreadyInUse( + f"An account already exists with username {username}. " + f"This username cannot be re-used with {idp_name} OAuth." + ) + return None + + +async def _create_user( + session: AsyncSession, + /, + *, + user_info: UserInfo, + idp: models.IdentityProvider, +) -> models.User: + await _ensure_email_and_username_are_not_used_by_other_idps( + session, + email=user_info.email, + username=user_info.username, + idp_id=idp.id, + idp_name=idp.name, + ) + member_role_id = ( + select(models.UserRole.id) + .where(models.UserRole.name == UserRole.MEMBER.value) + .scalar_subquery() + ) + user = await session.scalar( + insert(models.User) + .returning(models.User) + .values( + user_role_id=member_role_id, + identity_provider_id=idp.id, + identity_provider_user_id=user_info.idp_user_id, + username=user_info.username, + email=user_info.email, + profile_picture_url=user_info.profile_picture_url, + password_hash=None, + password_salt=None, + reset_password=False, + ) + .options(joinedload(models.User.role)) + ) + assert isinstance(user, models.User) + return user + + +async def _update_user( + session: AsyncSession, /, *, user_id: int, user_info: UserInfo +) -> models.User: + user = await session.scalar( + update(models.User) + .where(models.User.id == user_id) + .returning(models.User) + .values( + username=user_info.username, + email=user_info.email, + profile_picture_url=user_info.profile_picture_url, + ) + .options(joinedload(models.User.role)) + ) + assert isinstance(user, models.User) + return user + + +def _db_user_is_outdated(*, user: models.User, user_info: UserInfo) -> bool: + return ( + user.email != user_info.email + or user.username != user_info.username + or user.profile_picture_url != user_info.profile_picture_url + ) + + +class EmailAlreadyInUse(Exception): + pass + + +class UsernameAlreadyInUse(Exception): + pass diff --git a/src/phoenix/server/api/types/User.py b/src/phoenix/server/api/types/User.py index 3a87c03e37..eddf677e62 100644 --- a/src/phoenix/server/api/types/User.py +++ b/src/phoenix/server/api/types/User.py @@ -53,5 +53,5 @@ def to_gql_user(user: models.User, api_keys: Optional[List[models.ApiKey]] = Non email=user.email, created_at=user.created_at, user_role_id=user.user_role_id, - auth_method=AuthMethod(user.auth_method), + auth_method=AuthMethod("MEMBER"), ) diff --git a/src/phoenix/server/bearer_auth.py b/src/phoenix/server/bearer_auth.py index 4a986f12da..14ca32af51 100644 --- a/src/phoenix/server/bearer_auth.py +++ b/src/phoenix/server/bearer_auth.py @@ -1,13 +1,19 @@ from abc import ABC +from datetime import datetime, timedelta, timezone from functools import cached_property -from typing import Any, Awaitable, Callable, Optional, Tuple +from typing import ( + Any, + Awaitable, + Callable, + Optional, + Tuple, +) import grpc -from fastapi import Request +from fastapi import HTTPException, Request from grpc_interceptor import AsyncServerInterceptor from grpc_interceptor.exceptions import Unauthenticated from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser -from starlette.exceptions import HTTPException from starlette.requests import HTTPConnection from starlette.status import HTTP_401_UNAUTHORIZED @@ -18,7 +24,20 @@ Token, ) from phoenix.db import enums -from phoenix.server.types import AccessTokenClaims, ApiKeyClaims, UserClaimSet, UserId +from phoenix.db.enums import UserRole +from phoenix.db.models import User as OrmUser +from phoenix.server.types import ( + AccessToken, + AccessTokenAttributes, + AccessTokenClaims, + ApiKeyClaims, + RefreshToken, + RefreshTokenAttributes, + RefreshTokenClaims, + TokenStore, + UserClaimSet, + UserId, +) class HasTokenStore(ABC): @@ -108,3 +127,34 @@ async def is_authenticated(request: Request) -> None: raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Expired token") if claims.status is not ClaimSetStatus.VALID: raise HTTPException(status_code=HTTP_401_UNAUTHORIZED, detail="Invalid token") + + +async def create_access_and_refresh_tokens( + *, + token_store: TokenStore, + user: OrmUser, + access_token_expiry: timedelta, + refresh_token_expiry: timedelta, +) -> Tuple[AccessToken, RefreshToken]: + issued_at = datetime.now(timezone.utc) + user_role = UserRole(user.role.name) + refresh_token_claims = RefreshTokenClaims( + subject=UserId(user.id), + issued_at=issued_at, + expiration_time=issued_at + refresh_token_expiry, + attributes=RefreshTokenAttributes( + user_role=user_role, + ), + ) + refresh_token, refresh_token_id = await token_store.create_refresh_token(refresh_token_claims) + access_token_claims = AccessTokenClaims( + subject=UserId(user.id), + issued_at=issued_at, + expiration_time=issued_at + access_token_expiry, + attributes=AccessTokenAttributes( + user_role=user_role, + refresh_token_id=refresh_token_id, + ), + ) + access_token, _ = await token_store.create_access_token(access_token_claims) + return access_token, refresh_token From 30b6cd1305e9a836a2dbdb025d421759c755cd98 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 16 Sep 2024 19:28:30 -0700 Subject: [PATCH 04/29] running on google, azure and aws --- src/phoenix/config.py | 52 ++++++++++++++++++++----- src/phoenix/server/api/routers/oauth.py | 38 ++++++++++++------ src/phoenix/server/oauth.py | 47 +--------------------- 3 files changed, 69 insertions(+), 68 deletions(-) diff --git a/src/phoenix/config.py b/src/phoenix/config.py index dfb11f9339..78b2d93611 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -3,9 +3,11 @@ import tempfile from dataclasses import dataclass from datetime import timedelta +from enum import Enum from logging import getLogger from pathlib import Path from typing import Dict, List, Optional, Tuple, overload +from urllib.parse import urlparse from typing_extensions import TypeAlias @@ -330,9 +332,7 @@ class OAuthClientConfig: display_name: str client_id: str client_secret: str - server_metadata_url: Optional[str] = None - authorize_url: Optional[str] = None - access_token_url: Optional[str] = None + server_metadata_url: str @classmethod def from_env(cls, idp_id: str) -> "OAuthClientConfig": @@ -353,16 +353,32 @@ def from_env(cls, idp_id: str) -> "OAuthClientConfig": f"A client secret must be set for the {idp_id} OAuth IDP " f"via the {client_secret_env_var} environment variable" ) + if ( + server_metadata_url := ( + os.getenv( + server_metadata_url_env_var + := f"PHOENIX_OAUTH_{idp_id_upper}_SERVER_METADATA_URL", + ) + or _get_default_server_metadata_url(idp_id) + ) + ) is None: + raise ValueError( + f"A server metadata URL must be set for the {idp_id} OAuth IDP " + f"via the {server_metadata_url_env_var} environment variable" + ) + if urlparse(server_metadata_url).scheme != "https": + raise ValueError( + f"Server metadata URL for {idp_id} OAuth IDP " + "must be a valid URL using the https protocol" + ) return cls( idp_id=idp_id, display_name=os.getenv( - f"PHOENIX_OAUTH_{idp_id_upper}_DISPLAY_NAME", get_default_idp_display_name(idp_id) + f"PHOENIX_OAUTH_{idp_id_upper}_DISPLAY_NAME", _get_default_idp_display_name(idp_id) ), client_id=client_id, client_secret=client_secret, - server_metadata_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_SERVER_METADATA_URL"), - access_token_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_ACCESS_TOKEN_URL"), - authorize_url=os.getenv(f"PHOENIX_OAUTH_{idp_id_upper}_AUTHORIZE_URL"), + server_metadata_url=server_metadata_url, ) def __post_init__(self) -> None: @@ -382,7 +398,7 @@ def get_env_oauth_settings() -> List[OAuthClientConfig]: idp_ids = set() pattern = re.compile( - r"^PHOENIX_OAUTH_(\w+)_(DISPLAY_NAME|CLIENT_ID|CLIENT_SECRET|SERVER_METADATA_URL|ACCESS_TOKEN_URL|AUTHORIZE_URL)$" + r"^PHOENIX_OAUTH_(\w+)_(DISPLAY_NAME|CLIENT_ID|CLIENT_SECRET|SERVER_METADATA_URL)$" ) for env_var in os.environ: if (match := pattern.match(env_var)) is not None and (idp_id := match.group(1).lower()): @@ -558,8 +574,24 @@ def get_web_base_url() -> str: return get_base_url() -def get_default_idp_display_name(ipd_id: IdpId) -> str: - return ipd_id.replace("_", " ").title() +class OAuth2Idp(Enum): + AWS_COGNITO = "aws_cognito" + AZURE_AD = "azure_ad" + GOOGLE = "google" + + +def _get_default_idp_display_name(idp_id: IdpId) -> str: + if idp_id == OAuth2Idp.AWS_COGNITO.value: + return "AWS Cognito" + if idp_id == OAuth2Idp.AZURE_AD.value: + return "Azure AD" + return idp_id.replace("_", " ").title() + + +def _get_default_server_metadata_url(idp_id: IdpId) -> Optional[str]: + if idp_id == OAuth2Idp.GOOGLE.value: + return "https://accounts.google.com/.well-known/openid-configuration" + return None DEFAULT_PROJECT_NAME = "default" diff --git a/src/phoenix/server/api/routers/oauth.py b/src/phoenix/server/api/routers/oauth.py index d0bb75b502..315bcc8009 100644 --- a/src/phoenix/server/api/routers/oauth.py +++ b/src/phoenix/server/api/routers/oauth.py @@ -66,18 +66,25 @@ async def create_tokens( token = await oauth_client.authorize_access_token(request) except OAuthError as error: raise HTTPException(HTTP_401_UNAUTHORIZED, detail=str(error)) - user_info = _get_user_info(token) - async with request.app.state.db() as session: - user = await _ensure_user_exists_and_is_up_to_date( - session, idp_name=idp_name, user_info=user_info + if (user_info := _get_user_info(token)) is None: + raise HTTPException( + status_code=HTTP_401_UNAUTHORIZED, + detail=f"OAuth IDP {idp_name} does not support OpenID Connect.", ) + async with request.app.state.db() as session: + try: + user = await _ensure_user_exists_and_is_up_to_date( + session, idp_name=idp_name, user_info=user_info + ) + except (EmailAlreadyInUse, UsernameAlreadyInUse) as error: + raise HTTPException(HTTP_401_UNAUTHORIZED, detail=str(error)) access_token, refresh_token = await create_access_and_refresh_tokens( user=user, token_store=token_store, access_token_expiry=access_token_expiry, refresh_token_expiry=refresh_token_expiry, ) - response = RedirectResponse(url="/") + response = RedirectResponse(url="/") # todo: sanitize a return url response = set_access_token_cookie( response=response, access_token=access_token, max_age=access_token_expiry ) @@ -95,11 +102,12 @@ class UserInfo: profile_picture_url: Optional[str] -def _get_user_info(token: Dict[str, Any]) -> UserInfo: +def _get_user_info(token: Dict[str, Any]) -> Optional[UserInfo]: assert isinstance(token.get("access_token"), str) assert isinstance(token_type := token.get("token_type"), str) assert token_type.lower() == "bearer" - assert isinstance(user_info := token.get("userinfo"), dict) + if (user_info := token.get("userinfo")) is None: + return None assert isinstance(subject := user_info.get("sub"), (str, int)) idp_user_id = str(subject) assert isinstance(email := user_info.get("email"), str) @@ -213,9 +221,9 @@ async def _create_user( .where(models.UserRole.name == UserRole.MEMBER.value) .scalar_subquery() ) - user = await session.scalar( + user_id = await session.scalar( insert(models.User) - .returning(models.User) + .returning(models.User.id) .values( user_role_id=member_role_id, identity_provider_id=idp.id, @@ -227,8 +235,11 @@ async def _create_user( password_salt=None, reset_password=False, ) - .options(joinedload(models.User.role)) ) + assert isinstance(user_id, int) + user = await session.scalar( + select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role)) + ) # query user for joined load assert isinstance(user, models.User) return user @@ -236,10 +247,9 @@ async def _create_user( async def _update_user( session: AsyncSession, /, *, user_id: int, user_info: UserInfo ) -> models.User: - user = await session.scalar( + await session.execute( update(models.User) .where(models.User.id == user_id) - .returning(models.User) .values( username=user_info.username, email=user_info.email, @@ -247,6 +257,10 @@ async def _update_user( ) .options(joinedload(models.User.role)) ) + assert isinstance(user_id, int) + user = await session.scalar( + select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role)) + ) # query user for joined load assert isinstance(user, models.User) return user diff --git a/src/phoenix/server/oauth.py b/src/phoenix/server/oauth.py index 5ce659d7da..e6fa7b323e 100644 --- a/src/phoenix/server/oauth.py +++ b/src/phoenix/server/oauth.py @@ -1,6 +1,4 @@ -from dataclasses import asdict, dataclass from datetime import datetime, timedelta -from types import MappingProxyType from typing import Any, Dict, Generic, List, Optional, Tuple from authlib.integrations.starlette_client import OAuth @@ -20,22 +18,11 @@ def __init__(self) -> None: def add_client(self, config: OAuthClientConfig) -> None: if (idp_id := config.idp_id) in self._clients: raise ValueError(f"oauth client already registered: {idp_id}") - config = _apply_oauth_config_defaults(config) - server_metadata_url = config.server_metadata_url - authorize_url = config.authorize_url - access_token_url = config.access_token_url - if not (server_metadata_url or (authorize_url and access_token_url)): - raise ValueError( - f"{idp_id} OAuth client must have either a server metadata URL," - " or authorize and access token URLs" - ) client = self._oauth.register( idp_id, client_id=config.client_id, client_secret=config.client_secret, - server_metadata_url=server_metadata_url, - authorize_url=authorize_url, - access_token_url=access_token_url, + server_metadata_url=config.server_metadata_url, client_kwargs={"scope": "openid email profile"}, ) assert isinstance(client, OAuthClient) @@ -54,38 +41,6 @@ def from_configs(cls, configs: List[OAuthClientConfig]) -> "OAuthClients": return oauth_clients -@dataclass -class OAuthClientDefaultConfig: - idp_id: IdpId - display_name: Optional[str] = None - server_metadata_url: Optional[str] = None - authorize_url: Optional[str] = None - access_token_url: Optional[str] = None - - -def _apply_oauth_config_defaults(config: OAuthClientConfig) -> OAuthClientConfig: - if (default_config := _OAUTH_CLIENT_DEFAULT_CONFIGS.get(config.idp_id)) is None: - return config - return OAuthClientConfig( - **{ - **{k: v for k, v in asdict(default_config).items() if v is not None}, - **{k: v for k, v in asdict(config).items() if v is not None}, - } - ) - - -_OAUTH_CLIENT_DEFAULT_CONFIGS = MappingProxyType( - { - config.idp_id: config - for config in ( - OAuthClientDefaultConfig( - idp_id="google", - server_metadata_url="https://accounts.google.com/.well-known/openid-configuration", - ), - ) - } -) - _CacheKey = TypeVar("_CacheKey") _CacheValue = TypeVar("_CacheValue") _Expiry: TypeAlias = datetime From ce2c5282a58a7da2c89b9a908b9eaecf5f4933df Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 16 Sep 2024 19:44:56 -0700 Subject: [PATCH 05/29] refactor idp id to idp name --- app/src/window.d.ts | 2 +- src/phoenix/config.py | 56 ++++++++++++++++++------------------- src/phoenix/server/app.py | 4 +-- src/phoenix/server/oauth.py | 18 ++++++------ 4 files changed, 39 insertions(+), 41 deletions(-) diff --git a/app/src/window.d.ts b/app/src/window.d.ts index 08ab1cb825..a1e00d700e 100644 --- a/app/src/window.d.ts +++ b/app/src/window.d.ts @@ -1,7 +1,7 @@ export {}; type OAuthIdp = { - id: string; + name: string; displayName: string; }; diff --git a/src/phoenix/config.py b/src/phoenix/config.py index 78b2d93611..96cba0fa8d 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -13,7 +13,6 @@ from phoenix.utilities.re import parse_env_headers -IdpId: TypeAlias = str EnvVarName: TypeAlias = str EnvVarValue: TypeAlias = str @@ -328,53 +327,54 @@ def get_env_smtp_validate_certs() -> bool: @dataclass(frozen=True) class OAuthClientConfig: - idp_id: str + idp_name: str display_name: str client_id: str client_secret: str server_metadata_url: str @classmethod - def from_env(cls, idp_id: str) -> "OAuthClientConfig": - idp_id_upper = idp_id.upper() + def from_env(cls, idp_name: str) -> "OAuthClientConfig": + idp_name_upper = idp_name.upper() if ( - client_id := os.getenv(client_id_env_var := f"PHOENIX_OAUTH_{idp_id_upper}_CLIENT_ID") + client_id := os.getenv(client_id_env_var := f"PHOENIX_OAUTH_{idp_name_upper}_CLIENT_ID") ) is None: raise ValueError( - f"A client id must be set for the {idp_id} OAuth IDP " + f"A client id must be set for the {idp_name} OAuth IDP " f"via the {client_id_env_var} environment variable" ) if ( client_secret := os.getenv( - client_secret_env_var := f"PHOENIX_OAUTH_{idp_id_upper}_CLIENT_SECRET" + client_secret_env_var := f"PHOENIX_OAUTH_{idp_name_upper}_CLIENT_SECRET" ) ) is None: raise ValueError( - f"A client secret must be set for the {idp_id} OAuth IDP " + f"A client secret must be set for the {idp_name} OAuth IDP " f"via the {client_secret_env_var} environment variable" ) if ( server_metadata_url := ( os.getenv( server_metadata_url_env_var - := f"PHOENIX_OAUTH_{idp_id_upper}_SERVER_METADATA_URL", + := f"PHOENIX_OAUTH_{idp_name_upper}_SERVER_METADATA_URL", ) - or _get_default_server_metadata_url(idp_id) + or _get_default_server_metadata_url(idp_name) ) ) is None: raise ValueError( - f"A server metadata URL must be set for the {idp_id} OAuth IDP " + f"A server metadata URL must be set for the {idp_name} OAuth IDP " f"via the {server_metadata_url_env_var} environment variable" ) if urlparse(server_metadata_url).scheme != "https": raise ValueError( - f"Server metadata URL for {idp_id} OAuth IDP " + f"Server metadata URL for {idp_name} OAuth IDP " "must be a valid URL using the https protocol" ) return cls( - idp_id=idp_id, + idp_name=idp_name, display_name=os.getenv( - f"PHOENIX_OAUTH_{idp_id_upper}_DISPLAY_NAME", _get_default_idp_display_name(idp_id) + f"PHOENIX_OAUTH_{idp_name_upper}_DISPLAY_NAME", + _get_default_idp_display_name(idp_name), ), client_id=client_id, client_secret=client_secret, @@ -382,13 +382,13 @@ def from_env(cls, idp_id: str) -> "OAuthClientConfig": ) def __post_init__(self) -> None: - assert self.idp_id + assert self.idp_name if not self.display_name: - raise ValueError(f"OAuth display name for {self.idp_id} cannot be empty") + raise ValueError(f"OAuth display name for {self.idp_name} cannot be empty") if not self.client_id: - raise ValueError(f"OAuth client id for {self.idp_id} cannot be empty") + raise ValueError(f"OAuth client id for {self.idp_name} cannot be empty") if not self.client_secret: - raise ValueError(f"OAuth client secret for {self.idp_id} cannot be empty") + raise ValueError(f"OAuth client secret for {self.idp_name} cannot be empty") def get_env_oauth_settings() -> List[OAuthClientConfig]: @@ -396,14 +396,14 @@ def get_env_oauth_settings() -> List[OAuthClientConfig]: Get OAuth settings from environment variables. """ - idp_ids = set() + idp_names = set() pattern = re.compile( r"^PHOENIX_OAUTH_(\w+)_(DISPLAY_NAME|CLIENT_ID|CLIENT_SECRET|SERVER_METADATA_URL)$" ) for env_var in os.environ: - if (match := pattern.match(env_var)) is not None and (idp_id := match.group(1).lower()): - idp_ids.add(idp_id) - return [OAuthClientConfig.from_env(idp_id) for idp_id in sorted(idp_ids)] + if (match := pattern.match(env_var)) is not None and (idp_name := match.group(1).lower()): + idp_names.add(idp_name) + return [OAuthClientConfig.from_env(idp_name) for idp_name in sorted(idp_names)] PHOENIX_DIR = Path(__file__).resolve().parent @@ -580,16 +580,16 @@ class OAuth2Idp(Enum): GOOGLE = "google" -def _get_default_idp_display_name(idp_id: IdpId) -> str: - if idp_id == OAuth2Idp.AWS_COGNITO.value: +def _get_default_idp_display_name(idp_name: str) -> str: + if idp_name == OAuth2Idp.AWS_COGNITO.value: return "AWS Cognito" - if idp_id == OAuth2Idp.AZURE_AD.value: + if idp_name == OAuth2Idp.AZURE_AD.value: return "Azure AD" - return idp_id.replace("_", " ").title() + return idp_name.replace("_", " ").title() -def _get_default_server_metadata_url(idp_id: IdpId) -> Optional[str]: - if idp_id == OAuth2Idp.GOOGLE.value: +def _get_default_server_metadata_url(idp_name: str) -> Optional[str]: + if idp_name == OAuth2Idp.GOOGLE.value: return "https://accounts.google.com/.well-known/openid-configuration" return None diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 1750119f33..df112be94f 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -154,7 +154,7 @@ class OAuthIdp(TypedDict): - id: str + name: str displayName: str @@ -745,7 +745,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: web_manifest_path = SERVER_DIR / "static" / ".vite" / "manifest.json" if serve_ui and web_manifest_path.is_file(): oauth_idps = [ - OAuthIdp(id=config.idp_id, displayName=config.display_name) + OAuthIdp(name=config.idp_name, displayName=config.display_name) for config in oauth_client_configs or [] ] app.mount( diff --git a/src/phoenix/server/oauth.py b/src/phoenix/server/oauth.py index e6fa7b323e..0a1f35a627 100644 --- a/src/phoenix/server/oauth.py +++ b/src/phoenix/server/oauth.py @@ -7,30 +7,28 @@ from phoenix.config import OAuthClientConfig -IdpId: TypeAlias = str - class OAuthClients: def __init__(self) -> None: - self._clients: Dict[IdpId, OAuthClient] = {} + self._clients: Dict[str, OAuthClient] = {} self._oauth = OAuth(cache=_OAuthClientTTLCache[str, Any]()) def add_client(self, config: OAuthClientConfig) -> None: - if (idp_id := config.idp_id) in self._clients: - raise ValueError(f"oauth client already registered: {idp_id}") + if (idp_name := config.idp_name) in self._clients: + raise ValueError(f"oauth client already registered: {idp_name}") client = self._oauth.register( - idp_id, + idp_name, client_id=config.client_id, client_secret=config.client_secret, server_metadata_url=config.server_metadata_url, client_kwargs={"scope": "openid email profile"}, ) assert isinstance(client, OAuthClient) - self._clients[config.idp_id] = client + self._clients[config.idp_name] = client - def get_client(self, idp_id: IdpId) -> OAuthClient: - if (client := self._clients.get(idp_id)) is None: - raise ValueError(f"unknown or unregistered oauth client: {idp_id}") + def get_client(self, idp_name: str) -> OAuthClient: + if (client := self._clients.get(idp_name)) is None: + raise ValueError(f"unknown or unregistered oauth client: {idp_name}") return client @classmethod From b816062bc372d9d55258657b798a4e3045bda4f8 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 16 Sep 2024 21:21:21 -0700 Subject: [PATCH 06/29] oauth to oauth2 --- app/src/pages/auth/LoginPage.tsx | 21 +++++---- app/src/window.d.ts | 4 +- src/phoenix/config.py | 51 ++++++++++------------ src/phoenix/server/api/routers/__init__.py | 4 +- src/phoenix/server/api/routers/oauth.py | 18 ++++---- src/phoenix/server/app.py | 26 +++++------ src/phoenix/server/main.py | 4 +- src/phoenix/server/oauth.py | 26 +++++------ src/phoenix/server/templates/index.html | 2 +- 9 files changed, 76 insertions(+), 80 deletions(-) diff --git a/app/src/pages/auth/LoginPage.tsx b/app/src/pages/auth/LoginPage.tsx index 1660d9d757..873174e18d 100644 --- a/app/src/pages/auth/LoginPage.tsx +++ b/app/src/pages/auth/LoginPage.tsx @@ -8,7 +8,7 @@ import { LoginForm } from "./LoginForm"; import { PhoenixLogo } from "./PhoenixLogo"; export function LoginPage() { - const oAuthIdps = window.Config.oAuthIdps; + const oAuth2Idps = window.Config.oAuth2Idps; return ( @@ -17,10 +17,10 @@ export function LoginPage() { - {oAuthIdps.map((idp) => ( - ( + ))} @@ -28,13 +28,16 @@ export function LoginPage() { ); } -type OAuthLoginFormProps = { - idpId: string; +type OAuth2LoginFormProps = { + idpName: string; idpDisplayName: string; }; -export function OAuthLoginForm({ idpId, idpDisplayName }: OAuthLoginFormProps) { +export function OAuth2LoginForm({ + idpName, + idpDisplayName, +}: OAuth2LoginFormProps) { return ( -
+
bool: @dataclass(frozen=True) -class OAuthClientConfig: +class OAuth2ClientConfig: idp_name: str display_name: str client_id: str @@ -334,46 +334,48 @@ class OAuthClientConfig: server_metadata_url: str @classmethod - def from_env(cls, idp_name: str) -> "OAuthClientConfig": + def from_env(cls, idp_name: str) -> "OAuth2ClientConfig": idp_name_upper = idp_name.upper() - if ( - client_id := os.getenv(client_id_env_var := f"PHOENIX_OAUTH_{idp_name_upper}_CLIENT_ID") - ) is None: + if not ( + client_id := os.getenv( + client_id_env_var := f"PHOENIX_OAUTH2_{idp_name_upper}_CLIENT_ID" + ) + ): raise ValueError( - f"A client id must be set for the {idp_name} OAuth IDP " + f"A client id must be set for the {idp_name} OAuth2 IDP " f"via the {client_id_env_var} environment variable" ) - if ( + if not ( client_secret := os.getenv( - client_secret_env_var := f"PHOENIX_OAUTH_{idp_name_upper}_CLIENT_SECRET" + client_secret_env_var := f"PHOENIX_OAUTH2_{idp_name_upper}_CLIENT_SECRET" ) - ) is None: + ): raise ValueError( - f"A client secret must be set for the {idp_name} OAuth IDP " + f"A client secret must be set for the {idp_name} OAuth2 IDP " f"via the {client_secret_env_var} environment variable" ) - if ( + if not ( server_metadata_url := ( os.getenv( server_metadata_url_env_var - := f"PHOENIX_OAUTH_{idp_name_upper}_SERVER_METADATA_URL", + := f"PHOENIX_OAUTH2_{idp_name_upper}_SERVER_METADATA_URL", ) or _get_default_server_metadata_url(idp_name) ) - ) is None: + ): raise ValueError( - f"A server metadata URL must be set for the {idp_name} OAuth IDP " + f"A server metadata URL must be set for the {idp_name} OAuth2 IDP " f"via the {server_metadata_url_env_var} environment variable" ) if urlparse(server_metadata_url).scheme != "https": raise ValueError( - f"Server metadata URL for {idp_name} OAuth IDP " + f"Server metadata URL for {idp_name} OAuth2 IDP " "must be a valid URL using the https protocol" ) return cls( idp_name=idp_name, display_name=os.getenv( - f"PHOENIX_OAUTH_{idp_name_upper}_DISPLAY_NAME", + f"PHOENIX_OAUTH2_{idp_name_upper}_DISPLAY_NAME", _get_default_idp_display_name(idp_name), ), client_id=client_id, @@ -381,29 +383,20 @@ def from_env(cls, idp_name: str) -> "OAuthClientConfig": server_metadata_url=server_metadata_url, ) - def __post_init__(self) -> None: - assert self.idp_name - if not self.display_name: - raise ValueError(f"OAuth display name for {self.idp_name} cannot be empty") - if not self.client_id: - raise ValueError(f"OAuth client id for {self.idp_name} cannot be empty") - if not self.client_secret: - raise ValueError(f"OAuth client secret for {self.idp_name} cannot be empty") - -def get_env_oauth_settings() -> List[OAuthClientConfig]: +def get_env_oauth2_settings() -> List[OAuth2ClientConfig]: """ - Get OAuth settings from environment variables. + Get OAuth2 settings from environment variables. """ idp_names = set() pattern = re.compile( - r"^PHOENIX_OAUTH_(\w+)_(DISPLAY_NAME|CLIENT_ID|CLIENT_SECRET|SERVER_METADATA_URL)$" + r"^PHOENIX_OAUTH2_(\w+)_(DISPLAY_NAME|CLIENT_ID|CLIENT_SECRET|SERVER_METADATA_URL)$" ) for env_var in os.environ: if (match := pattern.match(env_var)) is not None and (idp_name := match.group(1).lower()): idp_names.add(idp_name) - return [OAuthClientConfig.from_env(idp_name) for idp_name in sorted(idp_names)] + return [OAuth2ClientConfig.from_env(idp_name) for idp_name in sorted(idp_names)] PHOENIX_DIR = Path(__file__).resolve().parent diff --git a/src/phoenix/server/api/routers/__init__.py b/src/phoenix/server/api/routers/__init__.py index 8c65c0c768..fb4a98e3cd 100644 --- a/src/phoenix/server/api/routers/__init__.py +++ b/src/phoenix/server/api/routers/__init__.py @@ -1,11 +1,11 @@ from .auth import router as auth_router from .embeddings import create_embeddings_router -from .oauth import router as oauth_router +from .oauth import router as oauth2_router from .v1 import create_v1_router __all__ = [ "auth_router", "create_embeddings_router", "create_v1_router", - "oauth_router", + "oauth2_router", ] diff --git a/src/phoenix/server/api/routers/oauth.py b/src/phoenix/server/api/routers/oauth.py index 315bcc8009..52ac0ea155 100644 --- a/src/phoenix/server/api/routers/oauth.py +++ b/src/phoenix/server/api/routers/oauth.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Optional from authlib.integrations.starlette_client import OAuthError -from authlib.integrations.starlette_client import StarletteOAuth2App as OAuthClient +from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client from fastapi import APIRouter, Depends, HTTPException, Path, Request from sqlalchemy import and_, insert, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession @@ -32,7 +32,7 @@ ) login_rate_limiter = fastapi_rate_limiter(rate_limiter, paths=["/login"]) router = APIRouter( - prefix="/oauth", include_in_schema=False, dependencies=[Depends(login_rate_limiter)] + prefix="/oauth2", include_in_schema=False, dependencies=[Depends(login_rate_limiter)] ) @@ -42,11 +42,11 @@ async def login( idp_name: Annotated[str, Path(min_length=1, pattern=ALPHANUMS_AND_UNDERSCORES)], ) -> RedirectResponse: if not isinstance( - oauth_client := request.app.state.oauth_clients.get_client(idp_name), OAuthClient + oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client ): raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp_name}") redirect_uri = request.url_for("create_tokens", idp_name=idp_name) - response: RedirectResponse = await oauth_client.authorize_redirect(request, redirect_uri) + response: RedirectResponse = await oauth2_client.authorize_redirect(request, redirect_uri) return response @@ -59,17 +59,17 @@ async def create_tokens( assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta) token_store: JwtStore = request.app.state.get_token_store() if not isinstance( - oauth_client := request.app.state.oauth_clients.get_client(idp_name), OAuthClient + oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client ): raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp_name}") try: - token = await oauth_client.authorize_access_token(request) + token = await oauth2_client.authorize_access_token(request) except OAuthError as error: raise HTTPException(HTTP_401_UNAUTHORIZED, detail=str(error)) if (user_info := _get_user_info(token)) is None: raise HTTPException( status_code=HTTP_401_UNAUTHORIZED, - detail=f"OAuth IDP {idp_name} does not support OpenID Connect.", + detail=f"OAuth2 IDP {idp_name} does not support OpenID Connect.", ) async with request.app.state.db() as session: try: @@ -192,12 +192,12 @@ async def _ensure_email_and_username_are_not_used_by_other_idps( if user.email == email: raise EmailAlreadyInUse( f"An account for {email} is already in use. " - f"This email cannot be re-used with {idp_name} OAuth." + f"This email cannot be re-used with {idp_name} OAuth2." ) if username and user.username == username: raise UsernameAlreadyInUse( f"An account already exists with username {username}. " - f"This username cannot be re-used with {idp_name} OAuth." + f"This username cannot be re-used with {idp_name} OAuth2." ) return None diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index df112be94f..f28d3975c1 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -52,7 +52,7 @@ from phoenix.config import ( DEFAULT_PROJECT_NAME, SERVER_DIR, - OAuthClientConfig, + OAuth2ClientConfig, get_env_host, get_env_port, server_instrumentation_is_enabled, @@ -97,7 +97,7 @@ auth_router, create_embeddings_router, create_v1_router, - oauth_router, + oauth2_router, ) from phoenix.server.api.routers.v1 import REST_API_VERSION from phoenix.server.api.schema import schema @@ -107,7 +107,7 @@ from phoenix.server.email.types import EmailSender from phoenix.server.grpc_server import GrpcServer from phoenix.server.jwt_store import JwtStore -from phoenix.server.oauth import OAuthClients +from phoenix.server.oauth import OAuth2Clients from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider from phoenix.server.types import ( CanGetLastUpdatedAt, @@ -153,7 +153,7 @@ _Callback: TypeAlias = Callable[[], Union[None, Awaitable[None]]] -class OAuthIdp(TypedDict): +class OAuth2Idp(TypedDict): name: str displayName: str @@ -169,7 +169,7 @@ class AppConfig(NamedTuple): web_manifest_path: Path authentication_enabled: bool """ Whether authentication is enabled """ - oauth_idps: Sequence[OAuthIdp] + oauth2_idps: Sequence[OAuth2Idp] class Static(StaticFiles): @@ -218,7 +218,7 @@ async def get_response(self, path: str, scope: Scope) -> Response: "is_development": self._app_config.is_development, "manifest": self._web_manifest, "authentication_enabled": self._app_config.authentication_enabled, - "oauth_idps": self._app_config.oauth_idps, + "oauth2_idps": self._app_config.oauth2_idps, }, ) except Exception as e: @@ -627,7 +627,7 @@ def create_app( refresh_token_expiry: Optional[timedelta] = None, scaffolder_config: Optional[ScaffolderConfig] = None, email_sender: Optional[EmailSender] = None, - oauth_client_configs: Optional[List[OAuthClientConfig]] = None, + oauth2_client_configs: Optional[List[OAuth2ClientConfig]] = None, ) -> FastAPI: startup_callbacks_list: List[_Callback] = list(startup_callbacks) shutdown_callbacks_list: List[_Callback] = list(shutdown_callbacks) @@ -740,13 +740,13 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: app.include_router(graphql_router) if authentication_enabled: app.include_router(auth_router) - app.include_router(oauth_router) + app.include_router(oauth2_router) app.add_middleware(GZipMiddleware) web_manifest_path = SERVER_DIR / "static" / ".vite" / "manifest.json" if serve_ui and web_manifest_path.is_file(): - oauth_idps = [ - OAuthIdp(name=config.idp_name, displayName=config.display_name) - for config in oauth_client_configs or [] + oauth2_idps = [ + OAuth2Idp(name=config.idp_name, displayName=config.display_name) + for config in oauth2_client_configs or [] ] app.mount( "/", @@ -761,7 +761,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: is_development=dev, authentication_enabled=authentication_enabled, web_manifest_path=web_manifest_path, - oauth_idps=oauth_idps, + oauth2_idps=oauth2_idps, ), ), name="static", @@ -771,7 +771,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: app.state.password_reset_token_expiry = password_reset_token_expiry app.state.access_token_expiry = access_token_expiry app.state.refresh_token_expiry = refresh_token_expiry - app.state.oauth_clients = OAuthClients.from_configs(oauth_client_configs or []) + app.state.oauth2_clients = OAuth2Clients.from_configs(oauth2_client_configs or []) app.state.db = db app.state.email_sender = email_sender app = _add_get_secret_method(app=app, secret=secret) diff --git a/src/phoenix/server/main.py b/src/phoenix/server/main.py index bbf63bcb1f..ae3c251eda 100644 --- a/src/phoenix/server/main.py +++ b/src/phoenix/server/main.py @@ -26,7 +26,7 @@ get_env_grpc_port, get_env_host, get_env_host_root_path, - get_env_oauth_settings, + get_env_oauth2_settings, get_env_password_reset_token_expiry, get_env_port, get_env_refresh_token_expiry, @@ -418,7 +418,7 @@ def _get_pid_file() -> Path: refresh_token_expiry=get_env_refresh_token_expiry(), scaffolder_config=scaffolder_config, email_sender=email_sender, - oauth_client_configs=get_env_oauth_settings(), + oauth2_client_configs=get_env_oauth2_settings(), ) server = Server(config=Config(app, host=host, port=port, root_path=host_root_path)) # type: ignore Thread(target=_write_pid_file_when_ready, args=(server,), daemon=True).start() diff --git a/src/phoenix/server/oauth.py b/src/phoenix/server/oauth.py index 0a1f35a627..cc3256e2f9 100644 --- a/src/phoenix/server/oauth.py +++ b/src/phoenix/server/oauth.py @@ -2,18 +2,18 @@ from typing import Any, Dict, Generic, List, Optional, Tuple from authlib.integrations.starlette_client import OAuth -from authlib.integrations.starlette_client import StarletteOAuth2App as OAuthClient +from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client from typing_extensions import TypeAlias, TypeVar -from phoenix.config import OAuthClientConfig +from phoenix.config import OAuth2ClientConfig -class OAuthClients: +class OAuth2Clients: def __init__(self) -> None: - self._clients: Dict[str, OAuthClient] = {} - self._oauth = OAuth(cache=_OAuthClientTTLCache[str, Any]()) + self._clients: Dict[str, OAuth2Client] = {} + self._oauth = OAuth(cache=_OAuth2ClientTTLCache[str, Any]()) - def add_client(self, config: OAuthClientConfig) -> None: + def add_client(self, config: OAuth2ClientConfig) -> None: if (idp_name := config.idp_name) in self._clients: raise ValueError(f"oauth client already registered: {idp_name}") client = self._oauth.register( @@ -23,20 +23,20 @@ def add_client(self, config: OAuthClientConfig) -> None: server_metadata_url=config.server_metadata_url, client_kwargs={"scope": "openid email profile"}, ) - assert isinstance(client, OAuthClient) + assert isinstance(client, OAuth2Client) self._clients[config.idp_name] = client - def get_client(self, idp_name: str) -> OAuthClient: + def get_client(self, idp_name: str) -> OAuth2Client: if (client := self._clients.get(idp_name)) is None: raise ValueError(f"unknown or unregistered oauth client: {idp_name}") return client @classmethod - def from_configs(cls, configs: List[OAuthClientConfig]) -> "OAuthClients": - oauth_clients = cls() + def from_configs(cls, configs: List[OAuth2ClientConfig]) -> "OAuth2Clients": + oauth2_clients = cls() for config in configs: - oauth_clients.add_client(config) - return oauth_clients + oauth2_clients.add_client(config) + return oauth2_clients _CacheKey = TypeVar("_CacheKey") @@ -45,7 +45,7 @@ def from_configs(cls, configs: List[OAuthClientConfig]) -> "OAuthClients": _MINUTE = timedelta(minutes=1) -class _OAuthClientTTLCache(Generic[_CacheKey, _CacheValue]): +class _OAuth2ClientTTLCache(Generic[_CacheKey, _CacheValue]): """ A TTL cache satisfying the interface required by the Authlib Starlette integration. Provides an alternative to starlette session middleware. diff --git a/src/phoenix/server/templates/index.html b/src/phoenix/server/templates/index.html index 748288cbe8..0a78f03a02 100644 --- a/src/phoenix/server/templates/index.html +++ b/src/phoenix/server/templates/index.html @@ -87,7 +87,7 @@ nSamples: parseInt("{{n_samples}}"), }, authenticationEnabled: Boolean("{{authentication_enabled}}" == "True"), - oAuthIdps: {{ oauth_idps | tojson }}, + oAuth2Idps: {{ oauth2_idps | tojson }}, }), writable: false }); From d6b9c8d66233573305c9132535903b0f7a0a5776 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 16 Sep 2024 21:27:14 -0700 Subject: [PATCH 07/29] rename files oauth to oath2 --- src/phoenix/server/api/routers/__init__.py | 2 +- src/phoenix/server/api/routers/{oauth.py => oauth2.py} | 0 src/phoenix/server/app.py | 2 +- src/phoenix/server/{oauth.py => oauth2.py} | 0 4 files changed, 2 insertions(+), 2 deletions(-) rename src/phoenix/server/api/routers/{oauth.py => oauth2.py} (100%) rename src/phoenix/server/{oauth.py => oauth2.py} (100%) diff --git a/src/phoenix/server/api/routers/__init__.py b/src/phoenix/server/api/routers/__init__.py index fb4a98e3cd..cf52fd0c47 100644 --- a/src/phoenix/server/api/routers/__init__.py +++ b/src/phoenix/server/api/routers/__init__.py @@ -1,6 +1,6 @@ from .auth import router as auth_router from .embeddings import create_embeddings_router -from .oauth import router as oauth2_router +from .oauth2 import router as oauth2_router from .v1 import create_v1_router __all__ = [ diff --git a/src/phoenix/server/api/routers/oauth.py b/src/phoenix/server/api/routers/oauth2.py similarity index 100% rename from src/phoenix/server/api/routers/oauth.py rename to src/phoenix/server/api/routers/oauth2.py diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index f28d3975c1..0b557eed42 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -107,7 +107,7 @@ from phoenix.server.email.types import EmailSender from phoenix.server.grpc_server import GrpcServer from phoenix.server.jwt_store import JwtStore -from phoenix.server.oauth import OAuth2Clients +from phoenix.server.oauth2 import OAuth2Clients from phoenix.server.telemetry import initialize_opentelemetry_tracer_provider from phoenix.server.types import ( CanGetLastUpdatedAt, diff --git a/src/phoenix/server/oauth.py b/src/phoenix/server/oauth2.py similarity index 100% rename from src/phoenix/server/oauth.py rename to src/phoenix/server/oauth2.py From 17e1afc8461c11a1a2f23187b0d09d21ec20f807 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 16 Sep 2024 23:21:27 -0700 Subject: [PATCH 08/29] redirect to login --- src/phoenix/server/api/routers/oauth2.py | 54 +++++++++++------------- 1 file changed, 24 insertions(+), 30 deletions(-) diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index 52ac0ea155..3ae911f693 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -4,12 +4,12 @@ from authlib.integrations.starlette_client import OAuthError from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client -from fastapi import APIRouter, Depends, HTTPException, Path, Request +from fastapi import APIRouter, Depends, Path, Request from sqlalchemy import and_, insert, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload +from starlette.datastructures import URL from starlette.responses import RedirectResponse -from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_404_NOT_FOUND from typing_extensions import Annotated from phoenix.auth import ( @@ -44,7 +44,7 @@ async def login( if not isinstance( oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client ): - raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp_name}") + return _redirect_to_login(error=f"Unknown IDP: {idp_name}.") redirect_uri = request.url_for("create_tokens", idp_name=idp_name) response: RedirectResponse = await oauth2_client.authorize_redirect(request, redirect_uri) return response @@ -61,23 +61,22 @@ async def create_tokens( if not isinstance( oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client ): - raise HTTPException(HTTP_404_NOT_FOUND, f"Unknown IDP: {idp_name}") + return _redirect_to_login(error=f"Unknown IDP: {idp_name}.") try: token = await oauth2_client.authorize_access_token(request) except OAuthError as error: - raise HTTPException(HTTP_401_UNAUTHORIZED, detail=str(error)) + return _redirect_to_login(error=str(error)) if (user_info := _get_user_info(token)) is None: - raise HTTPException( - status_code=HTTP_401_UNAUTHORIZED, - detail=f"OAuth2 IDP {idp_name} does not support OpenID Connect.", + return _redirect_to_login( + error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect." ) - async with request.app.state.db() as session: - try: + try: + async with request.app.state.db() as session: user = await _ensure_user_exists_and_is_up_to_date( session, idp_name=idp_name, user_info=user_info ) - except (EmailAlreadyInUse, UsernameAlreadyInUse) as error: - raise HTTPException(HTTP_401_UNAUTHORIZED, detail=str(error)) + except (EmailAlreadyInUse, UsernameAlreadyInUse) as error: + return _redirect_to_login(error=str(error)) access_token, refresh_token = await create_access_and_refresh_tokens( user=user, token_store=token_store, @@ -174,31 +173,21 @@ async def _get_user( return user -async def _ensure_email_and_username_are_not_used_by_other_idps( - session: AsyncSession, /, *, email: str, username: Optional[str], idp_id: int, idp_name: str +async def _ensure_email_and_username_are_not_in_use( + session: AsyncSession, /, *, email: str, username: Optional[str] ) -> None: - # todo: simplify query conflicting_users = ( await session.scalars( select(models.User).where( - and_( - or_(models.User.email == email, models.User.username == username), - models.User.identity_provider_id != idp_id, - ) + or_(models.User.email == email, models.User.username == username) ) ) ).all() for user in conflicting_users: if user.email == email: - raise EmailAlreadyInUse( - f"An account for {email} is already in use. " - f"This email cannot be re-used with {idp_name} OAuth2." - ) + raise EmailAlreadyInUse(f"An account for {email} is already in use.") if username and user.username == username: - raise UsernameAlreadyInUse( - f"An account already exists with username {username}. " - f"This username cannot be re-used with {idp_name} OAuth2." - ) + raise UsernameAlreadyInUse(f'An account already exists with username "{username}".') return None @@ -209,12 +198,10 @@ async def _create_user( user_info: UserInfo, idp: models.IdentityProvider, ) -> models.User: - await _ensure_email_and_username_are_not_used_by_other_idps( + await _ensure_email_and_username_are_not_in_use( session, email=user_info.email, username=user_info.username, - idp_id=idp.id, - idp_name=idp.name, ) member_role_id = ( select(models.UserRole.id) @@ -279,3 +266,10 @@ class EmailAlreadyInUse(Exception): class UsernameAlreadyInUse(Exception): pass + + +def _redirect_to_login(*, error: str) -> RedirectResponse: + """ + Creates a RedirectResponse to the login page to display an error message. + """ + return RedirectResponse(url=URL("/login").include_query_params(error=error)) From 6f967c38f5a14f282d3a9db2a85e266402e1de57 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Mon, 16 Sep 2024 23:52:00 -0700 Subject: [PATCH 09/29] optimize query --- src/phoenix/server/api/routers/oauth2.py | 34 ++++++++++++++++-------- 1 file changed, 23 insertions(+), 11 deletions(-) diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index 3ae911f693..5100beddb7 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -5,7 +5,7 @@ from authlib.integrations.starlette_client import OAuthError from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client from fastapi import APIRouter, Depends, Path, Request -from sqlalchemy import and_, insert, or_, select, update +from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from starlette.datastructures import URL @@ -176,18 +176,30 @@ async def _get_user( async def _ensure_email_and_username_are_not_in_use( session: AsyncSession, /, *, email: str, username: Optional[str] ) -> None: - conflicting_users = ( - await session.scalars( - select(models.User).where( - or_(models.User.email == email, models.User.username == username) - ) + [(email_exists, username_exists)] = ( + await session.execute( + select( + cast( + func.coalesce( + func.max(case((models.User.email == email, 1), else_=0)), + 0, + ), + Boolean, + ).label("email_exists"), + cast( + func.coalesce( + func.max(case((models.User.username == username, 1), else_=0)), + 0, + ), + Boolean, + ).label("username_exists"), + ).where(or_(models.User.email == email, models.User.username == username)) ) ).all() - for user in conflicting_users: - if user.email == email: - raise EmailAlreadyInUse(f"An account for {email} is already in use.") - if username and user.username == username: - raise UsernameAlreadyInUse(f'An account already exists with username "{username}".') + if email_exists: + raise EmailAlreadyInUse(f"An account for {email} is already in use.") + if username_exists: + raise UsernameAlreadyInUse(f'An account already exists with username "{username}".') return None From fee4b1327066bf19df91e5b97252a4c768432966 Mon Sep 17 00:00:00 2001 From: Mikyo King Date: Tue, 17 Sep 2024 12:47:02 -0600 Subject: [PATCH 10/29] feat: style sso buttons - mikeldking --- app/src/pages/auth/LoginPage.tsx | 69 +++++++++++------------ app/src/pages/auth/Oauth2Login.tsx | 90 ++++++++++++++++++++++++++++++ cspell.json | 1 + 3 files changed, 125 insertions(+), 35 deletions(-) create mode 100644 app/src/pages/auth/Oauth2Login.tsx diff --git a/app/src/pages/auth/LoginPage.tsx b/app/src/pages/auth/LoginPage.tsx index 873174e18d..fd79ce1076 100644 --- a/app/src/pages/auth/LoginPage.tsx +++ b/app/src/pages/auth/LoginPage.tsx @@ -1,14 +1,31 @@ import React from "react"; import { css } from "@emotion/react"; -import { Button, Flex, Form, View } from "@arizeai/components"; +import { Flex, View } from "@arizeai/components"; import { AuthLayout } from "./AuthLayout"; import { LoginForm } from "./LoginForm"; +import { OAuth2Login } from "./Oauth2Login"; import { PhoenixLogo } from "./PhoenixLogo"; +const separatorCSS = css` + text-align: center; + margin-top: var(--ac-global-dimension-size-200); + margin-bottom: var(--ac-global-dimension-size-200); + color: var(--ac-global-text-color-700); +`; + +const oAuthLoginButtonListCSS = css` + display: flex; + flex-direction: column; + gap: var(--ac-global-dimension-size-100); + flex-wrap: wrap; + justify-content: center; +`; + export function LoginPage() { const oAuth2Idps = window.Config.oAuth2Idps; + const hasOAuth2Idps = oAuth2Idps.length > 0; return ( @@ -17,40 +34,22 @@ export function LoginPage() { - {oAuth2Idps.map((idp) => ( - - ))} + {hasOAuth2Idps && ( + <> +
or
+
    + {oAuth2Idps.map((idp) => ( +
  • + +
  • + ))} +
+ + )}
); } - -type OAuth2LoginFormProps = { - idpName: string; - idpDisplayName: string; -}; -export function OAuth2LoginForm({ - idpName, - idpDisplayName, -}: OAuth2LoginFormProps) { - return ( - -
- -
- - ); -} diff --git a/app/src/pages/auth/Oauth2Login.tsx b/app/src/pages/auth/Oauth2Login.tsx new file mode 100644 index 0000000000..9fbb85ae31 --- /dev/null +++ b/app/src/pages/auth/Oauth2Login.tsx @@ -0,0 +1,90 @@ +import React, { ReactNode } from "react"; +import { css } from "@emotion/react"; + +import { Button } from "@arizeai/components"; + +type OAuth2LoginProps = { + idpName: string; + idpDisplayName: string; +}; + +const loginCSS = css` + button { + width: 100%; + } + i { + display: block; + width: 20px; + height: 20px; + padding-right: var(--ac-global-dimension-size-50); + } + &[data-provider^="aws"], + &[data-provider^="google"] { + button { + background-color: white; + color: black; + &:hover { + background-color: #ececec !important; + } + } + } +`; + +export function OAuth2Login({ idpName, idpDisplayName }: OAuth2LoginProps) { + return ( +
+ +
+ ); +} + +function IDPIcon({ idpName }: { idpName: string }): ReactNode { + const hasIcon = + idpName === "github" || + idpName === "google" || + idpName === "azure_ad" || + idpName.startsWith("aws"); + if (!hasIcon) { + return null; + } + return ( + +
+ + ); +} diff --git a/cspell.json b/cspell.json index 264b970ded..40488d461b 100644 --- a/cspell.json +++ b/cspell.json @@ -22,6 +22,7 @@ "graphiql", "HDBSCAN", "httpx", + "Idps", "Instrumentor", "instrumentors", "langchain", From 7a0e84ee7169e3ccd60de812f6bcc687665a49d7 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 17 Sep 2024 14:21:34 -0700 Subject: [PATCH 11/29] display error messages from query parameters --- app/src/pages/auth/LoginForm.tsx | 15 ++++++++++++--- app/src/pages/auth/LoginPage.tsx | 7 ++++++- 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/app/src/pages/auth/LoginForm.tsx b/app/src/pages/auth/LoginForm.tsx index 98fe6c7588..54b915cc7f 100644 --- a/app/src/pages/auth/LoginForm.tsx +++ b/app/src/pages/auth/LoginForm.tsx @@ -13,12 +13,21 @@ type LoginFormParams = { password: string; }; -export function LoginForm() { +type LoginFormProps = { + initialError: string | null; + /** + * Callback function called when the form is submitted + */ + onSubmit?: () => void; +}; +export function LoginForm(props: LoginFormProps) { const navigate = useNavigate(); - const [error, setError] = useState(null); + const { initialError, onSubmit: propsOnSubmit } = props; + const [error, setError] = useState(initialError); const [isLoading, setIsLoading] = useState(false); const onSubmit = useCallback( async (params: LoginFormParams) => { + propsOnSubmit?.(); setError(null); setIsLoading(true); try { @@ -42,7 +51,7 @@ export function LoginForm() { const returnUrl = getReturnUrl(); navigate(returnUrl); }, - [navigate, setError] + [navigate, propsOnSubmit, setError] ); const { control, handleSubmit } = useForm({ defaultValues: { email: "", password: "" }, diff --git a/app/src/pages/auth/LoginPage.tsx b/app/src/pages/auth/LoginPage.tsx index fd79ce1076..b86711d565 100644 --- a/app/src/pages/auth/LoginPage.tsx +++ b/app/src/pages/auth/LoginPage.tsx @@ -1,4 +1,5 @@ import React from "react"; +import { useSearchParams } from "react-router-dom"; import { css } from "@emotion/react"; import { Flex, View } from "@arizeai/components"; @@ -26,6 +27,7 @@ const oAuthLoginButtonListCSS = css` export function LoginPage() { const oAuth2Idps = window.Config.oAuth2Idps; const hasOAuth2Idps = oAuth2Idps.length > 0; + const [searchParams, setSearchParams] = useSearchParams(); return ( @@ -33,7 +35,10 @@ export function LoginPage() { - + setSearchParams({})} + /> {hasOAuth2Idps && ( <>
or
From e0f2f613e0a586b9f4fd235eaca0145d720da6bb Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 17 Sep 2024 15:19:02 -0700 Subject: [PATCH 12/29] remove idp table --- src/phoenix/db/enums.py | 5 --- src/phoenix/db/facilitator.py | 36 +--------------- .../versions/cd164e83824f_users_and_tokens.py | 36 ++-------------- src/phoenix/db/models.py | 35 +++------------- .../server/api/mutations/user_mutations.py | 33 ++++++--------- src/phoenix/server/api/routers/oauth2.py | 41 ++++--------------- 6 files changed, 33 insertions(+), 153 deletions(-) diff --git a/src/phoenix/db/enums.py b/src/phoenix/db/enums.py index 71f6c2e528..12ae245942 100644 --- a/src/phoenix/db/enums.py +++ b/src/phoenix/db/enums.py @@ -12,11 +12,6 @@ class UserRole(Enum): MEMBER = "MEMBER" -class AuthMethod(Enum): - LOCAL = "LOCAL" - OAUTH = "OAUTH" - - class IdentityProviderName(Enum): LOCAL = "local" diff --git a/src/phoenix/db/facilitator.py b/src/phoenix/db/facilitator.py index 9e114fb714..d01e4f1560 100644 --- a/src/phoenix/db/facilitator.py +++ b/src/phoenix/db/facilitator.py @@ -5,7 +5,6 @@ from functools import partial from sqlalchemy import ( - and_, distinct, insert, select, @@ -20,7 +19,7 @@ compute_password_hash, ) from phoenix.db import models -from phoenix.db.enums import COLUMN_ENUMS, AuthMethod, IdentityProviderName, UserRole +from phoenix.db.enums import COLUMN_ENUMS, UserRole from phoenix.server.types import DbSessionFactory @@ -39,7 +38,6 @@ async def __call__(self) -> None: async with self._db() as session: for fn in ( _ensure_enums, - _ensure_identity_providers, _ensure_user_roles, ): async with session.begin_nested(): @@ -63,26 +61,6 @@ async def _ensure_enums(session: AsyncSession) -> None: await session.execute(insert(table), [{column.key: v} for v in missing]) -async def _ensure_identity_providers(session: AsyncSession) -> None: - """ - Ensures that the local identity provider is present in the database. - """ - local_idp = await session.scalar( - select(models.IdentityProvider).where( - models.IdentityProvider.name == IdentityProviderName.LOCAL.value - ) - ) - if local_idp is None: - local_idp = models.IdentityProvider( - name=IdentityProviderName.LOCAL.value, auth_method=AuthMethod.LOCAL.value - ) - session.add( - models.IdentityProvider( - name=IdentityProviderName.LOCAL.value, auth_method=AuthMethod.LOCAL.value - ) - ) - - async def _ensure_user_roles(session: AsyncSession) -> None: """ Ensure that the system and admin roles are present in the database. If they are not, they will @@ -101,23 +79,12 @@ async def _ensure_user_roles(session: AsyncSession) -> None: select(distinct(models.UserRole.name)).join_from(models.User, models.UserRole) ) ] - local_idp_id = ( - select(models.IdentityProvider.id) - .where( - and_( - models.IdentityProvider.name == IdentityProviderName.LOCAL.value, - models.IdentityProvider.auth_method == AuthMethod.LOCAL.value, - ) - ) - .scalar_subquery() - ) if (system_role := UserRole.SYSTEM.value) not in existing_roles and ( system_role_id := role_ids.get(system_role) ) is not None: system_user = models.User( user_role_id=system_role_id, email="system@localhost", - identity_provider_id=local_idp_id, reset_password=False, ) session.add(system_user) @@ -132,7 +99,6 @@ async def _ensure_user_roles(session: AsyncSession) -> None: user_role_id=admin_role_id, username=DEFAULT_ADMIN_USERNAME, email=DEFAULT_ADMIN_EMAIL, - identity_provider_id=local_idp_id, password_salt=salt, password_hash=hash_, reset_password=True, diff --git a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py index 0029c6808f..38f3a4fe4a 100644 --- a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +++ b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py @@ -19,27 +19,6 @@ def upgrade() -> None: - op.create_table( - "identity_providers", - sa.Column("id", sa.Integer, primary_key=True), - sa.Column( - "name", - sa.String, - index=True, - nullable=False, - ), - sa.Column( - "auth_method", - sa.String, - sa.CheckConstraint("auth_method IN ('LOCAL', 'OAUTH')", "valid_auth_method"), - index=True, - nullable=False, - ), - sa.UniqueConstraint( - "name", - "auth_method", - ), - ) op.create_table( "user_roles", sa.Column("id", sa.Integer, primary_key=True), @@ -61,20 +40,14 @@ def upgrade() -> None: nullable=False, index=True, ), - sa.Column( - "identity_provider_id", - sa.Integer, - sa.ForeignKey("identity_providers.id", ondelete="CASCADE"), - index=True, - nullable=False, - ), - sa.Column("identity_provider_user_id", sa.Integer, index=True, nullable=True), sa.Column("username", sa.String, nullable=True, unique=True, index=True), sa.Column("email", sa.String, nullable=False, unique=True, index=True), sa.Column("profile_picture_url", sa.String, nullable=True), sa.Column("password_hash", sa.LargeBinary, nullable=True), sa.Column("password_salt", sa.LargeBinary, nullable=True), sa.Column("reset_password", sa.Boolean, nullable=False), + sa.Column("oauth2_identity_provider_name", sa.String, nullable=True, index=True), + sa.Column("oauth2_identity_provider_user_id", sa.String, nullable=True, index=True), sa.Column( "created_at", sa.TIMESTAMP(timezone=True), @@ -95,8 +68,8 @@ def upgrade() -> None: ), sa.CheckConstraint("password_hash is null or password_salt is not null", name="salt"), sa.UniqueConstraint( - "identity_provider_id", - "identity_provider_user_id", + "oauth2_identity_provider_name", + "oauth2_identity_provider_user_id", ), sqlite_autoincrement=True, ) @@ -176,4 +149,3 @@ def downgrade() -> None: op.drop_table("password_reset_tokens") op.drop_table("users") op.drop_table("user_roles") - op.drop_table("identity_providers") diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 6b0e3362c5..3394dfeb56 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -620,24 +620,6 @@ class ExperimentRunAnnotation(Base): ) -class IdentityProvider(Base): - __tablename__ = "identity_providers" - id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str] = mapped_column(index=True, nullable=False) - auth_method: Mapped[str] = mapped_column( - CheckConstraint("auth_method IN ('LOCAL', 'OAUTH')", name="valid_auth_method"), - index=True, - ) - users: Mapped[List["User"]] = relationship("User", back_populates="identity_provider") - - __table_args__ = ( - UniqueConstraint( - "name", - "auth_method", - ), - ) - - class UserRole(Base): __tablename__ = "user_roles" id: Mapped[int] = mapped_column(primary_key=True) @@ -653,21 +635,16 @@ class User(Base): index=True, ) role: Mapped["UserRole"] = relationship("UserRole", back_populates="users") - identity_provider_id: Mapped[int] = mapped_column( - ForeignKey("identity_providers.id", ondelete="CASCADE"), - index=True, - nullable=False, - ) - identity_provider: Mapped["IdentityProvider"] = relationship( - "IdentityProvider", back_populates="users" - ) - identity_provider_user_id: Mapped[Optional[str]] = mapped_column(index=True, nullable=True) username: Mapped[Optional[str]] = mapped_column(nullable=True, unique=True, index=True) email: Mapped[str] = mapped_column(nullable=False, unique=True, index=True) profile_picture_url: Mapped[Optional[str]] password_hash: Mapped[Optional[bytes]] password_salt: Mapped[Optional[bytes]] reset_password: Mapped[bool] + oauth2_identity_provider_name: Mapped[Optional[str]] = mapped_column(index=True, nullable=True) + oauth2_identity_provider_user_id: Mapped[Optional[str]] = mapped_column( + index=True, nullable=True + ) created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now()) updated_at: Mapped[datetime] = mapped_column( UtcTimeStamp, server_default=func.now(), onupdate=func.now() @@ -686,8 +663,8 @@ class User(Base): __table_args__ = ( CheckConstraint("password_hash is null or password_salt is not null", name="salt"), UniqueConstraint( - "identity_provider_id", - "identity_provider_user_id", + "oauth2_identity_provider_name", + "oauth2_identity_provider_user_id", ), dict(sqlite_autoincrement=True), ) diff --git a/src/phoenix/server/api/mutations/user_mutations.py b/src/phoenix/server/api/mutations/user_mutations.py index 9610d959e7..487bd4fd91 100644 --- a/src/phoenix/server/api/mutations/user_mutations.py +++ b/src/phoenix/server/api/mutations/user_mutations.py @@ -91,21 +91,10 @@ async def create_user( validate_password_format(password := input.password) salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH) password_hash = await info.context.hash_password(password, salt) - local_idp_id = ( - select(models.IdentityProvider.id) - .where( - and_( - models.IdentityProvider.name == enums.IdentityProviderName.LOCAL.value, - models.IdentityProvider.auth_method == enums.AuthMethod.LOCAL.value, - ) - ) - .scalar_subquery() - ) user = models.User( reset_password=True, username=input.username or None, email=email, - identity_provider_id=local_idp_id, password_hash=password_hash, password_salt=salt, ) @@ -148,10 +137,7 @@ async def patch_user( raise NotFound(f"Role {input.new_role.value} not found") user.user_role_id = user_role_id if password := input.new_password: - if ( - (idp := user.identity_provider).name != enums.IdentityProviderName.LOCAL.value - or idp.auth_method != enums.AuthMethod.LOCAL.value - ): + if not _is_locally_authenticated_user(user): raise Conflict("Cannot modify password for non-local user") validate_password_format(password) user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH) @@ -184,10 +170,7 @@ async def patch_viewer( raise NotFound("User not found") stack.enter_context(session.no_autoflush) if password := input.new_password: - if ( - (idp := user.identity_provider).name != enums.IdentityProviderName.LOCAL.value - or idp.auth_method != enums.AuthMethod.LOCAL.value - ): + if not _is_locally_authenticated_user(user): raise Conflict("Cannot modify password for non-local user") if not ( current_password := input.current_password @@ -338,7 +321,17 @@ def _select_user_by_id(user_id: int) -> Select[Tuple[models.User]]: return ( select(models.User) .where(and_(models.User.id == user_id, models.User.deleted_at.is_(None))) - .options(joinedload(models.User.role), joinedload(models.User.identity_provider)) + .options(joinedload(models.User.role)) + ) + + +def _is_locally_authenticated_user(user: models.User) -> bool: + """ + Returns true if the user is authenticated locally, i.e., not through an + OAuth2 provider, and false otherwise. + """ + return ( + user.oauth2_identity_provider_name is None and user.oauth2_identity_provider_user_id is None ) diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index 5100beddb7..096d203aab 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -17,7 +17,7 @@ set_refresh_token_cookie, ) from phoenix.db import models -from phoenix.db.enums import AuthMethod, UserRole +from phoenix.db.enums import UserRole from phoenix.server.bearer_auth import create_access_and_refresh_tokens from phoenix.server.jwt_store import JwtStore from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_rate_limiter @@ -123,49 +123,26 @@ def _get_user_info(token: Dict[str, Any]) -> Optional[UserInfo]: ) -async def _ensure_identity_provider_exists( - session: AsyncSession, /, *, idp_name: str -) -> models.IdentityProvider: - idp = await session.scalar( - select(models.IdentityProvider).where( - and_( - models.IdentityProvider.name == idp_name, - models.IdentityProvider.auth_method == AuthMethod.OAUTH.value, - ) - ) - ) - if idp is not None: - return idp - idp = await session.scalar( - insert(models.IdentityProvider) - .returning(models.IdentityProvider) - .values(name=idp_name, auth_method=AuthMethod.OAUTH.value) - ) - assert isinstance(idp, models.IdentityProvider) - return idp - - async def _ensure_user_exists_and_is_up_to_date( session: AsyncSession, /, *, idp_name: str, user_info: UserInfo ) -> models.User: - idp = await _ensure_identity_provider_exists(session, idp_name=idp_name) - user = await _get_user(session, idp_id=idp.id, idp_user_id=user_info.idp_user_id) + user = await _get_user(session, idp_name=idp_name, idp_user_id=user_info.idp_user_id) if user is None: - user = await _create_user(session, user_info=user_info, idp=idp) + user = await _create_user(session, user_info=user_info, idp_name=idp_name) elif _db_user_is_outdated(user=user, user_info=user_info): user = await _update_user(session, user_id=user.id, user_info=user_info) return user async def _get_user( - session: AsyncSession, /, *, idp_id: int, idp_user_id: str + session: AsyncSession, /, *, idp_name: str, idp_user_id: str ) -> Optional[models.User]: user = await session.scalar( select(models.User) .where( and_( - models.User.identity_provider_id == idp_id, - models.User.identity_provider_user_id == idp_user_id, + models.User.oauth2_identity_provider_name == idp_name, + models.User.oauth2_identity_provider_user_id == idp_user_id, ) ) .options(joinedload(models.User.role)) @@ -208,7 +185,7 @@ async def _create_user( /, *, user_info: UserInfo, - idp: models.IdentityProvider, + idp_name: str, ) -> models.User: await _ensure_email_and_username_are_not_in_use( session, @@ -225,8 +202,8 @@ async def _create_user( .returning(models.User.id) .values( user_role_id=member_role_id, - identity_provider_id=idp.id, - identity_provider_user_id=user_info.idp_user_id, + oauth2_identity_provider_name=idp_name, + oauth2_identity_provider_user_id=user_info.idp_user_id, username=user_info.username, email=user_info.email, profile_picture_url=user_info.profile_picture_url, From aaae0b72f2b4cbdfd54b775294d31330af2633ab Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 17 Sep 2024 16:05:19 -0700 Subject: [PATCH 13/29] add oauth2 client id --- .../versions/cd164e83824f_users_and_tokens.py | 8 +++--- src/phoenix/db/models.py | 10 +++---- .../server/api/mutations/user_mutations.py | 6 ++--- src/phoenix/server/api/routers/oauth2.py | 26 ++++++++++++------- 4 files changed, 26 insertions(+), 24 deletions(-) diff --git a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py index 38f3a4fe4a..706746590b 100644 --- a/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py +++ b/src/phoenix/db/migrations/versions/cd164e83824f_users_and_tokens.py @@ -46,8 +46,8 @@ def upgrade() -> None: sa.Column("password_hash", sa.LargeBinary, nullable=True), sa.Column("password_salt", sa.LargeBinary, nullable=True), sa.Column("reset_password", sa.Boolean, nullable=False), - sa.Column("oauth2_identity_provider_name", sa.String, nullable=True, index=True), - sa.Column("oauth2_identity_provider_user_id", sa.String, nullable=True, index=True), + sa.Column("oauth2_client_id", sa.String, nullable=True, index=True), + sa.Column("oauth2_user_id", sa.String, nullable=True, index=True), sa.Column( "created_at", sa.TIMESTAMP(timezone=True), @@ -68,8 +68,8 @@ def upgrade() -> None: ), sa.CheckConstraint("password_hash is null or password_salt is not null", name="salt"), sa.UniqueConstraint( - "oauth2_identity_provider_name", - "oauth2_identity_provider_user_id", + "oauth2_client_id", + "oauth2_user_id", ), sqlite_autoincrement=True, ) diff --git a/src/phoenix/db/models.py b/src/phoenix/db/models.py index 3394dfeb56..9acc18c3b2 100644 --- a/src/phoenix/db/models.py +++ b/src/phoenix/db/models.py @@ -641,10 +641,8 @@ class User(Base): password_hash: Mapped[Optional[bytes]] password_salt: Mapped[Optional[bytes]] reset_password: Mapped[bool] - oauth2_identity_provider_name: Mapped[Optional[str]] = mapped_column(index=True, nullable=True) - oauth2_identity_provider_user_id: Mapped[Optional[str]] = mapped_column( - index=True, nullable=True - ) + oauth2_client_id: Mapped[Optional[str]] = mapped_column(index=True, nullable=True) + oauth2_user_id: Mapped[Optional[str]] = mapped_column(index=True, nullable=True) created_at: Mapped[datetime] = mapped_column(UtcTimeStamp, server_default=func.now()) updated_at: Mapped[datetime] = mapped_column( UtcTimeStamp, server_default=func.now(), onupdate=func.now() @@ -663,8 +661,8 @@ class User(Base): __table_args__ = ( CheckConstraint("password_hash is null or password_salt is not null", name="salt"), UniqueConstraint( - "oauth2_identity_provider_name", - "oauth2_identity_provider_user_id", + "oauth2_client_id", + "oauth2_user_id", ), dict(sqlite_autoincrement=True), ) diff --git a/src/phoenix/server/api/mutations/user_mutations.py b/src/phoenix/server/api/mutations/user_mutations.py index 487bd4fd91..76d48aab12 100644 --- a/src/phoenix/server/api/mutations/user_mutations.py +++ b/src/phoenix/server/api/mutations/user_mutations.py @@ -328,11 +328,9 @@ def _select_user_by_id(user_id: int) -> Select[Tuple[models.User]]: def _is_locally_authenticated_user(user: models.User) -> bool: """ Returns true if the user is authenticated locally, i.e., not through an - OAuth2 provider, and false otherwise. + OAuth2 identity provider, and false otherwise. """ - return ( - user.oauth2_identity_provider_name is None and user.oauth2_identity_provider_user_id is None - ) + return user.oauth2_client_id is None and user.oauth2_user_id is None def _user_operation_error_message( diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index 096d203aab..97f464cf67 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -73,7 +73,9 @@ async def create_tokens( try: async with request.app.state.db() as session: user = await _ensure_user_exists_and_is_up_to_date( - session, idp_name=idp_name, user_info=user_info + session, + oauth2_client_id=str(oauth2_client.client_id), + user_info=user_info, ) except (EmailAlreadyInUse, UsernameAlreadyInUse) as error: return _redirect_to_login(error=str(error)) @@ -124,25 +126,29 @@ def _get_user_info(token: Dict[str, Any]) -> Optional[UserInfo]: async def _ensure_user_exists_and_is_up_to_date( - session: AsyncSession, /, *, idp_name: str, user_info: UserInfo + session: AsyncSession, /, *, oauth2_client_id: str, user_info: UserInfo ) -> models.User: - user = await _get_user(session, idp_name=idp_name, idp_user_id=user_info.idp_user_id) + user = await _get_user( + session, + oauth2_client_id=oauth2_client_id, + idp_user_id=user_info.idp_user_id, + ) if user is None: - user = await _create_user(session, user_info=user_info, idp_name=idp_name) + user = await _create_user(session, oauth2_client_id=oauth2_client_id, user_info=user_info) elif _db_user_is_outdated(user=user, user_info=user_info): user = await _update_user(session, user_id=user.id, user_info=user_info) return user async def _get_user( - session: AsyncSession, /, *, idp_name: str, idp_user_id: str + session: AsyncSession, /, *, oauth2_client_id: str, idp_user_id: str ) -> Optional[models.User]: user = await session.scalar( select(models.User) .where( and_( - models.User.oauth2_identity_provider_name == idp_name, - models.User.oauth2_identity_provider_user_id == idp_user_id, + models.User.oauth2_client_id == oauth2_client_id, + models.User.oauth2_user_id == idp_user_id, ) ) .options(joinedload(models.User.role)) @@ -184,8 +190,8 @@ async def _create_user( session: AsyncSession, /, *, + oauth2_client_id: str, user_info: UserInfo, - idp_name: str, ) -> models.User: await _ensure_email_and_username_are_not_in_use( session, @@ -202,8 +208,8 @@ async def _create_user( .returning(models.User.id) .values( user_role_id=member_role_id, - oauth2_identity_provider_name=idp_name, - oauth2_identity_provider_user_id=user_info.idp_user_id, + oauth2_client_id=oauth2_client_id, + oauth2_user_id=user_info.idp_user_id, username=user_info.username, email=user_info.email, profile_picture_url=user_info.profile_picture_url, From d1314ef278bd884c3eb1eb36364cd48157dd75a6 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 17 Sep 2024 16:36:27 -0700 Subject: [PATCH 14/29] update azure ad to microsoft entra id --- app/src/pages/auth/Oauth2Login.tsx | 4 ++-- src/phoenix/config.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/app/src/pages/auth/Oauth2Login.tsx b/app/src/pages/auth/Oauth2Login.tsx index 9fbb85ae31..1d70f8297d 100644 --- a/app/src/pages/auth/Oauth2Login.tsx +++ b/app/src/pages/auth/Oauth2Login.tsx @@ -53,7 +53,7 @@ function IDPIcon({ idpName }: { idpName: string }): ReactNode { const hasIcon = idpName === "github" || idpName === "google" || - idpName === "azure_ad" || + idpName === "microsoft_entra_id" || idpName.startsWith("aws"); if (!hasIcon) { return null; @@ -76,7 +76,7 @@ function IDPIcon({ idpName }: { idpName: string }): ReactNode { &[data-provider^="google"] { background-image: url("data:image/svg+xml;charset=utf-8,%3Csvg xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink' viewBox='0 0 48 48'%3E%3Cdefs%3E%3Cpath id='a' d='M44.5 20H24v8.5h11.8C34.7 33.9 30.1 37 24 37c-7.2 0-13-5.8-13-13s5.8-13 13-13c3.1 0 5.9 1.1 8.1 2.9l6.4-6.4C34.6 4.1 29.6 2 24 2 11.8 2 2 11.8 2 24s9.8 22 22 22c11 0 21-8 21-22 0-1.3-.2-2.7-.5-4z'/%3E%3C/defs%3E%3CclipPath id='b'%3E%3Cuse xlink:href='%23a' overflow='visible'/%3E%3C/clipPath%3E%3Cpath clip-path='url(%23b)' fill='%23FBBC05' d='M0 37V11l17 13z'/%3E%3Cpath clip-path='url(%23b)' fill='%23EA4335' d='M0 11l17 13 7-6.1L48 14V0H0z'/%3E%3Cpath clip-path='url(%23b)' fill='%2334A853' d='M0 37l30-23 7.9 1L48 0v48H0z'/%3E%3Cpath clip-path='url(%23b)' fill='%234285F4' d='M48 48L17 24l-4-3 35-10z'/%3E%3C/svg%3E"); } - &[data-provider^="azure"] { + &[data-provider^="microsoft"] { background-image: url("data:image/svg+xml;charset=utf-8,%3Csvg xmlns='http://www.w3.org/2000/svg' width='21' height='21'%3E%3Cpath fill='%23f25022' d='M1 1h9v9H1z'/%3E%3Cpath fill='%2300a4ef' d='M1 11h9v9H1z'/%3E%3Cpath fill='%237fba00' d='M11 1h9v9h-9z'/%3E%3Cpath fill='%23ffb900' d='M11 11h9v9h-9z'/%3E%3C/svg%3E"); } &[data-provider^="aws"] { diff --git a/src/phoenix/config.py b/src/phoenix/config.py index 083caa66d9..aaeae1145e 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -569,15 +569,15 @@ def get_web_base_url() -> str: class OAuth2Idp(Enum): AWS_COGNITO = "aws_cognito" - AZURE_AD = "azure_ad" GOOGLE = "google" + MICROSOFT_ENTRA_ID = "microsoft_entra_id" def _get_default_idp_display_name(idp_name: str) -> str: if idp_name == OAuth2Idp.AWS_COGNITO.value: return "AWS Cognito" - if idp_name == OAuth2Idp.AZURE_AD.value: - return "Azure AD" + if idp_name == OAuth2Idp.MICROSOFT_ENTRA_ID.value: + return "Microsoft Entra ID" return idp_name.replace("_", " ").title() From 230d52ef98e3f5deb8a4c3083dd9f8a51793c7e1 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Tue, 17 Sep 2024 18:18:12 -0700 Subject: [PATCH 15/29] clean up --- app/src/pages/auth/LoginPage.tsx | 2 +- .../auth/{Oauth2Login.tsx => OAuth2Login.tsx} | 0 app/src/pages/auth/oAuthCallbackLoader.ts | 41 ------ src/phoenix/config.py | 15 +- src/phoenix/db/enums.py | 4 - src/phoenix/server/api/routers/oauth2.py | 135 ++++++++++-------- src/phoenix/server/app.py | 2 +- src/phoenix/server/bearer_auth.py | 5 +- src/phoenix/server/oauth2.py | 9 +- 9 files changed, 91 insertions(+), 122 deletions(-) rename app/src/pages/auth/{Oauth2Login.tsx => OAuth2Login.tsx} (100%) delete mode 100644 app/src/pages/auth/oAuthCallbackLoader.ts diff --git a/app/src/pages/auth/LoginPage.tsx b/app/src/pages/auth/LoginPage.tsx index b86711d565..26928ac770 100644 --- a/app/src/pages/auth/LoginPage.tsx +++ b/app/src/pages/auth/LoginPage.tsx @@ -6,7 +6,7 @@ import { Flex, View } from "@arizeai/components"; import { AuthLayout } from "./AuthLayout"; import { LoginForm } from "./LoginForm"; -import { OAuth2Login } from "./Oauth2Login"; +import { OAuth2Login } from "./OAuth2Login"; import { PhoenixLogo } from "./PhoenixLogo"; const separatorCSS = css` diff --git a/app/src/pages/auth/Oauth2Login.tsx b/app/src/pages/auth/OAuth2Login.tsx similarity index 100% rename from app/src/pages/auth/Oauth2Login.tsx rename to app/src/pages/auth/OAuth2Login.tsx diff --git a/app/src/pages/auth/oAuthCallbackLoader.ts b/app/src/pages/auth/oAuthCallbackLoader.ts deleted file mode 100644 index 251ed76915..0000000000 --- a/app/src/pages/auth/oAuthCallbackLoader.ts +++ /dev/null @@ -1,41 +0,0 @@ -// import { redirect } from "react-router"; -import { LoaderFunctionArgs } from "react-router-dom"; - -export async function oAuthCallbackLoader(args: LoaderFunctionArgs) { - const queryParameters = new URL(args.request.url).searchParams; - const authorizationCode = queryParameters.get("code"); - const state = queryParameters.get("state"); - const actualState = sessionStorage.getItem("oAuthState"); - sessionStorage.removeItem("oAuthState"); - if ( - authorizationCode == undefined || - state == undefined || - actualState == undefined || - state !== actualState - ) { - // todo: display error message - return null; - } - const origin = new URL(window.location.href).origin; - const redirectUri = `${origin}/oauth-callback`; - try { - const response = await fetch("/auth/oauth-tokens", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - authorization_code: authorizationCode, - redirect_uri: redirectUri, - }), - }); - if (!response.ok) { - // todo: parse response body and display error message - return null; - } - } catch (error) { - // todo: display error - } - // redirect("/"); - return null; -} diff --git a/src/phoenix/config.py b/src/phoenix/config.py index aaeae1145e..bec10aa91c 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -9,13 +9,8 @@ from typing import Dict, List, Optional, Tuple, overload from urllib.parse import urlparse -from typing_extensions import TypeAlias - from phoenix.utilities.re import parse_env_headers -EnvVarName: TypeAlias = str -EnvVarValue: TypeAlias = str - logger = getLogger(__name__) # Phoenix environment variables @@ -328,7 +323,7 @@ def get_env_smtp_validate_certs() -> bool: @dataclass(frozen=True) class OAuth2ClientConfig: idp_name: str - display_name: str + idp_display_name: str client_id: str client_secret: str server_metadata_url: str @@ -374,7 +369,7 @@ def from_env(cls, idp_name: str) -> "OAuth2ClientConfig": ) return cls( idp_name=idp_name, - display_name=os.getenv( + idp_display_name=os.getenv( f"PHOENIX_OAUTH2_{idp_name_upper}_DISPLAY_NAME", _get_default_idp_display_name(idp_name), ), @@ -574,6 +569,9 @@ class OAuth2Idp(Enum): def _get_default_idp_display_name(idp_name: str) -> str: + """ + Get the default display name for an OAuth2 IDP. + """ if idp_name == OAuth2Idp.AWS_COGNITO.value: return "AWS Cognito" if idp_name == OAuth2Idp.MICROSOFT_ENTRA_ID.value: @@ -582,6 +580,9 @@ def _get_default_idp_display_name(idp_name: str) -> str: def _get_default_server_metadata_url(idp_name: str) -> Optional[str]: + """ + Gets the default server metadata URL for an OAuth2 IDP. + """ if idp_name == OAuth2Idp.GOOGLE.value: return "https://accounts.google.com/.well-known/openid-configuration" return None diff --git a/src/phoenix/db/enums.py b/src/phoenix/db/enums.py index 12ae245942..8f99057750 100644 --- a/src/phoenix/db/enums.py +++ b/src/phoenix/db/enums.py @@ -12,10 +12,6 @@ class UserRole(Enum): MEMBER = "MEMBER" -class IdentityProviderName(Enum): - LOCAL = "local" - - COLUMN_ENUMS: Mapping[InstrumentedAttribute[str], Type[Enum]] = { models.UserRole.name: UserRole, } diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index 97f464cf67..33b2f3e1c6 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -4,7 +4,7 @@ from authlib.integrations.starlette_client import OAuthError from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client -from fastapi import APIRouter, Depends, Path, Request +from fastapi import APIRouter, Path, Request from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -20,26 +20,17 @@ from phoenix.db.enums import UserRole from phoenix.server.bearer_auth import create_access_and_refresh_tokens from phoenix.server.jwt_store import JwtStore -from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_rate_limiter -ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+" +_LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+" -rate_limiter = ServerRateLimiter( - per_second_rate_limit=0.2, - enforcement_window_seconds=30, - partition_seconds=60, - active_partitions=2, -) -login_rate_limiter = fastapi_rate_limiter(rate_limiter, paths=["/login"]) -router = APIRouter( - prefix="/oauth2", include_in_schema=False, dependencies=[Depends(login_rate_limiter)] -) + +router = APIRouter(prefix="/oauth2", include_in_schema=False) @router.post("/{idp_name}/login") async def login( request: Request, - idp_name: Annotated[str, Path(min_length=1, pattern=ALPHANUMS_AND_UNDERSCORES)], + idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)], ) -> RedirectResponse: if not isinstance( oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client @@ -53,7 +44,7 @@ async def login( @router.get("/{idp_name}/tokens") async def create_tokens( request: Request, - idp_name: Annotated[str, Path(min_length=1, pattern=ALPHANUMS_AND_UNDERSCORES)], + idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)], ) -> RedirectResponse: assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta) assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta) @@ -63,10 +54,10 @@ async def create_tokens( ): return _redirect_to_login(error=f"Unknown IDP: {idp_name}.") try: - token = await oauth2_client.authorize_access_token(request) + token_data = await oauth2_client.authorize_access_token(request) except OAuthError as error: return _redirect_to_login(error=str(error)) - if (user_info := _get_user_info(token)) is None: + if (user_info := _get_user_info(token_data)) is None: return _redirect_to_login( error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect." ) @@ -103,11 +94,14 @@ class UserInfo: profile_picture_url: Optional[str] -def _get_user_info(token: Dict[str, Any]) -> Optional[UserInfo]: - assert isinstance(token.get("access_token"), str) - assert isinstance(token_type := token.get("token_type"), str) +def _get_user_info(token_data: Dict[str, Any]) -> Optional[UserInfo]: + """ + Parses token data and extracts user info if available. + """ + assert isinstance(token_data.get("access_token"), str) + assert isinstance(token_type := token_data.get("token_type"), str) assert token_type.lower() == "bearer" - if (user_info := token.get("userinfo")) is None: + if (user_info := token_data.get("userinfo")) is None: return None assert isinstance(subject := user_info.get("sub"), (str, int)) idp_user_id = str(subject) @@ -135,7 +129,7 @@ async def _ensure_user_exists_and_is_up_to_date( ) if user is None: user = await _create_user(session, oauth2_client_id=oauth2_client_id, user_info=user_info) - elif _db_user_is_outdated(user=user, user_info=user_info): + elif not _user_is_up_to_date(user=user, user_info=user_info): user = await _update_user(session, user_id=user.id, user_info=user_info) return user @@ -143,6 +137,10 @@ async def _ensure_user_exists_and_is_up_to_date( async def _get_user( session: AsyncSession, /, *, oauth2_client_id: str, idp_user_id: str ) -> Optional[models.User]: + """ + Retrieves the user uniquely identified by the given OAuth2 client ID and IDP + user ID. + """ user = await session.scalar( select(models.User) .where( @@ -156,36 +154,6 @@ async def _get_user( return user -async def _ensure_email_and_username_are_not_in_use( - session: AsyncSession, /, *, email: str, username: Optional[str] -) -> None: - [(email_exists, username_exists)] = ( - await session.execute( - select( - cast( - func.coalesce( - func.max(case((models.User.email == email, 1), else_=0)), - 0, - ), - Boolean, - ).label("email_exists"), - cast( - func.coalesce( - func.max(case((models.User.username == username, 1), else_=0)), - 0, - ), - Boolean, - ).label("username_exists"), - ).where(or_(models.User.email == email, models.User.username == username)) - ) - ).all() - if email_exists: - raise EmailAlreadyInUse(f"An account for {email} is already in use.") - if username_exists: - raise UsernameAlreadyInUse(f'An account already exists with username "{username}".') - return None - - async def _create_user( session: AsyncSession, /, @@ -193,6 +161,9 @@ async def _create_user( oauth2_client_id: str, user_info: UserInfo, ) -> models.User: + """ + Creates a new user with the user info from the IDP. + """ await _ensure_email_and_username_are_not_in_use( session, email=user_info.email, @@ -213,15 +184,12 @@ async def _create_user( username=user_info.username, email=user_info.email, profile_picture_url=user_info.profile_picture_url, - password_hash=None, - password_salt=None, - reset_password=False, ) ) assert isinstance(user_id, int) user = await session.scalar( select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role)) - ) # query user for joined load + ) # query user again for joined load assert isinstance(user, models.User) return user @@ -229,6 +197,14 @@ async def _create_user( async def _update_user( session: AsyncSession, /, *, user_id: int, user_info: UserInfo ) -> models.User: + """ + Updates an existing user with user info from the IDP. + """ + await _ensure_email_and_username_are_not_in_use( + session, + email=user_info.email, + username=user_info.username, + ) await session.execute( update(models.User) .where(models.User.id == user_id) @@ -239,19 +215,54 @@ async def _update_user( ) .options(joinedload(models.User.role)) ) - assert isinstance(user_id, int) user = await session.scalar( select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role)) - ) # query user for joined load + ) # query user again for joined load assert isinstance(user, models.User) return user -def _db_user_is_outdated(*, user: models.User, user_info: UserInfo) -> bool: +async def _ensure_email_and_username_are_not_in_use( + session: AsyncSession, /, *, email: str, username: Optional[str] +) -> None: + """ + Raises an error if the email or username are already in use. + """ + [(email_exists, username_exists)] = ( + await session.execute( + select( + cast( + func.coalesce( + func.max(case((models.User.email == email, 1), else_=0)), + 0, + ), + Boolean, + ).label("email_exists"), + cast( + func.coalesce( + func.max(case((models.User.username == username, 1), else_=0)), + 0, + ), + Boolean, + ).label("username_exists"), + ).where(or_(models.User.email == email, models.User.username == username)) + ) + ).all() + if email_exists: + raise EmailAlreadyInUse(f"An account for {email} is already in use.") + if username_exists: + raise UsernameAlreadyInUse(f'An account already exists with username "{username}".') + + +def _user_is_up_to_date(*, user: models.User, user_info: UserInfo) -> bool: + """ + Determines whether the user's tuple in the database is up-to-date with the + IDP's user info. + """ return ( - user.email != user_info.email - or user.username != user_info.username - or user.profile_picture_url != user_info.profile_picture_url + user.email == user_info.email + and user.username == user_info.username + and user.profile_picture_url == user_info.profile_picture_url ) diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index 0b557eed42..72c01f1574 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -745,7 +745,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: web_manifest_path = SERVER_DIR / "static" / ".vite" / "manifest.json" if serve_ui and web_manifest_path.is_file(): oauth2_idps = [ - OAuth2Idp(name=config.idp_name, displayName=config.display_name) + OAuth2Idp(name=config.idp_name, displayName=config.idp_display_name) for config in oauth2_client_configs or [] ] app.mount( diff --git a/src/phoenix/server/bearer_auth.py b/src/phoenix/server/bearer_auth.py index 14ca32af51..3c47171c3e 100644 --- a/src/phoenix/server/bearer_auth.py +++ b/src/phoenix/server/bearer_auth.py @@ -137,9 +137,10 @@ async def create_access_and_refresh_tokens( refresh_token_expiry: timedelta, ) -> Tuple[AccessToken, RefreshToken]: issued_at = datetime.now(timezone.utc) + user_id = UserId(user.id) user_role = UserRole(user.role.name) refresh_token_claims = RefreshTokenClaims( - subject=UserId(user.id), + subject=user_id, issued_at=issued_at, expiration_time=issued_at + refresh_token_expiry, attributes=RefreshTokenAttributes( @@ -148,7 +149,7 @@ async def create_access_and_refresh_tokens( ) refresh_token, refresh_token_id = await token_store.create_refresh_token(refresh_token_claims) access_token_claims = AccessTokenClaims( - subject=UserId(user.id), + subject=user_id, issued_at=issued_at, expiration_time=issued_at + access_token_expiry, attributes=AccessTokenAttributes( diff --git a/src/phoenix/server/oauth2.py b/src/phoenix/server/oauth2.py index cc3256e2f9..8bff39929b 100644 --- a/src/phoenix/server/oauth2.py +++ b/src/phoenix/server/oauth2.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Any, Dict, Generic, List, Optional, Tuple +from typing import Any, Dict, Generic, Iterable, Optional, Tuple from authlib.integrations.starlette_client import OAuth from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client @@ -28,11 +28,11 @@ def add_client(self, config: OAuth2ClientConfig) -> None: def get_client(self, idp_name: str) -> OAuth2Client: if (client := self._clients.get(idp_name)) is None: - raise ValueError(f"unknown or unregistered oauth client: {idp_name}") + raise ValueError(f"unknown or unregistered OAuth2 client: {idp_name}") return client @classmethod - def from_configs(cls, configs: List[OAuth2ClientConfig]) -> "OAuth2Clients": + def from_configs(cls, configs: Iterable[OAuth2ClientConfig]) -> "OAuth2Clients": oauth2_clients = cls() for config in configs: oauth2_clients.add_client(config) @@ -51,7 +51,7 @@ class _OAuth2ClientTTLCache(Generic[_CacheKey, _CacheValue]): integration. Provides an alternative to starlette session middleware. """ - def __init__(self, cleanup_interval: timedelta = 10 * _MINUTE) -> None: + def __init__(self, cleanup_interval: timedelta = 1 * _MINUTE) -> None: self._data: Dict[_CacheKey, Tuple[_CacheValue, _Expiry]] = {} self._last_cleanup_time = datetime.now() self._cleanup_interval = cleanup_interval @@ -61,6 +61,7 @@ async def get(self, key: _CacheKey) -> Optional[_CacheValue]: Retrieves the value associated with the given key if it exists and has not expired, otherwise, returns None. """ + self._remove_expired_keys_if_cleanup_interval_exceeded() if (value_and_expiry := self._data.get(key)) is None: return None value, expiry = value_and_expiry From 448085febad401631f8844bbed0cae81d435d604 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 18 Sep 2024 17:27:35 -0700 Subject: [PATCH 16/29] store oauth2 state in cookies --- src/phoenix/auth.py | 60 +++++++++++++--- src/phoenix/server/api/routers/auth.py | 4 ++ src/phoenix/server/api/routers/oauth2.py | 85 +++++++++++++++++++---- src/phoenix/server/oauth2.py | 88 ++++++------------------ 4 files changed, 147 insertions(+), 90 deletions(-) diff --git a/src/phoenix/auth.py b/src/phoenix/auth.py index 72b65ecd83..8affa9613a 100644 --- a/src/phoenix/auth.py +++ b/src/phoenix/auth.py @@ -71,31 +71,53 @@ def validate_password_format(password: str) -> None: def set_access_token_cookie( *, response: ResponseType, access_token: str, max_age: timedelta ) -> ResponseType: - return _set_token_cookie( + return _set_cookie( response=response, cookie_name=PHOENIX_ACCESS_TOKEN_COOKIE_NAME, cookie_max_age=max_age, - token=access_token, + value=access_token, ) def set_refresh_token_cookie( *, response: ResponseType, refresh_token: str, max_age: timedelta ) -> ResponseType: - return _set_token_cookie( + return _set_cookie( response=response, cookie_name=PHOENIX_REFRESH_TOKEN_COOKIE_NAME, cookie_max_age=max_age, - token=refresh_token, + value=refresh_token, ) -def _set_token_cookie( - response: ResponseType, cookie_name: str, cookie_max_age: timedelta, token: str +def set_oauth2_state_cookie( + *, response: ResponseType, state: str, max_age: timedelta +) -> ResponseType: + return _set_cookie( + response=response, + cookie_name=PHOENIX_OAUTH2_STATE_COOKIE_NAME, + cookie_max_age=max_age, + value=state, + ) + + +def set_oauth2_nonce_cookie( + *, response: ResponseType, nonce: str, max_age: timedelta +) -> ResponseType: + return _set_cookie( + response=response, + cookie_name=PHOENIX_OAUTH2_NONCE_COOKIE_NAME, + cookie_max_age=max_age, + value=nonce, + ) + + +def _set_cookie( + response: ResponseType, cookie_name: str, cookie_max_age: timedelta, value: str ) -> ResponseType: response.set_cookie( key=cookie_name, - value=token, + value=value, secure=get_env_phoenix_use_secure_cookies(), httponly=True, samesite="strict", @@ -104,16 +126,26 @@ def _set_token_cookie( return response -def delete_access_token_cookie(response: Response) -> Response: +def delete_access_token_cookie(response: ResponseType) -> ResponseType: response.delete_cookie(key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME) return response -def delete_refresh_token_cookie(response: Response) -> Response: +def delete_refresh_token_cookie(response: ResponseType) -> ResponseType: response.delete_cookie(key=PHOENIX_REFRESH_TOKEN_COOKIE_NAME) return response +def delete_oauth2_state_cookie(response: ResponseType) -> ResponseType: + response.delete_cookie(key=PHOENIX_OAUTH2_STATE_COOKIE_NAME) + return response + + +def delete_oauth2_nonce_cookie(response: ResponseType) -> ResponseType: + response.delete_cookie(key=PHOENIX_OAUTH2_NONCE_COOKIE_NAME) + return response + + @dataclass(frozen=True) class _PasswordRequirements: """ @@ -209,6 +241,16 @@ def validate( """The name of the cookie that stores the Phoenix access token.""" PHOENIX_REFRESH_TOKEN_COOKIE_NAME = "phoenix-refresh-token" """The name of the cookie that stores the Phoenix refresh token.""" +PHOENIX_OAUTH2_STATE_COOKIE_NAME = "phoenix-oauth2-state" +"""The name of the cookie that stores the state used for the OAuth2 authorization code flow.""" +PHOENIX_OAUTH2_NONCE_COOKIE_NAME = "phoenix-oauth2-nonce" +"""The name of the cookie that stores the nonce used for the OAuth2 authorization code flow.""" +DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES = 15 +""" +The default amount of time in minutes that can elapse between the initial +redirect to the IDP and the invocation of the callback URL during the OAuth2 +authorization code flow. +""" class Token(str): ... diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py index c98cd5f015..be0656599f 100644 --- a/src/phoenix/server/api/routers/auth.py +++ b/src/phoenix/server/api/routers/auth.py @@ -22,6 +22,8 @@ Token, compute_password_hash, delete_access_token_cookie, + delete_oauth2_nonce_cookie, + delete_oauth2_state_cookie, delete_refresh_token_cookie, is_valid_password, set_access_token_cookie, @@ -122,6 +124,8 @@ async def logout( response = Response(status_code=HTTP_204_NO_CONTENT) response = delete_access_token_cookie(response) response = delete_refresh_token_cookie(response) + response = delete_oauth2_state_cookie(response) + response = delete_oauth2_nonce_cookie(response) return response diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index 33b2f3e1c6..1107069ba7 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -2,28 +2,36 @@ from datetime import timedelta from typing import Any, Dict, Optional +from authlib.common.security import generate_token from authlib.integrations.starlette_client import OAuthError -from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client -from fastapi import APIRouter, Path, Request +from fastapi import APIRouter, Cookie, Path, Query, Request from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from starlette.datastructures import URL from starlette.responses import RedirectResponse +from starlette.status import HTTP_302_FOUND from typing_extensions import Annotated from phoenix.auth import ( + DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES, + PHOENIX_OAUTH2_NONCE_COOKIE_NAME, + PHOENIX_OAUTH2_STATE_COOKIE_NAME, + delete_oauth2_nonce_cookie, + delete_oauth2_state_cookie, set_access_token_cookie, + set_oauth2_nonce_cookie, + set_oauth2_state_cookie, set_refresh_token_cookie, ) from phoenix.db import models from phoenix.db.enums import UserRole from phoenix.server.bearer_auth import create_access_and_refresh_tokens from phoenix.server.jwt_store import JwtStore +from phoenix.server.oauth2 import OAuth2Client _LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+" - router = APIRouter(prefix="/oauth2", include_in_schema=False) @@ -36,8 +44,24 @@ async def login( oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client ): return _redirect_to_login(error=f"Unknown IDP: {idp_name}.") - redirect_uri = request.url_for("create_tokens", idp_name=idp_name) - response: RedirectResponse = await oauth2_client.authorize_redirect(request, redirect_uri) + authorization_url_data = await oauth2_client.create_authorization_url( + redirect_uri=_get_create_tokens_endpoint(request=request, idp_name=idp_name), + state=generate_token(), + ) + assert isinstance(authorization_url := authorization_url_data.get("url"), str) + assert isinstance(state := authorization_url_data.get("state"), str) + assert isinstance(nonce := authorization_url_data.get("nonce"), str) + response = RedirectResponse(url=authorization_url, status_code=HTTP_302_FOUND) + response = set_oauth2_state_cookie( + response=response, + state=state, + max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES), + ) + response = set_oauth2_nonce_cookie( + response=response, + nonce=nonce, + max_age=timedelta(minutes=DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES), + ) return response @@ -45,7 +69,18 @@ async def login( async def create_tokens( request: Request, idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)], + state: str = Query(), + authorization_code: str = Query(alias="code"), + stored_state: str = Cookie(alias=PHOENIX_OAUTH2_STATE_COOKIE_NAME), + stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME), ) -> RedirectResponse: + if state != stored_state: + return _redirect_to_login( + error=( + "Received invalid state parameter during " + "OAuth2 authorization code flow for IDP {idp_name}." + ) + ) assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta) assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta) token_store: JwtStore = request.app.state.get_token_store() @@ -54,13 +89,20 @@ async def create_tokens( ): return _redirect_to_login(error=f"Unknown IDP: {idp_name}.") try: - token_data = await oauth2_client.authorize_access_token(request) + token_data = await oauth2_client.fetch_access_token( + state=state, + code=authorization_code, + redirect_uri=_get_create_tokens_endpoint(request=request, idp_name=idp_name), + ) except OAuthError as error: return _redirect_to_login(error=str(error)) - if (user_info := _get_user_info(token_data)) is None: + _validate_token_data(token_data) + if "id_token" not in token_data: return _redirect_to_login( error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect." ) + user_info = await oauth2_client.parse_id_token(token_data, nonce=stored_nonce) + user_info = _parse_user_info(user_info) try: async with request.app.state.db() as session: user = await _ensure_user_exists_and_is_up_to_date( @@ -76,13 +118,15 @@ async def create_tokens( access_token_expiry=access_token_expiry, refresh_token_expiry=refresh_token_expiry, ) - response = RedirectResponse(url="/") # todo: sanitize a return url + response = RedirectResponse(url="/", status_code=HTTP_302_FOUND) # todo: sanitize a return url response = set_access_token_cookie( response=response, access_token=access_token, max_age=access_token_expiry ) response = set_refresh_token_cookie( response=response, refresh_token=refresh_token, max_age=refresh_token_expiry ) + response = delete_oauth2_state_cookie(response) + response = delete_oauth2_nonce_cookie(response) return response @@ -94,15 +138,19 @@ class UserInfo: profile_picture_url: Optional[str] -def _get_user_info(token_data: Dict[str, Any]) -> Optional[UserInfo]: +def _validate_token_data(token_data: Dict[str, Any]) -> None: """ - Parses token data and extracts user info if available. + Performs basic validations on the token data returned by the IDP. """ assert isinstance(token_data.get("access_token"), str) assert isinstance(token_type := token_data.get("token_type"), str) assert token_type.lower() == "bearer" - if (user_info := token_data.get("userinfo")) is None: - return None + + +def _parse_user_info(user_info: Dict[str, Any]) -> UserInfo: + """ + Parses user info from the IDP's ID token. + """ assert isinstance(subject := user_info.get("sub"), (str, int)) idp_user_id = str(subject) assert isinstance(email := user_info.get("email"), str) @@ -278,4 +326,15 @@ def _redirect_to_login(*, error: str) -> RedirectResponse: """ Creates a RedirectResponse to the login page to display an error message. """ - return RedirectResponse(url=URL("/login").include_query_params(error=error)) + url = URL("/login").include_query_params(error=error) + response = RedirectResponse(url=url) + response = delete_oauth2_state_cookie(response) + response = delete_oauth2_nonce_cookie(response) + return response + + +def _get_create_tokens_endpoint(*, request: Request, idp_name: str) -> str: + """ + Gets the endpoint for create tokens route. + """ + return str(request.url_for(create_tokens.__name__, idp_name=idp_name)) diff --git a/src/phoenix/server/oauth2.py b/src/phoenix/server/oauth2.py index 8bff39929b..3ab521a3e6 100644 --- a/src/phoenix/server/oauth2.py +++ b/src/phoenix/server/oauth2.py @@ -1,23 +1,35 @@ -from datetime import datetime, timedelta -from typing import Any, Dict, Generic, Iterable, Optional, Tuple +from typing import Any, Dict, Iterable -from authlib.integrations.starlette_client import OAuth -from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client -from typing_extensions import TypeAlias, TypeVar +from authlib.integrations.base_client import BaseApp +from authlib.integrations.base_client.async_app import AsyncOAuth2Mixin +from authlib.integrations.base_client.async_openid import AsyncOpenIDMixin +from authlib.integrations.httpx_client import AsyncOAuth2Client as AsyncHttpxOAuth2Client from phoenix.config import OAuth2ClientConfig +class OAuth2Client(AsyncOAuth2Mixin, AsyncOpenIDMixin, BaseApp): # type:ignore[misc] + """ + An OAuth2 client class that supports OpenID Connect. Adapted from authlib's + `StarletteOAuth2App` to be useable without integration with Starlette. + + https://github.com/lepture/authlib/blob/904d66bebd79bf39fb8814353a22bab7d3e092c4/authlib/integrations/starlette_client/apps.py#L58 + """ + + client_cls = AsyncHttpxOAuth2Client + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(framework=None, *args, **kwargs) + + class OAuth2Clients: def __init__(self) -> None: self._clients: Dict[str, OAuth2Client] = {} - self._oauth = OAuth(cache=_OAuth2ClientTTLCache[str, Any]()) def add_client(self, config: OAuth2ClientConfig) -> None: if (idp_name := config.idp_name) in self._clients: raise ValueError(f"oauth client already registered: {idp_name}") - client = self._oauth.register( - idp_name, + client = OAuth2Client( client_id=config.client_id, client_secret=config.client_secret, server_metadata_url=config.server_metadata_url, @@ -37,63 +49,3 @@ def from_configs(cls, configs: Iterable[OAuth2ClientConfig]) -> "OAuth2Clients": for config in configs: oauth2_clients.add_client(config) return oauth2_clients - - -_CacheKey = TypeVar("_CacheKey") -_CacheValue = TypeVar("_CacheValue") -_Expiry: TypeAlias = datetime -_MINUTE = timedelta(minutes=1) - - -class _OAuth2ClientTTLCache(Generic[_CacheKey, _CacheValue]): - """ - A TTL cache satisfying the interface required by the Authlib Starlette - integration. Provides an alternative to starlette session middleware. - """ - - def __init__(self, cleanup_interval: timedelta = 1 * _MINUTE) -> None: - self._data: Dict[_CacheKey, Tuple[_CacheValue, _Expiry]] = {} - self._last_cleanup_time = datetime.now() - self._cleanup_interval = cleanup_interval - - async def get(self, key: _CacheKey) -> Optional[_CacheValue]: - """ - Retrieves the value associated with the given key if it exists and has - not expired, otherwise, returns None. - """ - self._remove_expired_keys_if_cleanup_interval_exceeded() - if (value_and_expiry := self._data.get(key)) is None: - return None - value, expiry = value_and_expiry - if datetime.now() < expiry: - return value - self._data.pop(key, None) - return None - - async def set(self, key: _CacheKey, value: _CacheValue, expires: int) -> None: - """ - Sets the value associated with the given key to the provided value with - the given expiry time in seconds. - """ - self._remove_expired_keys_if_cleanup_interval_exceeded() - expiry = datetime.now() + timedelta(seconds=expires) - self._data[key] = (value, expiry) - - async def delete(self, key: _CacheKey) -> None: - """ - Removes the value associated with the given key if it exists. - """ - self._remove_expired_keys_if_cleanup_interval_exceeded() - self._data.pop(key, None) - - def _remove_expired_keys_if_cleanup_interval_exceeded(self) -> None: - time_since_last_cleanup = datetime.now() - self._last_cleanup_time - if time_since_last_cleanup > self._cleanup_interval: - self._remove_expired_keys() - - def _remove_expired_keys(self) -> None: - current_time = datetime.now() - delete_keys = [key for key, (_, expiry) in self._data.items() if expiry <= current_time] - for key in delete_keys: - self._data.pop(key, None) - self._last_cleanup_time = current_time From 0f1ee0258b5e7d0f8f1b5776295468fb802e55c6 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 18 Sep 2024 20:01:46 -0700 Subject: [PATCH 17/29] fix types --- src/phoenix/auth.py | 9 +++++++++ src/phoenix/server/api/mutations/user_mutations.py | 13 +++---------- src/phoenix/server/api/routers/auth.py | 7 ++++--- src/phoenix/server/api/types/AuthMethod.py | 1 + src/phoenix/server/api/types/User.py | 3 ++- 5 files changed, 19 insertions(+), 14 deletions(-) diff --git a/src/phoenix/auth.py b/src/phoenix/auth.py index 8affa9613a..8d43cf6d78 100644 --- a/src/phoenix/auth.py +++ b/src/phoenix/auth.py @@ -11,6 +11,7 @@ from typing_extensions import TypeVar from phoenix.config import get_env_phoenix_use_secure_cookies +from phoenix.db.models import User as OrmUser ResponseType = TypeVar("ResponseType", bound=Response) @@ -68,6 +69,14 @@ def validate_password_format(password: str) -> None: PASSWORD_REQUIREMENTS.validate(password) +def is_locally_authenticated(user: OrmUser) -> bool: + """ + Returns true if the user is authenticated locally, i.e., not through an + OAuth2 identity provider, and false otherwise. + """ + return user.oauth2_client_id is None and user.oauth2_user_id is None + + def set_access_token_cookie( *, response: ResponseType, access_token: str, max_age: timedelta ) -> ResponseType: diff --git a/src/phoenix/server/api/mutations/user_mutations.py b/src/phoenix/server/api/mutations/user_mutations.py index 76d48aab12..af0c3ee702 100644 --- a/src/phoenix/server/api/mutations/user_mutations.py +++ b/src/phoenix/server/api/mutations/user_mutations.py @@ -18,6 +18,7 @@ PASSWORD_REQUIREMENTS, PHOENIX_ACCESS_TOKEN_COOKIE_NAME, PHOENIX_REFRESH_TOKEN_COOKIE_NAME, + is_locally_authenticated, validate_email_format, validate_password_format, ) @@ -137,7 +138,7 @@ async def patch_user( raise NotFound(f"Role {input.new_role.value} not found") user.user_role_id = user_role_id if password := input.new_password: - if not _is_locally_authenticated_user(user): + if not is_locally_authenticated(user): raise Conflict("Cannot modify password for non-local user") validate_password_format(password) user.password_salt = secrets.token_bytes(DEFAULT_SECRET_LENGTH) @@ -170,7 +171,7 @@ async def patch_viewer( raise NotFound("User not found") stack.enter_context(session.no_autoflush) if password := input.new_password: - if not _is_locally_authenticated_user(user): + if not is_locally_authenticated(user): raise Conflict("Cannot modify password for non-local user") if not ( current_password := input.current_password @@ -325,14 +326,6 @@ def _select_user_by_id(user_id: int) -> Select[Tuple[models.User]]: ) -def _is_locally_authenticated_user(user: models.User) -> bool: - """ - Returns true if the user is authenticated locally, i.e., not through an - OAuth2 identity provider, and false otherwise. - """ - return user.oauth2_client_id is None and user.oauth2_user_id is None - - def _user_operation_error_message( error: IntegrityError, operation: Literal["create", "modify"] = "create", diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py index be0656599f..f4666bd6ff 100644 --- a/src/phoenix/server/api/routers/auth.py +++ b/src/phoenix/server/api/routers/auth.py @@ -25,13 +25,14 @@ delete_oauth2_nonce_cookie, delete_oauth2_state_cookie, delete_refresh_token_cookie, + is_locally_authenticated, is_valid_password, set_access_token_cookie, set_refresh_token_cookie, validate_password_format, ) from phoenix.config import get_base_url -from phoenix.db import enums, models +from phoenix.db import models from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens from phoenix.server.email.templates.types import PasswordResetTemplateBody from phoenix.server.email.types import EmailSender @@ -198,7 +199,7 @@ async def initiate_password_reset(request: Request) -> Response: joinedload(models.User.password_reset_token).load_only(models.PasswordResetToken.id) ) ) - if user is None or user.auth_method != enums.AuthMethod.LOCAL.value: + if user is None or not is_locally_authenticated(user): # Withold privileged information return Response(status_code=HTTP_204_NO_CONTENT) token_store: TokenStore = request.app.state.get_token_store() @@ -230,7 +231,7 @@ async def reset_password(request: Request) -> Response: assert (user_id := claims.subject) async with request.app.state.db() as session: user = await session.scalar(_select_active_user().filter_by(id=int(user_id))) - if user is None or user.auth_method != enums.AuthMethod.LOCAL.value: + if user is None or not is_locally_authenticated(user): # Withold privileged information return Response(status_code=HTTP_204_NO_CONTENT) validate_password_format(password) diff --git a/src/phoenix/server/api/types/AuthMethod.py b/src/phoenix/server/api/types/AuthMethod.py index 011140e035..f3c77e9b51 100644 --- a/src/phoenix/server/api/types/AuthMethod.py +++ b/src/phoenix/server/api/types/AuthMethod.py @@ -6,3 +6,4 @@ @strawberry.enum class AuthMethod(Enum): LOCAL = "LOCAL" + OAUTH2 = "OAUTH2" diff --git a/src/phoenix/server/api/types/User.py b/src/phoenix/server/api/types/User.py index eddf677e62..05adae4066 100644 --- a/src/phoenix/server/api/types/User.py +++ b/src/phoenix/server/api/types/User.py @@ -7,6 +7,7 @@ from strawberry.relay import Node, NodeID from strawberry.types import Info +from phoenix.auth import is_locally_authenticated from phoenix.db import models from phoenix.server.api.context import Context from phoenix.server.api.exceptions import NotFound @@ -53,5 +54,5 @@ def to_gql_user(user: models.User, api_keys: Optional[List[models.ApiKey]] = Non email=user.email, created_at=user.created_at, user_role_id=user.user_role_id, - auth_method=AuthMethod("MEMBER"), + auth_method=AuthMethod.LOCAL if is_locally_authenticated(user) else AuthMethod.OAUTH2, ) From 577a4de0715e0e4da0b4c052f08ea3ed96993ea8 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 18 Sep 2024 20:31:59 -0700 Subject: [PATCH 18/29] update graphql schema --- app/schema.graphql | 1 + 1 file changed, 1 insertion(+) diff --git a/app/schema.graphql b/app/schema.graphql index 42b5d3c4d2..9d0f97e034 100644 --- a/app/schema.graphql +++ b/app/schema.graphql @@ -60,6 +60,7 @@ interface ApiKey { enum AuthMethod { LOCAL + OAUTH2 } union Bin = NominalBin | IntervalBin | MissingValueBin From cdb2286044ae49b5546bc4a29445b3eb6161da38 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 18 Sep 2024 20:35:07 -0700 Subject: [PATCH 19/29] update relay --- .../pages/settings/__generated__/UsersTable_users.graphql.ts | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/app/src/pages/settings/__generated__/UsersTable_users.graphql.ts b/app/src/pages/settings/__generated__/UsersTable_users.graphql.ts index 51eca6a1f8..8745d33ded 100644 --- a/app/src/pages/settings/__generated__/UsersTable_users.graphql.ts +++ b/app/src/pages/settings/__generated__/UsersTable_users.graphql.ts @@ -1,5 +1,5 @@ /** - * @generated SignedSource<> + * @generated SignedSource<<1ecb9d526bace62adc934d30278eb434>> * @lightSyntaxTransform * @nogrep */ @@ -9,7 +9,7 @@ // @ts-nocheck import { ReaderFragment, RefetchableFragment } from 'relay-runtime'; -export type AuthMethod = "LOCAL"; +export type AuthMethod = "LOCAL" | "OAUTH2"; import { FragmentRefs } from "relay-runtime"; export type UsersTable_users$data = { readonly users: { From 636eadd221e012869bba646193826b7707861a5c Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 18 Sep 2024 21:05:28 -0700 Subject: [PATCH 20/29] support return urls --- app/src/pages/auth/LoginPage.tsx | 2 + app/src/pages/auth/OAuth2Login.tsx | 9 ++++- src/phoenix/server/api/routers/oauth2.py | 48 +++++++++++++++++++++++- 3 files changed, 55 insertions(+), 4 deletions(-) diff --git a/app/src/pages/auth/LoginPage.tsx b/app/src/pages/auth/LoginPage.tsx index 26928ac770..bb0c29b264 100644 --- a/app/src/pages/auth/LoginPage.tsx +++ b/app/src/pages/auth/LoginPage.tsx @@ -28,6 +28,7 @@ export function LoginPage() { const oAuth2Idps = window.Config.oAuth2Idps; const hasOAuth2Idps = oAuth2Idps.length > 0; const [searchParams, setSearchParams] = useSearchParams(); + const returnUrl = searchParams.get("returnUrl"); return ( @@ -49,6 +50,7 @@ export function LoginPage() { key={idp.name} idpName={idp.name} idpDisplayName={idp.displayName} + returnUrl={returnUrl} /> ))} diff --git a/app/src/pages/auth/OAuth2Login.tsx b/app/src/pages/auth/OAuth2Login.tsx index 1d70f8297d..8bc4fee0a0 100644 --- a/app/src/pages/auth/OAuth2Login.tsx +++ b/app/src/pages/auth/OAuth2Login.tsx @@ -6,6 +6,7 @@ import { Button } from "@arizeai/components"; type OAuth2LoginProps = { idpName: string; idpDisplayName: string; + returnUrl?: string | null; }; const loginCSS = css` @@ -30,10 +31,14 @@ const loginCSS = css` } `; -export function OAuth2Login({ idpName, idpDisplayName }: OAuth2LoginProps) { +export function OAuth2Login({ + idpName, + idpDisplayName, + returnUrl, +}: OAuth2LoginProps) { return (
RedirectResponse: + secret = request.app.state.get_secret() if not isinstance( oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client ): return _redirect_to_login(error=f"Unknown IDP: {idp_name}.") authorization_url_data = await oauth2_client.create_authorization_url( redirect_uri=_get_create_tokens_endpoint(request=request, idp_name=idp_name), - state=generate_token(), + state=_generate_state_for_oauth2_authorization_code_flow( + secret=secret, return_url=return_url + ), ) assert isinstance(authorization_url := authorization_url_data.get("url"), str) assert isinstance(state := authorization_url_data.get("state"), str) @@ -74,6 +80,7 @@ async def create_tokens( stored_state: str = Cookie(alias=PHOENIX_OAUTH2_STATE_COOKIE_NAME), stored_nonce: str = Cookie(alias=PHOENIX_OAUTH2_NONCE_COOKIE_NAME), ) -> RedirectResponse: + secret = request.app.state.get_secret() if state != stored_state: return _redirect_to_login( error=( @@ -118,7 +125,9 @@ async def create_tokens( access_token_expiry=access_token_expiry, refresh_token_expiry=refresh_token_expiry, ) - response = RedirectResponse(url="/", status_code=HTTP_302_FOUND) # todo: sanitize a return url + response = RedirectResponse( + url=_get_return_url(secret=secret, state=state) or "/", status_code=HTTP_302_FOUND + ) response = set_access_token_cookie( response=response, access_token=access_token, max_age=access_token_expiry ) @@ -338,3 +347,38 @@ def _get_create_tokens_endpoint(*, request: Request, idp_name: str) -> str: Gets the endpoint for create tokens route. """ return str(request.url_for(create_tokens.__name__, idp_name=idp_name)) + + +def _generate_state_for_oauth2_authorization_code_flow( + *, secret: str, return_url: Optional[str] +) -> str: + """ + Generates a JWT whose payload contains both an OAuth2 state (generated using + the `authlib` default algorithm) and a return URL. This allows us to pass + the return URL to the OAuth2 authorization server via the `state` query + parameter and have it returned to us in the callback without needing to + maintain state. + """ + header = {"alg": _JWT_ALGORITHM} + payload = {"state": generate_token()} + if return_url is not None: + payload[_RETURN_URL] = return_url + jwt_bytes: bytes = jwt.encode(header=header, payload=payload, key=secret) + return jwt_bytes.decode() + + +def _get_return_url(*, secret: str, state: str) -> Optional[str]: + """ + Parses the return URL from the OAuth2 state. + """ + try: + payload = jwt.decode(s=state, key=secret) + return_url = payload.get(_RETURN_URL) + assert isinstance(return_url, str) or return_url is None + return return_url + except BadSignatureError: + return None + + +_RETURN_URL = "return_url" +_JWT_ALGORITHM = "HS256" From 9821860b91a248de41c694a98a122de407ba5db9 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Wed, 18 Sep 2024 22:05:47 -0700 Subject: [PATCH 21/29] ensure that state tokens with invalid signature are rejected --- src/phoenix/server/api/routers/oauth2.py | 36 +++++++++++++++--------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index 6b05c7cac8..642560b666 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -1,6 +1,6 @@ from dataclasses import dataclass from datetime import timedelta -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from authlib.common.security import generate_token from authlib.integrations.starlette_client import OAuthError @@ -82,12 +82,12 @@ async def create_tokens( ) -> RedirectResponse: secret = request.app.state.get_secret() if state != stored_state: - return _redirect_to_login( - error=( - "Received invalid state parameter during " - "OAuth2 authorization code flow for IDP {idp_name}." - ) - ) + return _redirect_to_login(error=_INVALID_OAUTH2_STATE_MESSAGE) + signature_is_valid, return_url = _validate_signature_and_parse_return_url( + secret=secret, state=state + ) + if not signature_is_valid: + return _redirect_to_login(error=_INVALID_OAUTH2_STATE_MESSAGE) assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta) assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta) token_store: JwtStore = request.app.state.get_token_store() @@ -125,9 +125,7 @@ async def create_tokens( access_token_expiry=access_token_expiry, refresh_token_expiry=refresh_token_expiry, ) - response = RedirectResponse( - url=_get_return_url(secret=secret, state=state) or "/", status_code=HTTP_302_FOUND - ) + response = RedirectResponse(url=return_url or "/", status_code=HTTP_302_FOUND) response = set_access_token_cookie( response=response, access_token=access_token, max_age=access_token_expiry ) @@ -367,18 +365,28 @@ def _generate_state_for_oauth2_authorization_code_flow( return jwt_bytes.decode() -def _get_return_url(*, secret: str, state: str) -> Optional[str]: +def _validate_signature_and_parse_return_url( + *, secret: str, state: str +) -> Tuple[bool, Optional[str]]: """ - Parses the return URL from the OAuth2 state. + Validates the JWT signature and parses the return URL from the OAuth2 state. """ + signature_is_valid: bool + return_url: Optional[str] try: payload = jwt.decode(s=state, key=secret) return_url = payload.get(_RETURN_URL) assert isinstance(return_url, str) or return_url is None - return return_url + signature_is_valid = True + return signature_is_valid, return_url except BadSignatureError: - return None + signature_is_valid = False + return_url = None + return signature_is_valid, return_url _RETURN_URL = "return_url" _JWT_ALGORITHM = "HS256" +_INVALID_OAUTH2_STATE_MESSAGE = ( + "Received invalid state parameter during OAuth2 authorization code flow for IDP {idp_name}." +) From 3f088fd94a5395fffc63947d5934a7b7446a899a Mon Sep 17 00:00:00 2001 From: Dustin Ngo Date: Thu, 19 Sep 2024 13:58:29 -0400 Subject: [PATCH 22/29] Add OAuth rate limiters --- src/phoenix/server/api/routers/auth.py | 14 ++++----- src/phoenix/server/api/routers/oauth2.py | 38 ++++++++++++++++++++++-- src/phoenix/server/rate_limiters.py | 24 +++++++++++++-- 3 files changed, 64 insertions(+), 12 deletions(-) diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py index f4666bd6ff..31478933f8 100644 --- a/src/phoenix/server/api/routers/auth.py +++ b/src/phoenix/server/api/routers/auth.py @@ -37,7 +37,7 @@ from phoenix.server.email.templates.types import PasswordResetTemplateBody from phoenix.server.email.types import EmailSender from phoenix.server.jwt_store import JwtStore -from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_rate_limiter +from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_ip_rate_limiter from phoenix.server.types import ( AccessTokenClaims, PasswordResetTokenClaims, @@ -53,14 +53,14 @@ partition_seconds=60, active_partitions=2, ) -login_rate_limiter = fastapi_rate_limiter( +login_rate_limiter = fastapi_ip_rate_limiter( rate_limiter, paths=[ - "/login", - "/logout", - "/refresh", - "/password-reset-email", - "/password-reset", + "/auth/login", + "/auth/logout", + "/auth/refresh", + "/auth/password-reset-email", + "/auth/password-reset", ], ) router = APIRouter( diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index 642560b666..8a1988a33e 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -1,3 +1,4 @@ +import re from dataclasses import dataclass from datetime import timedelta from typing import Any, Dict, Optional, Tuple @@ -6,7 +7,7 @@ from authlib.integrations.starlette_client import OAuthError from authlib.jose import jwt from authlib.jose.errors import BadSignatureError -from fastapi import APIRouter, Cookie, Path, Query, Request +from fastapi import APIRouter, Cookie, Depends, Path, Query, Request from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -31,10 +32,43 @@ from phoenix.server.bearer_auth import create_access_and_refresh_tokens from phoenix.server.jwt_store import JwtStore from phoenix.server.oauth2 import OAuth2Client +from phoenix.server.rate_limiters import ( + ServerRateLimiter, + fastapi_ip_rate_limiter, + fastapi_route_rate_limiter, +) _LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+" -router = APIRouter(prefix="/oauth2", include_in_schema=False) +login_rate_limiter = ServerRateLimiter( + per_second_rate_limit=0.2, + enforcement_window_seconds=30, + partition_seconds=60, + active_partitions=2, +) +login_rate_limiter = fastapi_ip_rate_limiter( + login_rate_limiter, + paths=[ + "oauth2/login", + ], +) + +token_rate_limiter = ServerRateLimiter( + per_second_rate_limit=0.5, + enforcement_window_seconds=30, + partition_seconds=60, + active_partitions=2, +) +token_rate_limiter = fastapi_route_rate_limiter( + token_rate_limiter, + paths=[re.compile(r"/oauth2/[a-z0-9_]+/tokens")], +) + +router = APIRouter( + prefix="/oauth2", + include_in_schema=False, + dependencies=[Depends(login_rate_limiter), Depends(token_rate_limiter)], +) @router.post("/{idp_name}/login") diff --git a/src/phoenix/server/rate_limiters.py b/src/phoenix/server/rate_limiters.py index 9e6b8e2814..08957c0e93 100644 --- a/src/phoenix/server/rate_limiters.py +++ b/src/phoenix/server/rate_limiters.py @@ -1,7 +1,8 @@ +import re import time from collections import defaultdict from functools import partial -from typing import Any, Callable, Coroutine, DefaultDict, List, Optional +from typing import Any, Callable, Coroutine, DefaultDict, List, Optional, Union from fastapi import HTTPException, Request @@ -136,7 +137,7 @@ def make_request(self, key: str) -> None: rate_limiter.make_request_if_ready() -def fastapi_rate_limiter( +def fastapi_ip_rate_limiter( rate_limiter: ServerRateLimiter, paths: Optional[List[str]] = None ) -> Callable[[Request], Coroutine[Any, Any, Request]]: async def dependency(request: Request) -> Request: @@ -153,5 +154,22 @@ async def dependency(request: Request) -> Request: return dependency -def path_match(path: str, match_pattern: str) -> bool: +def fastapi_route_rate_limiter( + rate_limiter: ServerRateLimiter, paths: Optional[List[Union[str, re.Pattern]]] = None +) -> Callable[[Request], Coroutine[Any, Any, Request]]: + async def dependency(request: Request) -> Request: + for match_path in paths: + if path_match(request.url.path, match_path): + try: + rate_limiter.make_request(str(match_path)) + except UnavailableTokensError: + raise HTTPException(status_code=429, detail="Too Many Requests") + return request + + return dependency + + +def path_match(path: str, match_pattern: Union[str, re.Pattern]) -> bool: + if isinstance(match_pattern, re.Pattern): + return bool(match_pattern.match(path)) return path == match_pattern From b002378c5c732dfa025123b846cce3e8a104eb6d Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 19 Sep 2024 17:07:35 -0700 Subject: [PATCH 23/29] fix rate limiter type error --- src/phoenix/server/api/routers/oauth2.py | 38 ++++++++++-------------- src/phoenix/server/rate_limiters.py | 16 +++++----- 2 files changed, 22 insertions(+), 32 deletions(-) diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index 8a1988a33e..a8467472e8 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -1,4 +1,3 @@ -import re from dataclasses import dataclass from datetime import timedelta from typing import Any, Dict, Optional, Tuple @@ -40,38 +39,31 @@ _LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+" -login_rate_limiter = ServerRateLimiter( - per_second_rate_limit=0.2, - enforcement_window_seconds=30, - partition_seconds=60, - active_partitions=2, -) login_rate_limiter = fastapi_ip_rate_limiter( - login_rate_limiter, - paths=[ - "oauth2/login", - ], + ServerRateLimiter( + per_second_rate_limit=0.2, + enforcement_window_seconds=30, + partition_seconds=60, + active_partitions=2, + ), ) -token_rate_limiter = ServerRateLimiter( - per_second_rate_limit=0.5, - enforcement_window_seconds=30, - partition_seconds=60, - active_partitions=2, -) -token_rate_limiter = fastapi_route_rate_limiter( - token_rate_limiter, - paths=[re.compile(r"/oauth2/[a-z0-9_]+/tokens")], +create_tokens_rate_limiter = fastapi_route_rate_limiter( + ServerRateLimiter( + per_second_rate_limit=0.5, + enforcement_window_seconds=30, + partition_seconds=60, + active_partitions=2, + ) ) router = APIRouter( prefix="/oauth2", include_in_schema=False, - dependencies=[Depends(login_rate_limiter), Depends(token_rate_limiter)], ) -@router.post("/{idp_name}/login") +@router.post("/{idp_name}/login", dependencies=[Depends(login_rate_limiter)]) async def login( request: Request, idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)], @@ -105,7 +97,7 @@ async def login( return response -@router.get("/{idp_name}/tokens") +@router.get("/{idp_name}/tokens", dependencies=[Depends(create_tokens_rate_limiter)]) async def create_tokens( request: Request, idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)], diff --git a/src/phoenix/server/rate_limiters.py b/src/phoenix/server/rate_limiters.py index 08957c0e93..f50e3eefa7 100644 --- a/src/phoenix/server/rate_limiters.py +++ b/src/phoenix/server/rate_limiters.py @@ -138,7 +138,7 @@ def make_request(self, key: str) -> None: def fastapi_ip_rate_limiter( - rate_limiter: ServerRateLimiter, paths: Optional[List[str]] = None + rate_limiter: ServerRateLimiter, paths: Optional[List[Union[str, re.Pattern[str]]]] = None ) -> Callable[[Request], Coroutine[Any, Any, Request]]: async def dependency(request: Request) -> Request: if paths is None or any(path_match(request.url.path, path) for path in paths): @@ -155,21 +155,19 @@ async def dependency(request: Request) -> Request: def fastapi_route_rate_limiter( - rate_limiter: ServerRateLimiter, paths: Optional[List[Union[str, re.Pattern]]] = None + rate_limiter: ServerRateLimiter, ) -> Callable[[Request], Coroutine[Any, Any, Request]]: async def dependency(request: Request) -> Request: - for match_path in paths: - if path_match(request.url.path, match_path): - try: - rate_limiter.make_request(str(match_path)) - except UnavailableTokensError: - raise HTTPException(status_code=429, detail="Too Many Requests") + try: + rate_limiter.make_request(request.url.path) + except UnavailableTokensError: + raise HTTPException(status_code=429, detail="Too Many Requests") return request return dependency -def path_match(path: str, match_pattern: Union[str, re.Pattern]) -> bool: +def path_match(path: str, match_pattern: Union[str, re.Pattern[str]]) -> bool: if isinstance(match_pattern, re.Pattern): return bool(match_pattern.match(path)) return path == match_pattern From 80d3a8f62ef3132f983257a9441bd84fdb57f6ca Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 19 Sep 2024 17:11:13 -0700 Subject: [PATCH 24/29] explicitly reset password --- src/phoenix/server/api/routers/oauth2.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index a8467472e8..6d9f5f137b 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -265,6 +265,7 @@ async def _create_user( username=user_info.username, email=user_info.email, profile_picture_url=user_info.profile_picture_url, + reset_password=False, ) ) assert isinstance(user_id, int) From a31d6fedb4192d1f76a3662725493736ca965027 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 19 Sep 2024 17:21:23 -0700 Subject: [PATCH 25/29] use TokenStore interface --- src/phoenix/server/api/routers/auth.py | 3 +-- src/phoenix/server/api/routers/oauth2.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py index 31478933f8..d9aab97788 100644 --- a/src/phoenix/server/api/routers/auth.py +++ b/src/phoenix/server/api/routers/auth.py @@ -36,7 +36,6 @@ from phoenix.server.bearer_auth import PhoenixUser, create_access_and_refresh_tokens from phoenix.server.email.templates.types import PasswordResetTemplateBody from phoenix.server.email.types import EmailSender -from phoenix.server.jwt_store import JwtStore from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_ip_rate_limiter from phoenix.server.types import ( AccessTokenClaims, @@ -72,7 +71,7 @@ async def login(request: Request) -> Response: assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta) assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta) - token_store: JwtStore = request.app.state.get_token_store() + token_store: TokenStore = request.app.state.get_token_store() data = await request.json() email = data.get("email") password = data.get("password") diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index 6d9f5f137b..dbc337c72d 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -29,13 +29,13 @@ from phoenix.db import models from phoenix.db.enums import UserRole from phoenix.server.bearer_auth import create_access_and_refresh_tokens -from phoenix.server.jwt_store import JwtStore from phoenix.server.oauth2 import OAuth2Client from phoenix.server.rate_limiters import ( ServerRateLimiter, fastapi_ip_rate_limiter, fastapi_route_rate_limiter, ) +from phoenix.server.types import TokenStore _LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+" @@ -116,7 +116,7 @@ async def create_tokens( return _redirect_to_login(error=_INVALID_OAUTH2_STATE_MESSAGE) assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta) assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta) - token_store: JwtStore = request.app.state.get_token_store() + token_store: TokenStore = request.app.state.get_token_store() if not isinstance( oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client ): From 69a5a9148478ce3676015d462dcb0be912e35e9d Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 19 Sep 2024 17:52:17 -0700 Subject: [PATCH 26/29] remove the explicit routes --- src/phoenix/server/api/routers/auth.py | 20 ++++++-------------- 1 file changed, 6 insertions(+), 14 deletions(-) diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py index d9aab97788..a2edc31e5f 100644 --- a/src/phoenix/server/api/routers/auth.py +++ b/src/phoenix/server/api/routers/auth.py @@ -46,21 +46,13 @@ UserId, ) -rate_limiter = ServerRateLimiter( - per_second_rate_limit=0.2, - enforcement_window_seconds=30, - partition_seconds=60, - active_partitions=2, -) login_rate_limiter = fastapi_ip_rate_limiter( - rate_limiter, - paths=[ - "/auth/login", - "/auth/logout", - "/auth/refresh", - "/auth/password-reset-email", - "/auth/password-reset", - ], + ServerRateLimiter( + per_second_rate_limit=0.2, + enforcement_window_seconds=30, + partition_seconds=60, + active_partitions=2, + ) ) router = APIRouter( prefix="/auth", include_in_schema=False, dependencies=[Depends(login_rate_limiter)] From d959adee73e3d66349d3959312aa3561bab8ee02 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 19 Sep 2024 18:50:45 -0700 Subject: [PATCH 27/29] import pattern from typing to add support for 3.8 --- .../experiments/evaluators/code_evaluators.py | 12 +++++++++--- src/phoenix/server/rate_limiters.py | 15 ++++++++++++--- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/src/phoenix/experiments/evaluators/code_evaluators.py b/src/phoenix/experiments/evaluators/code_evaluators.py index 74d267dfc3..e99b2b921c 100644 --- a/src/phoenix/experiments/evaluators/code_evaluators.py +++ b/src/phoenix/experiments/evaluators/code_evaluators.py @@ -2,7 +2,13 @@ import json import re -from typing import Any, List, Optional, Union +from typing import ( + Any, + List, + Optional, + Pattern, # import from re module when we drop support for 3.8 + Union, +) from phoenix.experiments.evaluators.base import CodeEvaluator from phoenix.experiments.types import EvaluationResult, TaskOutput @@ -144,7 +150,7 @@ class MatchesRegex(CodeEvaluator): An experiment evaluator that checks if the output of an experiment run matches a regex pattern. Args: - pattern (Union[str, re.Pattern[str]]): The regex pattern to match the output against. + pattern (Union[str, Pattern[str]]): The regex pattern to match the output against. name (str, optional): An optional name for the evaluator. Defaults to "matches_({pattern})". Example: @@ -157,7 +163,7 @@ class MatchesRegex(CodeEvaluator): run_experiment(dataset, task, evaluators=[phone_number_evaluator]) """ - def __init__(self, pattern: Union[str, re.Pattern[str]], name: Optional[str] = None) -> None: + def __init__(self, pattern: Union[str, Pattern[str]], name: Optional[str] = None) -> None: if isinstance(pattern, str): pattern = re.compile(pattern) self.pattern = pattern diff --git a/src/phoenix/server/rate_limiters.py b/src/phoenix/server/rate_limiters.py index f50e3eefa7..d1eb3f380e 100644 --- a/src/phoenix/server/rate_limiters.py +++ b/src/phoenix/server/rate_limiters.py @@ -2,7 +2,16 @@ import time from collections import defaultdict from functools import partial -from typing import Any, Callable, Coroutine, DefaultDict, List, Optional, Union +from typing import ( + Any, + Callable, + Coroutine, + DefaultDict, + List, + Optional, + Pattern, # import from re module when we drop support for 3.8 + Union, +) from fastapi import HTTPException, Request @@ -138,7 +147,7 @@ def make_request(self, key: str) -> None: def fastapi_ip_rate_limiter( - rate_limiter: ServerRateLimiter, paths: Optional[List[Union[str, re.Pattern[str]]]] = None + rate_limiter: ServerRateLimiter, paths: Optional[List[Union[str, Pattern[str]]]] = None ) -> Callable[[Request], Coroutine[Any, Any, Request]]: async def dependency(request: Request) -> Request: if paths is None or any(path_match(request.url.path, path) for path in paths): @@ -167,7 +176,7 @@ async def dependency(request: Request) -> Request: return dependency -def path_match(path: str, match_pattern: Union[str, re.Pattern[str]]) -> bool: +def path_match(path: str, match_pattern: Union[str, Pattern[str]]) -> bool: if isinstance(match_pattern, re.Pattern): return bool(match_pattern.match(path)) return path == match_pattern From 7bfe1d207ae0fe6d4d080f36f50b7c2466328662 Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 19 Sep 2024 19:06:25 -0700 Subject: [PATCH 28/29] increase rate limit --- src/phoenix/server/api/routers/auth.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py index a2edc31e5f..79aa86c5cb 100644 --- a/src/phoenix/server/api/routers/auth.py +++ b/src/phoenix/server/api/routers/auth.py @@ -48,7 +48,7 @@ login_rate_limiter = fastapi_ip_rate_limiter( ServerRateLimiter( - per_second_rate_limit=0.2, + per_second_rate_limit=1.0, enforcement_window_seconds=30, partition_seconds=60, active_partitions=2, From 4e1834804f0c40c4f238a2657c4ead06ab8388bb Mon Sep 17 00:00:00 2001 From: Alexander Song Date: Thu, 19 Sep 2024 19:16:22 -0700 Subject: [PATCH 29/29] undo rate limiter fix --- src/phoenix/config.py | 3 +-- src/phoenix/server/api/routers/auth.py | 20 ++++++++++++++------ 2 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/phoenix/config.py b/src/phoenix/config.py index bec10aa91c..892a47184c 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -69,8 +69,7 @@ "PHOENIX_SERVER_INSTRUMENTATION_OTLP_TRACE_COLLECTOR_GRPC_ENDPOINT" ) -# Auth is under active development. Phoenix users are strongly advised not to -# set these environment variables until the feature is officially released. +# Authentication settings ENV_PHOENIX_ENABLE_AUTH = "PHOENIX_ENABLE_AUTH" ENV_PHOENIX_SECRET = "PHOENIX_SECRET" ENV_PHOENIX_API_KEY = "PHOENIX_API_KEY" diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py index 79aa86c5cb..ed3dd3cb76 100644 --- a/src/phoenix/server/api/routers/auth.py +++ b/src/phoenix/server/api/routers/auth.py @@ -46,13 +46,21 @@ UserId, ) +rate_limiter = ServerRateLimiter( + per_second_rate_limit=0.2, + enforcement_window_seconds=30, + partition_seconds=60, + active_partitions=2, +) login_rate_limiter = fastapi_ip_rate_limiter( - ServerRateLimiter( - per_second_rate_limit=1.0, - enforcement_window_seconds=30, - partition_seconds=60, - active_partitions=2, - ) + rate_limiter, + paths=[ + "/login", + "/logout", + "/refresh", + "/password-reset-email", + "/password-reset", + ], ) router = APIRouter( prefix="/auth", include_in_schema=False, dependencies=[Depends(login_rate_limiter)]