Querying vectordb with AgentAI #20

Open · wants to merge 4 commits into main

22 changes: 18 additions & 4 deletions agentai/api.py
@@ -1,12 +1,14 @@
"""
API functions for the agentai package
"""
import inspect
import json
from typing import Any, Callable, Tuple, Union

from loguru import logger
from openai import ChatCompletion
from openai.error import RateLimitError
from pydantic import BaseModel
from tenacity import (
retry,
retry_if_exception_type,
@@ -21,6 +23,18 @@
logger.disable(__name__)


def validate_function_args(func, args):
    """Coerce JSON-decoded arguments into the parameter types declared on ``func``.

    Arguments whose annotation is a Pydantic ``BaseModel`` subclass are instantiated
    from their dict payload; all other arguments are passed through unchanged.
    """
    parameters = inspect.signature(func).parameters
    validated_args = {}
    for param_name, param_value in args.items():
        if param_name in parameters:
            param_type = parameters[param_name].annotation
            # Guard against non-class annotations (typing generics, missing annotations),
            # which would make issubclass() raise a TypeError
            if isinstance(param_type, type) and issubclass(param_type, BaseModel):
                param_value = param_type(**param_value)
            validated_args[param_name] = param_value
    return validated_args


@retry(
retry=retry_if_exception_type((ValueError, RateLimitError)),
stop=stop_after_attempt(5),
@@ -79,16 +93,16 @@ def chat_complete_execute_fn(
conversation=conversation,
tool_registry=tool_registry,
model=model,
function_call=True,
function_call="auto",
)
message = completion.choices[0].message
function_call = message["function_call"]
function_arguments = json.loads(function_call["arguments"])
logger.info(f"function_arguments: {function_arguments}")
callable_function = tool_registry.get(function_call["name"])
logger.info(f"callable_function: {callable_function}")
callable_function.validate(**function_arguments)
validated_args = validate_function_args(callable_function, function_arguments)
logger.info("Validated function arguments")
results = callable_function(**function_arguments)
results = callable_function(**validated_args)
logger.info(f"results: {results}")
return results, function_arguments, callable_function
return results, function_arguments, callable_function
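
To illustrate what the new validate_function_args helper does, here is a minimal sketch of coercing JSON function-call arguments into a Pydantic model before a tool is invoked. The Query model and get_docs function below are illustrative stand-ins for a registered tool, not part of this diff.

import json
from typing import List

from pydantic import BaseModel

from agentai.api import validate_function_args


class Query(BaseModel):  # illustrative model, loosely mirroring agentai/vectordb.py
    query_texts: List[str]
    k: int


def get_docs(query: Query):  # hypothetical tool with a BaseModel-typed parameter
    return f"searching for {query.query_texts} (top {query.k})"


# Function-call arguments arrive from the model as a JSON string and are decoded to plain dicts
function_arguments = json.loads('{"query": {"query_texts": ["vector databases"], "k": 3}}')

# validate_function_args inspects get_docs, sees that `query` is annotated with a
# BaseModel subclass, and instantiates Query(**payload) before the tool is called
validated_args = validate_function_args(get_docs, function_arguments)
print(get_docs(**validated_args))  # -> searching for ['vector databases'] (top 3)
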
131 changes: 126 additions & 5 deletions agentai/parsers.py
@@ -1,15 +1,25 @@
from typing import List
from typing import Any, Dict, List, Optional

import pandas as pd
from azure.ai.formrecognizer import DocumentAnalysisClient
from azure.core.credentials import AzureKeyCredential
from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools.azure_cognitive_services.utils import detect_file_src_type
from unstructured.partition.pdf import partition_pdf


class Parser:
def __init__(self):
self.filename = ""
self.file = ""
self.docs: List[Document] = list()


class UnstructuredPdfParser(Parser):
def __init__(self):
super().__init__()

def process_element(self, element):
metadata = {}
if hasattr(element, "metadata"):
@@ -35,9 +45,120 @@ def process_element(self, element)
# if element is not a table, then it is a text
self.docs.append(Document(page_content=str(element), metadata=metadata))

def parse_pdf(self, filename: str):
self.filename = filename
elements = partition_pdf(self.filename)
def parse_pdf(self, file: str):
self.file = file
elements = partition_pdf(self.file)
for element in elements:
self.process_element(element)
return self.docs


class AzureDocumentIntelligencePdfParser(Parser):
    def __init__(self, endpoint: Optional[str] = None, key: Optional[str] = None):
super().__init__()
self.endpoint = endpoint
self.key = key

def parse_tables(self, tables: List[Any]) -> List[Document]:
all_row_data = []

        # Reconstruct each table as a pandas DataFrame, then emit one Document per row
for table in tables:
metadata = {}
# metadata["filename"] = filename
# metadata["filetype"] = filetype

json_data = table.to_dict()
# Extract column headers
column_headers = []
for cell in json_data["cells"]:
if cell["kind"] == "columnHeader":
column_headers.append(cell["content"])

# Initialize an empty DataFrame with column headers
df = pd.DataFrame(columns=column_headers)

# Fill in the DataFrame with cell content
for row_index in range(json_data["row_count"]):
row_data = []
for col_index in range(json_data["column_count"]):
                    content = next(
                        (
                            cell["content"]
                            for cell in json_data["cells"]
                            if cell["row_index"] == row_index and cell["column_index"] == col_index
                        ),
                        "",  # fall back to an empty string for cells absent from the payload (e.g. merged cells)
                    )
row_data.append(content)
df.loc[row_index] = row_data

            # Drop the first row, which repeats the column headers
df = df.drop(df.index[0])

for _, row in df.iterrows():
# go through each row of the table and create a document
# with the row data dictionary as page content
metadata["category"] = "Table"
metadata["page_number"] = json_data["bounding_regions"][0]["page_number"]
page_content = row.to_dict()

all_row_data.append(Document(page_content=str(page_content), metadata=metadata))

return all_row_data

def parse_kv_pairs(self, kv_pairs: List[Any]) -> List[Document]:
result = []
for kv_pair in kv_pairs:
key = kv_pair.key.content if kv_pair.key else ""
value = kv_pair.value.content if kv_pair.value else ""
# result.append((key, value))
page_content = {key: value}
metadata = {}
metadata["category"] = "Key Value Pair"
result.append(Document(page_content=str(page_content), metadata=metadata))
return result

def format_document_analysis_result(self, doc_dictionary: Dict) -> List[Document]:
formatted_result = []
if "content" in doc_dictionary:
            # split the content into chunks of roughly 300 tokens with a 30-token overlap
            splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(chunk_size=300, chunk_overlap=30)
            # create_documents expects one metadata dict per input text; since a single
            # text is passed, a single {"category": "Text"} entry is sufficient
            metadatas = [{"category": "Text"}]

splits = splitter.create_documents(texts=[doc_dictionary["content"]], metadatas=metadatas)
formatted_result = splits
if "tables" in doc_dictionary:
formatted_result.extend(doc_dictionary["tables"])
if "key_value_pairs" in doc_dictionary:
formatted_result.extend(doc_dictionary["key_value_pairs"])

print("formatted_result: ", formatted_result)
return formatted_result

def parse_pdf(self, file: str, pages: Optional[str] = None) -> List[Document]:
document_analysis_client = DocumentAnalysisClient(endpoint=self.endpoint, credential=AzureKeyCredential(self.key))
document_src_type = detect_file_src_type(file)
if document_src_type == "local":
with open(file, "rb") as document:
poller = document_analysis_client.begin_analyze_document("prebuilt-layout", document, pages=pages)
elif document_src_type == "remote":
poller = document_analysis_client.begin_analyze_document_from_url("prebuilt-layout", file, pages=pages)
else:
raise ValueError(f"Invalid document path: {file}")

result = poller.result()
print("result from azure: ", result)

res_dict = {}

if result.content is not None:
res_dict["content"] = result.content

if result.tables is not None:
print("result.tables: ", result.tables)
res_dict["tables"] = self.parse_tables(result.tables)

if result.key_value_pairs is not None:
res_dict["key_value_pairs"] = self.parse_kv_pairs(result.key_value_pairs)

return self.format_document_analysis_result(res_dict)
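
A rough usage sketch for the two parsers added in this file; the PDF path is a placeholder and the environment-variable names for the Azure endpoint and key are assumptions, not part of this diff.

import os

from agentai.parsers import AzureDocumentIntelligencePdfParser, UnstructuredPdfParser

# Local parsing via unstructured: no external service required
docs = UnstructuredPdfParser().parse_pdf(file="reports/example.pdf")  # placeholder path

# Azure Document Intelligence parsing: chunked text, per-row table Documents and
# key/value-pair Documents are returned in a single list
azure_parser = AzureDocumentIntelligencePdfParser(
    endpoint=os.environ["AZURE_FORM_RECOGNIZER_ENDPOINT"],  # assumed env var name
    key=os.environ["AZURE_FORM_RECOGNIZER_KEY"],            # assumed env var name
)
docs = azure_parser.parse_pdf(file="reports/example.pdf", pages="1-3")

for doc in docs[:5]:
    print(doc.metadata.get("category"), str(doc.page_content)[:80])
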
43 changes: 31 additions & 12 deletions agentai/vectordb.py
@@ -1,32 +1,49 @@
import uuid
from typing import List, Optional
from typing import List, Optional, Union

import chromadb
from .parsers import Parser
from pydantic import BaseModel, Field
from typing_extensions import Literal

from agentai.annotations import ToolRegistry, tool
from agentai.parsers import Parser, UnstructuredPdfParser

Embedding = List[float]

Include = List[
Union[
Literal["documents"],
Literal["embeddings"],
Literal["metadatas"],
Literal["distances"],
]
]


class Query(BaseModel):
query_embedding: Optional[List[int]] = Field(..., description="Embedding for the query to search")
query_text: Optional[str] = Field(..., description="Simplified query from the user to search")
"""Query Model to search the vector database. If query_embeddings is provided, query_texts will be ignored."""

query_embeddings: Optional[List[Embedding]] = Field(None, description="Embedding for the query to search")
query_texts: Optional[List[str]] = Field(None, description="Simplified query from the user to search")
k: int = Field(..., description="The number of results requested")
include: Include = Field(
["documents", "embeddings", "metadatas", "distances"], description="Data to include in results"
)


class VectorDB(BaseModel):
class VectorDB:
def __init__(self):
self.client = None


class ChromaDB(VectorDB):
def __init__(self):
self.client = chromadb.Client()
self.collection = self.client.create_collection(name="my_collection")
self.chroma_client = chromadb.Client()
self.collection = self.chroma_client.create_collection(name="my_collection")

def doc_loader(self, filename: str):
parser = Parser()
docs = parser.parse_pdf(filename)
    def doc_loader(self, filename: str, parser: Optional[Parser] = None):
if parser is None:
parser = UnstructuredPdfParser()
docs = parser.parse_pdf(file=filename)
# Using chroma db as an example
for doc in docs:
self.collection.add(
@@ -37,7 +54,9 @@ def doc_loader(self, filename: str):

def get_docs(self, query: Query):
results = self.collection.query(
query_embeddings=[query.query_embedding], query_texts=[query.query_text], n_results=query.k
query_texts=query.query_texts,
n_results=query.k,
include=query.include,
)
return results

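
And a short end-to-end sketch of how the updated ChromaDB wrapper and the new Query model could be used together; the file path and query text are placeholders.

from agentai.parsers import UnstructuredPdfParser
from agentai.vectordb import ChromaDB, Query

db = ChromaDB()
# Parse a PDF (placeholder path) and add each Document to the Chroma collection
db.doc_loader(filename="reports/example.pdf", parser=UnstructuredPdfParser())

# Query by text; Chroma embeds the query with its default embedding function
query = Query(
    query_texts=["What were the quarterly revenue figures?"],
    k=3,
    include=["documents", "metadatas", "distances"],
)
results = db.get_docs(query)
print(results["documents"])
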
262 changes: 262 additions & 0 deletions docs/06_Querying_VectorDB_AzureDocIntel.ipynb

Large diffs are not rendered by default.
