Skip to content

Commit

Permalink
Handle parallel calls to LiteLLM...
Browse files Browse the repository at this point in the history
  • Loading branch information
jp-agenta committed Dec 9, 2024
1 parent 2341fc5 commit dcc054d
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 44 deletions.
136 changes: 98 additions & 38 deletions agenta-cli/agenta/sdk/litellm/litellm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict
from opentelemetry.trace import SpanKind

import agenta as ag
Expand Down Expand Up @@ -34,14 +35,20 @@ class LitellmHandler(LitellmCustomLogger):
def __init__(self):
super().__init__()

self.span = None
self.span: Dict[str, CustomSpan] = dict()

def log_pre_api_call(
self,
model,
messages,
kwargs,
):
litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

type = ( # pylint: disable=redefined-builtin
"chat"
if kwargs.get("call_type") in ["completion", "acompletion"]
Expand All @@ -50,29 +57,31 @@ def log_pre_api_call(

kind = SpanKind.CLIENT

self.span = CustomSpan(
self.span[litellm_call_id] = CustomSpan(
ag.tracer.start_span(name=f"litellm_{kind.name.lower()}", kind=kind)
)

self.span.set_attributes(
attributes={"node": type},
namespace="type",
)
span = self.span[litellm_call_id]

if not self.span:
if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
if not span.is_recording():
log.error("Agenta SDK - litellm span not recording.")
return

self.span.set_attributes(
span.set_attributes(
attributes={"node": type},
namespace="type",
)

span.set_attributes(
attributes={"inputs": {"prompt": kwargs["messages"]}},
namespace="data",
)

self.span.set_attributes(
span.set_attributes(
attributes={
"configuration": {
"model": kwargs.get("model"),
Expand All @@ -89,11 +98,19 @@ def log_stream_event(
start_time,
end_time,
):
if not self.span:
litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

span = self.span[litellm_call_id]

if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
if not span.is_recording():
return

def log_success_event(
Expand All @@ -106,11 +123,19 @@ def log_success_event(
if kwargs.get("stream"):
return

if not self.span:
litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
span = self.span[litellm_call_id]

if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not span.is_recording():
return

try:
Expand All @@ -120,20 +145,20 @@ def log_success_event(
result.append(message)

outputs = {"completion": result}
self.span.set_attributes(
span.set_attributes(
attributes={"outputs": outputs},
namespace="data",
)

except Exception as e:
pass

self.span.set_attributes(
span.set_attributes(
attributes={"total": kwargs.get("response_cost")},
namespace="metrics.unit.costs",
)

self.span.set_attributes(
span.set_attributes(
attributes=(
{
"prompt": response_obj.usage.prompt_tokens,
Expand All @@ -144,9 +169,9 @@ def log_success_event(
namespace="metrics.unit.tokens",
)

self.span.set_status(status="OK")
span.set_status(status="OK")

self.span.end()
span.end()

def log_failure_event(
self,
Expand All @@ -155,18 +180,26 @@ def log_failure_event(
start_time,
end_time,
):
if not self.span:
litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
span = self.span[litellm_call_id]

if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not span.is_recording():
return

self.span.record_exception(kwargs["exception"])
span.record_exception(kwargs["exception"])

self.span.set_status(status="ERROR")
span.set_status(status="ERROR")

self.span.end()
span.end()

async def async_log_stream_event(
self,
Expand All @@ -175,11 +208,22 @@ async def async_log_stream_event(
start_time,
end_time,
):
if not self.span:
if kwargs.get("stream"):
return

litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

span = self.span[litellm_call_id]

if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
if not span.is_recording():
return

async def async_log_success_event(
Expand All @@ -189,11 +233,19 @@ async def async_log_success_event(
start_time,
end_time,
):
if not self.span:
litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

span = self.span[litellm_call_id]

if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
if not span.is_recording():
return

try:
Expand All @@ -203,20 +255,20 @@ async def async_log_success_event(
result.append(message)

outputs = {"completion": result}
self.span.set_attributes(
span.set_attributes(
attributes={"outputs": outputs},
namespace="data",
)

except Exception as e:
pass

self.span.set_attributes(
span.set_attributes(
attributes={"total": kwargs.get("response_cost")},
namespace="metrics.unit.costs",
)

self.span.set_attributes(
span.set_attributes(
attributes=(
{
"prompt": response_obj.usage.prompt_tokens,
Expand All @@ -227,9 +279,9 @@ async def async_log_success_event(
namespace="metrics.unit.tokens",
)

self.span.set_status(status="OK")
span.set_status(status="OK")

self.span.end()
span.end()

async def async_log_failure_event(
self,
Expand All @@ -238,17 +290,25 @@ async def async_log_failure_event(
start_time,
end_time,
):
if not self.span:
litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

span = self.span[litellm_call_id]

if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
if not span.is_recording():
return

self.span.record_exception(kwargs["exception"])
span.record_exception(kwargs["exception"])

self.span.set_status(status="ERROR")
span.set_status(status="ERROR")

self.span.end()
span.end()

return LitellmHandler()
5 changes: 1 addition & 4 deletions agenta-cli/agenta/sdk/tracing/inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,4 @@ def calculate_costs(span_idx: Dict[str, SpanDTO]):
span.metrics["unit.costs.total"] = total_cost

except: # pylint: disable=bare-except
print("Failed to calculate costs:")
print(
f"model={model}, prompt_tokens={prompt_tokens}, completion_tokens={completion_tokens}"
)
pass
2 changes: 0 additions & 2 deletions agenta-cli/agenta/sdk/tracing/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def on_start(
span: Span,
parent_context: Optional[Context] = None,
) -> None:
# ADD LINKS FROM CONTEXT, HERE

for key in self.references.keys():
span.set_attribute(f"ag.refs.{key}", self.references[key])

Expand Down

0 comments on commit dcc054d

Please sign in to comment.