Skip to content

Commit

Permalink
test: add memory tests and resolve initial implementation bug (#84)
Browse files Browse the repository at this point in the history
* adding memory test

* fix: only access memory_context if memory arg is True

* style(ruff): sort imports

---------

Co-authored-by: Tyler White <[email protected]>
  • Loading branch information
sfc-gh-alherrera and sfc-gh-twhite authored Dec 6, 2024
1 parent eb33f2a commit 72cdada
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 4 deletions.
6 changes: 3 additions & 3 deletions agent_gateway/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,8 +468,8 @@ async def acall(
inputs["context"] = formatted_contexts

max_memory = 3 # TODO consider exposing this to users

if len(self.memory_context) <= max_memory:
self.memory_context.append({"Question:": input, "Answer": answer})
if self.memory:
if len(self.memory_context) <= max_memory:
self.memory_context.append({"Question:": input, "Answer": answer})

return answer
55 changes: 54 additions & 1 deletion tests/test_quickstart.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest

from agent_gateway import Agent
from agent_gateway.tools import CortexSearchTool, CortexAnalystTool, PythonTool
from agent_gateway.tools import CortexAnalystTool, CortexSearchTool, PythonTool


@pytest.mark.parametrize(
Expand Down Expand Up @@ -142,3 +142,56 @@ def get_news(_) -> dict:
)
response = agent(question)
assert answer_contains in response


@pytest.mark.parametrize(
"question, answer_contains",
[
pytest.param(
"What is the market cap of Apple?",
"$3,019,131,060,224",
id="market_cap",
),
pytest.param(
"When is Apple releasing a new chip?",
"May 7",
id="product_revenue",
),
],
)
def test_gateway_agent_without_memory(session, question, answer_contains):
search_config = {
"service_name": "SEC_SEARCH_SERVICE",
"service_topic": "Snowflake's business,product offerings,and performance",
"data_description": "Snowflake annual reports",
"retrieval_columns": ["CHUNK"],
"snowflake_connection": session,
}
analyst_config = {
"semantic_model": "sp500_semantic_model.yaml",
"stage": "ANALYST",
"service_topic": "S&P500 company and stock metrics",
"data_description": "a table with stock and financial metrics about S&P500 companies ",
"snowflake_connection": session,
}

def get_news(_) -> dict:
with open("tests/data/response.json") as f:
d = json.load(f)
return d

python_config = {
"tool_description": "searches for relevant news based on user query",
"output_description": "relevant articles",
"python_func": get_news,
}
annual_reports = CortexSearchTool(**search_config)
sp500 = CortexAnalystTool(**analyst_config)
news_search = PythonTool(**python_config)
agent = Agent(
snowflake_connection=session,
tools=[annual_reports, sp500, news_search],
memory=False,
)
response = agent(question)
assert answer_contains in response

0 comments on commit 72cdada

Please sign in to comment.