Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

convert from using tokens to using api keys #39

Merged
merged 9 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 23 additions & 19 deletions frontend/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import "bootstrap/dist/css/bootstrap.min.css";
import TopNavbar from "components/nav/TopNavbar";
import NotFoundRedirect from "components/NotFoundRedirect";
import { AlertQueue, AlertQueueProvider } from "hooks/alerts";
import { AuthenticationProvider } from "hooks/auth";
import { AuthenticationProvider, OneTimePasswordWrapper } from "hooks/auth";
import { ThemeProvider } from "hooks/theme";
import Home from "pages/Home";
import NotFound from "pages/NotFound";
Expand All @@ -21,26 +21,30 @@ const App = () => {
<AuthenticationProvider>
<AlertQueueProvider>
<AlertQueue>
<TopNavbar />
<OneTimePasswordWrapper>
<TopNavbar />

<Container className="content">
<Routes>
<Route path="/" element={<Home />} />
<Route path="/robots/" element={<Robots />} />
<Route path="/robot/:id" element={<RobotDetails />} />
<Route path="/parts/" element={<Parts />} />
<Route path="/part/:id" element={<PartDetails />} />
<Route path="/404" element={<NotFound />} />
<Route path="*" element={<NotFoundRedirect />} />
</Routes>
</Container>
<Container className="content">
<Routes>
<Route path="/" element={<Home />} />
<Route path="/robots/" element={<Robots />} />
<Route path="/robot/:id" element={<RobotDetails />} />
<Route path="/parts/" element={<Parts />} />
<Route path="/part/:id" element={<PartDetails />} />
<Route path="/404" element={<NotFound />} />
<Route path="*" element={<NotFoundRedirect />} />
</Routes>
</Container>

<footer className="fixed-bottom">
{/* Solid background */}
<div className="text-center bg-body-tertiary p-2">
<a href="mailto:[email protected]">[email protected]</a>
</div>
</footer>
<footer className="fixed-bottom">
{/* Solid background */}
<div className="text-center bg-body-tertiary p-2">
<a href="mailto:[email protected]">
[email protected]
</a>
</div>
</footer>
</OneTimePasswordWrapper>
</AlertQueue>
</AlertQueueProvider>
</AuthenticationProvider>
Expand Down
6 changes: 3 additions & 3 deletions frontend/src/components/auth/GoogleAuthComponent.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ const GoogleAuthComponentInner = () => {
const [credential, setCredential] = useState<string | null>(null);
const [disableButton, setDisableButton] = useState(false);

const { setRefreshToken, api } = useAuthentication();
const { setApiKey, api } = useAuthentication();
const { addAlert } = useAlertQueue();

useEffect(() => {
Expand All @@ -28,15 +28,15 @@ const GoogleAuthComponentInner = () => {
const response = await api.post<UserLoginResponse>("/users/google", {
token: credential,
});
setRefreshToken(response.data.token);
setApiKey(response.data.token);
} catch (error) {
addAlert(humanReadableError(error), "error");
} finally {
setCredential(null);
}
}
})();
}, [credential, setRefreshToken, api, addAlert]);
}, [credential, setApiKey, api, addAlert]);

const login = useGoogleLogin({
onSuccess: (tokenResponse) => {
Expand Down
149 changes: 33 additions & 116 deletions frontend/src/hooks/auth.tsx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import axios, { AxiosError, AxiosInstance, isAxiosError } from "axios";
import axios, { AxiosInstance } from "axios";
import { BACKEND_URL } from "constants/backend";
import {
createContext,
Expand All @@ -10,49 +10,28 @@ import {
} from "react";
import { useNavigate, useSearchParams } from "react-router-dom";

const REFRESH_TOKEN_KEY = "__REFRESH_TOKEN";
const SESSION_TOKEN_KEY = "__SESSION_TOKEN";
const API_KEY_ID = "__API_KEY";

type TokenType = "refresh" | "session";

const getLocalStorageValueKey = (tokenType: TokenType) => {
switch (tokenType) {
case "refresh":
return REFRESH_TOKEN_KEY;
case "session":
return SESSION_TOKEN_KEY;
default:
throw new Error("Invalid token type");
}
};

const getLocalStorageToken = (tokenType: TokenType): string | null => {
return localStorage.getItem(getLocalStorageValueKey(tokenType));
const getLocalStorageApiKey = (): string | null => {
return localStorage.getItem(API_KEY_ID);
};

const setLocalStorageToken = (token: string, tokenType: TokenType) => {
localStorage.setItem(getLocalStorageValueKey(tokenType), token);
const setLocalStorageApiKey = (token: string) => {
localStorage.setItem(API_KEY_ID, token);
};

const deleteLocalStorageToken = (tokenType: TokenType) => {
localStorage.removeItem(getLocalStorageValueKey(tokenType));
const deleteLocalStorageApiKey = () => {
localStorage.removeItem(API_KEY_ID);
};

interface AuthenticationContextProps {
sessionToken: string | null;
setSessionToken: (token: string) => void;
refreshToken: string | null;
setRefreshToken: (token: string) => void;
apiKey: string | null;
setApiKey: (token: string) => void;
logout: () => void;
isAuthenticated: boolean;
api: AxiosInstance;
}

interface RefreshTokenResponse {
token: string;
token_type: string;
}

const AuthenticationContext = createContext<
AuthenticationContextProps | undefined
>(undefined);
Expand All @@ -64,52 +43,42 @@ interface AuthenticationProviderProps {
export const AuthenticationProvider = (props: AuthenticationProviderProps) => {
const { children } = props;

const [sessionToken, setSessionToken] = useState<string | null>(
getLocalStorageToken("session"),
);
const [refreshToken, setRefreshToken] = useState<string | null>(
getLocalStorageToken("refresh"),
);
const [apiKey, setApiKey] = useState<string | null>(getLocalStorageApiKey());

const navigate = useNavigate();

const isAuthenticated = refreshToken !== null;
const isAuthenticated = apiKey !== null;

const api = axios.create({
baseURL: BACKEND_URL,
withCredentials: true,
});

const baseApi = axios.create({
baseURL: BACKEND_URL,
});

useEffect(() => {
if (sessionToken === null) {
deleteLocalStorageToken("session");
if (apiKey === null) {
deleteLocalStorageApiKey();
} else {
setLocalStorageToken(sessionToken, "session");
setLocalStorageApiKey(apiKey);
}
}, [sessionToken]);

useEffect(() => {
if (refreshToken === null) {
deleteLocalStorageToken("refresh");
} else {
setLocalStorageToken(refreshToken, "refresh");
}
}, [refreshToken]);
}, [apiKey]);

const logout = useCallback(() => {
setSessionToken(null);
setRefreshToken(null);
navigate("/");
(async () => {
try {
await api.delete<boolean>("/users/logout");
setApiKey(null);
navigate("/");
} catch (error) {
// Do nothing
}
})();
}, [navigate]);

// Adds the API key to the request header, if it is set.
api.interceptors.request.use(
(config) => {
if (sessionToken !== null) {
config.headers.Authorization = `Bearer ${sessionToken}`;
if (apiKey !== null) {
config.headers.Authorization = `Bearer ${apiKey}`;
config.headers["Access-Control-Allow-Origin"] = "*";
}
return config;
Expand All @@ -119,63 +88,11 @@ export const AuthenticationProvider = (props: AuthenticationProviderProps) => {
},
);

api.interceptors.response.use(
(response) => response,
async (error) => {
const originalRequest = error.config;
if (error.response.status === 401 && !originalRequest._retry) {
originalRequest._retry = true;
if (refreshToken === null) {
return Promise.reject(error);
}

let localSessionToken;
try {
// Gets a new session token and try the request again.
const response = await baseApi.post<RefreshTokenResponse>(
"/users/refresh",
{},
{
headers: {
Authorization: `Bearer ${refreshToken}`,
"Access-Control-Allow-Origin": "*",
},
},
);
localSessionToken = response.data.token;
} catch (refreshError) {
if (isAxiosError(refreshError)) {
const axiosError = refreshError as AxiosError;
if (axiosError?.response?.status === 401) {
logout();
}
}
return Promise.reject(refreshError);
}

// Retry the request with the new session token.
setSessionToken(localSessionToken);
const updatedRequest = {
...originalRequest,
headers: {
Authorization: `Bearer ${localSessionToken}`,
"Access-Control-Allow-Origin": "*",
},
};
return await baseApi(updatedRequest);
}

return Promise.reject(error);
},
);

return (
<AuthenticationContext.Provider
value={{
sessionToken,
setSessionToken,
refreshToken,
setRefreshToken,
apiKey,
setApiKey,
logout,
isAuthenticated,
api,
Expand Down Expand Up @@ -210,7 +127,7 @@ export const OneTimePasswordWrapper = ({
}: OneTimePasswordWrapperProps) => {
const [searchParams] = useSearchParams();
const navigate = useNavigate();
const { setRefreshToken, api } = useAuthentication();
const { setApiKey, api } = useAuthentication();

useEffect(() => {
(async () => {
Expand All @@ -220,7 +137,7 @@ export const OneTimePasswordWrapper = ({
const response = await api.post<UserLoginResponse>("/users/otp", {
payload,
});
setRefreshToken(response.data.token);
setApiKey(response.data.token);
navigate("/");
} catch (error) {
// Do nothing
Expand All @@ -229,7 +146,7 @@ export const OneTimePasswordWrapper = ({
}
}
})();
}, [searchParams, navigate, setRefreshToken, api]);
}, [searchParams, navigate, setApiKey, api]);

return <>{children}</>;
};
49 changes: 29 additions & 20 deletions store/app/api/crud/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,23 @@
import asyncio
import uuid
import warnings
from typing import cast

from boto3.dynamodb.conditions import Key as KeyCondition

from store.app.api.crud.base import BaseCrud
from store.app.api.model import Token, User
from store.app.api.crypto import hash_api_key
from store.app.api.model import ApiKey, User


class UserCrud(BaseCrud):
async def add_user(self, user: User) -> None:
table = await self.db.Table("Users")
await table.put_item(Item=user.model_dump())

async def get_user(self, user_id: str) -> User | None:
async def get_user(self, user_id: uuid.UUID) -> User | None:
table = await self.db.Table("Users")
user_dict = await table.get_item(Key={"user_id": user_id})
user_dict = await table.get_item(Key={"user_id": str(user_id)})
if "Item" not in user_dict:
return None
user = User.model_validate(user_dict["Item"])
Expand All @@ -34,6 +36,15 @@ async def get_user_from_email(self, email: str) -> User | None:
user = User.model_validate(items[0])
return user

async def get_user_id_from_api_key(self, api_key: uuid.UUID) -> uuid.UUID | None:
table = await self.db.Table("ApiKeys")
api_key_hash = hash_api_key(api_key)
row = await table.get_item(Key={"api_key_hash": api_key_hash})
if "Item" not in row:
return None
user_id = cast(str, row["Item"]["user_id"])
return uuid.UUID(user_id)

async def delete_user(self, user: User) -> None:
table = await self.db.Table("Users")
await table.delete_item(Key={"user_id": user.user_id})
Expand All @@ -48,23 +59,21 @@ async def get_user_count(self) -> int:
table = await self.db.Table("Users")
return await table.item_count

async def add_token(self, token: Token) -> None:
table = await self.db.Table("Tokens")
await table.put_item(Item=token.model_dump())

async def get_token(self, token_id: str) -> Token | None:
table = await self.db.Table("Tokens")
token_dict = await table.get_item(Key={"token_id": token_id})
if "Item" not in token_dict:
return None
token = Token.model_validate(token_dict["Item"])
return token

async def get_user_tokens(self, user_id: str) -> list[Token]:
table = await self.db.Table("Tokens")
tokens = table.query(IndexName="userIdIndex", KeyConditionExpression=KeyCondition("user_id").eq(user_id))
tokens = [Token.model_validate(token) for token in await tokens]
return tokens
async def add_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID) -> None:
row = ApiKey.from_api_key(api_key, user_id)
table = await self.db.Table("ApiKeys")
await table.put_item(Item=row.model_dump())

async def check_api_key(self, api_key: uuid.UUID, user_id: uuid.UUID) -> bool:
table = await self.db.Table("ApiKeys")
row = await table.get_item(Key={"api_key_hash": hash_api_key(api_key)})
if "Item" not in row:
return False
return row["Item"]["user_id"] == str(user_id)

async def delete_api_key(self, api_key: uuid.UUID) -> None:
table = await self.db.Table("ApiKeys")
await table.delete_item(Key={"api_key_hash": hash_api_key(api_key)})


async def test_adhoc() -> None:
Expand Down
Loading
Loading