diff --git a/agenta-backend/agenta_backend/core/observability/utils.py b/agenta-backend/agenta_backend/core/observability/utils.py index 9446368191..8ce48cba49 100644 --- a/agenta-backend/agenta_backend/core/observability/utils.py +++ b/agenta-backend/agenta_backend/core/observability/utils.py @@ -388,12 +388,15 @@ def calculate_costs(span_idx: Dict[str, SpanDTO]): and span.meta and span.metrics ): + model = span.meta.get("response.model") + prompt_tokens = span.metrics.get("unit.tokens.prompt", 0.0) + completion_tokens = span.metrics.get("unit.tokens.completion", 0.0) + try: costs = cost_calculator.cost_per_token( - model=span.meta.get("response.model"), - prompt_tokens=span.metrics.get("unit.tokens.prompt", 0.0), - completion_tokens=span.metrics.get("unit.tokens.completion", 0.0), - call_type=span.node.type.name.lower(), + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, ) if not costs: @@ -406,5 +409,8 @@ def calculate_costs(span_idx: Dict[str, SpanDTO]): span.metrics["unit.costs.completion"] = completion_cost span.metrics["unit.costs.total"] = total_cost - except: # pylint: disable=W0702:bare-except - pass + except: # pylint: disable=bare-except + print("Failed to calculate costs:") + print( + f"model={model}, prompt_tokens={prompt_tokens}, completion_tokens={completion_tokens}" + ) diff --git a/agenta-cli/agenta/sdk/tracing/inline.py b/agenta-cli/agenta/sdk/tracing/inline.py index 8c4165ea6f..011405e328 100644 --- a/agenta-cli/agenta/sdk/tracing/inline.py +++ b/agenta-cli/agenta/sdk/tracing/inline.py @@ -1229,13 +1229,15 @@ def calculate_costs(span_idx: Dict[str, SpanDTO]): and span.meta and span.metrics ): + model = span.meta.get("response.model") + prompt_tokens = span.metrics.get("unit.tokens.prompt", 0.0) + completion_tokens = span.metrics.get("unit.tokens.completion", 0.0) + try: costs = cost_calculator.cost_per_token( - model=span.meta.get("response.model"), - prompt_tokens=span.metrics.get("unit.tokens.prompt", 0.0), - completion_tokens=span.metrics.get("unit.tokens.completion", 0.0), - call_type=span.node.type.name.lower(), - response_time_ms=span.time.span // 1_000, + model=model, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, ) if not costs: @@ -1248,5 +1250,8 @@ def calculate_costs(span_idx: Dict[str, SpanDTO]): span.metrics["unit.costs.completion"] = completion_cost span.metrics["unit.costs.total"] = total_cost - except: - pass + except: # pylint: disable=bare-except + print("Failed to calculate costs:") + print( + f"model={model}, prompt_tokens={prompt_tokens}, completion_tokens={completion_tokens}" + )