Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft: feat: Add types for function metadata and improve names #4

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ from agentai.api import chat_complete, chat_complete_execute_fn
from agentai.openai_function import tool, ToolRegistry
from agentai.conversation import Conversation
from enum import Enum
weather_registry = ToolRegistry()
tool_registry = ToolRegistry()
```

2. **Define a function with `@tool` decorator**
Expand All @@ -59,7 +59,7 @@ class TemperatureUnit(Enum):
fahrenheit = "fahrenheit"


@tool(regsitry=weather_registry)
@tool(regsitry=tool_registry)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@tool(regsitry=tool_registry)
@tool(registry=tool_registry)

Ohh there's a typo as well. Will fix.

def get_current_weather(location: str, format: TemperatureUnit) -> str:
"""
Get the current weather
Expand Down Expand Up @@ -87,7 +87,7 @@ conversation.add_message("user", "what is the weather like today?")
4. **Use the `chat_complete` function to get a response from the model**

```python
chat_response = chat_complete(conversation.conversation_history, function_registry=weather_registry, model=GPT_MODEL)
chat_response = chat_complete(conversation.conversation_history, tool_registry=tool_registry, model=GPT_MODEL)
```

Output:
Expand All @@ -103,7 +103,7 @@ Once the user provides the required information, the model can generate the func

```python
conversation.add_message("user", "I'm in Bengaluru, India")
chat_response = chat_complete(conversation.conversation_history, function_registry=weather_registry, model=GPT_MODEL)
chat_response = chat_complete(conversation.conversation_history, tool_registry=tool_registry, model=GPT_MODEL)

eval(chat_response.json()["choices"][0]["message"]["function_call"]["arguments"])
```
Expand Down Expand Up @@ -142,7 +142,7 @@ def ask_database(query: str) -> List[Tuple[str, str]]:
2. **Registering the function and using it**

```python
agentai_functions = [json.loads(func.json_info) for func in [ask_database]]
agentai_functions = [json.loads(func.metadata) for func in [ask_database]]

from agentai.api import chat_complete_execute_fn
agent_system_message = """You are ChinookGPT, a helpful assistant who gets answers to user questions from the Chinook Music Database.
Expand Down
22 changes: 11 additions & 11 deletions agentai/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class InvalidInputError(Exception):

@retry(retry=retry_unless_exception_type(InvalidInputError), stop=stop_after_attempt(3))
def chat_complete(
conversation: Conversation, model, function_registry: ToolRegistry = None, return_function_params: bool = False
conversation: Conversation, model, tool_registry: ToolRegistry = None, return_function_params: bool = False
):
messages = conversation.history
if openai.api_key is None:
Expand All @@ -32,8 +32,8 @@ def chat_complete(
"Authorization": "Bearer " + openai.api_key,
}
json_data = {"model": model, "messages": messages}
if function_registry is not None:
functions = function_registry.get_all_function_information()
if tool_registry is not None:
functions = tool_registry.get_metadata()
logger.debug(f"functions: {functions}")
json_data.update({"functions": functions})

Expand All @@ -56,39 +56,39 @@ def chat_complete(
return response


def get_function_arguments(
    message: dict, conversation: Conversation, tool_registry: ToolRegistry, model: str, _attempts_left: int = 3
):
    """Extract the function-call arguments from a chat-completion choice.

    Args:
        message: One ``choices[i]`` entry from the chat-completion response.
        conversation: The conversation to replay if the arguments need to be re-requested.
        tool_registry: Registry whose function metadata is sent on a retry.
        model: Model name forwarded to ``chat_complete`` on a retry.
        _attempts_left: Internal bound on re-query attempts (replaces the old
            unbounded recursion noted by the previous FIXME).

    Returns:
        The parsed arguments dict.

    Raises:
        ValueError: If the message is not a function call, or the arguments
            still fail to parse after all retries.
    """
    function_arguments = {}
    if message["finish_reason"] == "function_call":
        arguments = message["message"]["function_call"]["arguments"]
        try:
            # SECURITY NOTE(review): eval() executes arbitrary model output.
            # Prefer json.loads / ast.literal_eval for untrusted strings.
            function_arguments = eval(arguments)
        except SyntaxError:
            if _attempts_left <= 0:
                raise ValueError(f"Could not parse function arguments after retries: {arguments!r}") from None
            print("Syntax error, trying again")
            # BUG FIX: pass the Conversation object itself -- chat_complete
            # reads .history internally, so passing conversation.history here
            # would raise AttributeError on this retry path.
            response = chat_complete(conversation, tool_registry=tool_registry, model=model)
            message = response.json()["choices"][0]
            function_arguments = get_function_arguments(
                message,
                conversation,
                tool_registry=tool_registry,
                model=model,
                _attempts_left=_attempts_left - 1,
            )
        return function_arguments
    raise ValueError(f"Unexpected message: {message}")


@retry(retry=retry_unless_exception_type(InvalidInputError), stop=stop_after_attempt(3))
def chat_complete_execute_fn(
conversation: Conversation,
function_registry: ToolRegistry,
tool_registry: ToolRegistry,
callable_function: Callable,
model: str,
):
response = chat_complete(
conversation=conversation,
function_registry=function_registry,
tool_registry=tool_registry,
model=model,
return_function_params=True,
)
message = response.json()["choices"][0]
function_arguments = get_function_arguments(
message=message, conversation=conversation, function_registry=function_registry, model=model
message=message, conversation=conversation, tool_registry=tool_registry, model=model
)
logger.debug(f"function_arguments: {function_arguments}")
results = callable_function(**function_arguments)
Expand All @@ -97,7 +97,7 @@ def chat_complete_execute_fn(

response = chat_complete(
conversation=conversation,
function_registry=function_registry,
tool_registry=tool_registry,
model=model,
return_function_params=False,
)
Expand Down
79 changes: 50 additions & 29 deletions agentai/openai_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,43 @@
"""
import enum
import inspect
import typing
from typing import Any, Callable
from typing import Any, Callable, TypedDict, Optional, Literal, Union

from docstring_parser import parse


# Type vocabulary for the JSON-schema-like function metadata sent to the
# OpenAI function-calling API.
ArgName = str
JsonPrimitive = Literal["string", "integer", "number", "boolean", "null", "any", "object", "array"]


class _ArgPropertiesRequired(TypedDict):
    # Fields every argument entry carries.
    type: JsonPrimitive
    description: Optional[str]


class ArgProperties(_ArgPropertiesRequired, total=False):
    # BUG FIX: `enum` is only emitted for Enum-typed arguments, so the *key*
    # must be optional (total=False). Declaring it as a required key with an
    # Optional value (the previous form) made every non-enum entry a type
    # error under a TypedDict checker.
    enum: list[str]


class FunctionParameters(TypedDict):
    type: Literal["object"]
    properties: dict[ArgName, ArgProperties]  # function arg name -> arg properties
    required: list[ArgName]


class FunctionMetadata(TypedDict):
    # Shape of the per-function payload for the "functions" field of a
    # chat-completion request.
    name: str
    description: str
    parameters: FunctionParameters


def parse_annotation(annotation):
    """Translate a Python type annotation into a JSON-schema-ish type.

    Returns either a JsonPrimitive string (possibly parameterized, e.g.
    ``"array[string]"``) or the tuple ``("enum", [member names])`` when the
    annotation is an Enum subclass.
    """
    if getattr(annotation, "__origin__", None) == Union:
        # Optional[X] / Union[...]: map the first non-None member's type.
        types = [t.__name__ if t.__name__ != "NoneType" else "None" for t in annotation.__args__]
        return to_json_primitive(types[0])
    # BUG FIX: issubclass() raises TypeError when the annotation is a typing
    # special form rather than a class; guard with isinstance(annotation, type).
    elif isinstance(annotation, type) and issubclass(annotation, enum.Enum):
        # Return 'enum' and a list of the names of the enum members.
        return "enum", [item.name for item in annotation]
    elif getattr(annotation, "__origin__", None) is not None:
        # Parameterized generic, e.g. List[int] / dict[str, int].
        if annotation._name is not None:
            return f"{to_json_primitive(annotation._name)}[{','.join([to_json_primitive(i.__name__) for i in annotation.__args__])}]"
        else:
            return f"{to_json_primitive(annotation.__origin__.__name__)}[{','.join([to_json_primitive(i.__name__) for i in annotation.__args__])}]"
    else:
        return to_json_primitive(annotation.__name__)


class ToolRegistry:
Expand All @@ -42,48 +60,48 @@ def add(self, func: Callable[..., Any]):
"""
self.functions[func.__name__] = func

def get_function_metadata(self, func: Callable) -> FunctionMetadata:
    """Build OpenAI function-calling metadata for ``func``.

    Inspects the signature and parsed docstring of ``func`` and returns a
    FunctionMetadata dict (name, short description, JSON-schema-like
    ``parameters``). Parameters without defaults are marked required,
    except ``self`` and parameters whose JSON type resolves to "any".
    """
    signature = inspect.signature(func)
    # BUG FIX: getdoc() returns None for undocumented functions; feeding
    # None into parse() crashes, so fall back to the empty string.
    docstring = inspect.getdoc(func) or ""
    docstring_parsed = parse(docstring)
    # arg name -> docstring description, hoisted out of the parameter loop
    # (was a nested O(params * doc_params) scan).
    doc_descriptions = {p.arg_name: p.description for p in docstring_parsed.params}

    args_map: dict[ArgName, ArgProperties] = dict()
    required = []

    for arg_name, arg in signature.parameters.items():
        json_type = parse_annotation(arg.annotation)

        if isinstance(json_type, tuple) and json_type[0] == "enum":  # If the type is an Enum
            arg_properties: ArgProperties = {
                "type": "string",
                "enum": json_type[1],  # Add an 'enum' field with the names of the enum members
                "description": "",
            }
        else:
            arg_properties: ArgProperties = {"type": json_type, "description": ""}

        if json_type != "any" and arg_name != "self" and arg.default == inspect.Parameter.empty:
            required.append(arg_name)

        if arg_name in doc_descriptions:
            arg_properties["description"] = doc_descriptions[arg_name]

        args_map[arg_name] = arg_properties

    metadata: FunctionMetadata = {
        "name": func.__name__,
        # NOTE(review): short_description may be None for empty docstrings --
        # FunctionMetadata declares description as str; confirm callers tolerate None.
        "description": docstring_parsed.short_description,
        "parameters": {"type": "object", "properties": args_map, "required": required},
    }

    return metadata

def get_metadata(self):
    """Return the FunctionMetadata dict for every function in the registry."""
    return list(map(self.get_function_metadata, self.functions.values()))

def get(self, name: str) -> Callable[..., Any]:
"""
Expand All @@ -92,7 +110,7 @@ def get(self, name: str) -> Callable[..., Any]:
return self.functions[name]


def to_json_schema_type(type_name: str) -> str:
def to_json_primitive(python_type: str) -> JsonPrimitive:
type_map = {
"str": "string",
"int": "integer",
Expand All @@ -104,7 +122,7 @@ def to_json_schema_type(type_name: str) -> str:
"List": "array",
"Optional": "any",
}
return type_map.get(type_name, "any")
return type_map.get(python_type, "any")


def docstring_parameters(**kwargs):
Expand All @@ -114,8 +132,11 @@ def dec(obj):

return dec

# Marker type for functions decorated with @tool: the decorator attaches the
# generated FunctionMetadata to the function object as ``.metadata``.
# NOTE(review): subclassing typing.Callable turns Tool into an abstract
# collections.abc.Callable subclass; a typing.Protocol with a __call__ member
# and a ``metadata`` attribute would model "callable with metadata" more
# accurately -- confirm before changing, as it requires a new typing import.
class Tool(Callable):
    metadata: FunctionMetadata


def tool(registry: ToolRegistry, depends_on=None):
def tool(registry: ToolRegistry, depends_on=None) -> Tool:
if registry is None:
raise ValueError("The registry cannot be None")
if not isinstance(registry, ToolRegistry):
Expand All @@ -131,8 +152,8 @@ def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
if dependency_name not in registry.functions:
raise ValueError(f"Dependency function '{dependency_name}' is not registered in the registry")

func_info = registry.get_function_info(func)
func.json_info = func_info
func_info = registry.get_function_metadata(func)
func.metadata = func_info
registry.add(func) # Register the function in the passed registry
return func

Expand Down
8 changes: 4 additions & 4 deletions docs/AgentAI_Intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@
"conversation = Conversation()\n",
"conversation.add_message(\"user\", \"what is the weather like today?\")\n",
"\n",
"chat_response = chat_complete(conversation=conversation, function_registry=weather_registry, model=GPT_MODEL)\n",
"chat_response = chat_complete(conversation=conversation, tool_registry=weather_registry, model=GPT_MODEL)\n",
"message = chat_response.json()[\"choices\"][0][\"message\"]\n",
"conversation.add_message(message[\"role\"], message[\"content\"])\n",
"message"
Expand All @@ -201,7 +201,7 @@
"# Once the user provides the required information, the model can generate the function arguments\n",
"conversation.add_message(\"user\", \"I'm in Bengaluru, India\")\n",
"chat_response = chat_complete(\n",
" conversation=conversation, function_registry=weather_registry, model=GPT_MODEL, return_function_params=True\n",
" conversation=conversation, tool_registry=weather_registry, model=GPT_MODEL, return_function_params=True\n",
")\n",
"chat_response.json()\n",
"eval(chat_response.json()[\"choices\"][0][\"message\"][\"function_call\"][\"arguments\"])"
Expand Down Expand Up @@ -384,7 +384,7 @@
"sql_conversation.add_message(role=\"user\", content=\"Hi, who are the top 5 artists by number of tracks\")\n",
"assistant_message = chat_complete_execute_fn(\n",
" conversation=sql_conversation,\n",
" function_registry=db_registry,\n",
" tool_registry=db_registry,\n",
" model=GPT_MODEL,\n",
" callable_function=ask_database,\n",
")"
Expand Down Expand Up @@ -436,7 +436,7 @@
"source": [
"sql_conversation.add_message(\"user\", \"What is the name of the album with the most tracks\")\n",
"chat_response = chat_complete_execute_fn(\n",
" conversation=sql_conversation, function_registry=db_registry, model=GPT_MODEL, callable_function=ask_database\n",
" conversation=sql_conversation, tool_registry=db_registry, model=GPT_MODEL, callable_function=ask_database\n",
")"
]
},
Expand Down