diff --git a/agent_gateway/executors/agent_executor.py b/agent_gateway/executors/agent_executor.py index d211b2d..4a998ef 100644 --- a/agent_gateway/executors/agent_executor.py +++ b/agent_gateway/executors/agent_executor.py @@ -13,7 +13,6 @@ from __future__ import annotations import asyncio -import logging import time from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union @@ -36,8 +35,6 @@ from agent_gateway.chains.chain import Chain from agent_gateway.tools.base import BaseTool -logger = logging.getLogger(__name__) - class ExceptionTool(BaseTool): """Tool that just returns the query.""" @@ -439,7 +436,8 @@ async def _aperform_agent_action( # Use asyncio.gather to run multiple tool.arun() calls concurrently result = await asyncio.gather( - *[_aperform_agent_action(agent_action) for agent_action in actions] + *[_aperform_agent_action(agent_action) for agent_action in actions], + return_exceptions=True, ) return list(result) diff --git a/agent_gateway/gateway/gateway.py b/agent_gateway/gateway/gateway.py index fbc6196..d19fec3 100644 --- a/agent_gateway/gateway/gateway.py +++ b/agent_gateway/gateway/gateway.py @@ -224,6 +224,9 @@ def _parse_fusion_output(self, raw_answer: str) -> str: answer = self._extract_answer(raw_answer) is_replan = FUSION_REPLAN in answer + if is_replan: + answer = "We couldn't find the information you're looking for. You can try rephrasing your request or validate that the provided tools contain sufficient information." + return thought, answer, is_replan def _extract_answer(self, raw_answer): @@ -247,7 +250,11 @@ def _extract_answer(self, raw_answer): return answer else: if replan_index != 1: - print("....replanning...") + gateway_logger.log( + logging.INFO, + "Unable to answer the request. Replanning....", + block=True, + ) return "Replan required. Consider rephrasing your question." else: return None @@ -322,28 +329,49 @@ def __call__(self, input: str): input (str): user's natural language request """ result = [] - thread = threading.Thread(target=self.run_async, args=(input, result)) + error = [] + thread = threading.Thread(target=self.run_async, args=(input, result, error)) thread.start() thread.join() - try: - return result[0]["output"] - except IndexError: - raise AgentGatewayError( - message="Unable to retrieve response. Please check each of your Cortex tools and ensure all connections are valid." - ) + + if error: + raise error[0] + + if not result: + raise AgentGatewayError("Unable to retrieve response. Result is empty.") + + return result[0] def handle_exception(self, loop, context): - exception = context.get("exception") - if exception: - print(f"Caught unhandled exception: {exception}") - loop.default_exception_handler(context) - loop.stop() + loop.default_exception_handler(context) + loop.stop() - def run_async(self, input, result): + def run_async(self, input, result, error): loop = asyncio.new_event_loop() loop.set_exception_handler(self.handle_exception) asyncio.set_event_loop(loop) - result.append(loop.run_until_complete(self.acall(input))) + try: + task = loop.run_until_complete(self.acall(input)) + result.append(task) + except asyncio.CancelledError: + error.append(AgentGatewayError("Task was cancelled")) + except RuntimeError as e: + error.append(AgentGatewayError(f"RuntimeError: {str(e)}")) + except Exception as e: + error.append(AgentGatewayError(f"Gateway Execution Error: {str(e)}")) + + finally: + try: + # Cancel any pending tasks + pending = asyncio.all_tasks(loop) + for task in pending: + task.cancel() + # Wait for all tasks to be cancelled + loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) + finally: + loop.close() async def acall( self, @@ -416,4 +444,4 @@ async def acall( formatted_contexts = self._format_contexts(contexts) inputs["context"] = formatted_contexts - return {self.output_key: answer} + return answer diff --git a/agent_gateway/gateway/task_processor.py b/agent_gateway/gateway/task_processor.py index 79dd677..42dd168 100644 --- a/agent_gateway/gateway/task_processor.py +++ b/agent_gateway/gateway/task_processor.py @@ -19,11 +19,12 @@ from typing import Any, Callable, Dict, List, Optional from agent_gateway.tools.logger import gateway_logger +from agent_gateway.tools.snowflake_tools import SnowflakeError SCHEDULING_INTERVAL = 0.01 # seconds -class agent_gatewayError(Exception): +class AgentGatewayError(Exception): def __init__(self, message): self.message = message super().__init__(self.message) @@ -76,10 +77,10 @@ async def __call__(self) -> Any: x = await self.tool(*self.args) gateway_logger.log(logging.DEBUG, "task successfully completed") return x + except SnowflakeError as e: + return f"Unexpected error during Cortex Gateway Tool request: {str(e)}" except Exception as e: - raise agent_gatewayError( - f"Unexpected error during Cortex gateway Tool request: {str(e)}" - ) from e + return f"Unexpected error during Cortex Gateway Tool request: {str(e)}" def get_thought_action_observation( self, include_action=True, include_thought=True, include_action_idx=False @@ -145,8 +146,10 @@ async def _run_task(self, task: Task): try: observation = await task() task.observation = observation + except SnowflakeError as e: + return f"SnowflakeError in task: {str(e)}" except Exception as e: - raise agent_gatewayError(f"{str(e)}") from e + return f"Unexpected Error in task: {str(e)}" self.tasks_done[task.idx].set() async def schedule(self): diff --git a/agent_gateway/tools/snowflake_tools.py b/agent_gateway/tools/snowflake_tools.py index 46a844f..62e27f4 100644 --- a/agent_gateway/tools/snowflake_tools.py +++ b/agent_gateway/tools/snowflake_tools.py @@ -338,12 +338,16 @@ async def asearch(self, query): ) json_response = json.loads(response_text) + gateway_logger.log( + logging.DEBUG, f"Cortex Analyst Raw Response:{json_response}" + ) + try: - query_response = self._process_message( + query_response = self._process_analyst_message( json_response["message"]["content"] ) - if query_response == "Invalid Query": + if "Unable to generate valid SQL Query" in query_response: lm = dspy.Snowflake( session=Session.builder.config( "connection", self.connection @@ -352,7 +356,8 @@ async def asearch(self, query): ) dspy.settings.configure(lm=lm) rephrase_prompt = dspy.ChainOfThought(PromptRephrase) - current_query = rephrase_prompt(user_prompt=current_query)[ + prompt = f"Original Query: {current_query}. Previous Response Context: {query_response}" + current_query = rephrase_prompt(user_prompt=prompt)[ "rephrased_prompt" ] else: @@ -361,7 +366,6 @@ async def asearch(self, query): except Exception: raise SnowflakeError(message=json_response["message"]) - gateway_logger.log(logging.DEBUG, f"Cortex Analyst Response:{query_response}") return query_response def _prepare_analyst_request(self, prompt): @@ -378,20 +382,39 @@ def _prepare_analyst_request(self, prompt): return url, headers, data - def _process_message(self, response): - # ensure valid sql query is present in response - if response[1].get("type") != "sql": - return "Invalid Query" - - # execute sql query - sql_query = response[1]["statement"] - gateway_logger.log(logging.DEBUG, f"Cortex Analyst SQL Query:{sql_query}") - table = self.connection.cursor().execute(sql_query).fetch_arrow_all() + def _process_analyst_message(self, response): + if isinstance(response, list) and len(response) > 0: + first_item = response[0] + + if "type" in first_item: + if first_item["type"] == "text": + _ = None + for item in response: + _ = item + if item["type"] == "suggestions": + raise SnowflakeError( + message=f"Your request is unclear. Consider rephrasing your request to one of the following suggestions:{item['suggestions']}" + ) + elif item["type"] == "sql": + sql_query = item["statement"] + table = ( + self.connection.cursor() + .execute(sql_query) + .fetch_arrow_all() + ) + + if table is not None: + return str(table.to_pydict()) + else: + raise SnowflakeError( + message="No results found. Consider rephrasing your request" + ) + + raise SnowflakeError( + message=f"Unable to generate a valid SQL Query. {_['text']}" + ) - if table is not None: - return str(table.to_pydict()) - else: - return "No Results Found" + return SnowflakeError(message="Invalid Cortex Analyst Response") def _prepare_analyst_description( self, name, service_topic, data_source_description