diff --git a/actionweaver/llms/azure/chat.py b/actionweaver/llms/azure/chat.py index afab68e..4c85140 100644 --- a/actionweaver/llms/azure/chat.py +++ b/actionweaver/llms/azure/chat.py @@ -16,10 +16,10 @@ from actionweaver.actions.action import Action, ActionHandlers from actionweaver.llms.azure.functions import Functions -from actionweaver.llms.azure.tokens import TokenUsageTracker from actionweaver.telemetry import traceable from actionweaver.utils import DEFAULT_ACTION_SCOPE from actionweaver.utils.stream import get_first_element_and_iterator, merge_dicts +from actionweaver.utils.tokens import TokenUsageTracker # TODO: support AsyncAzureOpenAI diff --git a/actionweaver/llms/general/tokens.py b/actionweaver/llms/general/tokens.py deleted file mode 100644 index ea7f632..0000000 --- a/actionweaver/llms/general/tokens.py +++ /dev/null @@ -1,46 +0,0 @@ -import collections -import logging -import time -from typing import Dict - - -class TokenUsageTrackerException(Exception): - pass - - -class TokenUsageTracker: - def __init__(self, budget=None, logger=None): - self.logger = logger or logging.getLogger(__name__) - self.tracker = collections.Counter() - self.budget = budget - - def clear(self): - self.tracker = collections.Counter() - return self - - def track_usage(self, usage: Dict): - self.tracker = self.tracker + collections.Counter(usage) - - self.logger.debug( - { - "message": "token usage updated", - "usage": usage, - "total_usage": dict(self.tracker), - "timestamp": time.time(), - "budget": self.budget, - }, - ) - if self.budget is not None and self.tracker["total_tokens"] > self.budget: - self.logger.error( - { - "message": "Token budget exceeded", - "usage": usage, - "total_usage": dict(self.tracker), - "budget": self.budget, - }, - exc_info=True, - ) - raise TokenUsageTrackerException( - f"Token budget exceeded. Budget: {self.budget}, Usage: {dict(self.tracker)}" - ) - return self.tracker diff --git a/actionweaver/llms/openai/functions/chat.py b/actionweaver/llms/openai/functions/chat.py index 181accb..82f28ea 100644 --- a/actionweaver/llms/openai/functions/chat.py +++ b/actionweaver/llms/openai/functions/chat.py @@ -14,9 +14,9 @@ from actionweaver.actions.action import Action, ActionHandlers from actionweaver.llms.openai.functions.functions import Functions -from actionweaver.llms.openai.functions.tokens import TokenUsageTracker from actionweaver.utils import DEFAULT_ACTION_SCOPE from actionweaver.utils.stream import get_first_element_and_iterator, merge_dicts +from actionweaver.utils.tokens import TokenUsageTracker class OpenAIChatCompletionException(Exception): @@ -27,9 +27,7 @@ class OpenAIChatCompletion: def __init__(self, model, token_usage_tracker=None, logger=None): self.model = model self.logger = logger or logging.getLogger(__name__) - self.token_usage_tracker = token_usage_tracker or TokenUsageTracker( - logger=logger - ) + self.token_usage_tracker = token_usage_tracker or TokenUsageTracker() self.client = OpenAI() print( diff --git a/actionweaver/llms/openai/functions/tokens.py b/actionweaver/llms/openai/functions/tokens.py deleted file mode 100644 index ea7f632..0000000 --- a/actionweaver/llms/openai/functions/tokens.py +++ /dev/null @@ -1,46 +0,0 @@ -import collections -import logging -import time -from typing import Dict - - -class TokenUsageTrackerException(Exception): - pass - - -class TokenUsageTracker: - def __init__(self, budget=None, logger=None): - self.logger = logger or logging.getLogger(__name__) - self.tracker = collections.Counter() - self.budget = budget - - def clear(self): - self.tracker = collections.Counter() - return self - - def track_usage(self, usage: Dict): - self.tracker = self.tracker + collections.Counter(usage) - - self.logger.debug( - { - "message": "token usage updated", - "usage": usage, - "total_usage": dict(self.tracker), - "timestamp": time.time(), - "budget": self.budget, - }, - ) - if self.budget is not None and self.tracker["total_tokens"] > self.budget: - self.logger.error( - { - "message": "Token budget exceeded", - "usage": usage, - "total_usage": dict(self.tracker), - "budget": self.budget, - }, - exc_info=True, - ) - raise TokenUsageTrackerException( - f"Token budget exceeded. Budget: {self.budget}, Usage: {dict(self.tracker)}" - ) - return self.tracker diff --git a/actionweaver/llms/openai/tools/chat.py b/actionweaver/llms/openai/tools/chat.py index de8386d..9144138 100644 --- a/actionweaver/llms/openai/tools/chat.py +++ b/actionweaver/llms/openai/tools/chat.py @@ -14,11 +14,11 @@ ) from actionweaver.actions.action import Action, ActionHandlers -from actionweaver.llms.openai.tools.tokens import TokenUsageTracker from actionweaver.llms.openai.tools.tools import Tools from actionweaver.telemetry import traceable from actionweaver.utils import DEFAULT_ACTION_SCOPE from actionweaver.utils.stream import get_first_element_and_iterator, merge_dicts +from actionweaver.utils.tokens import TokenUsageTracker class OpenAIChatCompletionException(Exception): diff --git a/actionweaver/llms/openai/tools/tokens.py b/actionweaver/llms/openai/tools/tokens.py deleted file mode 100644 index 2fdd603..0000000 --- a/actionweaver/llms/openai/tools/tokens.py +++ /dev/null @@ -1,35 +0,0 @@ -import collections -import time - -from openai.types import CompletionUsage - - -class TokenUsageTrackerException(Exception): - pass - - -class TokenUsageTracker: - def __init__(self, budget=None): - self.tracker = CompletionUsage( - completion_tokens=0, prompt_tokens=0, total_tokens=0 - ) - self.budget = budget - - def clear(self): - self.tracker = CompletionUsage( - completion_tokens=0, prompt_tokens=0, total_tokens=0 - ) - return self - - def track_usage(self, usage: CompletionUsage): - self.tracker = CompletionUsage( - completion_tokens=self.tracker.completion_tokens + usage.completion_tokens, - prompt_tokens=self.tracker.prompt_tokens + usage.prompt_tokens, - total_tokens=self.tracker.total_tokens + usage.total_tokens, - ) - - if self.budget is not None and self.tracker.total_tokens > self.budget: - raise TokenUsageTrackerException( - f"Token budget exceeded. Budget: {self.budget}, Usage: {self.tracker}" - ) - return self.tracker diff --git a/actionweaver/llms/azure/tokens.py b/actionweaver/utils/tokens.py similarity index 96% rename from actionweaver/llms/azure/tokens.py rename to actionweaver/utils/tokens.py index 063e9ba..a48b7d5 100644 --- a/actionweaver/llms/azure/tokens.py +++ b/actionweaver/utils/tokens.py @@ -1,6 +1,4 @@ import collections -import logging -import time from typing import Dict