-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
1,124 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,7 +1,8 @@ | ||
from app.api.routes import health, login, users | ||
from app.api.routes import health, login, scan, users | ||
from app.core.utils import APIRouter | ||
|
||
api_router = APIRouter() | ||
api_router.include_router(health.router) | ||
api_router.include_router(login.router) | ||
api_router.include_router(scan.router) | ||
api_router.include_router(users.router) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from fastapi import HTTPException | ||
from loguru import logger | ||
from starlette.status import HTTP_500_INTERNAL_SERVER_ERROR | ||
|
||
from app.api.deps import CurrentUser | ||
from app.core.config import settings | ||
from app.core.utils import APIRouter | ||
from app.models.scan import AnalysisReport, Topic | ||
from app.scan.graph import CustomGraph | ||
|
||
router = APIRouter(tags=["SCAN"], prefix=f"{settings.API_V1_PREFIX}/scan") | ||
|
||
|
||
@router.post("/") | ||
async def ask_question(*, topic: Topic, _: CurrentUser) -> AnalysisReport: | ||
"""Ask SCAN for help with a questions.""" | ||
|
||
graph = CustomGraph(topic.topic) | ||
try: | ||
logger.debug("Preparing analysis") | ||
analysis = await graph.execute() | ||
except Exception as e: | ||
logger.error(f"An error occurred while answering question: {e}") | ||
raise HTTPException( | ||
status_code=HTTP_500_INTERNAL_SERVER_ERROR, | ||
detail="An error occurred when getting an answer", | ||
) from e | ||
|
||
return analysis |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from camel_converter.pydantic_base import CamelBase | ||
|
||
|
||
class AnalysisReport(CamelBase): | ||
dlpfc_analysis: str | None = None | ||
vmpfc_analysis: str | None = None | ||
ofc_analysis: str | None = None | ||
acc_analysis: str | None = None | ||
mpfc_analysis: str | None = None | ||
|
||
|
||
class Topic(CamelBase): | ||
topic: str |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from collections.abc import Hashable | ||
from typing import cast | ||
|
||
from langgraph.graph import END, StateGraph | ||
|
||
from app.models.scan import AnalysisReport | ||
from app.scan.pfc_agents import AgentState, PFCAgents | ||
|
||
|
||
class CustomGraph: | ||
def __init__(self, topic: str) -> None: | ||
self.topic = topic | ||
self.agents = PFCAgents(topic=self.topic) | ||
self.workflow = StateGraph(AgentState) | ||
|
||
def build_graph(self) -> None: | ||
for role, agent_function in self.agents.agents.items(): | ||
self.workflow.add_node(role, agent_function) | ||
|
||
def router(state: AgentState) -> str | None: | ||
next_agent = state["next"] | ||
if next_agent == END: | ||
return None | ||
|
||
if next_agent is not None: | ||
return next_agent | ||
return END | ||
|
||
dlpfc_edges: dict[str, str] = { | ||
"VMPFC": "VMPFC", | ||
"OFC": "OFC", | ||
"ACC": "ACC", | ||
"MPFC": "MPFC", | ||
END: END, | ||
} | ||
|
||
self.workflow.add_conditional_edges("DLPFC", router, cast(dict[Hashable, str], dlpfc_edges)) | ||
|
||
for role in ("VMPFC", "OFC", "ACC", "MPFC"): | ||
self.workflow.add_conditional_edges( | ||
role, | ||
router, | ||
) | ||
|
||
self.workflow.set_entry_point("DLPFC") | ||
self.graph = self.workflow.compile() | ||
|
||
async def execute(self) -> AnalysisReport: | ||
self.build_graph() | ||
|
||
initial_state = AgentState( | ||
input=self.topic, | ||
history=[], | ||
next="DLPFC", | ||
current_role="DLPFC", | ||
) | ||
|
||
final_state = await self.graph.ainvoke(initial_state) | ||
return self.prepare_output(cast(AgentState, final_state)) | ||
|
||
def prepare_output(self, state: AgentState) -> AnalysisReport: | ||
dlpfc_analysis = None | ||
vmpfc_analysis = None | ||
ofc_analysis = None | ||
acc_analysis = None | ||
mpfc_analysis = None | ||
|
||
for role, analysis in state["history"]: | ||
if role == "DLPFC": | ||
dlpfc_analysis = " ".join(analysis) if analysis else None | ||
elif role == "VMPFC": | ||
vmpfc_analysis = " ".join(analysis) if analysis else None | ||
elif role == "OFC": | ||
ofc_analysis = " ".join(analysis) if analysis else None | ||
elif role == "ACC": | ||
acc_analysis = " ".join(analysis) if analysis else None | ||
elif role == "MPFC": | ||
mpfc_analysis = " ".join(analysis) if analysis else None | ||
|
||
return AnalysisReport( | ||
dlpfc_analysis=dlpfc_analysis, | ||
vmpfc_analysis=vmpfc_analysis, | ||
ofc_analysis=ofc_analysis, | ||
acc_analysis=acc_analysis, | ||
mpfc_analysis=mpfc_analysis, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from collections.abc import Iterable | ||
|
||
from openai import AsyncOpenAI | ||
from openai.types.chat import ChatCompletionMessageParam | ||
|
||
from app.core.config import settings | ||
|
||
|
||
class OpenAIWrapper: | ||
"""Wrapper for OpenAI API interactions.""" | ||
|
||
def __init__(self, model_name: str): | ||
self.model_name = model_name | ||
self.client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY) | ||
|
||
async def create_chat_completion( | ||
self, | ||
messages: Iterable[ChatCompletionMessageParam], | ||
max_tokens: int = settings.MAX_TOKENS, | ||
temperature: float = settings.TEMPERATURE, | ||
) -> str | None: | ||
"""Create a chat completion with error handling.""" | ||
|
||
response = await self.client.chat.completions.create( | ||
model=self.model_name, | ||
messages=messages, | ||
max_tokens=max_tokens, | ||
temperature=temperature, | ||
) | ||
|
||
if not response.choices[0].message.content: | ||
return None | ||
|
||
return response.choices[0].message.content.strip() |
Oops, something went wrong.