diff --git a/.gitignore b/.gitignore index 0beb9d9..354364d 100644 --- a/.gitignore +++ b/.gitignore @@ -178,3 +178,9 @@ pyrightconfig.json # IDE files .idea/ .vscode/ + +# poetry version lock +poetry.lock + +# personal config +Configs/config.json \ No newline at end of file diff --git a/Configs/config.py b/Configs/config.py index cc39e83..a909d9f 100644 --- a/Configs/config.py +++ b/Configs/config.py @@ -1,17 +1,21 @@ import json -from typing import Any class Config: def __init__(self, config_path: str): with open(config_path, "r") as f: - self.config = json.load(f) + self.config: dict = json.load(f) - def get_db_config(self, key: str) -> Any | None: - try: - return self.config["database"][key] - except: - return None + def get_config(self, *keys) -> object | None: + keys = [*keys] + result: object | None = self.config + while len(keys) > 0: + key = keys.pop(0) + if (isinstance(result, list) and key < len(result)) or (isinstance(result, dict) and key in result.keys()): + result = result[key] + else: + return None + return result config = Config("Configs/config.json") diff --git a/Models/response.py b/Models/response.py index 77dcd60..9e23a3d 100644 --- a/Models/response.py +++ b/Models/response.py @@ -1,5 +1,6 @@ -from fastapi import status, HTTPException +from fastapi import status, HTTPException, Request from pydantic import BaseModel +from slowapi.errors import RateLimitExceeded class ExceptionResponse: @@ -18,6 +19,14 @@ def not_found(self) -> HTTPException: detail="Not found", ) + @staticmethod + def limit_exceeded(request: Request, exc: RateLimitExceeded): + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="Too many requests", + headers={"Detail": exc.detail}, + ) + class StandardResponse(BaseModel): status_code: int diff --git a/Routers/user.py b/Routers/user.py index d6da7f8..782296e 100644 --- a/Routers/user.py +++ b/Routers/user.py @@ -1,12 +1,12 @@ -from fastapi import APIRouter, Response, Depends, HTTPException, Form +from fastapi import APIRouter, Response, Depends, HTTPException, Form, Request from fastapi.security import OAuth2PasswordRequestForm from passlib.context import CryptContext -from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from uuid import uuid4 from datetime import datetime, timedelta from Services.Database.database import get_db +from Services.Limiter.limiter import limiter from Services.Security.user import ( ACCESS_TOKEN_EXPIRE_MINUTES, create_access_token, get_current_user, encrypt_src_password, decrypt_src_password) @@ -20,23 +20,26 @@ @user_router.post("/reg") -async def user_reg(email: str = Form(), username: str = Form(), password: str = Form(), db: Session = Depends(get_db)): - try: - db.add(UserDb( - uid=uuid4().hex, - email=email, - username=username, - hashed_password=pwd_ctx.hash(password), - created_at=datetime.now(), - )) - db.commit() - return Response(status_code=201) - except IntegrityError: +async def user_reg(email: str = Form(), username: str = Form(), password: str = Form(), + db: Session = Depends(get_db)): + if db.query(UserDb).filter( + UserDb.email == email or UserDb.username == username).first() is not None: # type: ignore raise HTTPException(status_code=409, detail="User already exists") + db.add(UserDb( + uid=uuid4().hex, + email=email, + username=username, + hashed_password=pwd_ctx.hash(password), + created_at=datetime.now(), + )) + db.commit() + return Response(status_code=201) + @user_router.post("/login") -async def user_login(body: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): +@limiter.limit("5/minute") +async def user_login(request: Request, body: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): user: UserDb = db.query(UserDb).filter(UserDb.username == body.username).first() # type: ignore if user is not None and pwd_ctx.verify(body.password, user.hashed_password): token = create_access_token( diff --git a/Services/Database/database.py b/Services/Database/database.py index c7610fb..77cb067 100644 --- a/Services/Database/database.py +++ b/Services/Database/database.py @@ -4,11 +4,11 @@ from Configs.config import config -host = config.get_db_config("host") -port = config.get_db_config("port") -db = config.get_db_config("database") -user = config.get_db_config("user") -pwd = config.get_db_config("password") +host = config.get_config("database", "host") +port = config.get_config("database", "port") +db = config.get_config("database", "database") +user = config.get_config("database", "user") +pwd = config.get_config("database", "password") if not host or not port or not db or not user or not pwd: raise ValueError("Please complete the database configuration") diff --git a/Services/Limiter/limiter.py b/Services/Limiter/limiter.py new file mode 100644 index 0000000..0b6689a --- /dev/null +++ b/Services/Limiter/limiter.py @@ -0,0 +1,5 @@ +from slowapi import Limiter +from slowapi.util import get_remote_address + + +limiter = Limiter(key_func=get_remote_address, default_limits=["1000/minute"]) diff --git a/main.py b/main.py index eaf0c0d..0782ac1 100644 --- a/main.py +++ b/main.py @@ -1,9 +1,18 @@ +import logging from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware +from rich.logging import RichHandler +from slowapi.errors import RateLimitExceeded from Routers.user import user_router +from Models.response import ExceptionResponse +from Services.Limiter.limiter import limiter + +logging.basicConfig(level="INFO", format="%(message)s", datefmt="[%X]", handlers=[RichHandler()]) app = FastAPI() +app.state.limiter = limiter +app.add_exception_handler(RateLimitExceeded, ExceptionResponse.limit_exceeded) # type: ignore app.add_middleware( CORSMiddleware, # type: ignore allow_origins=["*"], diff --git a/pyproject.toml b/pyproject.toml index bd0cdca..a0a3be8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,8 @@ psycopg2 = "^2.9.9" uvicorn = "^0.26.0" python-multipart = "^0.0.6" pycryptodome = "^3.20.0" +rich = "^13.7.0" +slowapi = "^0.1.8" [build-system]