Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add scan route and handling #4

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions backend/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ ENV \
RUN : \
&& apt-get update \
&& apt-get install -y --no-install-recommends \
build-essential \
curl \
ca-certificates \
&& apt-get clean \
Expand Down
3 changes: 2 additions & 1 deletion backend/app/api/router.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from app.api.routes import health, login, users
from app.api.routes import health, login, scan, users
from app.core.utils import APIRouter

api_router = APIRouter()
api_router.include_router(health.router)
api_router.include_router(login.router)
api_router.include_router(scan.router)
api_router.include_router(users.router)
29 changes: 29 additions & 0 deletions backend/app/api/routes/scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from fastapi import HTTPException
from loguru import logger
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR

from app.api.deps import CurrentUser
from app.core.config import settings
from app.core.utils import APIRouter
from app.models.scan import AnalysisReport, Topic
from app.scan.graph import CustomGraph

router = APIRouter(tags=["SCAN"], prefix=f"{settings.API_V1_PREFIX}/scan")


@router.post("/")
async def ask_question(*, topic: Topic, _: CurrentUser) -> AnalysisReport:
"""Ask SCAN for help with a questions."""

graph = CustomGraph(topic.topic)
try:
logger.debug("Preparing analysis")
analysis = await graph.execute()
except Exception as e:
logger.error(f"An error occurred while answering question: {e}")
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred when getting an answer",
) from e

return analysis
13 changes: 13 additions & 0 deletions backend/app/models/scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from camel_converter.pydantic_base import CamelBase


class AnalysisReport(CamelBase):
dlpfc_analysis: str | None = None
vmpfc_analysis: str | None = None
ofc_analysis: str | None = None
acc_analysis: str | None = None
mpfc_analysis: str | None = None


class Topic(CamelBase):
topic: str
Empty file added backend/app/scan/__init__.py
Empty file.
86 changes: 86 additions & 0 deletions backend/app/scan/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from collections.abc import Hashable
from typing import cast

from langgraph.graph import END, StateGraph

from app.models.scan import AnalysisReport
from app.scan.pfc_agents import AgentState, PFCAgents


class CustomGraph:
def __init__(self, topic: str) -> None:
self.topic = topic
self.agents = PFCAgents(topic=self.topic)
self.workflow = StateGraph(AgentState)

def build_graph(self) -> None:
for role, agent_function in self.agents.agents.items():
self.workflow.add_node(role, agent_function)

def router(state: AgentState) -> str | None:
next_agent = state["next"]
if next_agent == END:
return None

if next_agent is not None:
return next_agent
return END

dlpfc_edges: dict[str, str] = {
"VMPFC": "VMPFC",
"OFC": "OFC",
"ACC": "ACC",
"MPFC": "MPFC",
END: END,
}

self.workflow.add_conditional_edges("DLPFC", router, cast(dict[Hashable, str], dlpfc_edges))

for role in ("VMPFC", "OFC", "ACC", "MPFC"):
self.workflow.add_conditional_edges(
role,
router,
)

self.workflow.set_entry_point("DLPFC")
self.graph = self.workflow.compile()

async def execute(self) -> AnalysisReport:
self.build_graph()

initial_state = AgentState(
input=self.topic,
history=[],
next="DLPFC",
current_role="DLPFC",
)

final_state = await self.graph.ainvoke(initial_state)
return self.prepare_output(cast(AgentState, final_state))

def prepare_output(self, state: AgentState) -> AnalysisReport:
dlpfc_analysis = None
vmpfc_analysis = None
ofc_analysis = None
acc_analysis = None
mpfc_analysis = None

for role, analysis in state["history"]:
if role == "DLPFC":
dlpfc_analysis = analysis if analysis else None
elif role == "VMPFC":
vmpfc_analysis = analysis if analysis else None
elif role == "OFC":
ofc_analysis = analysis if analysis else None
elif role == "ACC":
acc_analysis = analysis if analysis else None
elif role == "MPFC":
mpfc_analysis = analysis if analysis else None

return AnalysisReport(
dlpfc_analysis=dlpfc_analysis,
vmpfc_analysis=vmpfc_analysis,
ofc_analysis=ofc_analysis,
acc_analysis=acc_analysis,
mpfc_analysis=mpfc_analysis,
)
34 changes: 34 additions & 0 deletions backend/app/scan/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from collections.abc import Iterable

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam

from app.core.config import settings


class OpenAIWrapper:
"""Wrapper for OpenAI API interactions."""

def __init__(self, model_name: str):
self.model_name = model_name
self.client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)

async def create_chat_completion(
self,
messages: Iterable[ChatCompletionMessageParam],
max_tokens: int = settings.MAX_TOKENS,
temperature: float = settings.TEMPERATURE,
) -> str | None:
"""Create a chat completion with error handling."""

response = await self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)

if not response.choices[0].message.content:
return None

return response.choices[0].message.content.strip()
Loading