Skip to content

Commit

Permalink
feat(auth): secure /exports when auth is enabled (#4589)
Browse files Browse the repository at this point in the history
  • Loading branch information
axiomofjoy authored and RogerHYang committed Sep 21, 2024
1 parent 0ce3fae commit b7af851
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 15 deletions.
9 changes: 9 additions & 0 deletions integration_tests/_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ def create_api_key(
def delete_api_key(self, api_key: _ApiKey, /) -> None:
return _delete_api_key(api_key, self)

def export_embeddings(self, filename: str) -> None:
_export_embeddings(self, filename=filename)


_SYSTEM_USER_GID = _GqlId(GlobalID(type_name="User", node_id="1"))
_DEFAULT_ADMIN = _User(
Expand Down Expand Up @@ -828,6 +831,11 @@ def _log_out(
resp.raise_for_status()


def _export_embeddings(auth: Optional[_SecurityArtifact] = None, /, *, filename: str) -> None:
resp = _httpx_client(auth).get("/exports", params={"filename": filename})
resp.raise_for_status()


def _json(
resp: httpx.Response,
) -> Dict[str, Any]:
Expand All @@ -853,3 +861,4 @@ def __exit__(self, *args: Any, **kwargs: Any) -> None: ...

_EXPECTATION_401 = pytest.raises(HTTPStatusError, match="401 Unauthorized")
_EXPECTATION_403 = pytest.raises(HTTPStatusError, match="403 Forbidden")
_EXPECTATION_404 = pytest.raises(HTTPStatusError, match="404 Not Found")
19 changes: 19 additions & 0 deletions integration_tests/auth/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
_DEFAULT_ADMIN,
_DENIED,
_EXPECTATION_401,
_EXPECTATION_404,
_MEMBER,
_OK,
_OK_OR_DENIED,
Expand All @@ -45,6 +46,7 @@
_create_user,
_DefaultAdminTokenSequestration,
_Expectation,
_export_embeddings,
_GetUser,
_GqlId,
_Headers,
Expand Down Expand Up @@ -741,3 +743,20 @@ def test_api_key(
if api_key and expected is SpanExportResult.SUCCESS:
_DEFAULT_ADMIN.delete_api_key(api_key)
assert export(_spans) is SpanExportResult.FAILURE


class TestEmbeddingsRestApi:
@pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN, _DEFAULT_ADMIN])
def test_authenticated_users_can_access_route(
self,
role_or_user: _RoleOrUser,
_get_user: _GetUser,
) -> None:
user = _get_user(role_or_user)
logged_in_user = user.log_in()
with _EXPECTATION_404: # no files have been exported
logged_in_user.export_embeddings("embeddings")

def test_unauthenticated_requests_receive_401(self) -> None:
with _EXPECTATION_401:
_export_embeddings(None, filename="embeddings")
7 changes: 6 additions & 1 deletion src/phoenix/server/api/routers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from .auth import router as auth_router
from .embeddings import create_embeddings_router
from .v1 import create_v1_router

__all__ = ["auth_router", "create_v1_router"]
__all__ = [
"auth_router",
"create_embeddings_router",
"create_v1_router",
]
26 changes: 26 additions & 0 deletions src/phoenix/server/api/routers/embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from fastapi import APIRouter, Depends
from fastapi.responses import FileResponse
from starlette.exceptions import HTTPException
from starlette.requests import Request

from phoenix.server.bearer_auth import is_authenticated


def create_embeddings_router(authentication_enabled: bool) -> APIRouter:
"""
Instantiates a router for the embeddings API.
"""
router = APIRouter(dependencies=[Depends(is_authenticated)] if authentication_enabled else [])

@router.get("/exports")
async def download_exported_file(request: Request, filename: str) -> FileResponse:
file = request.app.state.export_path / (filename + ".parquet")
if not file.is_file():
raise HTTPException(status_code=404)
return FileResponse(
path=file,
filename=file.name,
media_type="application/x-octet-stream",
)

return router
16 changes: 2 additions & 14 deletions src/phoenix/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import strawberry
from fastapi import APIRouter, Depends, FastAPI
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import FileResponse
from fastapi.utils import is_body_allowed_for_status_code
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
Expand Down Expand Up @@ -91,7 +90,7 @@
UserRolesDataLoader,
UsersDataLoader,
)
from phoenix.server.api.routers import auth_router, create_v1_router
from phoenix.server.api.routers import auth_router, create_embeddings_router, create_v1_router
from phoenix.server.api.routers.v1 import REST_API_VERSION
from phoenix.server.api.schema import schema
from phoenix.server.bearer_auth import BearerTokenAuthBackend, is_authenticated
Expand Down Expand Up @@ -225,18 +224,6 @@ async def dispatch(
ProjectRowId: TypeAlias = int


@router.get("/exports")
async def download_exported_file(request: Request, filename: str) -> FileResponse:
file = request.app.state.export_path / (filename + ".parquet")
if not file.is_file():
raise HTTPException(status_code=404)
return FileResponse(
path=file,
filename=file.name,
media_type="application/x-octet-stream",
)


@router.get("/arize_phoenix_version")
async def version() -> PlainTextResponse:
return PlainTextResponse(f"{phoenix.__version__}")
Expand Down Expand Up @@ -727,6 +714,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
},
)
app.include_router(create_v1_router(authentication_enabled))
app.include_router(create_embeddings_router(authentication_enabled))
app.include_router(router)
app.include_router(graphql_router)
if authentication_enabled:
Expand Down

0 comments on commit b7af851

Please sign in to comment.