💽 feat: New File Uploading Strategy & Additional Security (#9)
* feat: new file strategy using upload

* chore: typing and user authorization for queries/embedding
danny-avila authored Mar 22, 2024
1 parent 90970a7 commit 5190969
Showing 2 changed files with 92 additions and 18 deletions.
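For orientation, the headline change is that `/embed` now accepts a multipart upload instead of a server-local file path. A minimal client-side sketch (illustrative, not part of the commit), assuming the API runs at http://localhost:8000 and no auth middleware is attached, so the document is stored under the "public" user:

    import requests  # hypothetical client; any HTTP library works

    # "file-123" and report.pdf are placeholders; the form fields match the
    # new route's file_id: str = Form(...) and file: UploadFile = File(...).
    with open("report.pdf", "rb") as f:
        resp = requests.post(
            "http://localhost:8000/embed",
            data={"file_id": "file-123"},  # stored in each chunk's metadata
            files={"file": ("report.pdf", f, "application/pdf")},
        )

    print(resp.json())
    # On success: {"status": True, "message": "File processed successfully.", ...}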
109 changes: 91 additions & 18 deletions main.py
@@ -1,13 +1,17 @@
import os
import hashlib
import aiofiles
import aiofiles.os
from typing import Iterable
from shutil import copyfileobj
from langchain.schema import Document
from contextlib import asynccontextmanager
from fastapi.responses import JSONResponse
from dotenv import find_dotenv, load_dotenv
from fastapi.middleware.cors import CORSMiddleware
-from fastapi import FastAPI, File, Form, UploadFile, HTTPException, status
from langchain_core.runnables.config import run_in_executor
from langchain.text_splitter import RecursiveCharacterTextSplitter
+from fastapi import FastAPI, File, Form, UploadFile, HTTPException, status, Request
from langchain_community.document_loaders import (
    WebBaseLoader,
    TextLoader,
@@ -32,6 +36,7 @@
load_dotenv(find_dotenv())

from config import (
+    logger,
    debug_mode,
    CHUNK_SIZE,
    CHUNK_OVERLAP,
@@ -122,7 +127,13 @@ async def delete_documents(ids: list[str]):
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/query")
-async def query_embeddings_by_file_id(body: QueryRequestBody):
+async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request):
    if not hasattr(request.state, 'user'):
        user_authorized = "public"
    else:
        user_authorized = request.state.user.get('id')

    authorized_documents = []
    try:
        embedding = vector_store.embedding_function.embed_query(body.query)

@@ -141,15 +152,24 @@ async def query_embeddings_by_file_id(body: QueryRequestBody):
            filter={"file_id": body.file_id}
        )

-        return documents
+        document, score = documents[0]
+        doc_metadata = document.metadata
+        doc_user_id = doc_metadata.get('user_id')
+
+        if doc_user_id is None or doc_user_id == user_authorized:
+            authorized_documents = documents
+        else:
+            logger.warning(f"Unauthorized access attempt by user {user_authorized} to a document with user_id {doc_user_id}")
+
+        return authorized_documents
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

def generate_digest(page_content: str):
    hash_obj = hashlib.md5(page_content.encode())
    return hash_obj.hexdigest()

-async def store_data_in_vector_db(data, file_id, overwrite: bool = False) -> bool:
+async def store_data_in_vector_db(data: Iterable[Document], file_id: str, user_id: str = '') -> bool:
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=app.state.CHUNK_SIZE, chunk_overlap=app.state.CHUNK_OVERLAP
    )
@@ -162,6 +182,7 @@ async def store_data_in_vector_db(data, file_id, overwrite: bool = False) -> bool:
            page_content=doc.page_content,
            metadata={
                "file_id": file_id,
+                "user_id": user_id,
                "digest": generate_digest(doc.page_content),
                **(doc.metadata or {}),
            },
@@ -178,13 +199,7 @@ async def store_data_in_vector_db(data, file_id, overwrite: bool = False) -> bool:
        return {"message": "Documents added successfully", "ids": ids}

    except Exception as e:
-        print(e)
-        # Checking if a unique constraint error occurred, to handle overwrite logic if needed.
-        if e.__class__.__name__ == "UniqueConstraintError" and overwrite:
-            # Functionality to overwrite existing documents.
-            # This might require fetching existing document IDs, deleting them, and then re-inserting the documents.
-            return {"message": "Documents exist. Overwrite not implemented.", "error": str(e)}
+        logger.error(e)
        return {"message": "An error occurred while adding documents.", "error": str(e)}

def get_loader(filename: str, file_content_type: str, filepath: str):
@@ -224,20 +239,25 @@ def get_loader(filename: str, file_content_type: str, filepath: str):

    return loader, known_type

-@app.post("/embed")
-async def embed_file(document: StoreDocument):
+@app.post("/local/embed")
+async def embed_local_file(document: StoreDocument, request: Request):

    # Check if the file exists
    if not os.path.exists(document.filepath):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=ERROR_MESSAGES.FILE_NOT_FOUND,
        )

    if not hasattr(request.state, 'user'):
        user_id = "public"
    else:
        user_id = request.state.user.get('id')

    try:
        loader, known_type = get_loader(document.filename, document.file_content_type, document.filepath)
        data = loader.load()
-        result = await store_data_in_vector_db(data, document.file_id)
+        result = await store_data_in_vector_db(data, document.file_id, user_id)

        if result:
            return {
@@ -252,7 +272,7 @@ async def embed_file(document: StoreDocument):
                detail=ERROR_MESSAGES.DEFAULT(),
            )
    except Exception as e:
-        print(e)
+        logger.error(e)
        if "No pandoc was found" in str(e):
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
@@ -263,6 +283,54 @@
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

@app.post("/embed")
async def embed_file(request: Request, file_id: str = Form(...), file: UploadFile = File(...)):
known_type = None
if not hasattr(request.state, 'user'):
user_id = "public"
else:
user_id = request.state.user.get('id');

temp_base_path = os.path.join(RAG_UPLOAD_DIR, user_id)
os.makedirs(temp_base_path, exist_ok=True)
temp_file_path = os.path.join(RAG_UPLOAD_DIR, user_id, file.filename)

try:
async with aiofiles.open(temp_file_path, 'wb') as temp_file:
chunk_size = 64 * 1024 # 64 KB
while content := await file.read(chunk_size):
await temp_file.write(content)
except Exception as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to save the uploaded file. Error: {str(e)}")

try:
loader, known_type = get_loader(file.filename, file.content_type, temp_file_path)
data = loader.load()
result = await store_data_in_vector_db(data, file_id, user_id)

if not result:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to process/store the file data.",
)
except Exception as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Error during file processing: {str(e)}")
finally:
try:
await aiofiles.os.remove(temp_file_path)
except Exception as e:
logger.info(f"Failed to remove temporary file: {str(e)}")

return {
"status": True,
"message": "File processed successfully.",
"file_id": file_id,
"filename": file.filename,
"known_type": known_type,
}

@app.get("/documents/{id}/context")
async def load_document_context(id: str):
@@ -280,16 +348,21 @@ async def load_document_context(id: str):

        return process_documents(documents)
    except Exception as e:
-        print(e)
+        logger.error(e)
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=ERROR_MESSAGES.DEFAULT(e),
        )

@app.post("/embed-upload")
-async def embed_file_upload(file_id: str = Form(...), uploaded_file: UploadFile = File(...)):
+async def embed_file_upload(request: Request, file_id: str = Form(...), uploaded_file: UploadFile = File(...)):
    temp_file_path = os.path.join(RAG_UPLOAD_DIR, uploaded_file.filename)

    if not hasattr(request.state, 'user'):
        user_id = "public"
    else:
        user_id = request.state.user.get('id')

    try:
        with open(temp_file_path, 'wb') as temp_file:
            copyfileobj(uploaded_file.file, temp_file)
@@ -301,7 +374,7 @@ async def embed_file_upload(file_id: str = Form(...), uploaded_file: UploadFile = File(...)):
        loader, known_type = get_loader(uploaded_file.filename, uploaded_file.content_type, temp_file_path)

        data = loader.load()
-        result = await store_data_in_vector_db(data, file_id)
+        result = await store_data_in_vector_db(data, file_id, user_id)

        if not result:
            raise HTTPException(
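On the security side, a sketch of the query flow after this change (illustrative, not from the diff, same local-instance assumption as above): `/query` now returns results only when the stored `user_id` metadata matches the requester or is unset, so an unauthorized request comes back empty.

    import requests  # hypothetical client, same assumptions as the upload sketch

    # Results are (document, score) pairs; documents whose user_id metadata
    # belongs to another user are withheld and the attempt is logged server-side.
    resp = requests.post(
        "http://localhost:8000/query",
        json={"query": "What does the report conclude?", "file_id": "file-123"},
    )
    print(resp.json())  # [] if the requester is not authorized for this file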
1 change: 1 addition & 0 deletions requirements.txt
@@ -20,3 +20,4 @@ python-jose==3.3.0
asyncpg==0.29.0
python-multipart==0.0.9
sentence_transformers==2.5.1
+aiofiles==23.2.1
