From b7af851f42d04417d2851f62c7a3ee6ad99b7634 Mon Sep 17 00:00:00 2001 From: Xander Song Date: Thu, 12 Sep 2024 16:06:57 -0700 Subject: [PATCH] feat(auth): secure `/exports` when auth is enabled (#4589) --- integration_tests/_helpers.py | 9 +++++++ integration_tests/auth/test_auth.py | 19 ++++++++++++++ src/phoenix/server/api/routers/__init__.py | 7 +++++- src/phoenix/server/api/routers/embeddings.py | 26 ++++++++++++++++++++ src/phoenix/server/app.py | 16 ++---------- 5 files changed, 62 insertions(+), 15 deletions(-) create mode 100644 src/phoenix/server/api/routers/embeddings.py diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index f0dd4bb3c5..e026e08a7a 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -211,6 +211,9 @@ def create_api_key( def delete_api_key(self, api_key: _ApiKey, /) -> None: return _delete_api_key(api_key, self) + def export_embeddings(self, filename: str) -> None: + _export_embeddings(self, filename=filename) + _SYSTEM_USER_GID = _GqlId(GlobalID(type_name="User", node_id="1")) _DEFAULT_ADMIN = _User( @@ -828,6 +831,11 @@ def _log_out( resp.raise_for_status() +def _export_embeddings(auth: Optional[_SecurityArtifact] = None, /, *, filename: str) -> None: + resp = _httpx_client(auth).get("/exports", params={"filename": filename}) + resp.raise_for_status() + + def _json( resp: httpx.Response, ) -> Dict[str, Any]: @@ -853,3 +861,4 @@ def __exit__(self, *args: Any, **kwargs: Any) -> None: ... _EXPECTATION_401 = pytest.raises(HTTPStatusError, match="401 Unauthorized") _EXPECTATION_403 = pytest.raises(HTTPStatusError, match="403 Forbidden") +_EXPECTATION_404 = pytest.raises(HTTPStatusError, match="404 Not Found") diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index 3d5e0deb91..4f5444fc1f 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -36,6 +36,7 @@ _DEFAULT_ADMIN, _DENIED, _EXPECTATION_401, + _EXPECTATION_404, _MEMBER, _OK, _OK_OR_DENIED, @@ -45,6 +46,7 @@ _create_user, _DefaultAdminTokenSequestration, _Expectation, + _export_embeddings, _GetUser, _GqlId, _Headers, @@ -741,3 +743,20 @@ def test_api_key( if api_key and expected is SpanExportResult.SUCCESS: _DEFAULT_ADMIN.delete_api_key(api_key) assert export(_spans) is SpanExportResult.FAILURE + + +class TestEmbeddingsRestApi: + @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _DEFAULT_ADMIN]) + def test_authenticated_users_can_access_route( + self, + role_or_user: _RoleOrUser, + _get_user: _GetUser, + ) -> None: + user = _get_user(role_or_user) + logged_in_user = user.log_in() + with _EXPECTATION_404: # no files have been exported + logged_in_user.export_embeddings("embeddings") + + def test_unauthenticated_requests_receive_401(self) -> None: + with _EXPECTATION_401: + _export_embeddings(None, filename="embeddings") diff --git a/src/phoenix/server/api/routers/__init__.py b/src/phoenix/server/api/routers/__init__.py index ac65d00c58..354d5d106d 100644 --- a/src/phoenix/server/api/routers/__init__.py +++ b/src/phoenix/server/api/routers/__init__.py @@ -1,4 +1,9 @@ from .auth import router as auth_router +from .embeddings import create_embeddings_router from .v1 import create_v1_router -__all__ = ["auth_router", "create_v1_router"] +__all__ = [ + "auth_router", + "create_embeddings_router", + "create_v1_router", +] diff --git a/src/phoenix/server/api/routers/embeddings.py b/src/phoenix/server/api/routers/embeddings.py new file mode 100644 index 0000000000..af80944fa8 --- /dev/null +++ b/src/phoenix/server/api/routers/embeddings.py @@ -0,0 +1,26 @@ +from fastapi import APIRouter, Depends +from fastapi.responses import FileResponse +from starlette.exceptions import HTTPException +from starlette.requests import Request + +from phoenix.server.bearer_auth import is_authenticated + + +def create_embeddings_router(authentication_enabled: bool) -> APIRouter: + """ + Instantiates a router for the embeddings API. + """ + router = APIRouter(dependencies=[Depends(is_authenticated)] if authentication_enabled else []) + + @router.get("/exports") + async def download_exported_file(request: Request, filename: str) -> FileResponse: + file = request.app.state.export_path / (filename + ".parquet") + if not file.is_file(): + raise HTTPException(status_code=404) + return FileResponse( + path=file, + filename=file.name, + media_type="application/x-octet-stream", + ) + + return router diff --git a/src/phoenix/server/app.py b/src/phoenix/server/app.py index fa8a9bf64e..df67cbb7a0 100644 --- a/src/phoenix/server/app.py +++ b/src/phoenix/server/app.py @@ -28,7 +28,6 @@ import strawberry from fastapi import APIRouter, Depends, FastAPI from fastapi.middleware.gzip import GZipMiddleware -from fastapi.responses import FileResponse from fastapi.utils import is_body_allowed_for_status_code from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker @@ -91,7 +90,7 @@ UserRolesDataLoader, UsersDataLoader, ) -from phoenix.server.api.routers import auth_router, create_v1_router +from phoenix.server.api.routers import auth_router, create_embeddings_router, create_v1_router from phoenix.server.api.routers.v1 import REST_API_VERSION from phoenix.server.api.schema import schema from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated @@ -225,18 +224,6 @@ async def dispatch( ProjectRowId: TypeAlias = int -@router.get("/exports") -async def download_exported_file(request: Request, filename: str) -> FileResponse: - file = request.app.state.export_path / (filename + ".parquet") - if not file.is_file(): - raise HTTPException(status_code=404) - return FileResponse( - path=file, - filename=file.name, - media_type="application/x-octet-stream", - ) - - @router.get("/arize_phoenix_version") async def version() -> PlainTextResponse: return PlainTextResponse(f"{phoenix.__version__}") @@ -727,6 +714,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: }, ) app.include_router(create_v1_router(authentication_enabled)) + app.include_router(create_embeddings_router(authentication_enabled)) app.include_router(router) app.include_router(graphql_router) if authentication_enabled: