Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(auth): add support for oauth2 with openid connect discovery #4618

Merged
merged 29 commits into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
980d808
implement authorization code flow
axiomofjoy Sep 15, 2024
4e5e8e7
update database
axiomofjoy Sep 15, 2024
54f6cc7
working end-to-end
axiomofjoy Sep 16, 2024
30b6cd1
running on google, azure and aws
axiomofjoy Sep 17, 2024
ce2c528
refactor idp id to idp name
axiomofjoy Sep 17, 2024
b816062
oauth to oauth2
axiomofjoy Sep 17, 2024
d6b9c8d
rename files oauth to oath2
axiomofjoy Sep 17, 2024
17e1afc
redirect to login
axiomofjoy Sep 17, 2024
6f967c3
optimize query
axiomofjoy Sep 17, 2024
fee4b13
feat: style sso buttons - mikeldking
mikeldking Sep 17, 2024
7a0e84e
display error messages from query parameters
axiomofjoy Sep 17, 2024
e0f2f61
remove idp table
axiomofjoy Sep 17, 2024
aaae0b7
add oauth2 client id
axiomofjoy Sep 17, 2024
d1314ef
update azure ad to microsoft entra id
axiomofjoy Sep 17, 2024
230d52e
clean up
axiomofjoy Sep 18, 2024
448085f
store oauth2 state in cookies
axiomofjoy Sep 19, 2024
0f1ee02
fix types
axiomofjoy Sep 19, 2024
577a4de
update graphql schema
axiomofjoy Sep 19, 2024
cdb2286
update relay
axiomofjoy Sep 19, 2024
636eadd
support return urls
axiomofjoy Sep 19, 2024
9821860
ensure that state tokens with invalid signature are rejected
axiomofjoy Sep 19, 2024
3f088fd
Add OAuth rate limiters
anticorrelator Sep 19, 2024
b002378
fix rate limiter type error
axiomofjoy Sep 20, 2024
80d3a8f
explicitly reset password
axiomofjoy Sep 20, 2024
a31d6fe
use TokenStore interface
axiomofjoy Sep 20, 2024
69a5a91
remove the explicit routes
axiomofjoy Sep 20, 2024
d959ade
import pattern from typing to add support for 3.8
axiomofjoy Sep 20, 2024
7bfe1d2
increase rate limit
axiomofjoy Sep 20, 2024
4e18348
undo rate limiter fix
axiomofjoy Sep 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions app/schema.graphql
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ interface ApiKey {

enum AuthMethod {
LOCAL
OAUTH2
}

union Bin = NominalBin | IntervalBin | MissingValueBin
Expand Down
15 changes: 12 additions & 3 deletions app/src/pages/auth/LoginForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,21 @@ type LoginFormParams = {
password: string;
};

export function LoginForm() {
type LoginFormProps = {
initialError: string | null;
/**
* Callback function called when the form is submitted
*/
onSubmit?: () => void;
};
export function LoginForm(props: LoginFormProps) {
const navigate = useNavigate();
const [error, setError] = useState<string | null>(null);
const { initialError, onSubmit: propsOnSubmit } = props;
const [error, setError] = useState<string | null>(initialError);
const [isLoading, setIsLoading] = useState<boolean>(false);
const onSubmit = useCallback(
async (params: LoginFormParams) => {
propsOnSubmit?.();
setError(null);
setIsLoading(true);
try {
Expand All @@ -42,7 +51,7 @@ export function LoginForm() {
const returnUrl = getReturnUrl();
navigate(returnUrl);
},
[navigate, setError]
[navigate, propsOnSubmit, setError]
);
const { control, handleSubmit } = useForm<LoginFormParams>({
defaultValues: { email: "", password: "" },
Expand Down
44 changes: 43 additions & 1 deletion app/src/pages/auth/LoginPage.tsx
Original file line number Diff line number Diff line change
@@ -1,20 +1,62 @@
import React from "react";
import { useSearchParams } from "react-router-dom";
import { css } from "@emotion/react";

import { Flex, View } from "@arizeai/components";

import { AuthLayout } from "./AuthLayout";
import { LoginForm } from "./LoginForm";
import { OAuth2Login } from "./OAuth2Login";
import { PhoenixLogo } from "./PhoenixLogo";

const separatorCSS = css`
text-align: center;
margin-top: var(--ac-global-dimension-size-200);
margin-bottom: var(--ac-global-dimension-size-200);
color: var(--ac-global-text-color-700);
`;

const oAuthLoginButtonListCSS = css`
display: flex;
flex-direction: column;
gap: var(--ac-global-dimension-size-100);
flex-wrap: wrap;
justify-content: center;
`;

export function LoginPage() {
const oAuth2Idps = window.Config.oAuth2Idps;
const hasOAuth2Idps = oAuth2Idps.length > 0;
const [searchParams, setSearchParams] = useSearchParams();
const returnUrl = searchParams.get("returnUrl");
return (
<AuthLayout>
<Flex direction="column" gap="size-200" alignItems="center">
<View paddingBottom="size-200">
<PhoenixLogo />
</View>
</Flex>
<LoginForm />
<LoginForm
initialError={searchParams.get("error")}
onSubmit={() => setSearchParams({})}
/>
{hasOAuth2Idps && (
<>
<div css={separatorCSS}>or</div>
<ul css={oAuthLoginButtonListCSS}>
{oAuth2Idps.map((idp) => (
<li key={idp.name}>
<OAuth2Login
key={idp.name}
idpName={idp.name}
idpDisplayName={idp.displayName}
returnUrl={returnUrl}
/>
</li>
))}
</ul>
</>
)}
</AuthLayout>
);
}
95 changes: 95 additions & 0 deletions app/src/pages/auth/OAuth2Login.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import React, { ReactNode } from "react";
import { css } from "@emotion/react";

import { Button } from "@arizeai/components";

type OAuth2LoginProps = {
idpName: string;
idpDisplayName: string;
returnUrl?: string | null;
};

const loginCSS = css`
button {
width: 100%;
}
i {
display: block;
width: 20px;
height: 20px;
padding-right: var(--ac-global-dimension-size-50);
}
&[data-provider^="aws"],
&[data-provider^="google"] {
button {
background-color: white;
color: black;
&:hover {
background-color: #ececec !important;
}
}
}
`;

export function OAuth2Login({
idpName,
idpDisplayName,
returnUrl,
}: OAuth2LoginProps) {
return (
<form
action={`/oauth2/${idpName}/login${returnUrl ? `?returnUrl=${returnUrl}` : ""}`}
method="post"
css={loginCSS}
data-provider={idpName}
>
<Button
variant="default"
type="submit"
icon={<IDPIcon idpName={idpName} />}
>
Login with {idpDisplayName}
</Button>
</form>
);
}

function IDPIcon({ idpName }: { idpName: string }): ReactNode {
const hasIcon =
idpName === "github" ||
idpName === "google" ||
idpName === "microsoft_entra_id" ||
idpName.startsWith("aws");
if (!hasIcon) {
return null;
}
return (
<i>
<div
css={css`
display: inline-block;
width: 20px;
height: 20px;
position: relative;
background-size: contain;
background-repeat: no-repeat;
background-position: 50%;
&[data-provider^="github"] {
background-image: url("data:image/svg+xml;charset=utf-8,%3Csvg width='20' height='20' xmlns='http://www.w3.org/2000/svg'%3E%3Cpath d='M10 0C4.477 0 0 4.36 0 9.74c0 4.304 2.865 7.955 6.839 9.243.5.09.682-.211.682-.47 0-.23-.008-.843-.013-1.656-2.782.588-3.369-1.306-3.369-1.306-.454-1.125-1.11-1.425-1.11-1.425-.908-.604.069-.592.069-.592 1.003.069 1.531 1.004 1.531 1.004.892 1.488 2.341 1.059 2.91.81.092-.63.35-1.06.636-1.303-2.22-.245-4.555-1.081-4.555-4.814 0-1.063.39-1.933 1.029-2.613-.103-.247-.446-1.238.098-2.578 0 0 .84-.262 2.75.998A9.818 9.818 0 0 1 10 4.71c.85.004 1.705.112 2.504.328 1.909-1.26 2.747-.998 2.747-.998.546 1.34.203 2.331.1 2.578.64.68 1.028 1.55 1.028 2.613 0 3.742-2.339 4.566-4.566 4.807.359.3.678.895.678 1.804 0 1.301-.012 2.352-.012 2.671 0 .261.18.564.688.47C17.137 17.69 20 14.042 20 9.74 20 4.36 15.522 0 10 0z' fill='%23161514' fill-rule='evenodd'/%3E%3C/svg%3E");
}

&[data-provider^="google"] {
background-image: url("data:image/svg+xml;charset=utf-8,%3Csvg xmlns='http://www.w3.org/2000/svg' xmlns:xlink='http://www.w3.org/1999/xlink' viewBox='0 0 48 48'%3E%3Cdefs%3E%3Cpath id='a' d='M44.5 20H24v8.5h11.8C34.7 33.9 30.1 37 24 37c-7.2 0-13-5.8-13-13s5.8-13 13-13c3.1 0 5.9 1.1 8.1 2.9l6.4-6.4C34.6 4.1 29.6 2 24 2 11.8 2 2 11.8 2 24s9.8 22 22 22c11 0 21-8 21-22 0-1.3-.2-2.7-.5-4z'/%3E%3C/defs%3E%3CclipPath id='b'%3E%3Cuse xlink:href='%23a' overflow='visible'/%3E%3C/clipPath%3E%3Cpath clip-path='url(%23b)' fill='%23FBBC05' d='M0 37V11l17 13z'/%3E%3Cpath clip-path='url(%23b)' fill='%23EA4335' d='M0 11l17 13 7-6.1L48 14V0H0z'/%3E%3Cpath clip-path='url(%23b)' fill='%2334A853' d='M0 37l30-23 7.9 1L48 0v48H0z'/%3E%3Cpath clip-path='url(%23b)' fill='%234285F4' d='M48 48L17 24l-4-3 35-10z'/%3E%3C/svg%3E");
}
&[data-provider^="microsoft"] {
background-image: url("data:image/svg+xml;charset=utf-8,%3Csvg xmlns='http://www.w3.org/2000/svg' width='21' height='21'%3E%3Cpath fill='%23f25022' d='M1 1h9v9H1z'/%3E%3Cpath fill='%2300a4ef' d='M1 11h9v9H1z'/%3E%3Cpath fill='%237fba00' d='M11 1h9v9h-9z'/%3E%3Cpath fill='%23ffb900' d='M11 11h9v9h-9z'/%3E%3C/svg%3E");
}
&[data-provider^="aws"] {
background-image: url("data:image/svg+xml;charset=utf-8,%3Csvg width='400' height='334' xmlns='http://www.w3.org/2000/svg'%3E%3Cpath d='M236.578 94.824c-9.683.765-20.854 1.502-32.021 3.006-17.12 2.211-34.24 5.219-48.386 11.907-27.544 11.208-46.163 35.053-46.163 70.114 0 44.018 28.298 66.354 64.026 66.354 11.93 0 21.606-1.466 30.522-3.71 14.156-4.481 26.07-12.67 40.209-27.596 8.192 11.205 10.413 16.428 24.575 28.338 3.725 1.502 7.448 1.502 10.413-.742 8.932-7.458 24.561-20.873 32.773-28.33 3.71-3.012 2.955-7.463.739-11.204-8.198-10.435-16.381-19.401-16.381-39.506V96.324c0-28.359 2.214-54.453-18.614-73.822C261.147 6.815 234.34.86 213.5.86h-8.947c-37.965 2.247-78.169 18.635-87.122 65.629-1.462 5.955 3.012 8.198 5.989 8.962l41.677 5.224c4.471-.773 6.691-4.491 7.432-8.233 3.74-16.388 17.136-24.583 32.024-26.087h2.998c8.905 0 18.586 3.743 23.813 11.168 5.932 8.965 5.21 20.904 5.21 31.339v5.961h.004v.001zm0 43.278c0 17.162.723 30.571-8.195 45.461-5.208 10.437-14.141 17.154-23.827 19.4-1.481 0-3.698.766-5.947.766-16.371 0-26.077-12.673-26.077-31.334 0-23.856 14.159-35.056 32.023-40.277 9.687-2.241 20.86-2.982 32.021-2.982v8.966h.002z'/%3E%3Cpath d='M373.71 315.303c18.201-15.398 25.89-43.349 26.29-57.939v-2.44c0-3.255-.803-5.661-1.6-6.88-3.646-4.445-30.369-8.523-53.402-1.627-6.468 2.045-12.146 4.865-17.396 8.507-4.051 2.854-3.238 6.464.802 6.08 4.447-.823 10.126-1.208 16.594-2.048 14.159-1.18 30.742-1.592 34.784 3.662 5.642 6.87-6.468 36.868-11.749 49.838-1.593 4.065 2.03 5.696 5.677 2.847z' fill='%23FE9900'/%3E%3Cpath d='M2.008 257.364c52.17 47.404 120.925 75.775 197.791 75.775 47.725 0 102.727-13.381 145.199-38.899 5.676-3.27 11.316-6.912 16.565-10.952 7.286-5.25.817-13.38-6.463-10.147-3.229 1.215-6.873 2.857-10.103 4.066-46.539 18.248-95.441 26.76-140.762 26.76-72.008 0-141.56-19.87-197.786-52.684-5.259-2.822-8.907 2.428-4.441 6.081z' fill='%23FE9900'/%3E%3C/svg%3E");
}
`}
data-provider={idpName}
/>
</i>
);
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions app/src/window.d.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
export {};

type OAuth2Idp = {
name: string;
displayName: string;
};

declare global {
interface Window {
Config: {
Expand All @@ -15,6 +20,7 @@ declare global {
nSamples: number;
};
authenticationEnabled: boolean;
oAuth2Idps: OAuth2Idp[];
};
}
}
1 change: 1 addition & 0 deletions cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"graphiql",
"HDBSCAN",
"httpx",
"Idps",
"Instrumentor",
"instrumentors",
"langchain",
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ dependencies = [
"fastapi-mail",
"pydantic>=1.0,!=2.0.*,<3", # exclude 2.0.* since it does not support the `json_encoders` configuration setting
"pyjwt",
"authlib",
]
dynamic = ["version"]

Expand Down Expand Up @@ -407,6 +408,7 @@ module = [
"grpc.*",
"py_grpc_prometheus.*",
"orjson", # suppress fastapi internal type errors
"authlib.*",
]
ignore_missing_imports = true

Expand Down
84 changes: 69 additions & 15 deletions src/phoenix/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@
from hashlib import pbkdf2_hmac
from typing import Any, Literal, Optional, Protocol

from fastapi import Response
from starlette.responses import Response
from typing_extensions import TypeVar

from phoenix.config import get_env_phoenix_use_secure_cookies
from phoenix.db.models import User as OrmUser

ResponseType = TypeVar("ResponseType", bound=Response)


def compute_password_hash(*, password: str, salt: bytes) -> bytes:
Expand Down Expand Up @@ -65,34 +69,64 @@ def validate_password_format(password: str) -> None:
PASSWORD_REQUIREMENTS.validate(password)


def is_locally_authenticated(user: OrmUser) -> bool:
"""
Returns true if the user is authenticated locally, i.e., not through an
OAuth2 identity provider, and false otherwise.
"""
return user.oauth2_client_id is None and user.oauth2_user_id is None


def set_access_token_cookie(
*, response: Response, access_token: str, max_age: timedelta
) -> Response:
return _set_token_cookie(
*, response: ResponseType, access_token: str, max_age: timedelta
) -> ResponseType:
return _set_cookie(
response=response,
cookie_name=PHOENIX_ACCESS_TOKEN_COOKIE_NAME,
cookie_max_age=max_age,
token=access_token,
value=access_token,
)


def set_refresh_token_cookie(
*, response: Response, refresh_token: str, max_age: timedelta
) -> Response:
return _set_token_cookie(
*, response: ResponseType, refresh_token: str, max_age: timedelta
) -> ResponseType:
return _set_cookie(
response=response,
cookie_name=PHOENIX_REFRESH_TOKEN_COOKIE_NAME,
cookie_max_age=max_age,
token=refresh_token,
value=refresh_token,
)


def _set_token_cookie(
response: Response, cookie_name: str, cookie_max_age: timedelta, token: str
) -> Response:
def set_oauth2_state_cookie(
*, response: ResponseType, state: str, max_age: timedelta
) -> ResponseType:
return _set_cookie(
response=response,
cookie_name=PHOENIX_OAUTH2_STATE_COOKIE_NAME,
cookie_max_age=max_age,
value=state,
)


def set_oauth2_nonce_cookie(
*, response: ResponseType, nonce: str, max_age: timedelta
) -> ResponseType:
return _set_cookie(
response=response,
cookie_name=PHOENIX_OAUTH2_NONCE_COOKIE_NAME,
cookie_max_age=max_age,
value=nonce,
)


def _set_cookie(
response: ResponseType, cookie_name: str, cookie_max_age: timedelta, value: str
) -> ResponseType:
response.set_cookie(
key=cookie_name,
value=token,
value=value,
secure=get_env_phoenix_use_secure_cookies(),
httponly=True,
samesite="strict",
Expand All @@ -101,16 +135,26 @@ def _set_token_cookie(
return response


def delete_access_token_cookie(response: Response) -> Response:
def delete_access_token_cookie(response: ResponseType) -> ResponseType:
response.delete_cookie(key=PHOENIX_ACCESS_TOKEN_COOKIE_NAME)
return response


def delete_refresh_token_cookie(response: Response) -> Response:
def delete_refresh_token_cookie(response: ResponseType) -> ResponseType:
response.delete_cookie(key=PHOENIX_REFRESH_TOKEN_COOKIE_NAME)
return response


def delete_oauth2_state_cookie(response: ResponseType) -> ResponseType:
response.delete_cookie(key=PHOENIX_OAUTH2_STATE_COOKIE_NAME)
return response


def delete_oauth2_nonce_cookie(response: ResponseType) -> ResponseType:
response.delete_cookie(key=PHOENIX_OAUTH2_NONCE_COOKIE_NAME)
return response


@dataclass(frozen=True)
class _PasswordRequirements:
"""
Expand Down Expand Up @@ -206,6 +250,16 @@ def validate(
"""The name of the cookie that stores the Phoenix access token."""
PHOENIX_REFRESH_TOKEN_COOKIE_NAME = "phoenix-refresh-token"
"""The name of the cookie that stores the Phoenix refresh token."""
PHOENIX_OAUTH2_STATE_COOKIE_NAME = "phoenix-oauth2-state"
"""The name of the cookie that stores the state used for the OAuth2 authorization code flow."""
PHOENIX_OAUTH2_NONCE_COOKIE_NAME = "phoenix-oauth2-nonce"
"""The name of the cookie that stores the nonce used for the OAuth2 authorization code flow."""
DEFAULT_OAUTH2_LOGIN_EXPIRY_MINUTES = 15
"""
The default amount of time in minutes that can elapse between the initial
redirect to the IDP and the invocation of the callback URL during the OAuth2
authorization code flow.
"""


class Token(str): ...
Expand Down
Loading
Loading