diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index e026e08a7a..c115a8e551 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -35,7 +35,7 @@ import httpx import pytest -from httpx import HTTPStatusError +from httpx import Headers, HTTPStatusError from openinference.semconv.resource import ResourceAttributes from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import ReadableSpan, TracerProvider @@ -135,7 +135,7 @@ def gql( self, query: str, variables: Optional[Mapping[str, Any]] = None, - ) -> Dict[str, Any]: + ) -> Tuple[Dict[str, Any], Headers]: return _gql(self, query=query, variables=variables) def create_user( @@ -641,10 +641,10 @@ def _gql( *, query: str, variables: Optional[Mapping[str, Any]] = None, -) -> Dict[str, Any]: +) -> Tuple[Dict[str, Any], Headers]: json_ = dict(query=query, variables=dict(variables or {})) resp = _httpx_client(auth).post("graphql", json=json_) - return _json(resp) + return _json(resp), resp.headers def _get_gql_spans( @@ -654,8 +654,9 @@ def _get_gql_spans( ) -> Dict[_ProjectName, List[Dict[str, Any]]]: out = "name spans{edges{node{" + " ".join(fields) + "}}}" query = "query{projects{edges{node{" + out + "}}}}" - resp_dict = _gql(auth, query=query) + resp_dict, headers = _gql(auth, query=query) assert not resp_dict.get("errors") + assert not headers.get("set-cookie") return { project["node"]["name"]: [span["node"] for span in project["node"]["spans"]["edges"]] for project in resp_dict["data"]["projects"]["edges"] @@ -677,10 +678,11 @@ def _create_user( args.append(f'username:"{username}"') out = "user{id email role{name}}" query = "mutation{createUser(input:{" + ",".join(args) + "}){" + out + "}}" - resp_dict = _gql(auth, query=query) + resp_dict, headers = _gql(auth, query=query) assert (user := resp_dict["data"]["createUser"]["user"]) assert user["email"] == email assert user["role"]["name"] == role.value + assert not headers.get("set-cookie") return _User(_GqlId(user["id"]), role, profile) @@ -692,7 +694,8 @@ def _delete_users( ) -> None: user_ids = [u.gid if isinstance(u, _User) else u for u in users] query = "mutation($userIds:[GlobalID!]!){deleteUsers(input:{userIds:$userIds})}" - _gql(auth, query=query, variables=dict(userIds=user_ids)) + _, headers = _gql(auth, query=query, variables=dict(userIds=user_ids)) + assert not headers.get("set-cookie") def _patch_user_gid( @@ -713,7 +716,7 @@ def _patch_user_gid( args.append(f"newRole:{new_role.value}") out = "user{id username role{name}}" query = "mutation{patchUser(input:{" + ",".join(args) + "}){" + out + "}}" - resp_dict = _gql(auth, query=query) + resp_dict, headers = _gql(auth, query=query) assert (data := resp_dict["data"]["patchUser"]) assert (result := data["user"]) assert result["id"] == gid @@ -721,6 +724,7 @@ def _patch_user_gid( assert result["username"] == new_username if new_role: assert result["role"]["name"] == new_role.value + assert not headers.get("set-cookie") def _patch_user( @@ -765,11 +769,15 @@ def _patch_viewer( args.append(f'newUsername:"{new_username}"') out = "user{username}" query = "mutation{patchViewer(input:{" + ",".join(args) + "}){" + out + "}}" - resp_dict = _gql(auth, query=query) + resp_dict, headers = _gql(auth, query=query) assert (data := resp_dict["data"]["patchViewer"]) assert (user := data["user"]) if new_username: assert user["username"] == new_username + if new_password: + assert headers.get("set-cookie") + else: + assert not headers.get("set-cookie") def _create_api_key( @@ -786,12 +794,13 @@ def _create_api_key( args, out = (f'name:"{name}"' + exp), "jwt apiKey{id name expiresAt}" field = f"create{kind}ApiKey" query = "mutation{" + field + "(input:{" + args + "}){" + out + "}}" - resp_dict = _gql(auth, query=query) + resp_dict, headers = _gql(auth, query=query) assert (data := resp_dict["data"][field]) assert (key := data["apiKey"]) assert key["name"] == name exp_t = datetime.fromisoformat(key["expiresAt"]) if key["expiresAt"] else None assert exp_t == expires_at + assert not headers.get("set-cookie") return _ApiKey(data["jwt"], _GqlId(key["id"]), kind) @@ -805,8 +814,9 @@ def _delete_api_key( gid = api_key.gid args, out = f'id:"{gid}"', "apiKeyId" query = "mutation{" + field + "(input:{" + args + "}){" + out + "}}" - resp_dict = _gql(auth, query=query) + resp_dict, headers = _gql(auth, query=query) assert resp_dict["data"][field]["apiKeyId"] == gid + assert not headers.get("set-cookie") def _log_in( diff --git a/src/phoenix/server/api/context.py b/src/phoenix/server/api/context.py index 8ed113e57d..ed39375614 100644 --- a/src/phoenix/server/api/context.py +++ b/src/phoenix/server/api/context.py @@ -9,8 +9,6 @@ from strawberry.fastapi import BaseContext from phoenix.auth import ( - PHOENIX_ACCESS_TOKEN_COOKIE_NAME, - PHOENIX_REFRESH_TOKEN_COOKIE_NAME, compute_password_hash, ) from phoenix.core.model_schema import Model @@ -145,9 +143,6 @@ async def hash_password(password: str, salt: bytes) -> bytes: async def log_out(self, user_id: int) -> None: assert self.token_store is not None await self.token_store.log_out(UserId(user_id)) - response = self.get_response() - response.delete_cookie(PHOENIX_REFRESH_TOKEN_COOKIE_NAME) - response.delete_cookie(PHOENIX_ACCESS_TOKEN_COOKIE_NAME) @cached_property def user(self) -> PhoenixUser: diff --git a/src/phoenix/server/api/mutations/user_mutations.py b/src/phoenix/server/api/mutations/user_mutations.py index 541eb2a4e9..718207af03 100644 --- a/src/phoenix/server/api/mutations/user_mutations.py +++ b/src/phoenix/server/api/mutations/user_mutations.py @@ -16,6 +16,8 @@ DEFAULT_ADMIN_USERNAME, DEFAULT_SECRET_LENGTH, PASSWORD_REQUIREMENTS, + PHOENIX_ACCESS_TOKEN_COOKIE_NAME, + PHOENIX_REFRESH_TOKEN_COOKIE_NAME, validate_email_format, validate_password_format, ) @@ -188,6 +190,9 @@ async def patch_viewer( assert user if input.new_password: await info.context.log_out(user.id) + response = info.context.get_response() + response.delete_cookie(PHOENIX_REFRESH_TOKEN_COOKIE_NAME) + response.delete_cookie(PHOENIX_ACCESS_TOKEN_COOKIE_NAME) return UserMutationPayload(user=to_gql_user(user)) @strawberry.mutation(permission_classes=[IsNotReadOnly, IsAdmin]) # type: ignore