Skip to content

Commit

Permalink
add general llm support (#104)
Browse files Browse the repository at this point in the history
  • Loading branch information
TengHu authored Dec 28, 2023
1 parent 1e6c8ff commit 325eda8
Show file tree
Hide file tree
Showing 7 changed files with 241 additions and 0 deletions.
Empty file.
64 changes: 64 additions & 0 deletions actionweaver/llms/general/action_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typing import Any, Callable, Dict, Optional, Union

ExtractorType = Callable[[str], Dict[str, Any]]


class ActionProcessor:
def __init__(
self,
tools: Optional[Dict[str, Any]] = None,
custom_extractor: Optional[ExtractorType] = None,
) -> None:
self.tools = tools or {}
self.dict = {tool.name: tool for tool in tools}
self.custom_extractor = custom_extractor

def extract_function(self, text: str) -> Union[Dict[str, Any], None]:
if self.custom_extractor:
extracted = self.custom_extractor(text)
if (
not isinstance(extracted, dict)
or "name" not in extracted
or "parameters" not in extracted
):
raise ValueError(
"Custom extractor must return a dictionary with 'name' and 'parameters' keys."
)
return extracted
else:
import json

j = json.loads(text)
return {"name": j["function"], "parameters": j["parameters"]}

def respond(self, text: str):
function = None
try:
function = self.extract_function(text)
except Exception as e:
exception_type = type(e).__name__
exception_message = str(e)
full_exception_string = f"{exception_type}: {exception_message}"
return (
None,
False,
f"Unable to extract a valid function from the input. Error encountered in extractor: {full_exception_string}",
)

if function["name"] not in self.dict:
return None, False, "Function or tool not found"

response = ""
try:
response = self.dict[function["name"]](**function["parameters"])
except Exception as e:
exception_type = type(e).__name__
exception_message = str(e)
full_exception_string = f"{exception_type}: {exception_message}"
return (
None,
False,
f"Unable to invoke valid function {function['name']}, parameters: {function['parameters']}. Error encountered: {full_exception_string}",
)

return response, True, None
Empty file.
46 changes: 46 additions & 0 deletions actionweaver/llms/general/tokens.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import collections
import logging
import time
from typing import Dict


class TokenUsageTrackerException(Exception):
pass


class TokenUsageTracker:
def __init__(self, budget=None, logger=None):
self.logger = logger or logging.getLogger(__name__)
self.tracker = collections.Counter()
self.budget = budget

def clear(self):
self.tracker = collections.Counter()
return self

def track_usage(self, usage: Dict):
self.tracker = self.tracker + collections.Counter(usage)

self.logger.debug(
{
"message": "token usage updated",
"usage": usage,
"total_usage": dict(self.tracker),
"timestamp": time.time(),
"budget": self.budget,
},
)
if self.budget is not None and self.tracker["total_tokens"] > self.budget:
self.logger.error(
{
"message": "Token budget exceeded",
"usage": usage,
"total_usage": dict(self.tracker),
"budget": self.budget,
},
exc_info=True,
)
raise TokenUsageTrackerException(
f"Token budget exceeded. Budget: {self.budget}, Usage: {dict(self.tracker)}"
)
return self.tracker
30 changes: 30 additions & 0 deletions actionweaver/llms/general/tools.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# TODO: assume all actions are functions for now
from actionweaver.actions import Action


class ToolException(Exception):
pass


class Tools:
def __init__(self, tools=None) -> None:
self.tools = tools

@classmethod
def from_expr(cls, expr):
if expr is None:
return cls()
elif isinstance(expr, list):
return cls(
tools=expr,
)
else:
raise ToolException(f"Invalid orchestration expression: {expr}")

def to_arguments(self):
return "\n".join(
[
f"""{a.name}: \n description: {a.description} \n params: {a.json_schema()}"""
for a in self.tools
]
)
Empty file added tests/llms/general/__init__.py
Empty file.
101 changes: 101 additions & 0 deletions tests/llms/general/test_action_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from __future__ import annotations

import unittest
from unittest.mock import Mock, call, patch
from urllib import response

from actionweaver.actions import Action
from actionweaver.llms.general.action_processor import ActionProcessor


class TestActionProcessor(unittest.TestCase):
def test_action_processor(self):
def get_current_weather(location, unit="fahrenheit"):
"""mock method"""
import json

return json.dumps(
{"location": location, "temperature": "22", "unit": "celsius"}
)

ap = ActionProcessor(
tools=[
Action(
"GetWeather",
get_current_weather,
).build_pydantic_model_cls()
]
)

response, ok, err = ap.respond("hello")
self.assertFalse(ok)
self.assertTrue(response is None)
self.assertTrue(
"Unable to extract a valid function from the input. Error encountered in extractor"
in err,
)

response, ok, err = ap.respond(
'{\n "function": "GetWeather",\n "parameters": {\n "location": "San Francisco",\n "unit": "fahrenheit"\n }\n}'
)
self.assertTrue(ok)
self.assertTrue(
response,
{
"location": "San Francisco",
"temperature": "22",
"unit": "celsius",
},
)
self.assertTrue(err is None)

def test_action_processor_with_custom_extractor(self):
def get_current_weather(location, unit="fahrenheit"):
"""mock method"""
import json

return json.dumps(
{"location": location, "temperature": "22", "unit": "celsius"}
)

def extractor(text: str):
import json

j = json.loads(text)
return {"name": j["tool_name"], "parameters": j["tool_arguments"]}

ap = ActionProcessor(
tools=[
Action(
"GetWeather",
get_current_weather,
).build_pydantic_model_cls()
],
custom_extractor=extractor,
)

response, ok, err = ap.respond("hello")
self.assertFalse(ok)
self.assertTrue(response is None)
self.assertTrue(
"Unable to extract a valid function from the input. Error encountered in extractor"
in err,
)

response, ok, err = ap.respond(
'{\n "tool_name": "GetWeather",\n "tool_arguments": {\n "location": "San Francisco",\n "unit": "fahrenheit"\n }\n}'
)
self.assertTrue(ok)
self.assertTrue(
response,
{
"location": "San Francisco",
"temperature": "22",
"unit": "celsius",
},
)
self.assertTrue(err is None)


if __name__ == "__main__":
unittest.main()

0 comments on commit 325eda8

Please sign in to comment.