diff --git a/config.py b/config.py index fc20e0d4..e4263612 100644 --- a/config.py +++ b/config.py @@ -1,9 +1,9 @@ # config.py -import json import os +import json import logging +from enum import Enum from datetime import datetime - from dotenv import find_dotenv, load_dotenv from langchain_community.embeddings import ( HuggingFaceEmbeddings, @@ -12,12 +12,24 @@ ) from langchain_openai import AzureOpenAIEmbeddings, OpenAIEmbeddings from starlette.middleware.base import BaseHTTPMiddleware - from store_factory import get_vector_store load_dotenv(find_dotenv()) +class VectorDBType(Enum): + PGVECTOR = "pgvector" + ATLAS_MONGO = "atlas-mongo" + + +class EmbeddingsProvider(Enum): + OPENAI = "openai" + AZURE = "azure" + HUGGINGFACE = "huggingface" + HUGGINGFACETEI = "huggingfacetei" + OLLAMA = "ollama" + + def get_env_variable( var_name: str, default_value: str = None, required: bool = False ) -> str: @@ -36,7 +48,9 @@ def get_env_variable( if not os.path.exists(RAG_UPLOAD_DIR): os.makedirs(RAG_UPLOAD_DIR, exist_ok=True) -VECTOR_DB_TYPE = get_env_variable("VECTOR_DB_TYPE", "pgvector") +VECTOR_DB_TYPE = VectorDBType( + get_env_variable("VECTOR_DB_TYPE", VectorDBType.PGVECTOR.value) +) POSTGRES_DB = get_env_variable("POSTGRES_DB", "mydatabase") POSTGRES_USER = get_env_variable("POSTGRES_USER", "myuser") POSTGRES_PASSWORD = get_env_variable("POSTGRES_PASSWORD", "mypassword") @@ -140,7 +154,6 @@ async def dispatch(self, request, call_next): logging.getLogger("uvicorn.access").disabled = True - ## Credentials OPENAI_API_KEY = get_env_variable("OPENAI_API_KEY", "") @@ -163,51 +176,49 @@ async def dispatch(self, request, call_next): def init_embeddings(provider, model): - if provider == "openai": + if provider == EmbeddingsProvider.OPENAI: return OpenAIEmbeddings( model=model, api_key=RAG_OPENAI_API_KEY, openai_api_base=RAG_OPENAI_BASEURL, openai_proxy=RAG_OPENAI_PROXY, ) - elif provider == "azure": + elif provider == EmbeddingsProvider.AZURE: return AzureOpenAIEmbeddings( azure_deployment=model, api_key=RAG_AZURE_OPENAI_API_KEY, azure_endpoint=RAG_AZURE_OPENAI_ENDPOINT, api_version=RAG_AZURE_OPENAI_API_VERSION, ) - elif provider == "huggingface": + elif provider == EmbeddingsProvider.HUGGINGFACE: return HuggingFaceEmbeddings( model_name=model, encode_kwargs={"normalize_embeddings": True} ) - elif provider == "huggingfacetei": + elif provider == EmbeddingsProvider.HUGGINGFACETEI: return HuggingFaceHubEmbeddings(model=model) - elif provider == "ollama": + elif provider == EmbeddingsProvider.OLLAMA: return OllamaEmbeddings(model=model, base_url=OLLAMA_BASE_URL) else: raise ValueError(f"Unsupported embeddings provider: {provider}") -EMBEDDINGS_PROVIDER = get_env_variable("EMBEDDINGS_PROVIDER", "openai").lower() +EMBEDDINGS_PROVIDER = EmbeddingsProvider( + get_env_variable("EMBEDDINGS_PROVIDER", EmbeddingsProvider.OPENAI.value).lower() +) -if EMBEDDINGS_PROVIDER == "openai": +if EMBEDDINGS_PROVIDER == EmbeddingsProvider.OPENAI: EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small") - -elif EMBEDDINGS_PROVIDER == "azure": +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.AZURE: EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "text-embedding-3-small") - -elif EMBEDDINGS_PROVIDER == "huggingface": +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.HUGGINGFACE: EMBEDDINGS_MODEL = get_env_variable( "EMBEDDINGS_MODEL", "sentence-transformers/all-MiniLM-L6-v2" ) - -elif EMBEDDINGS_PROVIDER == "huggingfacetei": +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.HUGGINGFACETEI: EMBEDDINGS_MODEL = get_env_variable( "EMBEDDINGS_MODEL", "http://huggingfacetei:3000" ) - -elif EMBEDDINGS_PROVIDER == "ollama": +elif EMBEDDINGS_PROVIDER == EmbeddingsProvider.OLLAMA: EMBEDDINGS_MODEL = get_env_variable("EMBEDDINGS_MODEL", "nomic-embed-text") else: raise ValueError(f"Unsupported embeddings provider: {EMBEDDINGS_PROVIDER}") @@ -217,15 +228,15 @@ def init_embeddings(provider, model): logger.info(f"Initialized embeddings of type: {type(embeddings)}") # Vector store -if VECTOR_DB_TYPE == "pgvector": +if VECTOR_DB_TYPE == VectorDBType.PGVECTOR: vector_store = get_vector_store( connection_string=CONNECTION_STRING, embeddings=embeddings, collection_name=COLLECTION_NAME, mode="async", ) -elif VECTOR_DB_TYPE == "atlas-mongo": - # atlas-mongo vector: +elif VECTOR_DB_TYPE == VectorDBType.ATLAS_MONGO: + logger.warning("Using Atlas MongoDB as vector store is not fully supported yet.") vector_store = get_vector_store( connection_string=ATLAS_MONGO_DB_URI, embeddings=embeddings, diff --git a/main.py b/main.py index d4c288b7..ce5b880d 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,7 @@ import hashlib import aiofiles import aiofiles.os -from typing import Iterable +from typing import Iterable, List from shutil import copyfileobj import uvicorn @@ -13,14 +13,15 @@ from langchain_core.runnables.config import run_in_executor from langchain.text_splitter import RecursiveCharacterTextSplitter from fastapi import ( - FastAPI, File, Form, + Body, Query, - UploadFile, - HTTPException, status, + FastAPI, Request, + UploadFile, + HTTPException, ) from langchain_community.document_loaders import ( WebBaseLoader, @@ -35,11 +36,17 @@ UnstructuredExcelLoader, ) -from models import DocumentResponse, StoreDocument, QueryRequestBody, QueryMultipleBody +from models import ( + StoreDocument, + QueryRequestBody, + DocumentResponse, + QueryMultipleBody, +) from psql import PSQLDatabase, ensure_custom_id_index_on_embedding, pg_health_check -from middleware import security_middleware from pgvector_routes import router as pgvector_router from parsers import process_documents, clean_text +from middleware import security_middleware +from mongo import mongo_health_check from constants import ERROR_MESSAGES from store import AsyncPgVector @@ -57,6 +64,7 @@ LogMiddleware, RAG_HOST, RAG_PORT, + VectorDBType, # RAG_EMBEDDING_MODEL, # RAG_EMBEDDING_MODEL_DEVICE_TYPE, # RAG_TEMPLATE, @@ -107,8 +115,10 @@ async def get_all_ids(): def isHealthOK(): - if VECTOR_DB_TYPE == "pgvector": + if VECTOR_DB_TYPE == VectorDBType.PGVECTOR: return pg_health_check() + if VECTOR_DB_TYPE == VectorDBType.ATLAS_MONGO: + return mongo_health_check() else: return True @@ -131,9 +141,16 @@ async def get_documents_by_ids(ids: list[str] = Query(...)): existing_ids = vector_store.get_all_ids() documents = vector_store.get_documents_by_ids(ids) + # Ensure all requested ids exist if not all(id in existing_ids for id in ids): raise HTTPException(status_code=404, detail="One or more IDs not found") + # Ensure documents list is not empty + if not documents: + raise HTTPException( + status_code=404, detail="No documents found for the given IDs" + ) + return documents except HTTPException as http_exc: raise http_exc @@ -142,19 +159,19 @@ async def get_documents_by_ids(ids: list[str] = Query(...)): @app.delete("/documents") -async def delete_documents(ids: list[str] = Query(...)): +async def delete_documents(document_ids: List[str] = Body(...)): try: if isinstance(vector_store, AsyncPgVector): existing_ids = await vector_store.get_all_ids() - await vector_store.delete(ids=ids) + await vector_store.delete(ids=document_ids) else: existing_ids = vector_store.get_all_ids() - vector_store.delete(ids=ids) + vector_store.delete(ids=document_ids) - if not all(id in existing_ids for id in ids): + if not all(id in existing_ids for id in document_ids): raise HTTPException(status_code=404, detail="One or more IDs not found") - file_count = len(ids) + file_count = len(document_ids) return { "message": f"Documents for {file_count} file{'s' if file_count > 1 else ''} deleted successfully" } @@ -164,12 +181,11 @@ async def delete_documents(ids: list[str] = Query(...)): @app.post("/query") async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request): - if not hasattr(request.state, "user"): - user_authorized = "public" - else: - user_authorized = request.state.user.get("id") - + user_authorized = ( + "public" if not hasattr(request.state, "user") else request.state.user.get("id") + ) authorized_documents = [] + try: embedding = vector_store.embedding_function.embed_query(body.query) @@ -186,6 +202,9 @@ async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request): embedding, k=body.k, filter={"file_id": body.file_id} ) + if not documents: + return authorized_documents + document, score = documents[0] doc_metadata = document.metadata doc_user_id = doc_metadata.get("user_id") @@ -198,6 +217,7 @@ async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request): ) return authorized_documents + except Exception as e: logger.error(e) raise HTTPException(status_code=500, detail=str(e)) @@ -427,11 +447,18 @@ async def load_document_context(id: str): existing_ids = vector_store.get_all_ids() documents = vector_store.get_documents_by_ids(ids) + # Ensure the requested id exists if not all(id in existing_ids for id in ids): raise HTTPException( status_code=404, detail="The specified file_id was not found" ) + # Ensure documents list is not empty + if not documents: + raise HTTPException( + status_code=404, detail="No document found for the given ID" + ) + return process_documents(documents) except Exception as e: logger.error(e) @@ -511,6 +538,12 @@ async def query_embeddings_by_file_ids(body: QueryMultipleBody): embedding, k=body.k, filter={"file_id": {"$in": body.file_ids}} ) + # Ensure documents list is not empty + if not documents: + raise HTTPException( + status_code=404, detail="No documents found for the given query" + ) + return documents except Exception as e: raise HTTPException(status_code=500, detail=str(e)) diff --git a/middleware.py b/middleware.py index 30dec8d1..76794478 100644 --- a/middleware.py +++ b/middleware.py @@ -1,48 +1,57 @@ import os -from datetime import datetime, timezone from fastapi import Request +from datetime import datetime, timezone from fastapi.responses import JSONResponse -from jose import jwt, JWTError from config import logger +import jwt +from jwt import PyJWTError + async def security_middleware(request: Request, call_next): - async def next(): - response = await call_next(request) - return response - - if (request.url.path == "/docs" or - request.url.path == "/openapi.json" or - request.url.path == "/health"): - return await next() - - jwt_secret = os.getenv('JWT_SECRET') - - if jwt_secret: - authorization = request.headers.get('Authorization') - if not authorization or not authorization.startswith('Bearer '): - logger.info(f"Unauthorized request with missing or invalid Authorization header to: {request.url.path}") - return JSONResponse(status_code=401, content = { "detail" : "Missing or invalid Authorization header" }) - - token = authorization.split(' ')[1] - try: - payload = jwt.decode(token, jwt_secret, algorithms=['HS256']) - - # Check if the token has expired - exp_timestamp = payload.get('exp') - if exp_timestamp: - exp_datetime = datetime.fromtimestamp(exp_timestamp, tz=timezone.utc) - current_datetime = datetime.now(tz=timezone.utc) - if current_datetime > exp_datetime: - logger.info(f"Unauthorized request with expired token to: {request.url.path}") - return JSONResponse(status_code=401, content = { "detail" : "Token has expired" }) - - request.state.user = payload - logger.debug(f"{request.url.path} - {payload}") - except JWTError as e: - logger.info(f"Unauthorized request with invalid token to: {request.url.path}, reason: {str(e)}") - return JSONResponse(status_code=401, content = { "detail" : f"Invalid token: {str(e)}" }) - else: - logger.warn("JWT_SECRET not found in environment variables") + async def next_middleware_call(): + return await call_next(request) + + if request.url.path in {"/docs", "/openapi.json", "/health"}: + return await next_middleware_call() - return await next() + jwt_secret = os.getenv("JWT_SECRET") + if not jwt_secret: + logger.warn("JWT_SECRET not found in environment variables") + return await next_middleware_call() + + authorization = request.headers.get("Authorization") + if not authorization or not authorization.startswith("Bearer "): + logger.info( + f"Unauthorized request with missing or invalid Authorization header to: {request.url.path}" + ) + return JSONResponse( + status_code=401, + content={"detail": "Missing or invalid Authorization header"}, + ) + + token = authorization.split(" ")[1] + try: + payload = jwt.decode(token, jwt_secret, algorithms=["HS256"]) + exp_timestamp = payload.get("exp") + if exp_timestamp and datetime.now(tz=timezone.utc) > datetime.fromtimestamp( + exp_timestamp, tz=timezone.utc + ): + logger.info( + f"Unauthorized request with expired token to: {request.url.path}" + ) + return JSONResponse( + status_code=401, content={"detail": "Token has expired"} + ) + + request.state.user = payload + logger.debug(f"{request.url.path} - {payload}") + except PyJWTError as e: + logger.info( + f"Unauthorized request with invalid token to: {request.url.path}, reason: {str(e)}" + ) + return JSONResponse( + status_code=401, content={"detail": f"Invalid token: {str(e)}"} + ) + + return await next_middleware_call() diff --git a/models.py b/models.py index a8dd04cc..b584c992 100644 --- a/models.py +++ b/models.py @@ -3,6 +3,7 @@ from pydantic import BaseModel from typing import Optional, List + class DocumentResponse(BaseModel): page_content: str metadata: dict @@ -15,23 +16,27 @@ class DocumentModel(BaseModel): def generate_digest(self): hash_obj = hashlib.md5(self.page_content.encode()) return hash_obj.hexdigest() - + + class StoreDocument(BaseModel): filepath: str filename: str file_content_type: str file_id: str + class QueryRequestBody(BaseModel): file_id: str query: str k: int = 4 + class CleanupMethod(str, Enum): incremental = "incremental" full = "full" + class QueryMultipleBody(BaseModel): query: str file_ids: List[str] - k: int = 4 \ No newline at end of file + k: int = 4 diff --git a/mongo.py b/mongo.py new file mode 100644 index 00000000..93c78ea3 --- /dev/null +++ b/mongo.py @@ -0,0 +1,16 @@ +import logging +from pymongo import MongoClient +from pymongo.errors import PyMongoError +from config import ATLAS_MONGO_DB_URI + +logger = logging.getLogger(__name__) + + +async def mongo_health_check() -> bool: + try: + client = MongoClient(ATLAS_MONGO_DB_URI) + client.admin.command("ping") + return True + except PyMongoError as e: + logger.error(f"MongoDB health check failed: {e}") + return False diff --git a/requirements.lite.txt b/requirements.lite.txt index 6070ccf3..61083069 100644 --- a/requirements.lite.txt +++ b/requirements.lite.txt @@ -16,7 +16,7 @@ pandas==2.2.1 openpyxl==3.1.2 docx2txt==0.8 pypandoc==1.13 -python-jose==3.3.0 +PyJWT==2.8.0 asyncpg==0.29.0 python-multipart==0.0.9 aiofiles==23.2.1 diff --git a/requirements.txt b/requirements.txt index fc569822..37cdaa2f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -16,7 +16,7 @@ pandas==2.2.1 openpyxl==3.1.2 docx2txt==0.8 pypandoc==1.13 -python-jose==3.3.0 +PyJWT==2.8.0 asyncpg==0.29.0 python-multipart==0.0.9 sentence_transformers==2.5.1