From 0558d8f6b12f0a2f5429babb24b304f08735c942 Mon Sep 17 00:00:00 2001 From: Alejandro Herrera <149527975+sfc-gh-alherrera@users.noreply.github.com> Date: Thu, 5 Dec 2024 11:02:16 -0500 Subject: [PATCH] refactor: eliminate DSPy dependency (#76) * refactor: eliminiating dspy dependency * chore: updating uv.lock file * refactor: filter cleanup * refactor: filter cleanup --- agent_gateway/gateway/planner.py | 2 +- agent_gateway/tools/snowflake_tools.py | 162 ++----------------------- pyproject.toml | 3 +- uv.lock | 44 ------- 4 files changed, 15 insertions(+), 196 deletions(-) diff --git a/agent_gateway/gateway/planner.py b/agent_gateway/gateway/planner.py index aef5ab2..8183cc7 100644 --- a/agent_gateway/gateway/planner.py +++ b/agent_gateway/gateway/planner.py @@ -208,7 +208,7 @@ class Planner: def __init__( self, session: object, - llm: str, # point to dspy + llm: str, example_prompt: str, example_prompt_replan: str, tools: Sequence[Union[Tool, StructuredTool]], diff --git a/agent_gateway/tools/snowflake_tools.py b/agent_gateway/tools/snowflake_tools.py index 62e27f4..fe9d988 100644 --- a/agent_gateway/tools/snowflake_tools.py +++ b/agent_gateway/tools/snowflake_tools.py @@ -11,15 +11,13 @@ # limitations under the License. import asyncio -import contextlib import inspect import json import logging import re from typing import Any, Type, Union -import dspy -from pydantic import BaseModel, Field, ValidationError +from pydantic import BaseModel from snowflake.connector.connection import SnowflakeConnection from snowflake.snowpark import Session from snowflake.snowpark.functions import col @@ -46,8 +44,6 @@ class CortexSearchTool(Tool): retrieval_columns: list = [] service_name: str = "" connection: Union[Session, SnowflakeConnection] = None - auto_filter: bool = False - filter_generator: object = None def __init__( self, @@ -56,7 +52,6 @@ def __init__( data_description, retrieval_columns, snowflake_connection, - auto_filter=False, k=5, ): """Parameters @@ -67,7 +62,6 @@ def __init__( data_description (str): description of the source data that has been indexed. retrieval_columns (list): list of columns to include in Cortex Search results. snowflake_connection (object): snowpark connection object - auto_filter (bool): automatically generate filter based on user's query or not. k: number of records to include in results """ tool_name = f"{service_name.lower()}_cortexsearch" @@ -79,13 +73,7 @@ def __init__( super().__init__( name=tool_name, description=tool_description, func=self.asearch ) - self.auto_filter = auto_filter self.connection = _get_connection(snowflake_connection) - if self.auto_filter: - self.filter_generator = SmartSearch() - lm = dspy.Snowflake(session=self.session, model="mixtral-8x7b") - dspy.settings.configure(lm=lm) - self.k = k self.retrieval_columns = retrieval_columns self.service_name = service_name @@ -113,27 +101,11 @@ def _prepare_request(self, query): self.connection.schema, self.service_name, ) - if self.auto_filter: - search_attributes, sample_vals = self._get_sample_values( - snowflake_connection=Session.builder.config( - "connection", self.connection - ), - cortex_search_service=self.service_name, - ) - raw_filter = self.filter_generator( - query=query, - attributes=str(search_attributes), - sample_values=str(sample_vals), - )["answer"] - filter = json.loads(raw_filter) - else: - filter = None data = { "query": query, "columns": self.retrieval_columns, "limit": self.k, - "filter": filter, } return headers, url, data @@ -209,82 +181,6 @@ def get_min_length(model: Type[BaseModel]): return min_length -class JSONFilter(BaseModel): - answer: str = Field(description="The filter_query in valid JSON format") - - @classmethod - def model_validate_json( - cls, - json_data: str, - *, - strict: bool | None = None, - context: dict[str, Any] | None = None, - ): - __tracebackhide__ = True - try: - return cls.__pydantic_validator__.validate_json( - json_data, strict=strict, context=context - ) - except ValidationError: - min_length = get_min_length(cls) - for substring_length in range(len(json_data), min_length - 1, -1): - for start in range(len(json_data) - substring_length + 1): - substring = json_data[start : start + substring_length] - with contextlib.suppress(ValidationError): - return cls.__pydantic_validator__.validate_json( - substring, strict=strict, context=context - ) - raise ValueError("Could not find valid json") - - -class GenerateFilter(dspy.Signature): - """Given a query, attributes in the data, and example values of each attribute, generate a filter in valid JSON format. - Ensure the filter only uses valid operators: @eq, @contains,@and,@or,@not - Ensure only the valid JSON is output with no other reasoning. - - --- - Query: What was the sentiment of CEOs between 2021 and 2024? - Attributes: industry,hq,date - Sample Values: {"industry":["biotechnology","healthcare","agriculture"],"HQ":["NY, US","CA,US","FL,US"],"date":["01/01,1999","01/01/2024"]} - Answer: {"@or":[{"@eq":{"year":"2021"}},{"@eq":{"year":"2022"}},{"@eq":{"year":"2023"}},{"@eq":{"year":"2024"}}]} - - Query: What is the sentiment of Biotech CEOs of companies based in New York? - Attributes: industry,hq,date - Sample Values: {"industry":["biotechnology","healthcare","agriculture"],"HQ":["NY, US","CA,US","FL,US"],"date":["01/01,1999","01/01/2024"]} - Answer: {"@and":[{ "@eq": { "industry": "biotechnology" } },{"@not":{"@eq":{"HQ":"CA,US"}}}]} - - Query: What is the sentiment of Biotech CEOs outside of California? - Attributes: industry,hq,date - Sample Values: {"industry":["biotechnology","healthcare","agriculture"],"HQ":["NY, US","CA,US","FL,US"],"date":["01/01,1999","01/01/2024"]} - Answer: {"@and":[{ "@eq": { "industry": "biotechnology" } },{"@not":{"@eq":{"HQ":"CA,US"}}}]} - - Query: What is sentiment towards ag and biotech companies based outside of the US? - Attributes: industry,hq,date - Sample Values: {"industry":["biotechnology","healthcare","agriculture"],"COUNTRY":["United States","Ireland","Russia","Georgia","Spain"],"month":["01","02","03","06","11","12"],"year":["2022","2023","2024"]} - Answer: {"@and": [{ "@or": [{"@eq":{ "industry": "biotechnology" } },{"@eq":{"industry":"agriculture"}}]},{ "@not": {"@eq": { "COUNTRY": "United States" } }}]} - """ - - query = dspy.InputField(desc="user query") - attributes = dspy.InputField(desc="attributes to filter on") - sample_values = dspy.InputField(desc="examples of values per attribute") - answer: JSONFilter = dspy.OutputField( - desc="filter query in valid JSON format. ONLY output the filter query in JSON, no reasoning" - ) - - -class SmartSearch(dspy.Module): - def __init__(self): - super().__init__() - self.filter_gen = dspy.ChainOfThought(GenerateFilter) - - def forward(self, query, attributes, sample_values): - filter_query = self.filter_gen( - query=query, attributes=attributes, sample_values=sample_values - ) - - return filter_query - - class CortexAnalystTool(Tool): """""Cortex Analyst tool for use with Snowflake Agent Gateway""" "" @@ -329,42 +225,21 @@ def __call__(self, prompt) -> Any: async def asearch(self, query): gateway_logger.log(logging.DEBUG, f"Cortex Analyst Prompt:{query}") - for _ in range(3): - current_query = query - url, headers, data = self._prepare_analyst_request(prompt=query) + url, headers, data = self._prepare_analyst_request(prompt=query) - response_text = await post_cortex_request( - url=url, headers=headers, data=data - ) - json_response = json.loads(response_text) + response_text = await post_cortex_request(url=url, headers=headers, data=data) + json_response = json.loads(response_text) - gateway_logger.log( - logging.DEBUG, f"Cortex Analyst Raw Response:{json_response}" - ) + gateway_logger.log( + logging.DEBUG, f"Cortex Analyst Raw Response:{json_response}" + ) - try: - query_response = self._process_analyst_message( - json_response["message"]["content"] - ) - - if "Unable to generate valid SQL Query" in query_response: - lm = dspy.Snowflake( - session=Session.builder.config( - "connection", self.connection - ).getOrCreate(), - model="llama3.2-1b", - ) - dspy.settings.configure(lm=lm) - rephrase_prompt = dspy.ChainOfThought(PromptRephrase) - prompt = f"Original Query: {current_query}. Previous Response Context: {query_response}" - current_query = rephrase_prompt(user_prompt=prompt)[ - "rephrased_prompt" - ] - else: - break - - except Exception: - raise SnowflakeError(message=json_response["message"]) + try: + query_response = self._process_analyst_message( + json_response["message"]["content"] + ) + except Exception: + raise SnowflakeError(message=json_response["message"]) return query_response @@ -426,17 +301,6 @@ def _prepare_analyst_description( return base_analyst_description -class PromptRephrase(dspy.Signature): - """Takes in a prompt and rephrases it using context into to a single concise, and specific question. - If there are references to entities that are not clear or consistent with the question being asked, make the references more appropriate. - """ - - user_prompt = dspy.InputField(desc="original user prompt") - rephrased_prompt = dspy.OutputField( - desc="rephrased prompt with more clear and specific intent" - ) - - class PythonTool(Tool): python_callable: object = None diff --git a/pyproject.toml b/pyproject.toml index 432c964..40eb85e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,13 +4,12 @@ version = "0.1.0" requires-python = ">=3.9" description = "Multi-agent framework for Snowflake" authors = [ - { name = "Alejandro Ferrera", email = "alejandro.herrera@snowflake.com" }, + { name = "Alejandro Herrera", email = "alejandro.herrera@snowflake.com" }, ] readme = "README.md" dependencies = [ "snowflake-snowpark-python>=1.22.1", - "dspy-ai>=2.5.3", "langchain>=0.3.2", "asyncio>=3.4.3", "aiohttp>=3.10.9", diff --git a/uv.lock b/uv.lock index 952a141..de5035c 100644 --- a/uv.lock +++ b/uv.lock @@ -564,48 +564,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/12/b3/231ffd4ab1fc9d679809f356cebee130ac7daa00d6d6f3206dd4fd137e9e/distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2", size = 20277 }, ] -[[package]] -name = "dspy" -version = "2.5.41" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "anyio" }, - { name = "asyncer" }, - { name = "backoff" }, - { name = "datasets" }, - { name = "diskcache" }, - { name = "httpx" }, - { name = "joblib" }, - { name = "json-repair" }, - { name = "litellm" }, - { name = "magicattr" }, - { name = "openai" }, - { name = "optuna" }, - { name = "pandas" }, - { name = "pydantic" }, - { name = "regex" }, - { name = "requests" }, - { name = "tenacity" }, - { name = "tqdm" }, - { name = "ujson" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/dc/44/4ff003b2b0efb01cc0e8de59f6c079ae8cc7e8511cb30bb8df1dae2a195e/dspy-2.5.41.tar.gz", hash = "sha256:d94b74d39f5c57346ff4090b99688660c9e5fd385cf29d139e4e0ed104c8477b", size = 258727 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/72/82/a86a37f654da2789dbb910e4061749a4865f3fa2087f13d32263b09037b5/dspy-2.5.41-py3-none-any.whl", hash = "sha256:28732210ddc2a47dd3d887d444fd8edf38bdc535c420357fd5d9fbe88d1df286", size = 340372 }, -] - -[[package]] -name = "dspy-ai" -version = "2.5.41" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "dspy" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d6/25/94659c581e69fe11237294aeadf55785abec09361aa2d3a63bd0abc39d82/dspy-ai-2.5.41.tar.gz", hash = "sha256:a2804434eb3354d6e41f59a448938886914fed0c502eb42503569776a3f2f002", size = 258174 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/50/f715b62acfe29718c3575dde16eb35e3ab6a5642652b13fb229bd758c040/dspy_ai-2.5.41-py3-none-any.whl", hash = "sha256:4f366a6b610ee93fd04b92b0dff2f78b24e23ee74c71d20bcf9fedcbf0d7aa97", size = 339737 }, -] - [[package]] name = "exceptiongroup" version = "1.2.2" @@ -1460,7 +1418,6 @@ source = { editable = "." } dependencies = [ { name = "aiohttp" }, { name = "asyncio" }, - { name = "dspy-ai" }, { name = "langchain" }, { name = "pydantic" }, { name = "snowflake-snowpark-python" }, @@ -1480,7 +1437,6 @@ streamlit = [ requires-dist = [ { name = "aiohttp", specifier = ">=3.10.9" }, { name = "asyncio", specifier = ">=3.4.3" }, - { name = "dspy-ai", specifier = ">=2.5.3" }, { name = "langchain", specifier = ">=0.3.2" }, { name = "pre-commit", marker = "extra == 'dev'", specifier = ">=4.0.1" }, { name = "pydantic", specifier = ">=2.9.2" },