diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4a50a59d..4f37f571 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -38,12 +38,6 @@ jobs: # -v /tmp/dynamodb:/data # ports: # - 8000 - # redis-local: - # image: redis - # options: >- - # -p 0:6379 - # ports: - # - 6379 steps: - name: Check out repository diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 175e6fbf..efb68b1d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -9,7 +9,7 @@ To get started developing: 1. Clone the repository 2. Install the React dependencies and create a `.env.local` file 3. Install the FastAPI dependencies -4. Start the Redis and DynamoDB databases +4. Start the DynamoDB databases 5. Initialize the test databases 6. Serve the FastAPI application 7. Serve the React frontend @@ -61,22 +61,6 @@ To run, **source the same environment variables that you use for FastAPI** and t DYNAMO_ENDPOINT=http://127.0.0.1:4566 dynamodb-admin ``` -### Redis - -For Redis, use the `redis` Docker image: - -```bash -docker pull redis # If you haven't already -docker run --name store-redis -d -p 6379:6379 redis # Start the container in the background -``` - -Then, if you need to kill the database, you can run: - -```bash -docker kill store-redis || true -docker rm store-redis || true -``` - ## FastAPI Create a Python virtual environment using either [uv](https://astral.sh/blog/uv) or [virtualenv](https://virtualenv.pypa.io/en/latest/) with at least Python 3.11. This should look something like this: @@ -119,8 +103,6 @@ export ROBOLIST_SMTP_SENDER_EMAIL= export ROBOLIST_SMTP_PASSWORD= export ROBOLIST_SMTP_SENDER_NAME= export ROBOLIST_SMTP_USERNAME= -export ROBOLIST_REDIS_HOST= -export ROBOLIST_REDIS_PASSWORD= export GITHUB_CLIENT_ID= export GITHUB_CLIENT_SECRET= ``` diff --git a/Makefile b/Makefile index ade262a1..18ab2876 100644 --- a/Makefile +++ b/Makefile @@ -37,18 +37,13 @@ start-docker-dynamodb: @docker rm store-db || true @docker run --name store-db -d -p 8000:8000 amazon/dynamodb-local -start-docker-redis: - @docker kill store-redis || true - @docker rm store-redis || true - @docker run --name store-redis -d -p 6379:6379 redis - # ------------------------ # # Code Formatting # # ------------------------ # format-backend: - @black store - @ruff format store + @black store tests + @ruff format store tests .PHONY: format format-frontend: @@ -63,9 +58,9 @@ format: format-backend format-frontend # ------------------------ # static-checks-backend: - @black --diff --check store - @ruff check store - @mypy --install-types --non-interactive store + @black --diff --check store tests + @ruff check store tests + @mypy --install-types --non-interactive store tests .PHONY: lint static-checks-frontend: diff --git a/frontend/src/App.tsx b/frontend/src/App.tsx index 516309f7..17d475a7 100644 --- a/frontend/src/App.tsx +++ b/frontend/src/App.tsx @@ -6,10 +6,8 @@ import { AlertQueue, AlertQueueProvider } from "hooks/alerts"; import { AuthenticationProvider } from "hooks/auth"; import { ThemeProvider } from "hooks/theme"; import About from "pages/About"; -import ChangeEmail from "pages/ChangeEmail"; import EditPartForm from "pages/EditPartForm"; import EditRobotForm from "pages/EditRobotForm"; -import Forgot from "pages/Forgot"; import Home from "pages/Home"; import Login from "pages/Login"; import Logout from "pages/Logout"; @@ -18,9 +16,6 @@ import NewRobot from "pages/NewRobot"; import NotFound from "pages/NotFound"; import PartDetails from "pages/PartDetails"; import Parts from "pages/Parts"; -import Register from "pages/Register"; -import RegistrationEmail from "pages/RegistrationEmail"; -import ResetPassword from "pages/ResetPassword"; import RobotDetails from "pages/RobotDetails"; import Robots from "pages/Robots"; import TestImages from "pages/TestImages"; @@ -44,18 +39,7 @@ const App = () => { } /> } /> } /> - } /> - } /> } /> - } /> - } - /> - } - /> } /> } /> } /> diff --git a/frontend/src/components/RobotForm.tsx b/frontend/src/components/RobotForm.tsx index 88790e29..1d927ee6 100644 --- a/frontend/src/components/RobotForm.tsx +++ b/frontend/src/components/RobotForm.tsx @@ -295,7 +295,7 @@ const RobotForm: React.FC = ({ Select a Part {parts.map((part, index) => ( - ))} diff --git a/frontend/src/components/files/ViewImage.tsx b/frontend/src/components/files/ViewImage.tsx index 3e4ccc2d..721e8645 100644 --- a/frontend/src/components/files/ViewImage.tsx +++ b/frontend/src/components/files/ViewImage.tsx @@ -1,16 +1,17 @@ -import { BACKEND_URL } from "constants/backend"; +import { S3_URL } from "constants/backend"; import React from "react"; interface ImageProps { imageId: string; + caption: string; } -const ImageComponent: React.FC = ({ imageId }) => { +const ImageComponent: React.FC = ({ imageId, caption }) => { return (
Robot { - const { addAlert } = useAlertQueue(); - - const auth = useAuthentication(); - const auth_api = new api(auth.api); - - const [newEmail, setNewEmail] = useState(""); - const [changeEmailSuccess, setChangeEmailSuccess] = useState(false); - const [oldPassword, setOldPassword] = useState(""); - const [newPassword, setNewPassword] = useState(""); - const [changePasswordSuccess, setChangePasswordSuccess] = - useState(false); - - const sendChangeEmail = async (event: FormEvent) => { - event.preventDefault(); - try { - await auth_api.send_change_email(newEmail); - setChangeEmailSuccess(true); - } catch (error) { - if (error instanceof Error) { - addAlert(error.message, "error"); - } else { - addAlert("Unexpected error when trying to change email", "error"); - } - } - }; - - const changePassword = async (event: FormEvent) => { - event.preventDefault(); - try { - await auth_api.change_password(oldPassword, newPassword); - setChangePasswordSuccess(true); - } catch (error) { - if (error instanceof Error) { - addAlert(error.message, "error"); - } else { - addAlert("Unexpected error when trying to change password", "error"); - } - } - }; - return ( @@ -65,78 +20,6 @@ const Sidebar = ({ show, onHide }: Props) => { height: "100%", }} > - -

- Change Email -

- {auth.email == "dummy@kscale.dev" ? ( -

- No email address associated with this account. (This is because - you registered via OAuth.) -

- ) : ( -

Current email: {auth.email}

- )} - {changeEmailSuccess ? ( -

An email has been sent to your new email address.

- ) : ( -
- - { - setNewEmail(e.target.value); - }} - value={newEmail} - required - /> - Change Email - - )} -
- -

- Change Password -

-

- You may only change your password if you have a previous password. - If not, log out and reset your password. -

- {changePasswordSuccess ? ( -

Your password has been changed.

- ) : ( -
- - { - setOldPassword(e.target.value); - }} - value={oldPassword} - required - /> - - { - setNewPassword(e.target.value); - }} - value={newPassword} - required - /> - Change Password - - )} -
About diff --git a/frontend/src/components/nav/TopNavbar.tsx b/frontend/src/components/nav/TopNavbar.tsx index 76719185..41b99033 100644 --- a/frontend/src/components/nav/TopNavbar.tsx +++ b/frontend/src/components/nav/TopNavbar.tsx @@ -66,9 +66,6 @@ const TopNavbar = () => { Login - - Register - )}
diff --git a/frontend/src/constants/backend.ts b/frontend/src/constants/backend.ts index 86e060de..097b173d 100644 --- a/frontend/src/constants/backend.ts +++ b/frontend/src/constants/backend.ts @@ -1,7 +1,9 @@ import { AxiosError, isAxiosError } from "axios"; export const BACKEND_URL = - process.env.REACT_APP_BACKEND_URL || "http://localhost:8080"; + process.env.REACT_APP_BACKEND_URL || "http://127.0.0.1:8080"; + +export const S3_URL = process.env.S3_URL || "http://127.0.0.1:4566"; // eslint-disable-next-line export const humanReadableError = (error: any | undefined) => { diff --git a/frontend/src/hooks/api.tsx b/frontend/src/hooks/api.tsx index 3e9ff158..8015fbc2 100644 --- a/frontend/src/hooks/api.tsx +++ b/frontend/src/hooks/api.tsx @@ -4,7 +4,7 @@ export interface Part { description: string; owner: string; images: Image[]; - part_id: string; + id: string; name: string; } @@ -24,7 +24,7 @@ export interface Package { } export interface Robot { - robot_id: string; + id: string; name: string; description: string; owner: string; @@ -38,7 +38,7 @@ export interface Robot { } interface MeResponse { - user_id: string; + id: string; email: string; username: string; admin: boolean; @@ -51,45 +51,6 @@ export class api { this.api = api; } - public async send_register_email(email: string): Promise { - try { - await this.api.post("/users/send-register-email", { email }); - } catch (error) { - if (axios.isAxiosError(error)) { - console.error( - "Error sending registration email:", - error.response?.data, - ); - throw new Error( - error.response?.data?.detail || "Error sending verification email", - ); - } else { - console.error("Unexpected error:", error); - throw new Error("Unexpected error"); - } - } - } - - public async get_registration_email(token: string): Promise { - try { - const res = await this.api.get("/users/registration-email/" + token); - return res.data; - } catch (error) { - if (axios.isAxiosError(error)) { - console.error( - "Error fetching registration email:", - error.response?.data, - ); - throw new Error( - error.response?.data?.detail || "Error fetching registration email", - ); - } else { - console.error("Unexpected error:", error); - throw new Error("Unexpected error"); - } - } - } - public async send_register_github(): Promise { try { const res = await this.api.get("/users/github-login"); @@ -107,99 +68,6 @@ export class api { } } - public async register( - token: string, - username: string, - password: string, - ): Promise { - try { - await this.api.post("/users/register/", { token, username, password }); - } catch (error) { - if (axios.isAxiosError(error)) { - console.error("Error registering:", error.response?.data); - throw new Error( - error.response?.data?.detail || - "Error registering with token " + token, - ); - } else { - console.error("Unexpected error:", error); - throw new Error("Unexpected error"); - } - } - } - - public async change_email(code: string): Promise { - try { - await this.api.post("/users/change-email/" + code); - } catch (error) { - if (axios.isAxiosError(error)) { - console.error("Error changing email:", error.response?.data); - throw new Error( - error.response?.data?.detail || - "Error changing email with code " + code, - ); - } else { - console.error("Unexpected error:", error); - throw new Error("Unexpected error"); - } - } - } - - public async send_change_email(new_email: string): Promise { - try { - await this.api.post("/users/change-email", { new_email }); - } catch (error) { - if (axios.isAxiosError(error)) { - console.error("Error sending change email:", error.response?.data); - throw new Error( - error.response?.data?.detail || "Error sending change email", - ); - } else { - console.error("Unexpected error:", error); - throw new Error("Unexpected error"); - } - } - } - - public async change_password( - old_password: string, - new_password: string, - ): Promise { - try { - await this.api.post("/users/change-password", { - old_password, - new_password, - }); - } catch (error) { - if (axios.isAxiosError(error)) { - console.error("Error changing password:", error.response?.data); - throw new Error( - error.response?.data?.detail || "Error changing password", - ); - } else { - console.error("Unexpected error:", error); - throw new Error("Unexpected error"); - } - } - } - - public async login(email: string, password: string): Promise { - try { - await this.api.post("/users/login/", { email, password }); - } catch (error) { - if (axios.isAxiosError(error)) { - console.error("Error logging in:", error.response?.data); - throw new Error( - error.response?.data?.detail || - "Error logging in with email " + email, - ); - } else { - console.error("Unexpected error:", error); - throw new Error("Unexpected error"); - } - } - } - public async login_github(code: string): Promise { try { const res = await this.api.get(`/users/github-code/${code}`); @@ -231,38 +99,6 @@ export class api { } } - public async forgot(email: string): Promise { - try { - await this.api.post("/users/forgot-password/", { email }); - } catch (error) { - if (axios.isAxiosError(error)) { - console.error("Error sending forgot password:", error.response?.data); - throw new Error( - error.response?.data?.detail || "Error sending forgot password", - ); - } else { - console.error("Unexpected error:", error); - throw new Error("Unexpected error"); - } - } - } - - public async reset_password(token: string, password: string): Promise { - try { - await this.api.post("/users/reset-password/" + token, { password }); - } catch (error) { - if (axios.isAxiosError(error)) { - console.error("Error resetting password:", error.response?.data); - throw new Error( - error.response?.data?.detail || "Error resetting password", - ); - } else { - console.error("Unexpected error:", error); - throw new Error("Unexpected error"); - } - } - } - public async me(): Promise { try { const res = await this.api.get("/users/me/"); @@ -282,7 +118,7 @@ export class api { public async getUserById(userId: string | undefined): Promise { const response = await this.api.get(`/users/${userId}`); - return response.data.username; + return response.data.email; } public async getRobots( @@ -309,13 +145,13 @@ export class api { public async getUserBatch(userIds: string[]): Promise> { const params = new URLSearchParams(); - userIds.forEach((id) => params.append("user_ids", id)); + userIds.forEach((id) => params.append("ids", id)); const response = await this.api.get("/users/batch/", { params, }); const map = new Map(); for (const index in response.data) { - map.set(response.data[index].user_id, response.data[index].username); + map.set(response.data[index].id, response.data[index].email); } return map; } @@ -368,8 +204,8 @@ export class api { } public async currentUser(): Promise { try { - const response = await this.api.get("/robots/user/"); - return response.data; + const response = await this.api.get("/users/me"); + return response.data.id; } catch (error) { if (axios.isAxiosError(error)) { console.error("Error fetching current user:", error.response?.data); @@ -417,7 +253,7 @@ export class api { public async editRobot(robot: Robot): Promise { const s = robot.name; try { - await this.api.post(`robots/edit-robot/${robot.robot_id}/`, robot); + await this.api.post(`robots/edit-robot/${robot.id}/`, robot); } catch (error) { if (axios.isAxiosError(error)) { console.error("Error editing robot:", error.response?.data); @@ -514,7 +350,7 @@ export class api { public async editPart(part: Part): Promise { const s = part.name; try { - await this.api.post(`parts/edit-part/${part.part_id}/`, part); + await this.api.post(`parts/edit-part/${part.id}/`, part); } catch (error) { if (axios.isAxiosError(error)) { console.error("Error editing part:", error.response?.data); diff --git a/frontend/src/hooks/auth.tsx b/frontend/src/hooks/auth.tsx index 81680cf5..6798b617 100644 --- a/frontend/src/hooks/auth.tsx +++ b/frontend/src/hooks/auth.tsx @@ -30,8 +30,8 @@ interface AuthenticationContextProps { setIsAuthenticated: React.Dispatch>; id: string | null; api: AxiosInstance; - email: string; - setEmail: React.Dispatch>; + email: string | null; + setEmail: React.Dispatch>; } const AuthenticationContext = createContext< @@ -50,7 +50,7 @@ export const AuthenticationProvider = (props: AuthenticationProviderProps) => { const [isAuthenticated, setIsAuthenticated] = useState( getLocalStorageAuth() !== null, ); - const [email, setEmail] = useState("dummy@kscale.dev"); + const [email, setEmail] = useState(null); const id = getLocalStorageAuth(); const api = axios.create({ diff --git a/frontend/src/pages/ChangeEmail.tsx b/frontend/src/pages/ChangeEmail.tsx deleted file mode 100644 index bbfcb1ba..00000000 --- a/frontend/src/pages/ChangeEmail.tsx +++ /dev/null @@ -1,36 +0,0 @@ -import { api } from "hooks/api"; -import { useAuthentication } from "hooks/auth"; -import { useEffect, useState } from "react"; -import { useParams } from "react-router-dom"; - -const ChangeEmail = () => { - const auth = useAuthentication(); - const auth_api = new api(auth.api); - const { token } = useParams(); - const [needToSend, setNeedToSend] = useState(true); - const [message, setMessage] = useState(""); - useEffect(() => { - (async () => { - if (needToSend) { - setNeedToSend(false); - if (token !== undefined) { - try { - await auth_api.change_email(token); - setMessage("Successfully changed email."); - } catch (error) { - setMessage("Verification token invalid."); - } - } else { - setMessage("No token provided"); - } - } - })(); - }, [auth_api]); - return ( - <> -

Email Verification

-

{message}

- - ); -}; -export default ChangeEmail; diff --git a/frontend/src/pages/EditPartForm.tsx b/frontend/src/pages/EditPartForm.tsx index 49852e15..30f8aab2 100644 --- a/frontend/src/pages/EditPartForm.tsx +++ b/frontend/src/pages/EditPartForm.tsx @@ -31,7 +31,7 @@ const EditPartForm: React.FC = () => { setName(PartData.name); setDescription(PartData.description); setImages(PartData.images); - setPartId(PartData.part_id); + setPartId(PartData.id); } catch (err) { addAlert(humanReadableError(err), "error"); } @@ -47,7 +47,7 @@ const EditPartForm: React.FC = () => { return; } const newFormData: Part = { - part_id: Part_id, + id: Part_id, name: name, description: Part_description, owner: "", diff --git a/frontend/src/pages/EditRobotForm.tsx b/frontend/src/pages/EditRobotForm.tsx index 07e7eb05..339d1205 100644 --- a/frontend/src/pages/EditRobotForm.tsx +++ b/frontend/src/pages/EditRobotForm.tsx @@ -40,7 +40,7 @@ const EditRobotForm: React.FC = () => { setBom(robotData.bom); setURDF(robotData.urdf); setImages(robotData.images); - setRobotId(robotData.robot_id); + setRobotId(robotData.id); setHeight(robotData.height); setWeight(robotData.weight); setDof(robotData.degrees_of_freedom); @@ -61,7 +61,7 @@ const EditRobotForm: React.FC = () => { return; } const newFormData: Robot = { - robot_id: robot_id, + id: robot_id, name: robot_name, description: robot_description, owner: "", diff --git a/frontend/src/pages/Forgot.tsx b/frontend/src/pages/Forgot.tsx deleted file mode 100644 index 7c36ef70..00000000 --- a/frontend/src/pages/Forgot.tsx +++ /dev/null @@ -1,54 +0,0 @@ -import TCButton from "components/files/TCButton"; -import { api } from "hooks/api"; -import { useAuthentication } from "hooks/auth"; -import { FormEvent, useState } from "react"; -import { Form } from "react-bootstrap"; - -const Forgot = () => { - const auth = useAuthentication(); - const auth_api = new api(auth.api); - - const [email, setEmail] = useState(""); - const [success, setSuccess] = useState(false); - - const handleSubmit = async (event: FormEvent) => { - event.preventDefault(); - try { - await auth_api.forgot(email); - setSuccess(true); - } catch (err) { - console.error(err); - } - }; - - return ( -
-

Forgot Password

- {success ? ( -

- If your account exists, an email with a password reset link will be - sent. -

- ) : ( -
- - { - setEmail(e.target.value); - }} - value={email} - required - /> - - Send - - - )} -
- ); -}; - -export default Forgot; diff --git a/frontend/src/pages/Login.tsx b/frontend/src/pages/Login.tsx index 33295c74..343bc149 100644 --- a/frontend/src/pages/Login.tsx +++ b/frontend/src/pages/Login.tsx @@ -3,10 +3,9 @@ import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; import TCButton from "components/files/TCButton"; import { useAlertQueue } from "hooks/alerts"; import { api } from "hooks/api"; -import { setLocalStorageAuth, useAuthentication } from "hooks/auth"; -import { FormEvent, useState } from "react"; +import { useAuthentication } from "hooks/auth"; +import { FormEvent } from "react"; import { Form } from "react-bootstrap"; -import { Link, useNavigate } from "react-router-dom"; const Login = () => { const auth = useAuthentication(); @@ -14,32 +13,11 @@ const Login = () => { const { addAlert } = useAlertQueue(); - const navigate = useNavigate(); - - const [password, setPassword] = useState(""); - const [email, setEmail] = useState(""); - - const handleSubmit = async (event: FormEvent) => { - event.preventDefault(); - try { - await auth_api.login(email, password); - setLocalStorageAuth(email); - navigate("/"); - } catch (err) { - if (err instanceof Error) { - addAlert(err.message, "error"); - } else { - addAlert("Unexpected error when trying to log in", "error"); - } - } - }; - const handleGithubSubmit = async (event: FormEvent) => { event.preventDefault(); try { const redirectUrl = await auth_api.send_register_github(); window.location.href = redirectUrl; - // setSuccess(true); } catch (err) { if (err instanceof Error) { addAlert(err.message, "error"); @@ -52,38 +30,11 @@ const Login = () => { return (

Login

-
- - { - setEmail(e.target.value); - }} - value={email} - required - /> - - { - setPassword(e.target.value); - }} - value={password} - required - /> -
- Forgot your password? -
- Login - +

+ If you do not already have an account, authenticating will automatically + create an account for you. +

-
OR
Login with Github diff --git a/frontend/src/pages/NewPart.tsx b/frontend/src/pages/NewPart.tsx index f7dda3a3..ff6e5806 100644 --- a/frontend/src/pages/NewPart.tsx +++ b/frontend/src/pages/NewPart.tsx @@ -22,7 +22,7 @@ const NewPart: React.FC = () => { return; } const newFormData: Part = { - part_id: "", + id: "", name: name, description: part_description, owner: "Bob", diff --git a/frontend/src/pages/NewRobot.tsx b/frontend/src/pages/NewRobot.tsx index ee755dd0..208bc763 100644 --- a/frontend/src/pages/NewRobot.tsx +++ b/frontend/src/pages/NewRobot.tsx @@ -29,7 +29,7 @@ const NewRobot: React.FC = () => { return; } const newFormData: Robot = { - robot_id: "", + id: "", name: robot_name, description: robot_description, owner: "", diff --git a/frontend/src/pages/PartDetails.tsx b/frontend/src/pages/PartDetails.tsx index 916617ac..1223ae6f 100644 --- a/frontend/src/pages/PartDetails.tsx +++ b/frontend/src/pages/PartDetails.tsx @@ -32,7 +32,7 @@ const PartDetails = () => { const [userId, setUserId] = useState(null); const { id } = useParams(); const [show, setShow] = useState(false); - const [ownerUsername, setOwnerUsername] = useState(null); + const [ownerEmail, setOwnerEmail] = useState(null); const [part, setPart] = useState(null); const [imageIndex, setImageIndex] = useState(0); const [error, setError] = useState(null); @@ -49,8 +49,8 @@ const PartDetails = () => { try { const partData = await auth_api.getPartById(id); setPart(partData); - const ownerUsername = await auth_api.getUserById(partData.owner); - setOwnerUsername(ownerUsername); + const ownerEmail = await auth_api.getUserById(partData.owner); + setOwnerEmail(ownerEmail); } catch (err) { if (err instanceof Error) { setError(err.message); @@ -68,8 +68,8 @@ const PartDetails = () => { if (auth.isAuthenticated) { try { const fetchUserId = async () => { - const user_id = await auth_api.currentUser(); - setUserId(user_id); + const id = await auth_api.currentUser(); + setUserId(id); }; fetchUserId(); } catch (err) { @@ -129,7 +129,10 @@ const PartDetails = () => {

{name}

ID: {id}
- {ownerUsername} + + This listing is maintained by{" "} + {ownerEmail}. +
@@ -250,7 +253,10 @@ const PartDetails = () => { handleShow(); }} > - +
{
- +
diff --git a/frontend/src/pages/Parts.tsx b/frontend/src/pages/Parts.tsx index ebefca74..4c77fe97 100644 --- a/frontend/src/pages/Parts.tsx +++ b/frontend/src/pages/Parts.tsx @@ -100,8 +100,8 @@ const Parts = () => { {partsData.map((part) => ( - - navigate(`/part/${part.part_id}`)}> + + navigate(`/part/${part.id}`)}> {part.images[0] && (
{ borderTopRightRadius: ".25rem", }} > - +
)} diff --git a/frontend/src/pages/Register.tsx b/frontend/src/pages/Register.tsx deleted file mode 100644 index a69f5eb8..00000000 --- a/frontend/src/pages/Register.tsx +++ /dev/null @@ -1,115 +0,0 @@ -import TCButton from "components/files/TCButton"; -import { useAlertQueue } from "hooks/alerts"; -import { api } from "hooks/api"; -import { setLocalStorageAuth, useAuthentication } from "hooks/auth"; -import { FormEvent, useEffect, useState } from "react"; -import { Col, Container, Form, Row, Spinner } from "react-bootstrap"; -import { Link, useNavigate, useParams } from "react-router-dom"; - -const Register = () => { - const auth = useAuthentication(); - const auth_api = new api(auth.api); - const { addAlert } = useAlertQueue(); - const { token } = useParams(); - - const navigate = useNavigate(); - - const [password, setPassword] = useState(""); - const [username, setUsername] = useState(""); - const [email, setEmail] = useState(""); - const [waiting, setWaiting] = useState(true); - - const handleSubmit = async (event: FormEvent) => { - event.preventDefault(); - try { - await auth_api.register(token || "", username, password); - setLocalStorageAuth(email); - navigate("/"); - } catch (err) { - if (err instanceof Error) { - addAlert(err.message, "error"); - } else { - addAlert("Unexpected error.", "error"); - } - } - }; - - useEffect(() => { - (async () => { - if (!email) { - try { - setEmail(await auth_api.get_registration_email(token || "")); - setWaiting(false); - } catch (error) { - setWaiting(false); - - if (error instanceof Error) { - addAlert(error.message, "error"); - } else { - addAlert("Unexpected error.", "error"); - } - } - } - })(); - }, []); - - if (waiting) { - return ( - - - - - - - - ); - } - - return ( -
-

Register

- {email ? ( - <> -

You are registering with email {email}.

- - - { - setUsername(e.target.value); - }} - value={username} - required - /> - - { - setPassword(e.target.value); - }} - value={password} - required - /> - Register - - - ) : ( -

- Your registration link is invalid. Try{" "} - sending yourself a new one. -

- )} -
- ); -}; - -export default Register; diff --git a/frontend/src/pages/RegistrationEmail.tsx b/frontend/src/pages/RegistrationEmail.tsx deleted file mode 100644 index 9f0d8519..00000000 --- a/frontend/src/pages/RegistrationEmail.tsx +++ /dev/null @@ -1,106 +0,0 @@ -import { faGithub } from "@fortawesome/free-brands-svg-icons"; -import { FontAwesomeIcon } from "@fortawesome/react-fontawesome"; -import TCButton from "components/files/TCButton"; -import { useAlertQueue } from "hooks/alerts"; -import { api } from "hooks/api"; -import { useAuthentication } from "hooks/auth"; -import { FormEvent, useState } from "react"; -import { Col, Container, Form, Row, Spinner } from "react-bootstrap"; - -const RegistrationEmail = () => { - const auth = useAuthentication(); - const auth_api = new api(auth.api); - const { addAlert } = useAlertQueue(); - - const [email, setEmail] = useState(""); - const [submitted, setSubmitted] = useState(false); - const [success, setSuccess] = useState(false); - const handleSubmit = async (event: FormEvent) => { - event.preventDefault(); - try { - setSubmitted(true); - await auth_api.send_register_email(email); - setSuccess(true); - } catch (err) { - setSubmitted(false); - if (err instanceof Error) { - addAlert(err.message, "error"); - } else { - addAlert("Unexpected error.", "error"); - } - } - }; - - const handleGithubSubmit = async (event: FormEvent) => { - event.preventDefault(); - try { - setSubmitted(true); - const redirectUrl = await auth_api.send_register_github(); - window.location.href = redirectUrl; - // setSuccess(true); - } catch (err) { - setSubmitted(false); - if (err instanceof Error) { - addAlert(err.message, "error"); - } else { - addAlert("Unexpected error.", "error"); - } - } - }; - - if (success) { - return ( -
-

Register

-

Check your email for a registration code.

-
- ); - } else if (submitted) { - return ( - - - - - - - - ); - } - return ( -
-

Register

-

- To create an account, enter your email address. You will then be sent an - email containing a registration link. (This helps to avoid part of the - song and dance with email verification.) -

-
- - { - setEmail(e.target.value); - }} - value={email} - required - /> - Send Code - -
-
OR
- - - Register with Github - -
-
- ); -}; - -export default RegistrationEmail; diff --git a/frontend/src/pages/ResetPassword.tsx b/frontend/src/pages/ResetPassword.tsx deleted file mode 100644 index 2eaaa665..00000000 --- a/frontend/src/pages/ResetPassword.tsx +++ /dev/null @@ -1,64 +0,0 @@ -import TCButton from "components/files/TCButton"; -import { useAlertQueue } from "hooks/alerts"; -import { api } from "hooks/api"; -import { useAuthentication } from "hooks/auth"; -import { FormEvent, useState } from "react"; -import { Form } from "react-bootstrap"; -import { useParams } from "react-router-dom"; - -const ResetPassword = () => { - const auth = useAuthentication(); - const auth_api = new api(auth.api); - const { token } = useParams(); - - const [password, setPassword] = useState(""); - const [success, setSuccess] = useState(false); - const { addAlert } = useAlertQueue(); - - const handleSubmit = async (event: FormEvent) => { - event.preventDefault(); - if (token === undefined) { - addAlert("No token provided", "error"); - return; - } - try { - await auth_api.reset_password(token, password); - setSuccess(true); - } catch (err) { - if (err instanceof Error) { - addAlert(err.message, "error"); - } else { - addAlert("Unexpected error resetting password.", "error"); - } - } - }; - - return ( -
-

Reset Password

- {success ? ( -

Your password has been reset.

- ) : ( -
- - { - setPassword(e.target.value); - }} - value={password} - required - /> - - Send - - - )}{" "} -
- ); -}; - -export default ResetPassword; diff --git a/frontend/src/pages/RobotDetails.tsx b/frontend/src/pages/RobotDetails.tsx index a71c1318..1b9114f1 100644 --- a/frontend/src/pages/RobotDetails.tsx +++ b/frontend/src/pages/RobotDetails.tsx @@ -45,7 +45,7 @@ const RobotDetails = () => { const [userId, setUserId] = useState(null); const { id } = useParams(); const [show, setShow] = useState(false); - const [ownerUsername, setOwnerUsername] = useState(null); + const [ownerEmail, setOwnerEmail] = useState(null); const [robot, setRobot] = useState(null); const [parts, setParts] = useState([]); const [package_urls, setPackages] = useState([]); @@ -65,8 +65,8 @@ const RobotDetails = () => { try { const robotData = await auth_api.getRobotById(id); setRobot(robotData); - const ownerUsername = await auth_api.getUserById(robotData.owner); - setOwnerUsername(ownerUsername); + const ownerEmail = await auth_api.getUserById(robotData.owner); + setOwnerEmail(ownerEmail); const curPackages = []; for (let i = 0; i < robotData.packages.length; i++) { const package_id = robotData.packages[i].name; @@ -77,8 +77,8 @@ const RobotDetails = () => { setPackages(curPackages); const parts = robotData.bom.map(async (part) => { return { - name: (await auth_api.getPartById(part.part_id)).name, - part_id: part.part_id, + name: (await auth_api.getPartById(id)).name, + part_id: id, quantity: part.quantity, }; }); @@ -120,8 +120,8 @@ const RobotDetails = () => { if (auth.isAuthenticated) { try { const fetchUserId = async () => { - const user_id = await auth_api.currentUser(); - setUserId(user_id); + const id = await auth_api.currentUser(); + setUserId(id); }; fetchUserId(); } catch (err) { @@ -189,7 +189,10 @@ const RobotDetails = () => {

{name}

ID: {id}
- {ownerUsername} + + This listing is maintained by{" "} + {ownerEmail}. +
{((response.height && response.height !== "") || @@ -251,7 +254,7 @@ const RobotDetails = () => { {parts.map((part, key) => ( - {part.name} + {part.name} {part.quantity} @@ -375,7 +378,10 @@ const RobotDetails = () => { handleShow(); }} > - + { position: "relative", }} > - + diff --git a/frontend/src/pages/Robots.tsx b/frontend/src/pages/Robots.tsx index f6999323..ea9b8c97 100644 --- a/frontend/src/pages/Robots.tsx +++ b/frontend/src/pages/Robots.tsx @@ -100,8 +100,8 @@ const Robots = () => { {robotsData.map((robot) => ( - - navigate(`/robot/${robot.robot_id}`)}> + + navigate(`/robot/${robot.id}`)}> {robot.images[0] && (
{ borderTopRightRadius: ".25rem", }} > - +
)} diff --git a/frontend/src/pages/YourParts.tsx b/frontend/src/pages/YourParts.tsx index b97e1553..a7d606d6 100644 --- a/frontend/src/pages/YourParts.tsx +++ b/frontend/src/pages/YourParts.tsx @@ -75,8 +75,8 @@ const YourParts = () => { {partsData.map((part) => ( - - navigate(`/part/${part.part_id}`)}> + + navigate(`/part/${part.id}`)}> {part.images[0] && (
{ borderTopRightRadius: ".25rem", }} > - +
)} diff --git a/frontend/src/pages/YourRobots.tsx b/frontend/src/pages/YourRobots.tsx index 6b9ffa03..fa64c7ab 100644 --- a/frontend/src/pages/YourRobots.tsx +++ b/frontend/src/pages/YourRobots.tsx @@ -75,8 +75,8 @@ const YourRobots = () => { {robotsData.map((robot) => ( - - navigate(`/robot/${robot.robot_id}`)}> + + navigate(`/robot/${robot.id}`)}> {robot.images[0] && (
{ borderTopRightRadius: ".25rem", }} > - +
)} diff --git a/pyproject.toml b/pyproject.toml index 79a2da4e..94900354 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ namespace_packages = false module = [ "boto3.*", + "moto.*", ] ignore_missing_imports = true diff --git a/store/app/crud/base.py b/store/app/crud/base.py index 75b41d5a..9b3c3ff4 100644 --- a/store/app/crud/base.py +++ b/store/app/crud/base.py @@ -2,15 +2,29 @@ import itertools import logging -from typing import Any, AsyncContextManager, Literal, Self +from typing import Any, AsyncContextManager, Callable, Literal, Self, TypeVar, overload import aioboto3 +from boto3.dynamodb.conditions import Key from botocore.exceptions import ClientError from types_aiobotocore_dynamodb.service_resource import DynamoDBServiceResource from types_aiobotocore_s3.service_resource import S3ServiceResource +from store.app.model import RobolistBaseModel + +TABLE_NAME = "Robolist" + logger = logging.getLogger(__name__) +T = TypeVar("T", bound=RobolistBaseModel) + +DEFAULT_CHUNK_SIZE = 100 +DEFAULT_SCAN_LIMIT = 1000 +ITEMS_PER_PAGE = 12 + +TableKey = tuple[str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]] +GlobalSecondaryIndex = tuple[str, str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]] + class BaseCrud(AsyncContextManager["BaseCrud"]): def __init__(self) -> None: @@ -30,6 +44,12 @@ def s3(self) -> S3ServiceResource: raise RuntimeError("Must call __aenter__ first!") return self.__s3 + @classmethod + def get_gsis(cls) -> list[GlobalSecondaryIndex]: + return [ + ("typeIndex", "type", "S", "HASH"), + ] + async def __aenter__(self) -> Self: session = aioboto3.Session() db = session.resource("dynamodb") @@ -48,11 +68,215 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # if self.__s3 is not None: await self.__s3.__aexit__(exc_type, exc_val, exc_tb) + async def _add_item(self, item: RobolistBaseModel) -> None: + table = await self.db.Table(TABLE_NAME) + item_data = item.model_dump() + if "type" in item_data: + raise ValueError("Cannot add item with 'type' attribute") + item_data["type"] = item.__class__.__name__ + await table.put_item(Item=item_data) + + async def _delete_item(self, item: RobolistBaseModel | str) -> None: + table = await self.db.Table(TABLE_NAME) + if isinstance(item, str): + await table.delete_item(Key={"id": item}) + else: + await table.delete_item(Key={"id": item.id}) + + async def _list_items( + self, + item_class: type[T], + expression_attribute_names: dict[str, str] | None = None, + expression_attribute_values: dict[str, Any] | None = None, + filter_expression: str | None = None, + offset: int | None = None, + limit: int = DEFAULT_SCAN_LIMIT, + ) -> list[T]: + table = await self.db.Table(TABLE_NAME) + + query_params = { + "IndexName": "typeIndex", + "KeyConditionExpression": Key("type").eq(item_class.__name__), + "Limit": limit, + } + + if expression_attribute_names: + query_params["ExpressionAttributeNames"] = expression_attribute_names + if expression_attribute_values: + query_params["ExpressionAttributeValues"] = expression_attribute_values + if filter_expression: + query_params["FilterExpression"] = filter_expression + if offset: + query_params["ExclusiveStartKey"] = {"id": offset} + + items = (await table.query(**query_params))["Items"] + return [self._validate_item(item, item_class) for item in items] + + async def _list( + self, item_class: type[T], page: int, sort_key: Callable[[T], int], search_query: str | None = None + ) -> tuple[list[T], bool]: + if search_query: + response = await self._list_items( + item_class, + filter_expression="contains(#part_name, :query) OR contains(description, :query)", + expression_attribute_names={"#part_name": "name"}, + expression_attribute_values={":query": search_query}, + ) + else: + response = await self._list_items(item_class) + sorted_items = sorted(response, key=sort_key, reverse=True) + return sorted_items[(page - 1) * ITEMS_PER_PAGE : page * ITEMS_PER_PAGE], page * ITEMS_PER_PAGE < len(response) + + async def _list_your( + self, + item_class: type[T], + user_id: str, + page: int, + sort_key: Callable[[T], int], + search_query: str | None = None, + ) -> tuple[list[T], bool]: + if search_query: + response = await self._list_items( + item_class, + filter_expression="(contains(#p_name, :query) OR contains(description, :query)) AND #p_owner=:user_id", + expression_attribute_names={"#p_name": "name", "#p_owner": "owner"}, + expression_attribute_values={":query": search_query, ":user_id": user_id}, + ) + else: + response = await self._list_items( + item_class, + filter_expression="#p_owner=:user_id", + expression_attribute_values={":user_id": user_id}, + expression_attribute_names={"#p_owner": "owner"}, + ) + sorted_items = sorted(response, key=sort_key, reverse=True) + return sorted_items[(page - 1) * ITEMS_PER_PAGE : page * ITEMS_PER_PAGE], page * ITEMS_PER_PAGE < len(response) + + async def _count_items(self, item_class: type[T]) -> int: + table = await self.db.Table(TABLE_NAME) + item_dict = await table.scan( + IndexName="typeIndex", + Select="COUNT", + FilterExpression=Key("type").eq(item_class.__name__), + ) + return item_dict["Count"] + + def _validate_item(self, data: dict[str, Any], item_class: type[T]) -> T: + if (item_type := data.pop("type")) != item_class.__name__: + raise ValueError(f"Item type {str(item_type)} is not a {item_class.__name__}") + return item_class.model_validate(data) + + @overload + async def _get_item(self, item_id: str, item_class: type[T], throw_if_missing: Literal[True]) -> T: ... + + @overload + async def _get_item(self, item_id: str, item_class: type[T], throw_if_missing: Literal[False]) -> T | None: ... + + async def _get_item(self, item_id: str, item_class: type[T], throw_if_missing: bool = False) -> T | None: + table = await self.db.Table(TABLE_NAME) + item_dict = await table.get_item(Key={"id": item_id}) + if "Item" not in item_dict: + if throw_if_missing: + raise ValueError(f"Item {item_id} not found") + return None + item_data = item_dict["Item"] + return self._validate_item(item_data, item_class) + + async def _item_exists(self, item_id: str) -> bool: + table = await self.db.Table(TABLE_NAME) + item_dict = await table.get_item(Key={"id": item_id}) + return "Item" in item_dict + + async def _get_item_batch( + self, + item_ids: list[str], + item_class: type[T], + chunk_size: int = DEFAULT_CHUNK_SIZE, + ) -> list[T]: + items: list[T] = [] + for i in range(0, len(item_ids), chunk_size): + chunk = item_ids[i : i + chunk_size] + keys = [{"id": item_id} for item_id in chunk] + response = await self.db.batch_get_item(RequestItems={TABLE_NAME: {"Keys": keys}}) + for item in response["Responses"][TABLE_NAME]: + items.append(self._validate_item(item, item_class)) + return items + + async def _get_items_from_secondary_index( + self, + secondary_index: str, + secondary_index_name: str, + secondary_index_value: str, + item_class: type[T], + ) -> list[T]: + table = await self.db.Table(TABLE_NAME) + item_dict = await table.query( + IndexName=secondary_index, + KeyConditionExpression=Key(secondary_index_name).eq(secondary_index_value), + ) + items = item_dict["Items"] + return [self._validate_item(item, item_class) for item in items] + + @overload + async def _get_unique_item_from_secondary_index( + self, + secondary_index: str, + secondary_index_name: str, + secondary_index_value: str, + item_class: type[T], + throw_if_missing: Literal[True], + ) -> T: ... + + @overload + async def _get_unique_item_from_secondary_index( + self, + secondary_index: str, + secondary_index_name: str, + secondary_index_value: str, + item_class: type[T], + throw_if_missing: Literal[False] = False, + ) -> T | None: ... + + async def _get_unique_item_from_secondary_index( + self, + secondary_index: str, + secondary_index_name: str, + secondary_index_value: str, + item_class: type[T], + throw_if_missing: bool = False, + ) -> T | None: + items = await self._get_items_from_secondary_index( + secondary_index, + secondary_index_name, + secondary_index_value, + item_class, + ) + if len(items) == 0: + if throw_if_missing: + raise ValueError(f"No items found with {secondary_index_name} {secondary_index_value}") + return None + if len(items) > 1: + raise ValueError(f"Multiple items found with {secondary_index_name} {secondary_index_value}") + return items[0] + + async def _update_item(self, item_id: str, item_class: type[T], new_values: dict[str, Any]) -> None: # noqa: ANN401 + # Validates the new values. + for field_name, field_value in new_values.items(): + if item_class.model_fields.get(field_name) is None: + raise ValueError(f"Field {field_name} not in model {item_class.__name__}") + + # Updates the table. + table = await self.db.Table(TABLE_NAME) + await table.update_item( + Key={"id": item_id}, + AttributeUpdates={k: {"Value": v, "Action": "PUT"} for k, v in new_values.items() if k != "id"}, + ) + async def _create_dynamodb_table( self, name: str, - keys: list[tuple[str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]]], - gsis: list[tuple[str, str, Literal["S", "N", "B"], Literal["HASH", "RANGE"]]] | None = None, + keys: list[TableKey], + gsis: list[GlobalSecondaryIndex] | None = None, deletion_protection: bool = False, ) -> None: """Creates a table in the Dynamo database if a table of that name does not already exist. @@ -73,18 +297,7 @@ async def _create_dynamodb_table( except ClientError: logger.info("Creating %s table", name) - if gsis is None: - table = await self.db.create_table( - AttributeDefinitions=[ - {"AttributeName": n, "AttributeType": t} for n, t in ((n, t) for (n, t, _) in keys) - ], - TableName=name, - KeySchema=[{"AttributeName": n, "KeyType": t} for n, _, t in keys], - DeletionProtectionEnabled=deletion_protection, - BillingMode="PAY_PER_REQUEST", - ) - - else: + if gsis: table = await self.db.create_table( AttributeDefinitions=[ {"AttributeName": n, "AttributeType": t} @@ -106,6 +319,17 @@ async def _create_dynamodb_table( BillingMode="PAY_PER_REQUEST", ) + else: + table = await self.db.create_table( + AttributeDefinitions=[ + {"AttributeName": n, "AttributeType": t} for n, t in ((n, t) for (n, t, _) in keys) + ], + TableName=name, + KeySchema=[{"AttributeName": n, "KeyType": t} for n, _, t in keys], + DeletionProtectionEnabled=deletion_protection, + BillingMode="PAY_PER_REQUEST", + ) + await table.wait_until_exists() async def _delete_dynamodb_table(self, name: str) -> None: diff --git a/store/app/crud/robots.py b/store/app/crud/robots.py index fc96ea2c..559215d1 100644 --- a/store/app/crud/robots.py +++ b/store/app/crud/robots.py @@ -1,238 +1,48 @@ """Defines CRUD interface for robot API.""" import logging -from typing import Any, Dict, List, Optional -from boto3.dynamodb.conditions import Key from fastapi import UploadFile -from fastapi.responses import RedirectResponse -from pydantic import BaseModel from store.app.crud.base import BaseCrud -from store.app.model import Bom, Image, Package, Part, Robot -from store.settings import settings +from store.app.model import Part, Robot logger = logging.getLogger(__name__) -class EditPart(BaseModel): - name: str - description: str - images: List[Image] - - -class EditRobot(BaseModel): - name: str - description: str - bom: List[Bom] - images: List[Image] - height: Optional[str] - weight: Optional[str] - degrees_of_freedom: Optional[str] - urdf: Optional[str] - packages: List[Package] - - -def serialize_bom(bom: Bom) -> dict: - return {"part_id": bom.part_id, "quantity": str(bom.quantity)} - - -def serialize_bom_list(bom_list: List[Bom]) -> List[dict]: - return [serialize_bom(bom) for bom in bom_list] - - -def serialize_image(image: Image) -> dict: - return {"caption": image.caption, "url": image.url} - - -def serialize_image_list(image_list: List[Image]) -> List[dict]: - return [serialize_image(image) for image in image_list] - - -def serialize_package(package: Package) -> dict: - return {"name": package.name, "url": package.url} - - -def serialize_package_list(package_list: List[Package]) -> List[dict]: - return [serialize_package(package) for package in package_list] - - -def get_timestamp(item: Dict[str, Any]) -> int: - return item["timestamp"] - - class RobotCrud(BaseCrud): async def add_robot(self, robot: Robot) -> None: - table = await self.db.Table("Robots") - await table.put_item(Item=robot.model_dump()) - - async def add_part(self, part: Part) -> None: - table = await self.db.Table("Parts") - await table.put_item(Item=part.model_dump()) - - async def list_robots( - self, page: int = 1, items_per_page: int = 12, search_query: Optional[str] = None - ) -> tuple[list[Robot], bool]: - table = await self.db.Table("Robots") - if search_query: - response = await table.scan( - FilterExpression="contains(#robot_name, :query) OR contains(description, :query)", - ExpressionAttributeValues={":query": search_query}, - ExpressionAttributeNames={ - "#robot_name": "name" - }, # Define the placeholder since "name" is a dynamodb reserved keyword - ) - else: - response = await table.scan() - # This is O(n log n). Look into better ways to architect the schema. - sorted_items = sorted(response["Items"], key=get_timestamp, reverse=True) - return [ - Robot.model_validate(item) for item in sorted_items[(page - 1) * items_per_page : page * items_per_page] - ], page * items_per_page < response["Count"] - - async def list_your_robots(self, user_id: str, page: int = 1, items_per_page: int = 12) -> tuple[list[Robot], bool]: - table = await self.db.Table("Robots") - response = await table.query(IndexName="ownerIndex", KeyConditionExpression=Key("owner").eq(user_id)) - sorted_items = sorted(response["Items"], key=get_timestamp, reverse=True) - return [ - Robot.model_validate(item) for item in sorted_items[(page - 1) * items_per_page : page * items_per_page] - ], page * items_per_page < response["Count"] + await self._add_item(robot) async def get_robot(self, robot_id: str) -> Robot | None: - table = await self.db.Table("Robots") - robot_dict = await table.get_item(Key={"robot_id": robot_id}) - if "Item" not in robot_dict: - return None - return Robot.model_validate(robot_dict["Item"]) - - async def list_parts( - self, page: int = 1, items_per_page: int = 12, search_query: Optional[str] = None - ) -> tuple[list[Part], bool]: - table = await self.db.Table("Parts") - if search_query: - response = await table.scan( - FilterExpression="contains(#part_name, :query) OR contains(description, :query)", - ExpressionAttributeValues={":query": search_query}, - ExpressionAttributeNames={ - "#part_name": "name" - }, # Define the placeholder since "name" is a dynamodb reserved keyword - ) - else: - response = await table.scan() - # This is O(n log n). Look into better ways to architect the schema. - sorted_items = sorted(response["Items"], key=get_timestamp, reverse=True) - return [ - Part.model_validate(item) for item in sorted_items[(page - 1) * items_per_page : page * items_per_page] - ], page * items_per_page < response["Count"] + return await self._get_item(robot_id, Robot, throw_if_missing=False) - async def dump_parts(self) -> list[Part]: - table = await self.db.Table("Parts") - response = await table.scan() - return [Part.model_validate(item) for item in response["Items"]] - - async def list_your_parts(self, user_id: str, page: int = 1, items_per_page: int = 12) -> tuple[list[Part], bool]: - table = await self.db.Table("Parts") - response = await table.query(IndexName="ownerIndex", KeyConditionExpression=Key("owner").eq(user_id)) - sorted_items = sorted(response["Items"], key=get_timestamp, reverse=True) - return [ - Part.model_validate(item) for item in sorted_items[(page - 1) * items_per_page : page * items_per_page] - ], page * items_per_page < response["Count"] + async def delete_robot(self, robot_id: str) -> None: + await self._delete_item(robot_id) + + async def add_part(self, part: Part) -> None: + await self._add_item(part) async def get_part(self, part_id: str) -> Part | None: - table = await self.db.Table("Parts") - part_dict = await table.get_item(Key={"part_id": part_id}) - if "Item" not in part_dict: - return None - return Part.model_validate(part_dict["Item"]) + return await self._get_item(part_id, Part, throw_if_missing=False) async def delete_part(self, part_id: str) -> None: - table = await self.db.Table("Parts") - await table.delete_item(Key={"part_id": part_id}) + await self._delete_item(part_id) - async def delete_robot(self, robot_id: str) -> None: - table = await self.db.Table("Robots") - await table.delete_item(Key={"robot_id": robot_id}) - - async def update_part(self, id: str, part: EditPart) -> None: - table = await self.db.Table("Parts") - update_expression = "SET #name = :name, \ - #description = :description, \ - #images = :images, " - - expression_attribute_names = { - "#name": "name", - "#description": "description", - "#images": "images", - } - - expression_attribute_values = { - ":name": part.name, - ":description": part.description, - ":images": serialize_image_list(part.images), - } - - await table.update_item( - Key={"part_id": id}, - UpdateExpression=update_expression[:-2], - ExpressionAttributeValues=expression_attribute_values, - ExpressionAttributeNames=expression_attribute_names, - ReturnValues="NONE", - ) - - async def update_robot(self, id: str, robot: EditRobot) -> None: - table = await self.db.Table("Robots") - update_expression = "SET #name = :name, \ - #description = :description, \ - #bom = :bom, \ - #images = :images, \ - #packages = :packages, " - - expression_attribute_names = { - "#name": "name", - "#description": "description", - "#bom": "bom", - "#images": "images", - "#packages": "packages", - } - - expression_attribute_values = { - ":name": robot.name, - ":description": robot.description, - ":bom": serialize_bom_list(robot.bom), - ":images": serialize_image_list(robot.images), - ":packages": serialize_package_list(robot.packages), - } - - if robot.urdf is not None: - update_expression += "#urdf = :urdf, " - expression_attribute_names["#urdf"] = "urdf" - expression_attribute_values[":urdf"] = robot.urdf or "" - - if robot.height is not None: - update_expression += "#height = :height, " - expression_attribute_names["#height"] = "height" - expression_attribute_values[":height"] = robot.height or "" - - if robot.weight is not None: - update_expression += "#weight = :weight, " - expression_attribute_names["#weight"] = "weight" - expression_attribute_values[":weight"] = robot.weight or "" - - if robot.degrees_of_freedom is not None: - update_expression += "#degrees_of_freedom = :degrees_of_freedom, " - expression_attribute_names["#degrees_of_freedom"] = "degrees_of_freedom" - expression_attribute_values[":degrees_of_freedom"] = robot.degrees_of_freedom or "" - - await table.update_item( - Key={"robot_id": id}, - UpdateExpression=update_expression[:-2], - ExpressionAttributeValues=expression_attribute_values, - ExpressionAttributeNames=expression_attribute_names, - ReturnValues="NONE", - ) - - async def get_image(self, url: str) -> RedirectResponse: - return RedirectResponse(url=f"{settings.site.image_url}/{url}") + async def dump_parts(self) -> list[Part]: + return await self._list_items(Part) + + async def list_robots(self, page: int, search_query: str | None = None) -> tuple[list[Robot], bool]: + return await self._list(Robot, page, lambda x: x.timestamp, search_query) + + async def list_your_robots(self, user_id: str, page: int, search_query: str) -> tuple[list[Robot], bool]: + return await self._list_your(Robot, user_id, page, lambda x: x.timestamp, search_query) + + async def list_parts(self, page: int, search_query: str | None = None) -> tuple[list[Part], bool]: + return await self._list(Part, page, lambda x: x.timestamp, search_query) + + async def list_your_parts(self, user_id: str, page: int, search_query: str) -> tuple[list[Part], bool]: + return await self._list_your(Part, user_id, page, lambda x: x.timestamp, search_query) async def upload_image(self, file: UploadFile) -> None: await (await self.s3.Bucket("images")).upload_fileobj(file.file, file.filename or "") diff --git a/store/app/crud/users.py b/store/app/crud/users.py index dd05d275..8379ae78 100644 --- a/store/app/crud/users.py +++ b/store/app/crud/users.py @@ -1,223 +1,132 @@ """Defines CRUD interface for user API.""" import asyncio -import json import warnings -from typing import Any, Self +from datetime import datetime -from boto3.dynamodb.conditions import Key -from redis.asyncio import Redis - -from store.app.crud.base import BaseCrud -from store.app.crypto import hash_password, hash_token -from store.app.model import User +from store.app.crud.base import BaseCrud, GlobalSecondaryIndex +from store.app.model import APIKey, OAuthKey, User from store.settings import settings +from store.utils import LRUCache + +# This dictionary is used to locally cache the last time a token was validated +# against the database. We give the tokens some buffer time to avoid hitting +# the database too often. +LAST_API_KEY_VALIDATION = LRUCache[str, tuple[datetime, bool]](2**20) + + +def github_auth_key(github_id: str) -> str: + return f"github:{github_id}" + + +def google_auth_key(google_id: str) -> str: + return f"google:{google_id}" + + +class UserNotFoundError(Exception): + def __init__(self, message: str) -> None: + super().__init__(message) class UserCrud(BaseCrud): def __init__(self) -> None: super().__init__() - self.__session_kv: Redis | None = None - self.__register_kv: Redis | None = None - self.__reset_password_kv: Redis | None = None - self.__change_email_kv: Redis | None = None - - async def __aenter__(self) -> Self: - self, sessions = await asyncio.gather( - super().__aenter__(), - asyncio.gather( - *( - Redis( - host=settings.redis.host, - password=settings.redis.password if settings.redis.password else None, - port=settings.redis.port, - db=db, - ).__aenter__() - for db in ( - settings.redis.session_db, - settings.redis.verify_email_db, - settings.redis.reset_password_db, - settings.redis.change_email_db, - ) - ) - ), - ) - - self.__session_kv, self.__register_kv, self.__reset_password_kv, self.__change_email_kv = sessions - return self - - async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # noqa: ANN401 - await asyncio.gather( - super().__aexit__(exc_type, exc_val, exc_tb), - asyncio.gather( - *( - kv.__aexit__(exc_type, exc_val, exc_tb) - for kv in ( - self.session_kv, - self.register_kv, - self.reset_password_kv, - self.change_email_kv, - ) - ) - ), - ) - - @property - def session_kv(self) -> Redis: - if self.__session_kv is None: - raise RuntimeError("Must call __aenter__ first!") - return self.__session_kv - - @property - def register_kv(self) -> Redis: - if self.__register_kv is None: - raise RuntimeError("Must call __aenter__ first!") - return self.__register_kv - - @property - def reset_password_kv(self) -> Redis: - if self.__reset_password_kv is None: - raise RuntimeError("Must call __aenter__ first!") - return self.__reset_password_kv - - @property - def change_email_kv(self) -> Redis: - if self.__change_email_kv is None: - raise RuntimeError("Must call __aenter__ first!") - return self.__change_email_kv - - async def add_user(self, user: User) -> None: - # Then, add the user object to the Users table. - table = await self.db.Table("Users") - await table.put_item( - Item=user.model_dump(), - ConditionExpression="attribute_not_exists(oauth_id) AND attribute_not_exists(email) AND \ - attribute_not_exists(username)", - ) - - async def get_user(self, user_id: str) -> User | None: - table = await self.db.Table("Users") - user_dict = await table.get_item(Key={"user_id": user_id}) - if "Item" not in user_dict: + @classmethod + def get_gsis(cls) -> list[GlobalSecondaryIndex]: + return super().get_gsis() + [ + ("emailIndex", "email", "S", "HASH"), + ] + + async def get_user(self, id: str) -> User | None: + return await self._get_item(id, User, throw_if_missing=False) + + async def create_user_from_token(self, token: str, email: str) -> User: + user = User.create(email=email) + await self._add_item(user) + key = OAuthKey.create(token, user.id) + await self._add_item(key) + return user + + async def get_user_from_token(self, token: str) -> User | None: + key = await self._get_item(token, OAuthKey, throw_if_missing=False) + if key is None: return None - return User.model_validate(user_dict["Item"]) - - async def get_user_batch(self, user_ids: list[str]) -> list[User]: - users: list[User] = [] - chunk_size = 100 - for i in range(0, len(user_ids), chunk_size): - chunk = user_ids[i : i + chunk_size] - keys = [{"user_id": user_id} for user_id in chunk] - response = await self.db.batch_get_item(RequestItems={"Users": {"Keys": keys}}) - users.extend(User.model_validate(user) for user in response["Responses"]["Users"]) - return users + return await self.get_user(key.user_id) + + async def create_user_from_github_token(self, github_id: str, email: str) -> User: + return await self.create_user_from_token(github_auth_key(github_id), email) + + async def create_user_from_google_token(self, google_id: str, email: str) -> User: + return await self.create_user_from_token(google_auth_key(google_id), email) + + async def get_user_from_github_token(self, token: str) -> User | None: + return await self.get_user_from_token(github_auth_key(token)) + + async def get_user_from_google_token(self, token: str) -> User | None: + return await self.get_user_from_token(google_auth_key(token)) async def get_user_from_email(self, email: str) -> User | None: - table = await self.db.Table("Users") - user_dict = await table.query( - IndexName="emailIndex", - KeyConditionExpression=Key("email").eq(email), - ) - items = user_dict["Items"] - if len(items) == 0: - return None - if len(items) > 1: - raise ValueError(f"Multiple users found with email {email}") - return User.model_validate(items[0]) - - async def get_user_from_oauth_id(self, oauth_id: str) -> User | None: - table = await self.db.Table("Users") - user_dict = await table.query( - IndexName="oauthIdIndex", - KeyConditionExpression=Key("oauth_id").eq(oauth_id), - ) - items = user_dict["Items"] - if len(items) == 0: - return None - if len(items) > 1: - raise ValueError(f"Multiple users found with oauth id {oauth_id}") - return User.model_validate(items[0]) + return await self._get_unique_item_from_secondary_index("emailIndex", "email", email, User) - async def get_user_id_from_session_token(self, session_token: str) -> str | None: - user_id = await self.session_kv.get(hash_token(session_token)) - if user_id is None: - return None - return user_id.decode("utf-8") + async def create_user_from_email(self, email: str) -> User: + user = User.create(email=email) + await self._add_item(user) + return user - async def delete_user(self, user_id: str) -> None: - # Then, delete the user object from the Users table. - table = await self.db.Table("Users") - await table.delete_item(Key={"user_id": user_id}) + async def get_user_batch(self, ids: list[str]) -> list[User]: + return await self._get_item_batch(ids, User) + + async def get_user_from_api_key(self, key: str) -> User: + key = await self.get_api_key(key) + return await self._get_item(key.user_id, User, throw_if_missing=True) + + async def delete_user(self, id: str) -> None: + await self._delete_item(id) async def list_users(self) -> list[User]: warnings.warn("`list_users` probably shouldn't be called in production", ResourceWarning) - table = await self.db.Table("Users") - return [User.model_validate(user) for user in await table.scan()] + return await self._list_items(User) async def get_user_count(self) -> int: - table = await self.db.Table("Users") - return await table.item_count - - async def add_session_token(self, token: str, user_id: str, lifetime: int) -> None: - await self.session_kv.setex(hash_token(token), lifetime, user_id) - - async def delete_session_token(self, token: str) -> None: - await self.session_kv.delete(hash_token(token)) - - async def add_register_token(self, token: str, email: str, lifetime: int) -> None: - await self.register_kv.setex(hash_token(token), lifetime, email) - - async def delete_register_token(self, token: str) -> None: - await self.register_kv.delete(hash_token(token)) - - async def check_register_token(self, token: str) -> str: - email = await self.register_kv.get(hash_token(token)) - if email is None: - raise ValueError("Provided token is invalid") - return email.decode("utf-8") - - async def change_password(self, user_id: str, new_password: str) -> None: - await (await self.db.Table("Users")).update_item( - Key={"user_id": user_id}, - AttributeUpdates={"password_hash": {"Value": hash_password(new_password), "Action": "PUT"}}, - ) - - async def add_reset_password_token(self, token: str, user_id: str, lifetime: int) -> None: - await self.reset_password_kv.setex(hash_token(token), lifetime, user_id) - - async def delete_reset_password_token(self, token: str) -> None: - await self.reset_password_kv.delete(hash_token(token)) - - async def use_reset_password_token(self, token: str, new_password: str) -> None: - id = await self.reset_password_kv.get(hash_token(token)) - if id is None: - raise ValueError("Provided token is invalid") - await self.change_password(id.decode("utf-8"), new_password) - await self.delete_reset_password_token(token) - - async def add_change_email_token(self, token: str, user_id: str, new_email: str, lifetime: int) -> None: - await self.change_email_kv.setex( - hash_token(token), lifetime, json.dumps({"user_id": user_id, "new_email": new_email}) - ) - - async def use_change_email_token(self, token: str) -> None: - data = await self.change_email_kv.get(hash_token(token)) - if data is None: - raise ValueError("Provided token is invalid") - data = json.loads(data) - await (await self.db.Table("Users")).update_item( - Key={"user_id": data["user_id"]}, - AttributeUpdates={ - "email": {"Value": data["new_email"], "Action": "PUT"}, - }, - ) - await self.change_email_kv.delete(hash_token(token)) + return await self._count_items(User) + + async def get_api_key(self, id: str) -> APIKey: + return await self._get_item(id, APIKey, throw_if_missing=True) + + async def add_api_key(self, id: str) -> APIKey: + token = APIKey.create(id=id) + await self._add_item(token) + return token + + async def delete_api_key(self, token: APIKey | str) -> None: + await self._delete_item(token) + + async def api_key_is_valid(self, token: str) -> bool: + """Validates a token against the database, with caching. + + In order to reduce the number of database queries, we locally cache + whether or not a token is valid for some amount of time. + + Args: + token: The token to validate. + + Returns: + If the token is valid, meaning, if it exists in the database. + """ + cur_time = datetime.now() + if token in LAST_API_KEY_VALIDATION: + last_time, is_valid = LAST_API_KEY_VALIDATION[token] + if (cur_time - last_time).seconds < settings.crypto.cache_token_db_result_seconds: + return is_valid + is_valid = await self._item_exists(token) + LAST_API_KEY_VALIDATION[token] = (cur_time, is_valid) + return is_valid async def test_adhoc() -> None: async with UserCrud() as crud: - await crud.add_user(User.create(username="ben", email="ben@kscale.dev", password="password")) + await crud.create_user_from_email(email="ben@kscale.dev") if __name__ == "__main__": diff --git a/store/app/crypto.py b/store/app/crypto.py index a1548063..7394b5dd 100644 --- a/store/app/crypto.py +++ b/store/app/crypto.py @@ -8,10 +8,6 @@ from argon2 import PasswordHasher -def new_uuid() -> uuid.UUID: - return uuid.uuid4() - - def new_token(length: int = 64) -> str: """Generates a cryptographically secure random 64 character alphanumeric token.""" return "".join(secrets.choice(string.ascii_letters + string.digits) for _ in range(length)) @@ -22,11 +18,11 @@ def hash_token(token: str) -> str: def check_hash(token: str, hash: str) -> bool: - return hashlib.sha256(token.encode()).hexdigest() == hash + return hash_token(token) == hash -def hash_password(password: str) -> str: - return PasswordHasher().hash(password) +def new_uuid() -> uuid.UUID: + return uuid.uuid4() def check_password(password: str, hash: str) -> bool: diff --git a/store/app/db.py b/store/app/db.py index ce41a3cb..dbf65c0c 100644 --- a/store/app/db.py +++ b/store/app/db.py @@ -5,7 +5,7 @@ import logging from typing import AsyncGenerator, Self -from store.app.crud.base import BaseCrud +from store.app.crud.base import TABLE_NAME, BaseCrud from store.app.crud.robots import RobotCrud from store.app.crud.users import UserCrud @@ -37,43 +37,13 @@ async def create_tables(crud: Crud | None = None, deletion_protection: bool = Fa await create_tables(new_crud) else: - await asyncio.gather( - crud._create_dynamodb_table( - name="Users", - keys=[ - ("user_id", "S", "HASH"), - ], - gsis=[ - ("emailIndex", "email", "S", "HASH"), - ("usernameIndex", "username", "S", "HASH"), - ("oauthIdIndex", "oauth_id", "S", "HASH"), - ], - deletion_protection=deletion_protection, - ), - crud._create_dynamodb_table( - name="Robots", - keys=[ - ("robot_id", "S", "HASH"), - ], - gsis=[ - ("ownerIndex", "owner", "S", "HASH"), - ("nameIndex", "name", "S", "HASH"), - ("timestampIndex", "timestamp", "N", "HASH"), - ], - deletion_protection=deletion_protection, - ), - crud._create_dynamodb_table( - name="Parts", - keys=[ - ("part_id", "S", "HASH"), - ], - gsis=[ - ("ownerIndex", "owner", "S", "HASH"), - ("nameIndex", "name", "S", "HASH"), - ("timestampIndex", "timestamp", "N", "HASH"), - ], - deletion_protection=deletion_protection, - ), + await crud._create_dynamodb_table( + name=TABLE_NAME, + keys=[ + ("id", "S", "HASH"), + ], + gsis=crud.get_gsis(), + deletion_protection=deletion_protection, ) @@ -90,11 +60,7 @@ async def delete_tables(crud: Crud | None = None) -> None: await delete_tables(new_crud) else: - await asyncio.gather( - crud._delete_dynamodb_table("Users"), - crud._delete_dynamodb_table("Robots"), - crud._delete_dynamodb_table("Parts"), - ) + await crud._delete_dynamodb_table(TABLE_NAME) async def populate_with_dummy_data(crud: Crud | None = None) -> None: diff --git a/store/app/model.py b/store/app/model.py index 1252dd11..82a3ef83 100644 --- a/store/app/model.py +++ b/store/app/model.py @@ -5,43 +5,113 @@ expects (for example, converting a UUID into a string). """ -import uuid -from typing import Optional +from typing import Self +import jwt from pydantic import BaseModel -from store.app.crypto import hash_password +from store.app.crypto import new_uuid +from store.settings import settings -class User(BaseModel): - user_id: str # Primary key - username: str +class RobolistBaseModel(BaseModel): + """Defines the base model for Robolist database rows. + + Our database architecture uses a single table with a single primary key + (the `id` field). This class provides a common interface for all models + that are stored in the database. + """ + + id: str + + +class UserPermissions(BaseModel): + is_admin: bool = False + + +class User(RobolistBaseModel): + """Defines the user model for the API. + + Users are defined by their email, username and password hash. This is the + simplest form of authentication, and is used for users who sign up with + their email and password. + """ + email: str - password_hash: str - oauth_id: str - admin: bool + permissions: UserPermissions = UserPermissions() + + @classmethod + def create(cls, email: str) -> Self: + return cls(id=str(new_uuid()), email=email) + + +class OAuthKey(RobolistBaseModel): + """Keys for OAuth providers which identify users.""" + + user_id: str @classmethod - def create(cls, email: str, username: str, password: str) -> "User": + def create(cls, token: str, user_id: str) -> Self: + return cls(id=token, user_id=user_id) + + +class APIKey(RobolistBaseModel): + """The API key is used for querying the API. + + Downstream users keep the JWT locally, and it is used to authenticate + requests to the API. The key is stored in the database, and can be + revoked by the user at any time. + """ + + user_id: str + + @classmethod + def create(cls, id: str) -> Self: return cls( - user_id=str(uuid.uuid4()), - email=email, - username=username, - password_hash=hash_password(password), - oauth_id="dummy_oauth", - admin=False, + id=str(new_uuid()), + user_id=id, + ) + + def to_jwt(self) -> str: + return jwt.encode( + payload={"token": self.id, "user_id": self.user_id}, + key=settings.crypto.jwt_secret, ) @classmethod - def create_oauth(cls, username: str, oauth_id: str) -> "User": + def from_jwt(cls, jwt_token: str) -> Self: + data = jwt.decode( + jwt=jwt_token, + key=settings.crypto.jwt_secret, + ) + return cls(id=data["token"], user_id=data["user_id"]) + + +class RegisterToken(RobolistBaseModel): + """Stores a token for registering a new user.""" + + email: str + + @classmethod + def create(cls, email: str) -> Self: return cls( - user_id=str(uuid.uuid4()), - username=username, - email="dummy@kscale.dev", - oauth_id=oauth_id, - admin=False, - password_hash="", + id=str(new_uuid()), + email=email, + ) + + def to_jwt(self) -> str: + return jwt.encode( + payload={"token": self.id, "email": self.email}, + key=settings.crypto.jwt_secret, + ) + + @classmethod + def from_jwt(cls, jwt_token: str) -> Self: + data = jwt.decode( + jwt=jwt_token, + key=settings.crypto.jwt_secret, ) + return cls(id=data["token"], email=data["email"]) class Bom(BaseModel): @@ -59,23 +129,21 @@ class Package(BaseModel): url: str -class Robot(BaseModel): - robot_id: str # Primary key +class Robot(RobolistBaseModel): owner: str name: str description: str bom: list[Bom] images: list[Image] - height: Optional[str] = "" - weight: Optional[str] = "" - degrees_of_freedom: Optional[str] = "" + height: str | None = None + weight: str | None = None + degrees_of_freedom: int | None = None timestamp: int urdf: str packages: list[Package] -class Part(BaseModel): - part_id: str # Primary key +class Part(RobolistBaseModel): name: str owner: str description: str diff --git a/store/app/routers/image.py b/store/app/routers/image.py index ee71a708..8d06999a 100644 --- a/store/app/routers/image.py +++ b/store/app/routers/image.py @@ -5,7 +5,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, UploadFile -from fastapi.responses import JSONResponse, RedirectResponse +from fastapi.responses import JSONResponse from PIL import Image from store.app.crypto import new_uuid @@ -16,14 +16,6 @@ logger = logging.getLogger(__name__) -@image_router.get("/{url}/") -async def get_image(url: str, crud: Annotated[Crud, Depends(Crud.get)]) -> RedirectResponse: - try: - return await crud.get_image(url + ".png") - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) - - @image_router.post("/upload/") async def upload_image(crud: Annotated[Crud, Depends(Crud.get)], file: UploadFile) -> JSONResponse: try: diff --git a/store/app/routers/part.py b/store/app/routers/part.py index 75c794fe..c5e0323d 100644 --- a/store/app/routers/part.py +++ b/store/app/routers/part.py @@ -2,12 +2,11 @@ import logging import time -from typing import Annotated, List +from typing import Annotated, Any from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel -from store.app.crud.robots import EditPart from store.app.crypto import new_uuid from store.app.db import Crud from store.app.model import Image, Part @@ -23,12 +22,12 @@ async def list_parts( crud: Annotated[Crud, Depends(Crud.get)], page: int = Query(description="Page number for pagination"), search_query: str = Query(None, description="Search query string"), -) -> tuple[List[Part], bool]: +) -> tuple[list[Part], bool]: return await crud.list_parts(page, search_query=search_query) @parts_router.get("/dump/") -async def dump_parts(crud: Annotated[Crud, Depends(Crud.get)]) -> List[Part]: +async def dump_parts(crud: Annotated[Crud, Depends(Crud.get)]) -> list[Part]: return await crud.dump_parts() @@ -37,11 +36,10 @@ async def list_your_parts( crud: Annotated[Crud, Depends(Crud.get)], token: Annotated[str, Depends(get_session_token)], page: int = Query(description="Page number for pagination"), -) -> tuple[List[Part], bool]: - user_id = await crud.get_user_id_from_session_token(token) - if user_id is None: - raise HTTPException(status_code=401, detail="Must be logged in to view your parts") - return await crud.list_your_parts(user_id, page) + search_query: str = Query(None, description="Search query string"), +) -> tuple[list[Part], bool]: + user = await crud.get_user_from_api_key(token) + return await crud.list_your_parts(user.id, page, search_query=search_query) @parts_router.get("/{part_id}") @@ -54,14 +52,14 @@ async def current_user( crud: Annotated[Crud, Depends(Crud.get)], token: Annotated[str, Depends(get_session_token)], ) -> str | None: - user_id = await crud.get_user_id_from_session_token(token) - return str(user_id) + user = await crud.get_user_from_api_key(token) + return user.id class NewPart(BaseModel): name: str description: str - images: List[Image] + images: list[Image] @parts_router.post("/add/") @@ -70,16 +68,14 @@ async def add_part( token: Annotated[str, Depends(get_session_token)], crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: - user_id = await crud.get_user_id_from_session_token(token) - if user_id is None: - raise HTTPException(status_code=401, detail="Must be logged in to add a part") + user = await crud.get_user_from_api_key(token) await crud.add_part( Part( name=part.name, description=part.description, images=part.images, - owner=str(user_id), - part_id=str(new_uuid()), + owner=user.id, + id=str(new_uuid()), timestamp=int(time.time()), ) ) @@ -95,8 +91,8 @@ async def delete_part( part = await crud.get_part(part_id) if part is None: raise HTTPException(status_code=404, detail="Part not found") - user_id = await crud.get_user_id_from_session_token(token) - if part.owner != user_id: + user = await crud.get_user_from_api_key(token) + if part.owner != user.id: raise HTTPException(status_code=403, detail="You do not own this part") await crud.delete_part(part_id) return True @@ -105,12 +101,18 @@ async def delete_part( @parts_router.post("/edit-part/{part_id}/") async def edit_part( part_id: str, - part: EditPart, + part: dict[ + str, Any + ], # There has got to be a better type annotation than this (possibly the deleted) EditPart class token: Annotated[str, Depends(get_session_token)], crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: - user_id = await crud.get_user_id_from_session_token(token) - if user_id is None: - raise HTTPException(status_code=401, detail="Must be logged in to edit a part") - await crud.update_part(part_id, part) + user = await crud.get_user_from_api_key(token) + part_info = await crud.get_part(part_id) + if part_info is None: + raise HTTPException(status_code=404, detail="Part not found") + if user.id != part_info.owner: + raise HTTPException(status_code=403, detail="You do not own this part") + part["owner"] = user.id + await crud._update_item(part_id, Part, part) return True diff --git a/store/app/routers/robot.py b/store/app/routers/robot.py index b6626e83..50f8e925 100644 --- a/store/app/routers/robot.py +++ b/store/app/routers/robot.py @@ -2,12 +2,11 @@ import logging import time -from typing import Annotated, List, Optional +from typing import Annotated, Any, List, Optional from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel -from store.app.crud.robots import EditRobot from store.app.crypto import new_uuid from store.app.db import Crud from store.app.model import Bom, Image, Package, Robot @@ -18,48 +17,11 @@ logger = logging.getLogger(__name__) -@robots_router.get("/") -async def list_robots( - crud: Annotated[Crud, Depends(Crud.get)], - page: int = Query(description="Page number for pagination"), - search_query: str = Query(None, description="Search query string"), -) -> tuple[List[Robot], bool]: - """Lists the robots in the database. - - The function is paginated. The page size is 12. - - Returns the robots on the page and a boolean indicating if there are more pages. - """ - return await crud.list_robots(page, search_query=search_query) - - -@robots_router.get("/your/") -async def list_your_robots( - crud: Annotated[Crud, Depends(Crud.get)], - token: Annotated[str, Depends(get_session_token)], - page: int = Query(description="Page number for pagination"), -) -> tuple[List[Robot], bool]: - """Lists the robots that you own.""" - user_id = await crud.get_user_id_from_session_token(token) - if user_id is None: - raise HTTPException(status_code=401, detail="Must be logged in to view your parts") - - return await crud.list_your_robots(user_id, page) - - @robots_router.get("/{robot_id}") async def get_robot(robot_id: str, crud: Annotated[Crud, Depends(Crud.get)]) -> Robot | None: return await crud.get_robot(robot_id) -@robots_router.get("/user/") -async def current_user( - crud: Annotated[Crud, Depends(Crud.get)], token: Annotated[str, Depends(get_session_token)] -) -> str | None: - user_id = await crud.get_user_id_from_session_token(token) - return str(user_id) - - class NewRobot(BaseModel): name: str description: str @@ -67,23 +29,42 @@ class NewRobot(BaseModel): images: List[Image] height: Optional[str] weight: Optional[str] - degrees_of_freedom: Optional[str] + degrees_of_freedom: Optional[int] urdf: str packages: List[Package] +@robots_router.get("/") +async def list_robots( + crud: Annotated[Crud, Depends(Crud.get)], + page: int = Query(description="Page number for pagination"), + search_query: str = Query(None, description="Search query string"), +) -> tuple[list[Robot], bool]: + return await crud.list_robots(page, search_query=search_query) + + +@robots_router.get("/your/") +async def list_your_robots( + crud: Annotated[Crud, Depends(Crud.get)], + token: Annotated[str, Depends(get_session_token)], + page: int = Query(description="Page number for pagination"), + search_query: str = Query(None, description="Search query string"), +) -> tuple[list[Robot], bool]: + user = await crud.get_user_from_api_key(token) + return await crud.list_your_robots(user.id, page, search_query=search_query) + + @robots_router.post("/add/") async def add_robot( new_robot: NewRobot, token: Annotated[str, Depends(get_session_token)], crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: - user_id = await crud.get_user_id_from_session_token(token) - if user_id is None: - raise HTTPException(status_code=401, detail="Must be logged in to add a robot") + user = await crud.get_user_from_api_key(token) await crud.add_robot( Robot( + id=str(new_uuid()), name=new_robot.name, description=new_robot.description, bom=new_robot.bom, @@ -91,8 +72,7 @@ async def add_robot( height=new_robot.height, weight=new_robot.weight, degrees_of_freedom=new_robot.degrees_of_freedom, - owner=str(user_id), - robot_id=str(new_uuid()), + owner=user.id, timestamp=int(time.time()), urdf=new_robot.urdf, packages=new_robot.packages, @@ -110,8 +90,8 @@ async def delete_robot( robot = await crud.get_robot(robot_id) if robot is None: raise HTTPException(status_code=404, detail="Robot not found") - user_id = await crud.get_user_id_from_session_token(token) - if str(robot.owner) != str(user_id): + user = await crud.get_user_from_api_key(token) + if robot.owner != user.id: raise HTTPException(status_code=403, detail="You do not own this robot") await crud.delete_robot(robot_id) return True @@ -120,12 +100,16 @@ async def delete_robot( @robots_router.post("/edit-robot/{id}/") async def edit_robot( id: str, - robot: EditRobot, + robot: dict[str, Any], token: Annotated[str, Depends(get_session_token)], crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: - user_id = await crud.get_user_id_from_session_token(token) - if user_id is None: - raise HTTPException(status_code=401, detail="Must be logged in to edit a robot") - await crud.update_robot(id, robot) + robot_info = await crud.get_robot(id) + if robot_info is None: + raise HTTPException(status_code=404, detail="Robot not found") + user = await crud.get_user_from_api_key(token) + if robot_info.owner != user.id: + raise HTTPException(status_code=403, detail="You do not own this robot") + robot["owner"] = user.id + await crud._update_item(id, Robot, robot) return True diff --git a/store/app/routers/users.py b/store/app/routers/users.py index 51b740af..a24fe100 100644 --- a/store/app/routers/users.py +++ b/store/app/routers/users.py @@ -6,13 +6,12 @@ from fastapi import APIRouter, Depends, HTTPException, Query, Request, Response, status from fastapi.security.utils import get_authorization_scheme_param -from httpx import AsyncClient +from httpx import AsyncClient, Response as HttpxResponse from pydantic.main import BaseModel as PydanticBaseModel -from store.app.crypto import check_password, new_token from store.app.db import Crud -from store.app.model import User -from store.app.utils.email import send_change_email, send_delete_email, send_register_email, send_reset_password_email +from store.app.model import UserPermissions +from store.app.utils.email import send_delete_email from store.settings import settings logger = logging.getLogger(__name__) @@ -73,207 +72,28 @@ class SendRegister(BaseModel): email: str -@users_router.post("/send-register-email") -async def send_register_email_endpoint( - data: SendRegister, - crud: Annotated[Crud, Depends(Crud.get)], -) -> bool: - """Sends a verification email to the new email address.""" - email = validate_email(data.email) - verify_email_token = new_token() - # Magic number: 7 days - await crud.add_register_token(verify_email_token, email, 60 * 60 * 24 * 7) - await send_register_email(email, verify_email_token) - return True - - -@users_router.get("/registration-email/{token}") -async def get_registration_email_endpoint( - token: str, - crud: Annotated[Crud, Depends(Crud.get)], -) -> str: - """Gets the email address associated with a registration token.""" - return await crud.check_register_token(token) - - class UserRegister(BaseModel): token: str - username: str - password: str - - -@users_router.post("/register") -async def register_user_endpoint( - data: UserRegister, - crud: Annotated[Crud, Depends(Crud.get)], -) -> bool: - """Registers a new user with the given email and password.""" - email = await crud.check_register_token(data.token) - user = await crud.get_user_from_email(email) - if user is not None: - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail="User already exists") - user = User.create(username=data.username, email=email, password=data.password) - await crud.add_user(user) - return True - - -class UserForgotPassword(BaseModel): - email: str - - -@users_router.post("/forgot-password") -async def forgot_password_user_endpoint( - data: UserForgotPassword, - crud: Annotated[Crud, Depends(Crud.get)], -) -> bool: - """Sends a reset password email to the user.""" - email = validate_email(data.email) - user = await crud.get_user_from_email(email) - if user is None: - return True - reset_password_token = new_token() - # Magic number: 1 hour - await crud.add_reset_password_token(reset_password_token, user.user_id, 60 * 60) - await send_reset_password_email(email, reset_password_token) - return True - - -class ResetPassword(BaseModel): - password: str - - -@users_router.post("/reset-password/{token}") -async def reset_password_user_endpoint( - token: str, - data: ResetPassword, - crud: Annotated[Crud, Depends(Crud.get)], -) -> bool: - """Resets a user's password.""" - await crud.use_reset_password_token(token, data.password) - return True - - -class NewEmail(BaseModel): - new_email: str - - -@users_router.post("/change-email") -async def send_change_email_user_endpoint( - data: NewEmail, - crud: Annotated[Crud, Depends(Crud.get)], - token: Annotated[str, Depends(get_session_token)], -) -> bool: - user_id = await crud.get_user_id_from_session_token(token) - if user_id is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") - user = await crud.get_user(user_id) - if user is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") - change_email_token = new_token() - """Sends a verification email to the new email address.""" - # Magic number: 1 hour - await crud.add_change_email_token(change_email_token, user.user_id, data.new_email, 60 * 60) - await send_change_email(data.new_email, change_email_token) - return True - - -@users_router.post("/change-email/{token}") -async def change_email_user_endpoint( - token: str, - crud: Annotated[Crud, Depends(Crud.get)], -) -> bool: - """Changes the user's email address.""" - await crud.use_change_email_token(token) - return True - - -class ChangePassword(BaseModel): - old_password: str - new_password: str - - -@users_router.post("/change-password") -async def change_password_user_endpoint( - data: ChangePassword, - token: Annotated[str, Depends(get_session_token)], - crud: Annotated[Crud, Depends(Crud.get)], -) -> bool: - """Changes the user's password.""" - user_id = await crud.get_user_id_from_session_token(token) - if user_id is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") - user = await crud.get_user(user_id) - if user is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") - if not check_password(data.old_password, user.password_hash): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid password") - await crud.change_password(user_id, data.new_password) - return True - - -class UserLogin(BaseModel): - email: str - password: str - - -@users_router.post("/login") -async def login_user_endpoint( - data: UserLogin, - crud: Annotated[Crud, Depends(Crud.get)], - response: Response, -) -> bool: - """Gives the user a session token if they present the correct credentials. - - Args: - data: User email and password. - crud: The CRUD object. - response: The response object. - - Returns: - True if the credentials are correct. - """ - user = await crud.get_user_from_email(data.email) - if user is None: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password") - if not check_password(data.password, user.password_hash): - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or password") - token = new_token() - response.set_cookie( - key="session_token", - value=token, - httponly=True, - ) - await crud.add_session_token(token, user.user_id, 60 * 60 * 24 * 7) - - return True class UserInfoResponse(BaseModel): - email: str - username: str - user_id: str - admin: bool + id: str + permissions: UserPermissions @users_router.get("/me", response_model=UserInfoResponse) async def get_user_info_endpoint( token: Annotated[str, Depends(get_session_token)], crud: Annotated[Crud, Depends(Crud.get)], -) -> UserInfoResponse: - user_id = await crud.get_user_id_from_session_token(token) - if user_id is None: - print("executed 1") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") - user_obj = await crud.get_user(user_id) - if user_obj is None: - print("executed 2") - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") - return UserInfoResponse( - email=user_obj.email, - username=user_obj.username, - user_id=user_obj.user_id, - admin=user_obj.admin, - ) +) -> UserInfoResponse | None: + try: + user = await crud.get_user_from_api_key(token) + return UserInfoResponse( + id=user.id, + permissions=user.permissions, + ) + except ValueError: + return None @users_router.delete("/me") @@ -281,14 +101,9 @@ async def delete_user_endpoint( token: Annotated[str, Depends(get_session_token)], crud: Annotated[Crud, Depends(Crud.get)], ) -> bool: - user_id = await crud.get_user_id_from_session_token(token) - if user_id is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") - user_obj = await crud.get_user(user_id) - if user_obj is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") - await crud.delete_user(user_id) - await send_delete_email(user_obj.email) + user = await crud.get_user_from_api_key(token) + await crud.delete_user(user.id) + await send_delete_email(user.email) return True @@ -298,33 +113,23 @@ async def logout_user_endpoint( crud: Annotated[Crud, Depends(Crud.get)], response: Response, ) -> bool: - await crud.delete_session_token(token) + await crud.delete_api_key(token) response.delete_cookie("session_token") return True class PublicUserInfoResponse(BaseModel): - username: str - user_id: str + id: str + email: str @users_router.get("/batch", response_model=list[PublicUserInfoResponse]) async def get_users_batch_endpoint( crud: Annotated[Crud, Depends(Crud.get)], - user_ids: list[str] = Query(...), + ids: list[str] = Query(...), ) -> list[PublicUserInfoResponse]: - user_objs = await crud.get_user_batch(user_ids) - return [ - PublicUserInfoResponse( - username=user_obj.username, - user_id=user_obj.user_id, - ) - for user_obj in user_objs - ] - - -class SessionData(BaseModel): - username: str + users = await crud.get_user_batch(ids) + return [PublicUserInfoResponse(id=user.id, email=user.email) for user in users] @users_router.get("/github-login") @@ -334,7 +139,26 @@ async def github_login() -> str: Returns: Github oauth redirect url. """ - return f"https://github.com/login/oauth/authorize?client_id={settings.oauth.github_client_id}" + return f"https://github.com/login/oauth/authorize?scope=user:email&client_id={settings.oauth.github_client_id}" + + +async def github_access_token_req(params: dict[str, str], headers: dict[str, str]) -> HttpxResponse: + async with AsyncClient() as client: + return await client.post( + url="https://github.com/login/oauth/access_token", + params=params, + headers=headers, + ) + + +async def github_req(headers: dict[str, str]) -> HttpxResponse: + async with AsyncClient() as client: + return await client.get("https://api.github.com/user", headers=headers) + + +async def github_email_req(headers: dict[str, str]) -> HttpxResponse: + async with AsyncClient() as client: + return await client.get("https://api.github.com/user/emails", headers=headers) @users_router.get("/github-code/{code}", response_model=UserInfoResponse) @@ -359,60 +183,39 @@ async def github_code( "code": code, } headers = {"Accept": "application/json"} - async with AsyncClient() as client: - oauth_response = await client.post( - url="https://github.com/login/oauth/access_token", params=params, headers=headers - ) + oauth_response = await github_access_token_req(params, headers) response_json = oauth_response.json() - print("\n\n", response_json, "\n\n") # access token is used to retrieve user oauth details access_token = response_json["access_token"] - async with AsyncClient() as client: - headers.update({"Authorization": f"Bearer {access_token}"}) - oauth_response = await client.get("https://api.github.com/user", headers=headers) + headers.update({"Authorization": f"Bearer {access_token}"}) + oauth_response = await github_req(headers) + oauth_email_response = await github_email_req(headers) github_id = oauth_response.json()["html_url"] - github_username = oauth_response.json()["login"] - - user = await crud.get_user_from_oauth_id(github_id) + email = next(entry["email"] for entry in oauth_email_response.json() if entry["primary"]) - # create a user if it doesn't exist, with dummy email since email is required for secondary indexing + user = await crud.get_user_from_github_token(github_id) + # Exception occurs when user does not exist. + # Create a user if this is the case. if user is None: - user = User.create_oauth(username=github_username, oauth_id=github_id) - await crud.add_user(user) - - token = new_token() - - await crud.add_session_token(token, user.user_id, 60 * 60 * 24 * 7) - - response.set_cookie( - key="session_token", - value=token, - httponly=True, - ) + user = await crud.create_user_from_github_token( + email=email, + github_id=github_id, + ) + # This is solely so mypy stops complaining. + assert user is not None - user_obj = await crud.get_user(user.user_id) + api_key = await crud.add_api_key(user.id) - if user_obj is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") + response.set_cookie(key="session_token", value=api_key.id, httponly=True, samesite="lax") - return UserInfoResponse( - email=user_obj.email, - username=user_obj.username, - user_id=user_obj.user_id, - admin=user_obj.admin, - ) + return UserInfoResponse(id=user.id, permissions=user.permissions) -@users_router.get("/{user_id}", response_model=PublicUserInfoResponse) -async def get_user_info_by_id_endpoint( - user_id: str, crud: Annotated[Crud, Depends(Crud.get)] -) -> PublicUserInfoResponse: - user_obj = await crud.get_user(user_id) - if user_obj is None: +@users_router.get("/{id}", response_model=PublicUserInfoResponse) +async def get_user_info_by_id_endpoint(id: str, crud: Annotated[Crud, Depends(Crud.get)]) -> PublicUserInfoResponse: + user = await crud.get_user(id) + if user is None: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found") - return PublicUserInfoResponse( - username=user_obj.username, - user_id=user_obj.user_id, - ) + return PublicUserInfoResponse(id=user.id, email=user.email) diff --git a/store/requirements-dev.txt b/store/requirements-dev.txt index c5ebe370..19e8fca8 100644 --- a/store/requirements-dev.txt +++ b/store/requirements-dev.txt @@ -8,7 +8,6 @@ ruff # Testing moto[dynamodb] -fakeredis pytest pytest-aiohttp pytest-aiomoto diff --git a/store/requirements.txt b/store/requirements.txt index deb4dcc4..8aba6cad 100644 --- a/store/requirements.txt +++ b/store/requirements.txt @@ -7,10 +7,10 @@ pydantic # AWS dependencies. aioboto3 -redis # Crypto dependencies argon2-cffi +pyjwt[asyncio] # FastAPI dependencies. aiohttp diff --git a/store/settings/environment.py b/store/settings/environment.py index 62a7e628..a23fd27f 100644 --- a/store/settings/environment.py +++ b/store/settings/environment.py @@ -11,20 +11,9 @@ class OauthSettings: github_client_secret: str = field(default=II("oc.env:GITHUB_CLIENT_SECRET")) -@dataclass -class RedisSettings: - host: str = field(default=II("oc.env:ROBOLIST_REDIS_HOST,127.0.0.1")) - password: str = field(default=II("oc.env:ROBOLIST_REDIS_PASSWORD,''")) - port: int = field(default=6379) - session_db: int = field(default=0) - verify_email_db: int = field(default=1) - reset_password_db: int = field(default=2) - change_email_db: int = field(default=3) - - @dataclass class CryptoSettings: - expire_token_minutes: int = field(default=10) + cache_token_db_result_seconds: int = field(default=30) expire_otp_minutes: int = field(default=10) jwt_secret: str = field(default=MISSING) algorithm: str = field(default="HS256") @@ -55,7 +44,6 @@ class SiteSettings: @dataclass class EnvironmentSettings: oauth: OauthSettings = field(default_factory=OauthSettings) - redis: RedisSettings = field(default_factory=RedisSettings) user: UserSettings = field(default_factory=UserSettings) crypto: CryptoSettings = field(default_factory=CryptoSettings) email: EmailSettings = field(default_factory=EmailSettings) diff --git a/store/utils.py b/store/utils.py index 2dd73adf..1e32018d 100644 --- a/store/utils.py +++ b/store/utils.py @@ -1,6 +1,52 @@ """Defines package-wide utility functions.""" import datetime +from collections import OrderedDict +from typing import Generic, Hashable, TypeVar, overload + +Tk = TypeVar("Tk", bound=Hashable) +Tv = TypeVar("Tv") + + +class LRUCache(Generic[Tk, Tv]): + def __init__(self, capacity: int) -> None: + super().__init__() + + self.cache: OrderedDict[Tk, Tv] = OrderedDict() + self.capacity = capacity + + @overload + def get(self, key: Tk) -> Tv | None: ... + + @overload + def get(self, key: Tk, default: Tv) -> Tv: ... + + def get(self, key: Tk, default: Tv | None = None) -> Tv | None: + if key not in self.cache: + return None + else: + self.cache.move_to_end(key) + return self.cache[key] + + def __contains__(self, key: Tk) -> bool: + return key in self.cache + + def __len__(self) -> int: + return len(self.cache) + + def put(self, key: Tk, value: Tv) -> None: + self.cache[key] = value + self.cache.move_to_end(key) + if len(self.cache) > self.capacity: + self.cache.popitem(last=False) + + def __getitem__(self, key: Tk) -> Tv: + if (item := self.get(key)) is None: + raise KeyError(key) + return item + + def __setitem__(self, key: Tk, value: Tv) -> None: + self.put(key, value) def server_time() -> datetime.datetime: diff --git a/tests/conftest.py b/tests/conftest.py index 522d2267..1a5da0cf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,12 +1,12 @@ """Pytest configuration file.""" import os -from typing import Generator +from typing import AsyncGenerator, Generator, cast -import fakeredis import pytest from _pytest.python import Function -from httpx import ASGITransport, AsyncClient +from httpx import ASGITransport, AsyncClient, Response +from httpx._transports.asgi import _ASGIApp from moto.dynamodb import mock_dynamodb from moto.server import ThreadedMotoServer from pytest_mock.plugin import MockerFixture, MockType @@ -38,6 +38,8 @@ def mock_aws() -> Generator[None, None, None]: os.environ["AWS_SECRET_ACCESS_KEY"] = "test" os.environ["AWS_ACCESS_KEY_ID"] = "test" os.environ["AWS_DEFAULT_REGION"] = os.environ["AWS_REGION"] = "us-east-1" + os.environ["GITHUB_CLIENT_ID"] = "test" + os.environ["GITHUB_CLIENT_SECRET"] = "test" # Starts a local AWS server. server = ThreadedMotoServer(port=0) @@ -59,18 +61,11 @@ def mock_aws() -> Generator[None, None, None]: os.environ[k] = v -@pytest.fixture(autouse=True) -def mock_redis(mocker: MockerFixture) -> None: - os.environ["ROBOLIST_REDIS_HOST"] = "localhost" - os.environ["ROBOLIST_REDIS_PASSWORD"] = "" - fake_redis = fakeredis.aioredis.FakeRedis() - mocker.patch("store.app.crud.users.Redis", return_value=fake_redis) - - @pytest.fixture() -async def app_client() -> AsyncClient: +async def app_client() -> AsyncGenerator[AsyncClient, None]: from store.app.main import app - transport = ASGITransport(app) + + transport = ASGITransport(cast(_ASGIApp, app)) async with AsyncClient(transport=transport, base_url="http://test") as app_client: yield app_client @@ -81,3 +76,24 @@ def mock_send_email(mocker: MockerFixture) -> MockType: mock = mocker.patch("store.app.utils.email.send_email") mock.return_value = None return mock + + +@pytest.fixture(autouse=True) +def mock_github_access_token(mocker: MockerFixture) -> MockType: + mock = mocker.patch("store.app.routers.users.github_access_token_req") + mock.return_value = Response(status_code=200, json={"access_token": ""}) + return mock + + +@pytest.fixture(autouse=True) +def mock_github(mocker: MockerFixture) -> MockType: + mock = mocker.patch("store.app.routers.users.github_req") + mock.return_value = Response(status_code=200, json={"html_url": "https://github.com/chennisden"}) + return mock + + +@pytest.fixture(autouse=True) +def mock_github_email(mocker: MockerFixture) -> MockType: + mock = mocker.patch("store.app.routers.users.github_email_req") + mock.return_value = Response(status_code=200, json=[{"email": "dchen@kscale.dev", "primary": True}]) + return mock diff --git a/tests/test_data_structures.py b/tests/test_data_structures.py new file mode 100644 index 00000000..d4a6f923 --- /dev/null +++ b/tests/test_data_structures.py @@ -0,0 +1,20 @@ +"""Tests some common shared data structures.""" + +from store.utils import LRUCache + + +def test_lru_cache() -> None: + cache = LRUCache[int, str](3) + cache.put(1, "one") + cache.put(2, "two") + cache.put(3, "three") + assert len(cache) == 3 + assert cache.get(1) == "one" + cache.put(4, "four") + assert len(cache) == 3 + assert 2 not in cache + assert cache.get(2) is None + assert 1 in cache + assert cache.get(1) == "one" + assert cache.get(3) == "three" + assert cache.get(4) == "four" diff --git a/tests/test_robots.py b/tests/test_robots.py index 18884bdc..63913c14 100644 --- a/tests/test_robots.py +++ b/tests/test_robots.py @@ -1,56 +1,40 @@ """Runs tests on the robot APIs.""" + from httpx import AsyncClient -from store.app.crud.users import UserCrud from store.app.db import create_tables async def test_robots(app_client: AsyncClient) -> None: - crud = UserCrud() await create_tables() - await crud.__aenter__() - - test_username = "test" - test_email = "test@example.com" - test_password = "test password" - test_token = "test_token" - - await crud.add_register_token(test_token, test_email, 3600) - # Register. - response = await app_client.post("/users/register", json={ - "username": test_username, - "token": test_token, - "password": test_password, - }) - assert response.status_code == 200 - - # Log in. - response = await app_client.post("/users/login", json={ - "email": test_email, - "password": test_password, - }) - assert response.status_code == 200 + response = await app_client.get("/users/github-code/doesnt-matter") + assert response.status_code == 200, response.json() assert "session_token" in response.cookies # Create a part. - response = await app_client.post("/parts/add", json={ - "name": "test part", - "description": "test description", - "images": [{ - "url": "", - "caption": "", - }], - }) + response = await app_client.post( + "/parts/add", + json={ + "name": "test part", + "description": "test description", + "images": [ + { + "url": "", + "caption": "", + } + ], + }, + ) # Create a robot. - response = await app_client.post("/robots/add", json={ - "name": "test robot", - "description": "test description", - "bom": [], - "images": [{ - "url": "", - "caption": "" - }], - }) + response = await app_client.post( + "/robots/add", + json={ + "name": "test robot", + "description": "test description", + "bom": [], + "images": [{"url": "", "caption": ""}], + }, + ) diff --git a/tests/test_users.py b/tests/test_users.py index 77b051ae..385188c3 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -1,38 +1,12 @@ """Runs tests on the user APIs.""" - - from httpx import AsyncClient -from pytest_mock.plugin import MockType -from store.app.crud.users import UserCrud from store.app.db import create_tables -async def test_user_auth_functions(app_client: AsyncClient, mock_send_email: MockType) -> None: - crud = UserCrud() +async def test_user_auth_functions(app_client: AsyncClient) -> None: await create_tables() - await crud.__aenter__() - - test_username = "test" - test_email = "test@example.com" - test_password = "test password" - test_token = "test_token" - - await crud.add_register_token(test_token, test_email, 3600) - - # Send registration email. - response = await app_client.post("/users/send-register-email", json={"email": test_email}) - assert response.status_code == 200 - assert mock_send_email.call_count == 1 - - # Register. - response = await app_client.post("/users/register", json={ - "username": test_username, - "token": test_token, - "password": test_password - }) - assert response.status_code == 200 # Checks that without the session token we get a 401 response. response = await app_client.get("/users/me") @@ -43,28 +17,26 @@ async def test_user_auth_functions(app_client: AsyncClient, mock_send_email: Moc response = await app_client.delete("/users/logout") assert response.status_code == 401, response.json() - # Log in. - response = await app_client.post("/users/login", json={ - "email": test_email, - "password": test_password, - }) - assert response.status_code == 200 + # Because of the way we patched GitHub functions for mocking, it doesn't matter what token we pass in. + response = await app_client.get("/users/github-code/doesnt-matter") + assert response.status_code == 200, response.json() assert "session_token" in response.cookies token = response.cookies["session_token"] + user_id = response.json()["id"] + # Checks that with the session token we get a 200 response. response = await app_client.get("/users/me") assert response.status_code == 200, response.json() - assert response.json()["email"] == test_email + # Check the id of the user we are authenticated as matches the id of the user we created. + assert response.json()["id"] == user_id # Use the Authorization header instead of the cookie. response = await app_client.get( - "/users/me", - cookies={"session_token": ""}, - headers={"Authorization": f"Bearer {token}"} + "/users/me", cookies={"session_token": ""}, headers={"Authorization": f"Bearer {token}"} ) assert response.status_code == 200, response.json() - assert response.json()["email"] == test_email + assert response.json()["id"] == user_id # Log the user out, which deletes the session token. response = await app_client.delete("/users/logout") @@ -77,13 +49,9 @@ async def test_user_auth_functions(app_client: AsyncClient, mock_send_email: Moc assert response.json()["detail"] == "Not authenticated" # Log the user back in, getting new session token. - response = await app_client.post("/users/login", json={ - "email": test_email, - "password": test_password, - }) - assert response.status_code == 200 + response = await app_client.get("/users/github-code/doesnt-matter") + assert response.status_code == 200, response.json() assert "session_token" in response.cookies - token = response.cookies["session_token"] # Delete the user using the new session token. response = await app_client.delete("/users/me") @@ -92,5 +60,5 @@ async def test_user_auth_functions(app_client: AsyncClient, mock_send_email: Moc # Tries deleting the user again, which should fail. response = await app_client.delete("/users/me") - assert response.status_code == 404, response.json() - assert response.json()["detail"] == "User not found" + assert response.status_code == 400, response.json() + assert response.json()["detail"] == "Item " + user_id + " not found"