From 72cdadabcf24e0f04290c5460e93f76a9776a92c Mon Sep 17 00:00:00 2001 From: Alejandro Herrera <149527975+sfc-gh-alherrera@users.noreply.github.com> Date: Fri, 6 Dec 2024 17:29:41 -0500 Subject: [PATCH] test: add memory tests and resolve initial implementation bug (#84) * adding memory test * fix: only access memory_context if memory arg is True * style(ruff): sort imports --------- Co-authored-by: Tyler White --- agent_gateway/gateway/gateway.py | 6 ++-- tests/test_quickstart.py | 55 +++++++++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/agent_gateway/gateway/gateway.py b/agent_gateway/gateway/gateway.py index 485a02f..ebc704f 100644 --- a/agent_gateway/gateway/gateway.py +++ b/agent_gateway/gateway/gateway.py @@ -468,8 +468,8 @@ async def acall( inputs["context"] = formatted_contexts max_memory = 3 # TODO consider exposing this to users - - if len(self.memory_context) <= max_memory: - self.memory_context.append({"Question:": input, "Answer": answer}) + if self.memory: + if len(self.memory_context) <= max_memory: + self.memory_context.append({"Question:": input, "Answer": answer}) return answer diff --git a/tests/test_quickstart.py b/tests/test_quickstart.py index d455ae9..4cab29c 100644 --- a/tests/test_quickstart.py +++ b/tests/test_quickstart.py @@ -16,7 +16,7 @@ import pytest from agent_gateway import Agent -from agent_gateway.tools import CortexSearchTool, CortexAnalystTool, PythonTool +from agent_gateway.tools import CortexAnalystTool, CortexSearchTool, PythonTool @pytest.mark.parametrize( @@ -142,3 +142,56 @@ def get_news(_) -> dict: ) response = agent(question) assert answer_contains in response + + +@pytest.mark.parametrize( + "question, answer_contains", + [ + pytest.param( + "What is the market cap of Apple?", + "$3,019,131,060,224", + id="market_cap", + ), + pytest.param( + "When is Apple releasing a new chip?", + "May 7", + id="product_revenue", + ), + ], +) +def test_gateway_agent_without_memory(session, question, answer_contains): + search_config = { + "service_name": "SEC_SEARCH_SERVICE", + "service_topic": "Snowflake's business,product offerings,and performance", + "data_description": "Snowflake annual reports", + "retrieval_columns": ["CHUNK"], + "snowflake_connection": session, + } + analyst_config = { + "semantic_model": "sp500_semantic_model.yaml", + "stage": "ANALYST", + "service_topic": "S&P500 company and stock metrics", + "data_description": "a table with stock and financial metrics about S&P500 companies ", + "snowflake_connection": session, + } + + def get_news(_) -> dict: + with open("tests/data/response.json") as f: + d = json.load(f) + return d + + python_config = { + "tool_description": "searches for relevant news based on user query", + "output_description": "relevant articles", + "python_func": get_news, + } + annual_reports = CortexSearchTool(**search_config) + sp500 = CortexAnalystTool(**analyst_config) + news_search = PythonTool(**python_config) + agent = Agent( + snowflake_connection=session, + tools=[annual_reports, sp500, news_search], + memory=False, + ) + response = agent(question) + assert answer_contains in response