Skip to content

Commit

Permalink
Add test
Browse files Browse the repository at this point in the history
  • Loading branch information
sanders41 committed Nov 17, 2024
1 parent 95122eb commit 0076da9
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 5 deletions.
10 changes: 5 additions & 5 deletions backend/app/scan/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,15 @@ def prepare_output(self, state: AgentState) -> AnalysisReport:

for role, analysis in state["history"]:
if role == "DLPFC":
dlpfc_analysis = " ".join(analysis) if analysis else None
dlpfc_analysis = analysis if analysis else None
elif role == "VMPFC":
vmpfc_analysis = " ".join(analysis) if analysis else None
vmpfc_analysis = analysis if analysis else None
elif role == "OFC":
ofc_analysis = " ".join(analysis) if analysis else None
ofc_analysis = analysis if analysis else None
elif role == "ACC":
acc_analysis = " ".join(analysis) if analysis else None
acc_analysis = analysis if analysis else None
elif role == "MPFC":
mpfc_analysis = " ".join(analysis) if analysis else None
mpfc_analysis = analysis if analysis else None

return AnalysisReport(
dlpfc_analysis=dlpfc_analysis,
Expand Down
69 changes: 69 additions & 0 deletions backend/tests/scan/test_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import pytest

from app.models.scan import AnalysisReport
from app.scan.graph import CustomGraph
from app.scan.pfc_agents import AgentState


@pytest.mark.parametrize(
"state, expected",
(
(
AgentState(
current_role="DLPFC",
input="some input",
history=[
("DLPFC", "dlpfc"),
("VMPFC", "vmpfc"),
("OFC", "ofc"),
("ACC", "acc"),
("MPFC", "mpfc"),
],
next="VMPFC",
),
AnalysisReport(
dlpfc_analysis="dlpfc",
vmpfc_analysis="vmpfc",
ofc_analysis="ofc",
acc_analysis="acc",
mpfc_analysis="mpfc",
),
),
(
AgentState(
current_role="DLPFC",
input="some input",
history=[
("DLPFC", "dlpfc"),
],
next="VMPFC",
),
AnalysisReport(
dlpfc_analysis="dlpfc",
vmpfc_analysis=None,
ofc_analysis=None,
acc_analysis=None,
mpfc_analysis=None,
),
),
(
AgentState(
current_role="DLPFC",
input="some input",
history=[],
next="VMPFC",
),
AnalysisReport(
dlpfc_analysis=None,
vmpfc_analysis=None,
ofc_analysis=None,
acc_analysis=None,
mpfc_analysis=None,
),
),
),
)
def test_prepare_output(state, expected):
graph = CustomGraph("some topic")

assert graph.prepare_output(state) == expected

0 comments on commit 0076da9

Please sign in to comment.