From dcc054d0136594fba594f17e2a9e3a8ecdd813bb Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Mon, 9 Dec 2024 13:10:26 +0100 Subject: [PATCH] Handle parallel calls to LiteLLM... --- agenta-cli/agenta/sdk/litellm/litellm.py | 136 ++++++++++++++------ agenta-cli/agenta/sdk/tracing/inline.py | 5 +- agenta-cli/agenta/sdk/tracing/processors.py | 2 - 3 files changed, 99 insertions(+), 44 deletions(-) diff --git a/agenta-cli/agenta/sdk/litellm/litellm.py b/agenta-cli/agenta/sdk/litellm/litellm.py index fba79d5c2c..d6e2e57c14 100644 --- a/agenta-cli/agenta/sdk/litellm/litellm.py +++ b/agenta-cli/agenta/sdk/litellm/litellm.py @@ -1,3 +1,4 @@ +from typing import Dict from opentelemetry.trace import SpanKind import agenta as ag @@ -34,7 +35,7 @@ class LitellmHandler(LitellmCustomLogger): def __init__(self): super().__init__() - self.span = None + self.span: Dict[str, CustomSpan] = dict() def log_pre_api_call( self, @@ -42,6 +43,12 @@ def log_pre_api_call( messages, kwargs, ): + litellm_call_id = kwargs.get("litellm_call_id") + + if not litellm_call_id: + log.warning("Agenta SDK - litellm tracing failed") + return + type = ( # pylint: disable=redefined-builtin "chat" if kwargs.get("call_type") in ["completion", "acompletion"] @@ -50,29 +57,31 @@ def log_pre_api_call( kind = SpanKind.CLIENT - self.span = CustomSpan( + self.span[litellm_call_id] = CustomSpan( ag.tracer.start_span(name=f"litellm_{kind.name.lower()}", kind=kind) ) - self.span.set_attributes( - attributes={"node": type}, - namespace="type", - ) + span = self.span[litellm_call_id] - if not self.span: + if not span: log.warning("Agenta SDK - litellm tracing failed") return - if not self.span.is_recording(): + if not span.is_recording(): log.error("Agenta SDK - litellm span not recording.") return - self.span.set_attributes( + span.set_attributes( + attributes={"node": type}, + namespace="type", + ) + + span.set_attributes( attributes={"inputs": {"prompt": kwargs["messages"]}}, namespace="data", ) - self.span.set_attributes( + span.set_attributes( attributes={ "configuration": { "model": kwargs.get("model"), @@ -89,11 +98,19 @@ def log_stream_event( start_time, end_time, ): - if not self.span: + litellm_call_id = kwargs.get("litellm_call_id") + + if not litellm_call_id: + log.warning("Agenta SDK - litellm tracing failed") + return + + span = self.span[litellm_call_id] + + if not span: log.warning("Agenta SDK - litellm tracing failed") return - if not self.span.is_recording(): + if not span.is_recording(): return def log_success_event( @@ -106,11 +123,19 @@ def log_success_event( if kwargs.get("stream"): return - if not self.span: + litellm_call_id = kwargs.get("litellm_call_id") + + if not litellm_call_id: log.warning("Agenta SDK - litellm tracing failed") return - if not self.span.is_recording(): + span = self.span[litellm_call_id] + + if not span: + log.warning("Agenta SDK - litellm tracing failed") + return + + if not span.is_recording(): return try: @@ -120,7 +145,7 @@ def log_success_event( result.append(message) outputs = {"completion": result} - self.span.set_attributes( + span.set_attributes( attributes={"outputs": outputs}, namespace="data", ) @@ -128,12 +153,12 @@ def log_success_event( except Exception as e: pass - self.span.set_attributes( + span.set_attributes( attributes={"total": kwargs.get("response_cost")}, namespace="metrics.unit.costs", ) - self.span.set_attributes( + span.set_attributes( attributes=( { "prompt": response_obj.usage.prompt_tokens, @@ -144,9 +169,9 @@ def log_success_event( namespace="metrics.unit.tokens", ) - self.span.set_status(status="OK") + span.set_status(status="OK") - self.span.end() + span.end() def log_failure_event( self, @@ -155,18 +180,26 @@ def log_failure_event( start_time, end_time, ): - if not self.span: + litellm_call_id = kwargs.get("litellm_call_id") + + if not litellm_call_id: log.warning("Agenta SDK - litellm tracing failed") return - if not self.span.is_recording(): + span = self.span[litellm_call_id] + + if not span: + log.warning("Agenta SDK - litellm tracing failed") + return + + if not span.is_recording(): return - self.span.record_exception(kwargs["exception"]) + span.record_exception(kwargs["exception"]) - self.span.set_status(status="ERROR") + span.set_status(status="ERROR") - self.span.end() + span.end() async def async_log_stream_event( self, @@ -175,11 +208,22 @@ async def async_log_stream_event( start_time, end_time, ): - if not self.span: + if kwargs.get("stream"): + return + + litellm_call_id = kwargs.get("litellm_call_id") + + if not litellm_call_id: + log.warning("Agenta SDK - litellm tracing failed") + return + + span = self.span[litellm_call_id] + + if not span: log.warning("Agenta SDK - litellm tracing failed") return - if not self.span.is_recording(): + if not span.is_recording(): return async def async_log_success_event( @@ -189,11 +233,19 @@ async def async_log_success_event( start_time, end_time, ): - if not self.span: + litellm_call_id = kwargs.get("litellm_call_id") + + if not litellm_call_id: + log.warning("Agenta SDK - litellm tracing failed") + return + + span = self.span[litellm_call_id] + + if not span: log.warning("Agenta SDK - litellm tracing failed") return - if not self.span.is_recording(): + if not span.is_recording(): return try: @@ -203,7 +255,7 @@ async def async_log_success_event( result.append(message) outputs = {"completion": result} - self.span.set_attributes( + span.set_attributes( attributes={"outputs": outputs}, namespace="data", ) @@ -211,12 +263,12 @@ async def async_log_success_event( except Exception as e: pass - self.span.set_attributes( + span.set_attributes( attributes={"total": kwargs.get("response_cost")}, namespace="metrics.unit.costs", ) - self.span.set_attributes( + span.set_attributes( attributes=( { "prompt": response_obj.usage.prompt_tokens, @@ -227,9 +279,9 @@ async def async_log_success_event( namespace="metrics.unit.tokens", ) - self.span.set_status(status="OK") + span.set_status(status="OK") - self.span.end() + span.end() async def async_log_failure_event( self, @@ -238,17 +290,25 @@ async def async_log_failure_event( start_time, end_time, ): - if not self.span: + litellm_call_id = kwargs.get("litellm_call_id") + + if not litellm_call_id: + log.warning("Agenta SDK - litellm tracing failed") + return + + span = self.span[litellm_call_id] + + if not span: log.warning("Agenta SDK - litellm tracing failed") return - if not self.span.is_recording(): + if not span.is_recording(): return - self.span.record_exception(kwargs["exception"]) + span.record_exception(kwargs["exception"]) - self.span.set_status(status="ERROR") + span.set_status(status="ERROR") - self.span.end() + span.end() return LitellmHandler() diff --git a/agenta-cli/agenta/sdk/tracing/inline.py b/agenta-cli/agenta/sdk/tracing/inline.py index d8309db99c..6905ad5cf0 100644 --- a/agenta-cli/agenta/sdk/tracing/inline.py +++ b/agenta-cli/agenta/sdk/tracing/inline.py @@ -1143,7 +1143,4 @@ def calculate_costs(span_idx: Dict[str, SpanDTO]): span.metrics["unit.costs.total"] = total_cost except: # pylint: disable=bare-except - print("Failed to calculate costs:") - print( - f"model={model}, prompt_tokens={prompt_tokens}, completion_tokens={completion_tokens}" - ) + pass diff --git a/agenta-cli/agenta/sdk/tracing/processors.py b/agenta-cli/agenta/sdk/tracing/processors.py index 3700f329d9..b5d04d8085 100644 --- a/agenta-cli/agenta/sdk/tracing/processors.py +++ b/agenta-cli/agenta/sdk/tracing/processors.py @@ -43,8 +43,6 @@ def on_start( span: Span, parent_context: Optional[Context] = None, ) -> None: - # ADD LINKS FROM CONTEXT, HERE - for key in self.references.keys(): span.set_attribute(f"ag.refs.{key}", self.references[key])