Skip to content

Commit

Permalink
refactor: better output handling for cortex analyst
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-alherrera authored and sfc-gh-twhite committed Dec 3, 2024
1 parent d154f09 commit 9f7a01a
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 42 deletions.
6 changes: 2 additions & 4 deletions agent_gateway/executors/agent_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from __future__ import annotations

import asyncio
import logging
import time
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union
Expand All @@ -36,8 +35,6 @@
from agent_gateway.chains.chain import Chain
from agent_gateway.tools.base import BaseTool

logger = logging.getLogger(__name__)


class ExceptionTool(BaseTool):
"""Tool that just returns the query."""
Expand Down Expand Up @@ -439,7 +436,8 @@ async def _aperform_agent_action(

# Use asyncio.gather to run multiple tool.arun() calls concurrently
result = await asyncio.gather(
*[_aperform_agent_action(agent_action) for agent_action in actions]
*[_aperform_agent_action(agent_action) for agent_action in actions],
return_exceptions=True,
)

return list(result)
Expand Down
60 changes: 44 additions & 16 deletions agent_gateway/gateway/gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ def _parse_fusion_output(self, raw_answer: str) -> str:
answer = self._extract_answer(raw_answer)
is_replan = FUSION_REPLAN in answer

if is_replan:
answer = "We couldn't find the information you're looking for. You can try rephrasing your request or validate that the provided tools contain sufficient information."

return thought, answer, is_replan

def _extract_answer(self, raw_answer):
Expand All @@ -247,7 +250,11 @@ def _extract_answer(self, raw_answer):
return answer
else:
if replan_index != 1:
print("....replanning...")
gateway_logger.log(
logging.INFO,
"Unable to answer the request. Replanning....",
block=True,
)
return "Replan required. Consider rephrasing your question."
else:
return None
Expand Down Expand Up @@ -322,28 +329,49 @@ def __call__(self, input: str):
input (str): user's natural language request
"""
result = []
thread = threading.Thread(target=self.run_async, args=(input, result))
error = []
thread = threading.Thread(target=self.run_async, args=(input, result, error))
thread.start()
thread.join()
try:
return result[0]["output"]
except IndexError:
raise AgentGatewayError(
message="Unable to retrieve response. Please check each of your Cortex tools and ensure all connections are valid."
)

if error:
raise error[0]

if not result:
raise AgentGatewayError("Unable to retrieve response. Result is empty.")

return result[0]

def handle_exception(self, loop, context):
exception = context.get("exception")
if exception:
print(f"Caught unhandled exception: {exception}")
loop.default_exception_handler(context)
loop.stop()
loop.default_exception_handler(context)
loop.stop()

def run_async(self, input, result):
def run_async(self, input, result, error):
loop = asyncio.new_event_loop()
loop.set_exception_handler(self.handle_exception)
asyncio.set_event_loop(loop)
result.append(loop.run_until_complete(self.acall(input)))
try:
task = loop.run_until_complete(self.acall(input))
result.append(task)
except asyncio.CancelledError:
error.append(AgentGatewayError("Task was cancelled"))
except RuntimeError as e:
error.append(AgentGatewayError(f"RuntimeError: {str(e)}"))
except Exception as e:
error.append(AgentGatewayError(f"Gateway Execution Error: {str(e)}"))

finally:
try:
# Cancel any pending tasks
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()
# Wait for all tasks to be cancelled
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
finally:
loop.close()

async def acall(
self,
Expand Down Expand Up @@ -416,4 +444,4 @@ async def acall(
formatted_contexts = self._format_contexts(contexts)
inputs["context"] = formatted_contexts

return {self.output_key: answer}
return answer
13 changes: 8 additions & 5 deletions agent_gateway/gateway/task_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
from typing import Any, Callable, Dict, List, Optional

from agent_gateway.tools.logger import gateway_logger
from agent_gateway.tools.snowflake_tools import SnowflakeError

SCHEDULING_INTERVAL = 0.01 # seconds


class agent_gatewayError(Exception):
class AgentGatewayError(Exception):
def __init__(self, message):
self.message = message
super().__init__(self.message)
Expand Down Expand Up @@ -76,10 +77,10 @@ async def __call__(self) -> Any:
x = await self.tool(*self.args)
gateway_logger.log(logging.DEBUG, "task successfully completed")
return x
except SnowflakeError as e:
return f"Unexpected error during Cortex Gateway Tool request: {str(e)}"
except Exception as e:
raise agent_gatewayError(
f"Unexpected error during Cortex gateway Tool request: {str(e)}"
) from e
return f"Unexpected error during Cortex Gateway Tool request: {str(e)}"

def get_thought_action_observation(
self, include_action=True, include_thought=True, include_action_idx=False
Expand Down Expand Up @@ -145,8 +146,10 @@ async def _run_task(self, task: Task):
try:
observation = await task()
task.observation = observation
except SnowflakeError as e:
return f"SnowflakeError in task: {str(e)}"
except Exception as e:
raise agent_gatewayError(f"{str(e)}") from e
return f"Unexpected Error in task: {str(e)}"
self.tasks_done[task.idx].set()

async def schedule(self):
Expand Down
57 changes: 40 additions & 17 deletions agent_gateway/tools/snowflake_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,12 +338,16 @@ async def asearch(self, query):
)
json_response = json.loads(response_text)

gateway_logger.log(
logging.DEBUG, f"Cortex Analyst Raw Response:{json_response}"
)

try:
query_response = self._process_message(
query_response = self._process_analyst_message(
json_response["message"]["content"]
)

if query_response == "Invalid Query":
if "Unable to generate valid SQL Query" in query_response:
lm = dspy.Snowflake(
session=Session.builder.config(
"connection", self.connection
Expand All @@ -352,7 +356,8 @@ async def asearch(self, query):
)
dspy.settings.configure(lm=lm)
rephrase_prompt = dspy.ChainOfThought(PromptRephrase)
current_query = rephrase_prompt(user_prompt=current_query)[
prompt = f"Original Query: {current_query}. Previous Response Context: {query_response}"
current_query = rephrase_prompt(user_prompt=prompt)[
"rephrased_prompt"
]
else:
Expand All @@ -361,7 +366,6 @@ async def asearch(self, query):
except Exception:
raise SnowflakeError(message=json_response["message"])

gateway_logger.log(logging.DEBUG, f"Cortex Analyst Response:{query_response}")
return query_response

def _prepare_analyst_request(self, prompt):
Expand All @@ -378,20 +382,39 @@ def _prepare_analyst_request(self, prompt):

return url, headers, data

def _process_message(self, response):
# ensure valid sql query is present in response
if response[1].get("type") != "sql":
return "Invalid Query"

# execute sql query
sql_query = response[1]["statement"]
gateway_logger.log(logging.DEBUG, f"Cortex Analyst SQL Query:{sql_query}")
table = self.connection.cursor().execute(sql_query).fetch_arrow_all()
def _process_analyst_message(self, response):
if isinstance(response, list) and len(response) > 0:
first_item = response[0]

if "type" in first_item:
if first_item["type"] == "text":
_ = None
for item in response:
_ = item
if item["type"] == "suggestions":
raise SnowflakeError(
message=f"Your request is unclear. Consider rephrasing your request to one of the following suggestions:{item['suggestions']}"
)
elif item["type"] == "sql":
sql_query = item["statement"]
table = (
self.connection.cursor()
.execute(sql_query)
.fetch_arrow_all()
)

if table is not None:
return str(table.to_pydict())
else:
raise SnowflakeError(
message="No results found. Consider rephrasing your request"
)

raise SnowflakeError(
message=f"Unable to generate a valid SQL Query. {_['text']}"
)

if table is not None:
return str(table.to_pydict())
else:
return "No Results Found"
return SnowflakeError(message="Invalid Cortex Analyst Response")

def _prepare_analyst_description(
self, name, service_topic, data_source_description
Expand Down

0 comments on commit 9f7a01a

Please sign in to comment.