Improvements to DB tools (#26)
* added slack tools
improved database tools and separated into db_tools.py
minor updates to prompt

* bump black version

* removed debug msg

* improved prompt for REACT
ofermend authored Nov 1, 2024
1 parent 70c8dd2 commit 868aa8e
Showing 10 changed files with 124 additions and 51 deletions.
4 changes: 2 additions & 2 deletions .pylintrc
@@ -14,6 +14,6 @@ disable =
dangerous-default-value,
too-many-return-statements,
import-outside-toplevel,
-eval-used

+eval-used,
+too-few-public-methods

13 changes: 11 additions & 2 deletions README.md
@@ -134,12 +134,21 @@ print(response)
 - `stock_news`: provides news about a company
 - `stock_analyst_recommendations`: provides stock analyst recommendations for a company.

-1. **Database tools**: providing tools to inspect and query a database
+4. **Database tools**: providing tools to inspect and query a database
 - `list_tables`: list all tables in the database
 - `describe_tables`: describe the schema of tables in the database
 - `load_data`: returns data based on a SQL query
+- `load_sample_data`: returns the first 25 rows of a table
+- `load_unique_values`: returns the top unique values for a given column

-More tools coming soon...
+In addition, we include various other tools from LlamaIndex ToolSpecs:
+* Tavily search
+* arxiv
+* neo4j
+* Google tools (including gmail, calendar, and search)
+* Slack

+Note that some of these tools may require API keys as environment variables

 You can create your own tool directly from a Python function using the `create_tool()` method of the `ToolsFactory` class:

2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -1,4 +1,4 @@
mypy==1.11.0
pylint==3.2.6
flake8==7.1.0
-black==23.11.0
+black==24.10.0
1 change: 1 addition & 0 deletions requirements.txt
@@ -15,6 +15,7 @@ llama-index-tools-database==0.2.0
llama-index-tools-google==0.2.0
llama-index-tools-tavily_research==0.2.0
llama-index-tools-neo4j==0.2.0
+llama-index-tools-slack==0.2.0
tavily-python==0.5.0
yahoo-finance==1.4.0
openinference-instrumentation-llama-index==3.0.2
2 changes: 1 addition & 1 deletion setup.py
@@ -8,7 +8,7 @@ def read_requirements():

setup(
name="vectara_agentic",
-version="0.1.17",
+version="0.1.18",
author="Ofer Mendelevitch",
author_email="[email protected]",
description="A Python package for creating AI Assistants and AI Agents with Vectara",
2 changes: 1 addition & 1 deletion vectara_agentic/__init__.py
@@ -3,7 +3,7 @@
"""

# Define the package version
-__version__ = "0.1.17"
+__version__ = "0.1.18"

# Import classes and functions from modules
# from .module1 import Class1, function1
10 changes: 7 additions & 3 deletions vectara_agentic/_prompts.py
@@ -24,9 +24,11 @@
 - If you are provided with database tools use them for analytical queries (such as counting, calculating max, min, average, sum, or other statistics).
 For each database, the database tools include: x_list_tables, x_load_data, x_describe_tables, and x_load_sample_data, where 'x' in the database name.
 The x_list_tables tool provides a list of available tables in the x database.
-Always use the x_describe_tables tool to understand the schema of each table, before you access data from that table.
-Always use the x_load_sample_data tool to understand the column names, and the unique values in each column, so you can use them in your queries.
-Some times the user may ask for a specific column value, but the actual value in the table may be different, and you will need to use the correct value.
+Before issuing a SQL query, always:
+- Use the x_describe_tables tool to understand the schema of each table
+- Use the x_load_unique_values tool to understand the unique values in each column.
+Sometimes the user may ask for a specific column value, but the actual value in the table may be different, and you will need to use the correct value.
+- Use the x_load_sample_data tool to understand the column names, and typical values in each column.
 - Never call x_load_data to retrieve values from each row in the table.
 - Do not mention table names or database names in your response.
"""
@@ -89,6 +91,8 @@
Please use a valid JSON format for the Action Input. Do NOT do this {{'input': 'hello world', 'num_beams': 5}}.
Do not include the Action Input in a wrapper dictionary 'properties' like this: {{'properties': {{'input': 'hello world', 'num_beams': 5}} }}.
If this format is used, the user will respond in the following format:
```
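The prompt change in `_prompts.py` asks the agent to check a column's distinct values before writing a SQL query, because the user's wording may not match what is stored. The sketch below illustrates that situation with Python's standard `sqlite3` module. The `customers` table and its rows are hypothetical and exist only for this example; nothing here is taken from the package itself.

```python
import sqlite3

# Hypothetical in-memory table, used only for illustration.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE customers (name TEXT, state TEXT)")
conn.executemany(
    "INSERT INTO customers VALUES (?, ?)",
    [("Ann", "NY"), ("Bob", "CA"), ("Cat", "NY")],
)

# A query built from the user's wording ("New York") matches no rows.
naive = conn.execute(
    "SELECT COUNT(*) FROM customers WHERE state = ?", ("New York",)
).fetchone()[0]

# Listing the distinct values first reveals how the state is actually stored.
states = sorted(
    row[0] for row in conn.execute("SELECT DISTINCT state FROM customers LIMIT 200")
)

# With the stored spelling, the same analytical query returns the real count.
correct = conn.execute(
    "SELECT COUNT(*) FROM customers WHERE state = ?", ("NY",)
).fetchone()[0]

print(naive, states, correct)
```

The distinct-value lookup is cheap compared with a wrong answer: the naive query does not fail, it silently returns a count of zero.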
74 changes: 74 additions & 0 deletions vectara_agentic/db_tools.py
@@ -0,0 +1,74 @@
"""
This module contains the code to extend and improve DatabaseToolSpec
Specifically adding load_sample_data and load_unique_values methods, as well as
making sure the load_data method returns a list of text values from the database, not Document[] objects.
"""
from abc import ABC
from typing import Callable, Any

#
# Additional database tool
#
class DBTool(ABC):
    """
    A base class for vectara-agentic database tools extensions
    """
    def __init__(self, load_data_fn: Callable):
        self.load_data_fn = load_data_fn

class DBLoadData(DBTool):
    """
    A tool to Run SQL query on the database and return the result.
    """
    def __call__(self, query: str) -> Any:
        """Query and load data from the Database, returning a list of Documents.
        Args:
            query (str): an SQL query to filter tables and rows.
        Returns:
            List[text]: a list of text values from the database.
        """
        res = self.load_data_fn(query)
        return [d.text for d in res]

class DBLoadSampleData(DBTool):
    """
    A tool to load a sample of data from the specified database table.
    This tool fetches the first num_rows (default 25) rows from the given table
    using a provided database query function.
    """
    def __call__(self, table_name: str, num_rows: int = 25) -> Any:
        """
        Fetches the first num_rows rows from the specified database table.
        Args:
            table_name (str): The name of the database table.
        Returns:
            Any: The result of the database query.
        """
        return self.load_data_fn(f"SELECT * FROM {table_name} LIMIT {num_rows}")

class DBLoadUniqueValues(DBTool):
    """
    A tool to list all unique values for each column in a set of columns of a database table.
    """
    def __call__(self, table_name: str, columns: list[str], num_vals: int = 200) -> dict:
        """
        Fetches the first num_vals unique values from the specified columns of the database table.
        Args:
            table_name (str): The name of the database table.
            columns (list[str]): The list of columns to fetch unique values from.
            num_vals (int): The number of unique values to fetch for each column. Default is 200.
        Returns:
            dict: A dictionary containing the unique values for each column.
        """
        res = {}
        for column in columns:
            unique_vals = self.load_data_fn(f'SELECT DISTINCT "{column}" FROM {table_name} LIMIT {num_vals}')
            res[column] = [d.text for d in unique_vals]
        return res
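
To make the contract of `db_tools.py` concrete, here is a minimal, self-contained sketch of how the unique-values tool might behave against a real database. It does not import `vectara_agentic`: `UniqueValues` is a condensed restatement of `DBLoadUniqueValues` from the file above, while `FakeDocument` and `make_load_data_fn` are hypothetical stand-ins for the Document objects and the `load_data` function that the underlying tool spec would supply. The row-to-text format is also an assumption.

```python
import sqlite3
from dataclasses import dataclass


@dataclass
class FakeDocument:
    """Hypothetical stand-in for the Document objects returned by load_data."""
    text: str


def make_load_data_fn(conn):
    """Hypothetical helper: wrap a sqlite3 connection as a load_data-style callable."""
    def load_data(query: str):
        # One FakeDocument per row; joining values with ", " is an assumed format.
        return [FakeDocument(", ".join(str(v) for v in row)) for row in conn.execute(query)]
    return load_data


class UniqueValues:
    """Condensed restatement of DBLoadUniqueValues from db_tools.py."""
    def __init__(self, load_data_fn):
        self.load_data_fn = load_data_fn

    def __call__(self, table_name, columns, num_vals=200):
        res = {}
        for column in columns:
            docs = self.load_data_fn(
                f'SELECT DISTINCT "{column}" FROM {table_name} LIMIT {num_vals}'
            )
            res[column] = [d.text for d in docs]
        return res


# Hypothetical table used only to exercise the tool.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE orders (status TEXT, region TEXT)")
conn.executemany(
    "INSERT INTO orders VALUES (?, ?)",
    [("open", "EU"), ("closed", "EU"), ("open", "US")],
)

tool = UniqueValues(make_load_data_fn(conn))
result = tool("orders", ["status", "region"])
print(result)
```

The tool issues one `SELECT DISTINCT` per requested column, so the result is a dictionary keyed by column name. The order of values within each list depends on the database and should not be relied on.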
29 changes: 24 additions & 5 deletions vectara_agentic/tools.py
@@ -18,7 +18,8 @@


from .types import ToolType
-from .tools_catalog import summarize_text, rephrase_text, critique_text, get_bad_topics, DBLoadSampleData
+from .tools_catalog import summarize_text, rephrase_text, critique_text, get_bad_topics
+from .db_tools import DBLoadSampleData, DBLoadUniqueValues, DBLoadData

LI_packages = {
"yahoo_finance": ToolType.QUERY,
@@ -42,9 +43,15 @@
},
"GoogleSearchToolSpec": {"google_search": ToolType.QUERY},
},
+"slack": {
+"SlackToolSpec": {
+"load_data": ToolType.QUERY,
+"send_message": ToolType.ACTION,
+"fetch_channel": ToolType.QUERY,
+}
+}
}


class VectaraToolMetadata(ToolMetadata):
"""
A subclass of ToolMetadata adding the tool_type attribute.
@@ -495,10 +502,22 @@ def database_tools(
tool.metadata.description + f"The database tables include data about {content_description}."
)

+# Update load_data_tool to return only text instead of "Document" objects (to save on space)
+# Add two new tools: load_sample_data and load_unique_values
load_data_tool_index = next(i for i, t in enumerate(tools) if t.metadata.name.endswith("load_data"))
-sample_data_fn = DBLoadSampleData(tools[load_data_tool_index])
+load_data_fn_original = tools[load_data_tool_index].fn

+load_data_fn = DBLoadData(load_data_fn_original)
+load_data_fn.__name__ = f"{tool_name_prefix}_load_data"
+load_data_tool = self.create_tool(load_data_fn, ToolType.QUERY)

+sample_data_fn = DBLoadSampleData(load_data_fn_original)
sample_data_fn.__name__ = f"{tool_name_prefix}_load_sample_data"
sample_data_tool = self.create_tool(sample_data_fn, ToolType.QUERY)
-tools.append(sample_data_tool)

+load_unique_values_fn = DBLoadUniqueValues(load_data_fn_original)
+load_unique_values_fn.__name__ = f"{tool_name_prefix}_load_unique_values"
+load_unique_values_tool = self.create_tool(load_unique_values_fn, ToolType.QUERY)

+tools[load_data_tool_index] = load_data_tool
+tools.extend([sample_data_tool, load_unique_values_tool])
return tools
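
The hunk above wraps the original `load_data` function in callable objects and then assigns each one a `__name__` before passing it to `create_tool`. The sketch below shows why that assignment may be needed: a class instance, unlike a plain function, has no `__name__` attribute of its own. `Echo` and `tool_name` are hypothetical, and how `create_tool` actually derives a tool's name is not shown in this diff, so the lookup here is only an assumption.

```python
class Echo:
    """Hypothetical callable object, similar in shape to the DB tool classes."""

    def __init__(self, prefix: str):
        self.prefix = prefix

    def __call__(self, text: str) -> str:
        return f"{self.prefix}{text}"


def tool_name(fn) -> str:
    """Hypothetical stand-in for how a tool factory might pick a tool's name."""
    # Plain functions carry __name__; instances fall back to their class name here.
    return getattr(fn, "__name__", type(fn).__name__)


echo = Echo(">> ")
before = tool_name(echo)          # the instance has no __name__ yet

echo.__name__ = "mydb_load_data"  # same pattern as in database_tools()
after = tool_name(echo)

print(before, after, echo("hi"))
```

Setting `__name__` per instance lets several databases share the same wrapper classes while each tool still gets a distinct, prefixed name.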
38 changes: 2 additions & 36 deletions vectara_agentic/tools_catalog.py
@@ -1,8 +1,7 @@
"""
This module contains the tools catalog for the Vectara Agentic.
"""

-from typing import Callable, Any, List
+from typing import List
from functools import lru_cache
from pydantic import Field
import requests
@@ -114,7 +113,7 @@ def critique_text(


#
-# Guardrails tools
+# Guardrails tool: returns list of topics to avoid
#
def get_bad_topics() -> List[str]:
"""
@@ -128,36 +127,3 @@ def get_bad_topics() -> List[str]:
"adult content",
"illegal activities",
]


-#
-# Additional database tool
-#
-class DBLoadSampleData:
-    """
-    A tool to load a sample of data from the specified database table.
-    This tool fetches the first num_rows (default 25) rows from the given table
-    using a provided database query function.
-    """

-    def __init__(self, load_data_tool: Callable):
-        """
-        Initializes the DBLoadSampleData object with the provided load_data_tool function.
-        Args:
-            load_data_tool (Callable): A function to execute the SQL query.
-        """
-        self.load_data_tool = load_data_tool

-    def __call__(self, table_name: str, num_rows: int = 25) -> Any:
-        """
-        Fetches the first num_rows rows from the specified database table.
-        Args:
-            table_name (str): The name of the database table.
-        Returns:
-            Any: The result of the database query.
-        """
-        return self.load_data_tool(f"SELECT * FROM {table_name} LIMIT {num_rows}")
