diff --git a/app/src/pages/auth/LoginPage.tsx b/app/src/pages/auth/LoginPage.tsx index b86711d565..26928ac770 100644 --- a/app/src/pages/auth/LoginPage.tsx +++ b/app/src/pages/auth/LoginPage.tsx @@ -6,7 +6,7 @@ import { Flex, View } from "@arizeai/components"; import { AuthLayout } from "./AuthLayout"; import { LoginForm } from "./LoginForm"; -import { OAuth2Login } from "./Oauth2Login"; +import { OAuth2Login } from "./OAuth2Login"; import { PhoenixLogo } from "./PhoenixLogo"; const separatorCSS = css` diff --git a/app/src/pages/auth/Oauth2Login.tsx b/app/src/pages/auth/OAuth2Login.tsx similarity index 100% rename from app/src/pages/auth/Oauth2Login.tsx rename to app/src/pages/auth/OAuth2Login.tsx diff --git a/app/src/pages/auth/oAuthCallbackLoader.ts b/app/src/pages/auth/oAuthCallbackLoader.ts deleted file mode 100644 index 251ed76915..0000000000 --- a/app/src/pages/auth/oAuthCallbackLoader.ts +++ /dev/null @@ -1,41 +0,0 @@ -// import { redirect } from "react-router"; -import { LoaderFunctionArgs } from "react-router-dom"; - -export async function oAuthCallbackLoader(args: LoaderFunctionArgs) { - const queryParameters = new URL(args.request.url).searchParams; - const authorizationCode = queryParameters.get("code"); - const state = queryParameters.get("state"); - const actualState = sessionStorage.getItem("oAuthState"); - sessionStorage.removeItem("oAuthState"); - if ( - authorizationCode == undefined || - state == undefined || - actualState == undefined || - state !== actualState - ) { - // todo: display error message - return null; - } - const origin = new URL(window.location.href).origin; - const redirectUri = `${origin}/oauth-callback`; - try { - const response = await fetch("/auth/oauth-tokens", { - method: "POST", - headers: { - "Content-Type": "application/json", - }, - body: JSON.stringify({ - authorization_code: authorizationCode, - redirect_uri: redirectUri, - }), - }); - if (!response.ok) { - // todo: parse response body and display error message - return null; - } - } catch (error) { - // todo: display error - } - // redirect("/"); - return null; -} diff --git a/src/phoenix/config.py b/src/phoenix/config.py index 48bc918df2..480404f4ac 100644 --- a/src/phoenix/config.py +++ b/src/phoenix/config.py @@ -10,13 +10,9 @@ from urllib.parse import urlparse import pandas as pd -from typing_extensions import TypeAlias from phoenix.utilities.re import parse_env_headers -EnvVarName: TypeAlias = str -EnvVarValue: TypeAlias = str - logger = getLogger(__name__) # Phoenix environment variables @@ -211,7 +207,7 @@ def get_env_refresh_token_expiry() -> timedelta: @dataclass(frozen=True) class OAuth2ClientConfig: idp_name: str - display_name: str + idp_display_name: str client_id: str client_secret: str server_metadata_url: str @@ -257,7 +253,7 @@ def from_env(cls, idp_name: str) -> "OAuth2ClientConfig": ) return cls( idp_name=idp_name, - display_name=os.getenv( + idp_display_name=os.getenv( f"PHOENIX_OAUTH2_{idp_name_upper}_DISPLAY_NAME", _get_default_idp_display_name(idp_name), ), @@ -473,6 +469,9 @@ class OAuth2Idp(Enum): def _get_default_idp_display_name(idp_name: str) -> str: + """ + Get the default display name for an OAuth2 IDP. + """ if idp_name == OAuth2Idp.AWS_COGNITO.value: return "AWS Cognito" if idp_name == OAuth2Idp.MICROSOFT_ENTRA_ID.value: @@ -481,6 +480,9 @@ def _get_default_idp_display_name(idp_name: str) -> str: def _get_default_server_metadata_url(idp_name: str) -> Optional[str]: + """ + Gets the default server metadata URL for an OAuth2 IDP. + """ if idp_name == OAuth2Idp.GOOGLE.value: return "https://accounts.google.com/.well-known/openid-configuration" return None diff --git a/src/phoenix/db/enums.py b/src/phoenix/db/enums.py index 12ae245942..8f99057750 100644 --- a/src/phoenix/db/enums.py +++ b/src/phoenix/db/enums.py @@ -12,10 +12,6 @@ class UserRole(Enum): MEMBER = "MEMBER" -class IdentityProviderName(Enum): - LOCAL = "local" - - COLUMN_ENUMS: Mapping[InstrumentedAttribute[str], Type[Enum]] = { models.UserRole.name: UserRole, } diff --git a/src/phoenix/server/api/routers/oauth2.py b/src/phoenix/server/api/routers/oauth2.py index 97f464cf67..33b2f3e1c6 100644 --- a/src/phoenix/server/api/routers/oauth2.py +++ b/src/phoenix/server/api/routers/oauth2.py @@ -4,7 +4,7 @@ from authlib.integrations.starlette_client import OAuthError from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client -from fastapi import APIRouter, Depends, Path, Request +from fastapi import APIRouter, Path, Request from sqlalchemy import Boolean, and_, case, cast, func, insert, or_, select, update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload @@ -20,26 +20,17 @@ from phoenix.db.enums import UserRole from phoenix.server.bearer_auth import create_access_and_refresh_tokens from phoenix.server.jwt_store import JwtStore -from phoenix.server.rate_limiters import ServerRateLimiter, fastapi_rate_limiter -ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+" +_LOWERCASE_ALPHANUMS_AND_UNDERSCORES = r"[a-z0-9_]+" -rate_limiter = ServerRateLimiter( - per_second_rate_limit=0.2, - enforcement_window_seconds=30, - partition_seconds=60, - active_partitions=2, -) -login_rate_limiter = fastapi_rate_limiter(rate_limiter, paths=["/login"]) -router = APIRouter( - prefix="/oauth2", include_in_schema=False, dependencies=[Depends(login_rate_limiter)] -) + +router = APIRouter(prefix="/oauth2", include_in_schema=False) @router.post("/{idp_name}/login") async def login( request: Request, - idp_name: Annotated[str, Path(min_length=1, pattern=ALPHANUMS_AND_UNDERSCORES)], + idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)], ) -> RedirectResponse: if not isinstance( oauth2_client := request.app.state.oauth2_clients.get_client(idp_name), OAuth2Client @@ -53,7 +44,7 @@ async def login( @router.get("/{idp_name}/tokens") async def create_tokens( request: Request, - idp_name: Annotated[str, Path(min_length=1, pattern=ALPHANUMS_AND_UNDERSCORES)], + idp_name: Annotated[str, Path(min_length=1, pattern=_LOWERCASE_ALPHANUMS_AND_UNDERSCORES)], ) -> RedirectResponse: assert isinstance(access_token_expiry := request.app.state.access_token_expiry, timedelta) assert isinstance(refresh_token_expiry := request.app.state.refresh_token_expiry, timedelta) @@ -63,10 +54,10 @@ async def create_tokens( ): return _redirect_to_login(error=f"Unknown IDP: {idp_name}.") try: - token = await oauth2_client.authorize_access_token(request) + token_data = await oauth2_client.authorize_access_token(request) except OAuthError as error: return _redirect_to_login(error=str(error)) - if (user_info := _get_user_info(token)) is None: + if (user_info := _get_user_info(token_data)) is None: return _redirect_to_login( error=f"OAuth2 IDP {idp_name} does not appear to support OpenID Connect." ) @@ -103,11 +94,14 @@ class UserInfo: profile_picture_url: Optional[str] -def _get_user_info(token: Dict[str, Any]) -> Optional[UserInfo]: - assert isinstance(token.get("access_token"), str) - assert isinstance(token_type := token.get("token_type"), str) +def _get_user_info(token_data: Dict[str, Any]) -> Optional[UserInfo]: + """ + Parses token data and extracts user info if available. + """ + assert isinstance(token_data.get("access_token"), str) + assert isinstance(token_type := token_data.get("token_type"), str) assert token_type.lower() == "bearer" - if (user_info := token.get("userinfo")) is None: + if (user_info := token_data.get("userinfo")) is None: return None assert isinstance(subject := user_info.get("sub"), (str, int)) idp_user_id = str(subject) @@ -135,7 +129,7 @@ async def _ensure_user_exists_and_is_up_to_date( ) if user is None: user = await _create_user(session, oauth2_client_id=oauth2_client_id, user_info=user_info) - elif _db_user_is_outdated(user=user, user_info=user_info): + elif not _user_is_up_to_date(user=user, user_info=user_info): user = await _update_user(session, user_id=user.id, user_info=user_info) return user @@ -143,6 +137,10 @@ async def _ensure_user_exists_and_is_up_to_date( async def _get_user( session: AsyncSession, /, *, oauth2_client_id: str, idp_user_id: str ) -> Optional[models.User]: + """ + Retrieves the user uniquely identified by the given OAuth2 client ID and IDP + user ID. + """ user = await session.scalar( select(models.User) .where( @@ -156,36 +154,6 @@ async def _get_user( return user -async def _ensure_email_and_username_are_not_in_use( - session: AsyncSession, /, *, email: str, username: Optional[str] -) -> None: - [(email_exists, username_exists)] = ( - await session.execute( - select( - cast( - func.coalesce( - func.max(case((models.User.email == email, 1), else_=0)), - 0, - ), - Boolean, - ).label("email_exists"), - cast( - func.coalesce( - func.max(case((models.User.username == username, 1), else_=0)), - 0, - ), - Boolean, - ).label("username_exists"), - ).where(or_(models.User.email == email, models.User.username == username)) - ) - ).all() - if email_exists: - raise EmailAlreadyInUse(f"An account for {email} is already in use.") - if username_exists: - raise UsernameAlreadyInUse(f'An account already exists with username "{username}".') - return None - - async def _create_user( session: AsyncSession, /, @@ -193,6 +161,9 @@ async def _create_user( oauth2_client_id: str, user_info: UserInfo, ) -> models.User: + """ + Creates a new user with the user info from the IDP. + """ await _ensure_email_and_username_are_not_in_use( session, email=user_info.email, @@ -213,15 +184,12 @@ async def _create_user( username=user_info.username, email=user_info.email, profile_picture_url=user_info.profile_picture_url, - password_hash=None, - password_salt=None, - reset_password=False, ) ) assert isinstance(user_id, int) user = await session.scalar( select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role)) - ) # query user for joined load + ) # query user again for joined load assert isinstance(user, models.User) return user @@ -229,6 +197,14 @@ async def _create_user( async def _update_user( session: AsyncSession, /, *, user_id: int, user_info: UserInfo ) -> models.User: + """ + Updates an existing user with user info from the IDP. + """ + await _ensure_email_and_username_are_not_in_use( + session, + email=user_info.email, + username=user_info.username, + ) await session.execute( update(models.User) .where(models.User.id == user_id) @@ -239,19 +215,54 @@ async def _update_user( ) .options(joinedload(models.User.role)) ) - assert isinstance(user_id, int) user = await session.scalar( select(models.User).where(models.User.id == user_id).options(joinedload(models.User.role)) - ) # query user for joined load + ) # query user again for joined load assert isinstance(user, models.User) return user -def _db_user_is_outdated(*, user: models.User, user_info: UserInfo) -> bool: +async def _ensure_email_and_username_are_not_in_use( + session: AsyncSession, /, *, email: str, username: Optional[str] +) -> None: + """ + Raises an error if the email or username are already in use. + """ + [(email_exists, username_exists)] = ( + await session.execute( + select( + cast( + func.coalesce( + func.max(case((models.User.email == email, 1), else_=0)), + 0, + ), + Boolean, + ).label("email_exists"), + cast( + func.coalesce( + func.max(case((models.User.username == username, 1), else_=0)), + 0, + ), + Boolean, + ).label("username_exists"), + ).where(or_(models.User.email == email, models.User.username == username)) + ) + ).all() + if email_exists: + raise EmailAlreadyInUse(f"An account for {email} is already in use.") + if username_exists: + raise UsernameAlreadyInUse(f'An account already exists with username "{username}".') + + +def _user_is_up_to_date(*, user: models.User, user_info: UserInfo) -> bool: + """ + Determines whether the user's tuple in the database is up-to-date with the + IDP's user info. + """ return ( - user.email != user_info.email - or user.username != user_info.username - or user.profile_picture_url != user_info.profile_picture_url + user.email == user_info.email + and user.username == user_info.username + and user.profile_picture_url == user_info.profile_picture_url ) diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index db9231c1c5..41d945b481 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -742,7 +742,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: web_manifest_path = SERVER_DIR / "static" / ".vite" / "manifest.json" if serve_ui and web_manifest_path.is_file(): oauth2_idps = [ - OAuth2Idp(name=config.idp_name, displayName=config.display_name) + OAuth2Idp(name=config.idp_name, displayName=config.idp_display_name) for config in oauth2_client_configs or [] ] app.mount( diff --git a/src/phoenix/server/bearer_auth.py b/src/phoenix/server/bearer_auth.py index 14ca32af51..3c47171c3e 100644 --- a/src/phoenix/server/bearer_auth.py +++ b/src/phoenix/server/bearer_auth.py @@ -137,9 +137,10 @@ async def create_access_and_refresh_tokens( refresh_token_expiry: timedelta, ) -> Tuple[AccessToken, RefreshToken]: issued_at = datetime.now(timezone.utc) + user_id = UserId(user.id) user_role = UserRole(user.role.name) refresh_token_claims = RefreshTokenClaims( - subject=UserId(user.id), + subject=user_id, issued_at=issued_at, expiration_time=issued_at + refresh_token_expiry, attributes=RefreshTokenAttributes( @@ -148,7 +149,7 @@ async def create_access_and_refresh_tokens( ) refresh_token, refresh_token_id = await token_store.create_refresh_token(refresh_token_claims) access_token_claims = AccessTokenClaims( - subject=UserId(user.id), + subject=user_id, issued_at=issued_at, expiration_time=issued_at + access_token_expiry, attributes=AccessTokenAttributes( diff --git a/src/phoenix/server/oauth2.py b/src/phoenix/server/oauth2.py index cc3256e2f9..8bff39929b 100644 --- a/src/phoenix/server/oauth2.py +++ b/src/phoenix/server/oauth2.py @@ -1,5 +1,5 @@ from datetime import datetime, timedelta -from typing import Any, Dict, Generic, List, Optional, Tuple +from typing import Any, Dict, Generic, Iterable, Optional, Tuple from authlib.integrations.starlette_client import OAuth from authlib.integrations.starlette_client import StarletteOAuth2App as OAuth2Client @@ -28,11 +28,11 @@ def add_client(self, config: OAuth2ClientConfig) -> None: def get_client(self, idp_name: str) -> OAuth2Client: if (client := self._clients.get(idp_name)) is None: - raise ValueError(f"unknown or unregistered oauth client: {idp_name}") + raise ValueError(f"unknown or unregistered OAuth2 client: {idp_name}") return client @classmethod - def from_configs(cls, configs: List[OAuth2ClientConfig]) -> "OAuth2Clients": + def from_configs(cls, configs: Iterable[OAuth2ClientConfig]) -> "OAuth2Clients": oauth2_clients = cls() for config in configs: oauth2_clients.add_client(config) @@ -51,7 +51,7 @@ class _OAuth2ClientTTLCache(Generic[_CacheKey, _CacheValue]): integration. Provides an alternative to starlette session middleware. """ - def __init__(self, cleanup_interval: timedelta = 10 * _MINUTE) -> None: + def __init__(self, cleanup_interval: timedelta = 1 * _MINUTE) -> None: self._data: Dict[_CacheKey, Tuple[_CacheValue, _Expiry]] = {} self._last_cleanup_time = datetime.now() self._cleanup_interval = cleanup_interval @@ -61,6 +61,7 @@ async def get(self, key: _CacheKey) -> Optional[_CacheValue]: Retrieves the value associated with the given key if it exists and has not expired, otherwise, returns None. """ + self._remove_expired_keys_if_cleanup_interval_exceeded() if (value_and_expiry := self._data.get(key)) is None: return None value, expiry = value_and_expiry