Skip to content

Commit

Permalink
Setup scan route
Browse files Browse the repository at this point in the history
  • Loading branch information
sanders41 committed Nov 16, 2024
1 parent 83d4246 commit 92c0aab
Show file tree
Hide file tree
Showing 12 changed files with 1,124 additions and 3 deletions.
3 changes: 2 additions & 1 deletion backend/app/api/router.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from app.api.routes import health, login, users
from app.api.routes import health, login, scan, users
from app.core.utils import APIRouter

api_router = APIRouter()
api_router.include_router(health.router)
api_router.include_router(login.router)
api_router.include_router(scan.router)
api_router.include_router(users.router)
29 changes: 29 additions & 0 deletions backend/app/api/routes/scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from fastapi import HTTPException
from loguru import logger
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR

from app.api.deps import CurrentUser
from app.core.config import settings
from app.core.utils import APIRouter
from app.models.scan import AnalysisReport, Topic
from app.scan.graph import CustomGraph

router = APIRouter(tags=["SCAN"], prefix=f"{settings.API_V1_PREFIX}/scan")


@router.post("/")
async def ask_question(*, topic: Topic, _: CurrentUser) -> AnalysisReport:
"""Ask SCAN for help with a questions."""

graph = CustomGraph(topic.topic)
try:
logger.debug("Preparing analysis")
analysis = await graph.execute()
except Exception as e:
logger.error(f"An error occurred while answering question: {e}")
raise HTTPException(
status_code=HTTP_500_INTERNAL_SERVER_ERROR,
detail="An error occurred when getting an answer",
) from e

return analysis
13 changes: 13 additions & 0 deletions backend/app/models/scan.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from camel_converter.pydantic_base import CamelBase


class AnalysisReport(CamelBase):
dlpfc_analysis: str | None = None
vmpfc_analysis: str | None = None
ofc_analysis: str | None = None
acc_analysis: str | None = None
mpfc_analysis: str | None = None


class Topic(CamelBase):
topic: str
Empty file added backend/app/scan/__init__.py
Empty file.
86 changes: 86 additions & 0 deletions backend/app/scan/graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from collections.abc import Hashable
from typing import cast

from langgraph.graph import END, StateGraph

from app.models.scan import AnalysisReport
from app.scan.pfc_agents import AgentState, PFCAgents


class CustomGraph:
def __init__(self, topic: str) -> None:
self.topic = topic
self.agents = PFCAgents(topic=self.topic)
self.workflow = StateGraph(AgentState)

def build_graph(self) -> None:
for role, agent_function in self.agents.agents.items():
self.workflow.add_node(role, agent_function)

def router(state: AgentState) -> str | None:
next_agent = state["next"]
if next_agent == END:
return None

if next_agent is not None:
return next_agent
return END

dlpfc_edges: dict[str, str] = {
"VMPFC": "VMPFC",
"OFC": "OFC",
"ACC": "ACC",
"MPFC": "MPFC",
END: END,
}

self.workflow.add_conditional_edges("DLPFC", router, cast(dict[Hashable, str], dlpfc_edges))

for role in ("VMPFC", "OFC", "ACC", "MPFC"):
self.workflow.add_conditional_edges(
role,
router,
)

self.workflow.set_entry_point("DLPFC")
self.graph = self.workflow.compile()

async def execute(self) -> AnalysisReport:
self.build_graph()

initial_state = AgentState(
input=self.topic,
history=[],
next="DLPFC",
current_role="DLPFC",
)

final_state = await self.graph.ainvoke(initial_state)
return self.prepare_output(cast(AgentState, final_state))

def prepare_output(self, state: AgentState) -> AnalysisReport:
dlpfc_analysis = None
vmpfc_analysis = None
ofc_analysis = None
acc_analysis = None
mpfc_analysis = None

for role, analysis in state["history"]:
if role == "DLPFC":
dlpfc_analysis = " ".join(analysis) if analysis else None
elif role == "VMPFC":
vmpfc_analysis = " ".join(analysis) if analysis else None
elif role == "OFC":
ofc_analysis = " ".join(analysis) if analysis else None
elif role == "ACC":
acc_analysis = " ".join(analysis) if analysis else None
elif role == "MPFC":
mpfc_analysis = " ".join(analysis) if analysis else None

return AnalysisReport(
dlpfc_analysis=dlpfc_analysis,
vmpfc_analysis=vmpfc_analysis,
ofc_analysis=ofc_analysis,
acc_analysis=acc_analysis,
mpfc_analysis=mpfc_analysis,
)
34 changes: 34 additions & 0 deletions backend/app/scan/openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from collections.abc import Iterable

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletionMessageParam

from app.core.config import settings


class OpenAIWrapper:
"""Wrapper for OpenAI API interactions."""

def __init__(self, model_name: str):
self.model_name = model_name
self.client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)

async def create_chat_completion(
self,
messages: Iterable[ChatCompletionMessageParam],
max_tokens: int = settings.MAX_TOKENS,
temperature: float = settings.TEMPERATURE,
) -> str | None:
"""Create a chat completion with error handling."""

response = await self.client.chat.completions.create(
model=self.model_name,
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
)

if not response.choices[0].message.content:
return None

return response.choices[0].message.content.strip()
Loading

0 comments on commit 92c0aab

Please sign in to comment.