Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hotfix] Failing templates due to LiteLLM callback issues #2355

Merged
merged 2 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 98 additions & 38 deletions agenta-cli/agenta/sdk/litellm/litellm.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Dict
from opentelemetry.trace import SpanKind

import agenta as ag
Expand Down Expand Up @@ -34,14 +35,20 @@ class LitellmHandler(LitellmCustomLogger):
def __init__(self):
super().__init__()

self.span = None
self.span: Dict[str, CustomSpan] = dict()

def log_pre_api_call(
self,
model,
messages,
kwargs,
):
litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

type = ( # pylint: disable=redefined-builtin
"chat"
if kwargs.get("call_type") in ["completion", "acompletion"]
Expand All @@ -50,29 +57,31 @@ def log_pre_api_call(

kind = SpanKind.CLIENT

self.span = CustomSpan(
self.span[litellm_call_id] = CustomSpan(
ag.tracer.start_span(name=f"litellm_{kind.name.lower()}", kind=kind)
)

self.span.set_attributes(
attributes={"node": type},
namespace="type",
)
span = self.span[litellm_call_id]

if not self.span:
if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
if not span.is_recording():
log.error("Agenta SDK - litellm span not recording.")
return

self.span.set_attributes(
span.set_attributes(
attributes={"node": type},
namespace="type",
)

span.set_attributes(
attributes={"inputs": {"prompt": kwargs["messages"]}},
namespace="data",
)

self.span.set_attributes(
span.set_attributes(
attributes={
"configuration": {
"model": kwargs.get("model"),
Expand All @@ -89,11 +98,19 @@ def log_stream_event(
start_time,
end_time,
):
if not self.span:
litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

span = self.span[litellm_call_id]

if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
if not span.is_recording():
return

def log_success_event(
Expand All @@ -106,11 +123,19 @@ def log_success_event(
if kwargs.get("stream"):
return

if not self.span:
litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
span = self.span[litellm_call_id]

if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not span.is_recording():
return

try:
Expand All @@ -120,20 +145,20 @@ def log_success_event(
result.append(message)

outputs = {"completion": result}
self.span.set_attributes(
span.set_attributes(
attributes={"outputs": outputs},
namespace="data",
)

except Exception as e:
pass

self.span.set_attributes(
span.set_attributes(
attributes={"total": kwargs.get("response_cost")},
namespace="metrics.unit.costs",
)

self.span.set_attributes(
span.set_attributes(
attributes=(
{
"prompt": response_obj.usage.prompt_tokens,
Expand All @@ -144,9 +169,9 @@ def log_success_event(
namespace="metrics.unit.tokens",
)

self.span.set_status(status="OK")
span.set_status(status="OK")

self.span.end()
span.end()

def log_failure_event(
self,
Expand All @@ -155,18 +180,26 @@ def log_failure_event(
start_time,
end_time,
):
if not self.span:
litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
span = self.span[litellm_call_id]

if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not span.is_recording():
return

self.span.record_exception(kwargs["exception"])
span.record_exception(kwargs["exception"])

self.span.set_status(status="ERROR")
span.set_status(status="ERROR")

self.span.end()
span.end()

async def async_log_stream_event(
self,
Expand All @@ -175,11 +208,22 @@ async def async_log_stream_event(
start_time,
end_time,
):
if not self.span:
if kwargs.get("stream"):
return

litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

span = self.span[litellm_call_id]

if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
if not span.is_recording():
return

async def async_log_success_event(
Expand All @@ -189,11 +233,19 @@ async def async_log_success_event(
start_time,
end_time,
):
if not self.span:
litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

span = self.span[litellm_call_id]

if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
if not span.is_recording():
return

try:
Expand All @@ -203,20 +255,20 @@ async def async_log_success_event(
result.append(message)

outputs = {"completion": result}
self.span.set_attributes(
span.set_attributes(
attributes={"outputs": outputs},
namespace="data",
)

except Exception as e:
pass

self.span.set_attributes(
span.set_attributes(
attributes={"total": kwargs.get("response_cost")},
namespace="metrics.unit.costs",
)

self.span.set_attributes(
span.set_attributes(
attributes=(
{
"prompt": response_obj.usage.prompt_tokens,
Expand All @@ -227,9 +279,9 @@ async def async_log_success_event(
namespace="metrics.unit.tokens",
)

self.span.set_status(status="OK")
span.set_status(status="OK")

self.span.end()
span.end()

async def async_log_failure_event(
self,
Expand All @@ -238,17 +290,25 @@ async def async_log_failure_event(
start_time,
end_time,
):
if not self.span:
litellm_call_id = kwargs.get("litellm_call_id")

if not litellm_call_id:
log.warning("Agenta SDK - litellm tracing failed")
return

span = self.span[litellm_call_id]

if not span:
log.warning("Agenta SDK - litellm tracing failed")
return

if not self.span.is_recording():
if not span.is_recording():
return

self.span.record_exception(kwargs["exception"])
span.record_exception(kwargs["exception"])

self.span.set_status(status="ERROR")
span.set_status(status="ERROR")

self.span.end()
span.end()

return LitellmHandler()
5 changes: 1 addition & 4 deletions agenta-cli/agenta/sdk/tracing/inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1143,7 +1143,4 @@ def calculate_costs(span_idx: Dict[str, SpanDTO]):
span.metrics["unit.costs.total"] = total_cost

except: # pylint: disable=bare-except
print("Failed to calculate costs:")
print(
f"model={model}, prompt_tokens={prompt_tokens}, completion_tokens={completion_tokens}"
)
pass
2 changes: 0 additions & 2 deletions agenta-cli/agenta/sdk/tracing/processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ def on_start(
span: Span,
parent_context: Optional[Context] = None,
) -> None:
# ADD LINKS FROM CONTEXT, HERE

for key in self.references.keys():
span.set_attribute(f"ag.refs.{key}", self.references[key])

Expand Down
2 changes: 1 addition & 1 deletion agenta-web/src/components/Playground/Views/TestView.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ const App: React.FC<TestViewProps> = ({
`❌ ${getErrorMessage(e?.response?.data?.error || e?.response?.data, e)}`,
index,
)
if (e.response.status === 401) {
if (e?.response?.status === 401) {
setIsLLMProviderMissingModalOpen(true)
}
} else {
Expand Down
Loading