From 51909693fd173373f921010cd5616cfd10a71542 Mon Sep 17 00:00:00 2001 From: Danny Avila Date: Fri, 22 Mar 2024 18:53:10 -0400 Subject: [PATCH] =?UTF-8?q?=F0=9F=92=BD=20feat:=20New=20File=20Uploading?= =?UTF-8?q?=20Strategy=20&=20Additional=20Security=20(#9)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: new file strategy using upload * chore: typing and user authorization for queries/embedding --- main.py | 109 +++++++++++++++++++++++++++++++++++++++-------- requirements.txt | 1 + 2 files changed, 92 insertions(+), 18 deletions(-) diff --git a/main.py b/main.py index 750eddac..8a3bba09 100644 --- a/main.py +++ b/main.py @@ -1,13 +1,17 @@ import os import hashlib +import aiofiles +import aiofiles.os +from typing import Iterable from shutil import copyfileobj from langchain.schema import Document from contextlib import asynccontextmanager +from fastapi.responses import JSONResponse from dotenv import find_dotenv, load_dotenv from fastapi.middleware.cors import CORSMiddleware -from fastapi import FastAPI, File, Form, UploadFile, HTTPException, status from langchain_core.runnables.config import run_in_executor from langchain.text_splitter import RecursiveCharacterTextSplitter +from fastapi import FastAPI, File, Form, UploadFile, HTTPException, status, Request from langchain_community.document_loaders import ( WebBaseLoader, TextLoader, @@ -32,6 +36,7 @@ load_dotenv(find_dotenv()) from config import ( + logger, debug_mode, CHUNK_SIZE, CHUNK_OVERLAP, @@ -122,7 +127,13 @@ async def delete_documents(ids: list[str]): raise HTTPException(status_code=500, detail=str(e)) @app.post("/query") -async def query_embeddings_by_file_id(body: QueryRequestBody): +async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request): + if not hasattr(request.state, 'user'): + user_authorized = "public" + else: + user_authorized = request.state.user.get('id'); + + authorized_documents = [] try: embedding = vector_store.embedding_function.embed_query(body.query) @@ -141,7 +152,16 @@ async def query_embeddings_by_file_id(body: QueryRequestBody): filter={"file_id": body.file_id} ) - return documents + document, score = documents[0] + doc_metadata = document.metadata + doc_user_id = doc_metadata.get('user_id') + + if doc_user_id is None or doc_user_id == user_authorized: + authorized_documents = documents + else: + logger.warn(f"Unauthorized access attempt by user {user_authorized} to a document with user_id {doc_user_id}") + + return authorized_documents except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -149,7 +169,7 @@ def generate_digest(page_content: str): hash_obj = hashlib.md5(page_content.encode()) return hash_obj.hexdigest() -async def store_data_in_vector_db(data, file_id, overwrite: bool = False) -> bool: +async def store_data_in_vector_db(data: Iterable[Document], file_id: str, user_id: str = '') -> bool: text_splitter = RecursiveCharacterTextSplitter( chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP ) @@ -162,6 +182,7 @@ async def store_data_in_vector_db(data, file_id, overwrite: bool = False) -> boo page_content=doc.page_content, metadata={ "file_id": file_id, + "user_id": user_id, "digest": generate_digest(doc.page_content), **(doc.metadata or {}), }, @@ -178,13 +199,7 @@ async def store_data_in_vector_db(data, file_id, overwrite: bool = False) -> boo return {"message": "Documents added successfully", "ids": ids} except Exception as e: - print(e) - # Checking if a unique constraint error occurred, to handle overwrite logic if needed. - if e.__class__.__name__ == "UniqueConstraintError" and overwrite: - # Functionality to overwrite existing documents. - # This might require fetching existing document IDs, deleting them, and then re-inserting the documents. - return {"message": "Documents exist. Overwrite not implemented.", "error": str(e)} - + logger.error(e) return {"message": "An error occurred while adding documents.", "error": str(e)} def get_loader(filename: str, file_content_type: str, filepath: str): @@ -224,8 +239,8 @@ def get_loader(filename: str, file_content_type: str, filepath: str): return loader, known_type -@app.post("/embed") -async def embed_file(document: StoreDocument): +@app.post("/local/embed") +async def embed_local_file(document: StoreDocument, request: Request): # Check if the file exists if not os.path.exists(document.filepath): @@ -233,11 +248,16 @@ async def embed_file(document: StoreDocument): status_code=status.HTTP_404_NOT_FOUND, detail=ERROR_MESSAGES.FILE_NOT_FOUND, ) + + if not hasattr(request.state, 'user'): + user_id = "public" + else: + user_id = request.state.user.get('id'); try: loader, known_type = get_loader(document.filename, document.file_content_type, document.filepath) data = loader.load() - result = await store_data_in_vector_db(data, document.file_id) + result = await store_data_in_vector_db(data, document.file_id, user_id) if result: return { @@ -252,7 +272,7 @@ async def embed_file(document: StoreDocument): detail=ERROR_MESSAGES.DEFAULT(), ) except Exception as e: - print(e) + logger.error(e) if "No pandoc was found" in str(e): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, @@ -263,6 +283,54 @@ async def embed_file(document: StoreDocument): status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), ) + +@app.post("/embed") +async def embed_file(request: Request, file_id: str = Form(...), file: UploadFile = File(...)): + known_type = None + if not hasattr(request.state, 'user'): + user_id = "public" + else: + user_id = request.state.user.get('id'); + + temp_base_path = os.path.join(RAG_UPLOAD_DIR, user_id) + os.makedirs(temp_base_path, exist_ok=True) + temp_file_path = os.path.join(RAG_UPLOAD_DIR, user_id, file.filename) + + try: + async with aiofiles.open(temp_file_path, 'wb') as temp_file: + chunk_size = 64 * 1024 # 64 KB + while content := await file.read(chunk_size): + await temp_file.write(content) + except Exception as e: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to save the uploaded file. Error: {str(e)}") + + try: + loader, known_type = get_loader(file.filename, file.content_type, temp_file_path) + data = loader.load() + result = await store_data_in_vector_db(data, file_id, user_id) + + if not result: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to process/store the file data.", + ) + except Exception as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Error during file processing: {str(e)}") + finally: + try: + await aiofiles.os.remove(temp_file_path) + except Exception as e: + logger.info(f"Failed to remove temporary file: {str(e)}") + + return { + "status": True, + "message": "File processed successfully.", + "file_id": file_id, + "filename": file.filename, + "known_type": known_type, + } @app.get("/documents/{id}/context") async def load_document_context(id: str): @@ -280,16 +348,21 @@ async def load_document_context(id: str): return process_documents(documents) except Exception as e: - print(e) + logger.error(e) raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail=ERROR_MESSAGES.DEFAULT(e), ) @app.post("/embed-upload") -async def embed_file_upload(file_id: str = Form(...), uploaded_file: UploadFile = File(...)): +async def embed_file_upload(request: Request, file_id: str = Form(...), uploaded_file: UploadFile = File(...)): temp_file_path = os.path.join(RAG_UPLOAD_DIR, uploaded_file.filename) + if not hasattr(request.state, 'user'): + user_id = "public" + else: + user_id = request.state.user.get('id'); + try: with open(temp_file_path, 'wb') as temp_file: copyfileobj(uploaded_file.file, temp_file) @@ -301,7 +374,7 @@ async def embed_file_upload(file_id: str = Form(...), uploaded_file: UploadFile loader, known_type = get_loader(uploaded_file.filename, uploaded_file.content_type, temp_file_path) data = loader.load() - result = await store_data_in_vector_db(data, file_id) + result = await store_data_in_vector_db(data, file_id, user_id) if not result: raise HTTPException( diff --git a/requirements.txt b/requirements.txt index a21c6da0..a2c049d1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,3 +20,4 @@ python-jose==3.3.0 asyncpg==0.29.0 python-multipart==0.0.9 sentence_transformers==2.5.1 +aiofiles==23.2.1