diff --git a/integration_tests/_helpers.py b/integration_tests/_helpers.py index bfcc16704b..016a89812a 100644 --- a/integration_tests/_helpers.py +++ b/integration_tests/_helpers.py @@ -53,6 +53,8 @@ DEFAULT_ADMIN_PASSWORD, DEFAULT_ADMIN_USERNAME, PHOENIX_ACCESS_TOKEN_COOKIE_NAME, + PHOENIX_OAUTH2_NONCE_COOKIE_NAME, + PHOENIX_OAUTH2_STATE_COOKIE_NAME, PHOENIX_REFRESH_TOKEN_COOKIE_NAME, ) from phoenix.config import ( @@ -284,7 +286,9 @@ def log_out(self) -> None: _log_out(self) -class _RefreshToken(_Token): ... +class _RefreshToken(_Token, _CanLogOut[None]): + def log_out(self) -> None: + _log_out(self) @dataclass(frozen=True) @@ -885,6 +889,9 @@ def _log_out( ) -> None: resp = _httpx_client(auth).post("auth/logout") resp.raise_for_status() + tokens = _extract_tokens(resp.headers, "set-cookie") + for k in _COOKIE_NAMES: + assert tokens[k] == '""' def _initiate_password_reset( @@ -958,12 +965,7 @@ def _extract_tokens( if not (cookies := headers.get(key)): return {} parts = re.split(r"[ ,;=]", cookies) - return { - k: v - for k, v in zip(parts[:-1], parts[1:]) - if v.strip('"') - and k in (PHOENIX_ACCESS_TOKEN_COOKIE_NAME, PHOENIX_REFRESH_TOKEN_COOKIE_NAME) - } + return {k: v for k, v in zip(parts[:-1], parts[1:]) if k in _COOKIE_NAMES} def _decode_token_ids( @@ -973,6 +975,7 @@ def _decode_token_ids( return [ jwt.decode(v, options={"verify_signature": False})["jti"] for v in _extract_tokens(headers, key).values() + if v != '""' ] @@ -1000,3 +1003,11 @@ def _extract_html(msg: Message) -> Optional[bs4.BeautifulSoup]: content = payload.decode(part.get_content_charset() or "utf-8") return bs4.BeautifulSoup(content, "html.parser") return None + + +_COOKIE_NAMES = ( + PHOENIX_ACCESS_TOKEN_COOKIE_NAME, + PHOENIX_REFRESH_TOKEN_COOKIE_NAME, + PHOENIX_OAUTH2_STATE_COOKIE_NAME, + PHOENIX_OAUTH2_NONCE_COOKIE_NAME, +) diff --git a/integration_tests/auth/test_auth.py b/integration_tests/auth/test_auth.py index 4d3d654d1e..2a33c691db 100644 --- a/integration_tests/auth/test_auth.py +++ b/integration_tests/auth/test_auth.py @@ -54,6 +54,7 @@ _http_span_exporter, _initiate_password_reset, _log_in, + _log_out, _LoggedInUser, _Password, _patch_user, @@ -288,6 +289,19 @@ def test_can_log_out( with _EXPECTATION_401: logged_in_user.create_api_key() + @pytest.mark.parametrize("role_or_user", [_MEMBER, _ADMIN]) + def test_can_log_out_with_only_refresh_token( + self, + role_or_user: _RoleOrUser, + _get_user: _GetUser, + ) -> None: + u = _get_user(role_or_user) + refresh_token = u.log_in().tokens.refresh_token + refresh_token.log_out() + + def test_log_out_does_not_raise_exception(self) -> None: + _log_out() + class TestLoggedInTokens: class _JtiSet(Generic[_TokenT]): diff --git a/src/phoenix/server/api/routers/auth.py b/src/phoenix/server/api/routers/auth.py index 83f7ba0438..05834ddb36 100644 --- a/src/phoenix/server/api/routers/auth.py +++ b/src/phoenix/server/api/routers/auth.py @@ -116,9 +116,19 @@ async def logout( request: Request, ) -> Response: token_store: TokenStore = request.app.state.get_token_store() - if not isinstance(user := request.user, PhoenixUser): - raise HTTPException(status_code=HTTP_401_UNAUTHORIZED) - await token_store.log_out(user.identity) + user_id = None + if isinstance(user := request.user, PhoenixUser): + user_id = user.identity + elif (refresh_token := request.cookies.get(PHOENIX_REFRESH_TOKEN_COOKIE_NAME)) and ( + isinstance( + refresh_token_claims := await token_store.read(Token(refresh_token)), + RefreshTokenClaims, + ) + and isinstance(subject := refresh_token_claims.subject, UserId) + ): + user_id = subject + if user_id: + await token_store.log_out(user_id) response = Response(status_code=HTTP_204_NO_CONTENT) response = delete_access_token_cookie(response) response = delete_refresh_token_cookie(response)