Skip to content

Commit

Permalink
✨ feat: Add 'entity_id' Parameter (#107)
Browse files Browse the repository at this point in the history
* ✨ feat: Add envFile configuration to VSCode launch settings

* 🛠️ fix: Improve debug_mode environment variable handling for consistency

* feat: add entity_id as alternate user_id param
  • Loading branch information
danny-avila authored Dec 18, 2024
1 parent 95a0cd0 commit 9e4bb52
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 21 deletions.
3 changes: 2 additions & 1 deletion .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
"8000",
"--reload"
],
"jinja": true
"jinja": true,
"envFile": "${workspaceFolder}/.env"
}
]
}
8 changes: 7 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,13 @@ def get_env_variable(

logger = logging.getLogger()

debug_mode = get_env_variable("DEBUG_RAG_API", "False").lower() == "true"
debug_mode = os.getenv("DEBUG_RAG_API", "False").lower() in (
"true",
"1",
"yes",
"y",
"t",
)
console_json = get_env_variable("CONSOLE_JSON", "False").lower() == "true"

if debug_mode:
Expand Down
83 changes: 65 additions & 18 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from langchain.schema import Document
from contextlib import asynccontextmanager
from dotenv import find_dotenv, load_dotenv
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.exceptions import RequestValidationError
from langchain_core.runnables.config import run_in_executor
from langchain.text_splitter import RecursiveCharacterTextSplitter
from fastapi import (
Expand Down Expand Up @@ -84,7 +86,7 @@ async def lifespan(app: FastAPI):
yield


app = FastAPI(lifespan=lifespan)
app = FastAPI(lifespan=lifespan, debug=debug_mode)

app.add_middleware(
CORSMiddleware,
Expand Down Expand Up @@ -213,10 +215,17 @@ async def delete_documents(document_ids: List[str] = Body(...)):


@app.post("/query")
async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request):
user_authorized = (
"public" if not hasattr(request.state, "user") else request.state.user.get("id")
)
async def query_embeddings_by_file_id(
body: QueryRequestBody,
request: Request,
):
if not hasattr(request.state, "user"):
user_authorized = body.entity_id if body.entity_id else "public"
else:
user_authorized = (
body.entity_id if body.entity_id else request.state.user.get("id")
)

authorized_documents = []

try:
Expand Down Expand Up @@ -245,9 +254,24 @@ async def query_embeddings_by_file_id(body: QueryRequestBody, request: Request):
if doc_user_id is None or doc_user_id == user_authorized:
authorized_documents = documents
else:
logger.warn(
f"Unauthorized access attempt by user {user_authorized} to a document with user_id {doc_user_id}"
)
# If using entity_id and access denied, try again with user's actual ID
if body.entity_id and hasattr(request.state, "user"):
user_authorized = request.state.user.get("id")
if doc_user_id == user_authorized:
authorized_documents = documents
else:
if body.entity_id == doc_user_id:
logger.warning(
f"Entity ID {body.entity_id} matches document user_id but user {user_authorized} is not authorized"
)
else:
logger.warning(
f"Access denied for both entity ID {body.entity_id} and user {user_authorized} to document with user_id {doc_user_id}"
)
else:
logger.warning(
f"Unauthorized access attempt by user {user_authorized} to a document with user_id {doc_user_id}"
)

return authorized_documents

Expand Down Expand Up @@ -361,8 +385,9 @@ def get_loader(filename: str, file_content_type: str, filepath: str):


@app.post("/local/embed")
async def embed_local_file(document: StoreDocument, request: Request):

async def embed_local_file(
document: StoreDocument, request: Request, entity_id: str = None
):
# Check if the file exists
if not os.path.exists(document.filepath):
raise HTTPException(
Expand All @@ -371,9 +396,9 @@ async def embed_local_file(document: StoreDocument, request: Request):
)

if not hasattr(request.state, "user"):
user_id = "public"
user_id = entity_id if entity_id else "public"
else:
user_id = request.state.user.get("id")
user_id = entity_id if entity_id else request.state.user.get("id")

try:
loader, known_type = get_loader(
Expand Down Expand Up @@ -410,15 +435,18 @@ async def embed_local_file(document: StoreDocument, request: Request):

@app.post("/embed")
async def embed_file(
request: Request, file_id: str = Form(...), file: UploadFile = File(...)
request: Request,
file_id: str = Form(...),
file: UploadFile = File(...),
entity_id: str = Form(None),
):
response_status = True
response_message = "File processed successfully."
known_type = None
if not hasattr(request.state, "user"):
user_id = "public"
user_id = entity_id if entity_id else "public"
else:
user_id = request.state.user.get("id")
user_id = entity_id if entity_id else request.state.user.get("id")

temp_base_path = os.path.join(RAG_UPLOAD_DIR, user_id)
os.makedirs(temp_base_path, exist_ok=True)
Expand Down Expand Up @@ -538,14 +566,17 @@ async def load_document_context(id: str):

@app.post("/embed-upload")
async def embed_file_upload(
request: Request, file_id: str = Form(...), uploaded_file: UploadFile = File(...)
request: Request,
file_id: str = Form(...),
uploaded_file: UploadFile = File(...),
entity_id: str = Form(None),
):
temp_file_path = os.path.join(RAG_UPLOAD_DIR, uploaded_file.filename)

if not hasattr(request.state, "user"):
user_id = "public"
user_id = entity_id if entity_id else "public"
else:
user_id = request.state.user.get("id")
user_id = entity_id if entity_id else request.state.user.get("id")

try:
with open(temp_file_path, "wb") as temp_file:
Expand Down Expand Up @@ -624,6 +655,22 @@ async def query_embeddings_by_file_ids(body: QueryMultipleBody):
raise HTTPException(status_code=500, detail=str(e))


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Log a request validation failure and return a 422 JSON response.

    Args:
        request: The incoming request whose payload failed validation.
        exc: The validation error raised for the request.

    Returns:
        A ``JSONResponse`` with status 422 whose content holds the validation
        errors (``detail``), the request body as text (``body``), and a fixed
        human-readable ``message``.
    """
    raw_body = await request.body()
    # The body is not guaranteed to be valid UTF-8 (e.g. a multipart upload
    # carrying binary file content). A strict decode would raise
    # UnicodeDecodeError here and mask the validation error with a 500, so
    # undecodable bytes are replaced instead.
    body_text = raw_body.decode("utf-8", errors="replace")
    errors = exc.errors()

    logger.debug("Validation error occurred")
    logger.debug("Raw request body: %s", body_text)
    logger.debug("Validation errors: %s", errors)

    return JSONResponse(
        status_code=422,
        content={
            "detail": errors,
            "body": body_text,
            "message": "Request validation failed",
        },
    )


if debug_mode:
app.include_router(router=pgvector_router)

Expand Down
3 changes: 2 additions & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ class StoreDocument(BaseModel):


class QueryRequestBody(BaseModel):
file_id: str
query: str
file_id: str
k: int = 4
entity_id: Optional[str] = None


class CleanupMethod(str, Enum):
Expand Down

0 comments on commit 9e4bb52

Please sign in to comment.