Skip to content

Commit

Permalink
Merge pull request #1945 from Agenta-AI/fix/litellm-observability
Browse files: browse the repository at this point in the history
fix(sdk): AGE-480 Update litellm callbacks with new SDK
  • Loading branch information
mmabrouk authored Jul 29, 2024
2 parents d497550 + 46da762 commit 3e2c090
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 40 deletions.
89 changes: 51 additions & 38 deletions agenta-cli/agenta/sdk/tracing/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,13 @@ def log_pre_api_call(self, model, messages, kwargs):
span_kind = (
"llm" if call_type in ["completion", "acompletion"] else "embedding"
)
self._trace.start_span(

ag.tracing.start_span(
name=f"{span_kind}_call",
input={"messages": kwargs["messages"]},
spankind=span_kind,
)
self._trace.set_span_attribute(
ag.tracing.set_attributes(
{
"model_config": {
"model": kwargs.get("model"),
Expand All @@ -49,15 +50,17 @@ def log_pre_api_call(self, model, messages, kwargs):
)

def log_stream_event(self, kwargs, response_obj, start_time, end_time):
self._trace.update_span_status(span=self._trace.active_span, value="OK")
self._trace.end_span(
ag.tracing.set_status(status="OK")
ag.tracing.end_span(
outputs={
"message": kwargs.get(
"complete_streaming_response"
), # the complete streamed response (only set if `completion(..stream=True)`)
"usage": response_obj.usage.dict()
if hasattr(response_obj, "usage")
else None, # litellm calculates usage
"usage": (
response_obj.usage.dict()
if hasattr(response_obj, "usage")
else None
), # litellm calculates usage
"cost": kwargs.get(
"response_cost"
), # litellm calculates response cost
Expand All @@ -67,13 +70,15 @@ def log_stream_event(self, kwargs, response_obj, start_time, end_time):
def log_success_event(
self, kwargs, response_obj: ModelResponse, start_time, end_time
):
self._trace.update_span_status(span=self._trace.active_span, value="OK")
self._trace.end_span(
ag.tracing.set_status(status="OK")
ag.tracing.end_span(
outputs={
"message": response_obj.choices[0].message.content,
"usage": response_obj.usage.dict()
if hasattr(response_obj, "usage")
else None, # litellm calculates usage
"usage": (
response_obj.usage.dict()
if hasattr(response_obj, "usage")
else None
), # litellm calculates usage
"cost": kwargs.get(
"response_cost"
), # litellm calculates response cost
Expand All @@ -83,23 +88,25 @@ def log_success_event(
def log_failure_event(
self, kwargs, response_obj: ModelResponse, start_time, end_time
):
self._trace.update_span_status(span=self._trace.active_span, value="ERROR")
self._trace.set_span_attribute(
ag.tracing.set_status(status="ERROR")
ag.tracing.set_attributes(
{
"traceback_exception": kwargs[
"traceback_exception"
], # the traceback generated via `traceback.format_exc()`
"traceback_exception": repr(
kwargs["traceback_exception"]
), # the traceback generated via `traceback.format_exc()`
"call_end_time": kwargs[
"end_time"
], # datetime object of when call was completed
},
)
self._trace.end_span(
ag.tracing.end_span(
outputs={
"message": kwargs["exception"], # the Exception raised
"usage": response_obj.usage.dict()
if hasattr(response_obj, "usage")
else None, # litellm calculates usage
"usage": (
response_obj.usage.dict()
if hasattr(response_obj, "usage")
else None
), # litellm calculates usage
"cost": kwargs.get(
"response_cost"
), # litellm calculates response cost
Expand All @@ -109,15 +116,17 @@ def log_failure_event(
async def async_log_stream_event(
self, kwargs, response_obj, start_time, end_time
):
self._trace.update_span_status(span=self._trace.active_span, value="OK")
self._trace.end_span(
ag.tracing.set_status(status="OK")
ag.tracing.end_span(
outputs={
"message": kwargs.get(
"complete_streaming_response"
), # the complete streamed response (only set if `completion(..stream=True)`)
"usage": response_obj.usage.dict()
if hasattr(response_obj, "usage")
else None, # litellm calculates usage
"usage": (
response_obj.usage.dict()
if hasattr(response_obj, "usage")
else None
), # litellm calculates usage
"cost": kwargs.get(
"response_cost"
), # litellm calculates response cost
Expand All @@ -127,13 +136,15 @@ async def async_log_stream_event(
async def async_log_success_event(
self, kwargs, response_obj, start_time, end_time
):
self._trace.update_span_status(span=self._trace.active_span, value="OK")
self._trace.end_span(
ag.tracing.set_status(status="OK")
ag.tracing.end_span(
outputs={
"message": response_obj.choices[0].message.content,
"usage": response_obj.usage.dict()
if hasattr(response_obj, "usage")
else None, # litellm calculates usage
"usage": (
response_obj.usage.dict()
if hasattr(response_obj, "usage")
else None
), # litellm calculates usage
"cost": kwargs.get(
"response_cost"
), # litellm calculates response cost
Expand All @@ -143,8 +154,8 @@ async def async_log_success_event(
async def async_log_failure_event(
self, kwargs, response_obj, start_time, end_time
):
self._trace.update_span_status(span=self._trace.active_span, value="ERROR")
self._trace.set_span_attribute(
ag.tracing.set_status(status="ERROR")
ag.tracing.set_attributes(
{
"traceback_exception": kwargs[
"traceback_exception"
Expand All @@ -154,12 +165,14 @@ async def async_log_failure_event(
], # datetime object of when call was completed
},
)
self._trace.end_span(
ag.tracing.end_span(
outputs={
"message": kwargs["exception"], # the Exception raised
"usage": response_obj.usage.dict()
if hasattr(response_obj, "usage")
else None, # litellm calculates usage
"message": repr(kwargs["exception"]), # the Exception raised
"usage": (
response_obj.usage.dict()
if hasattr(response_obj, "usage")
else None
), # litellm calculates usage
"cost": kwargs.get(
"response_cost"
), # litellm calculates response cost
Expand Down
3 changes: 2 additions & 1 deletion agenta-cli/agenta/sdk/tracing/llm_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,8 @@ def _process_closed_spans(self) -> None:
"trace",
# mock_create_traces(
self.client.create_traces(
trace=tracing.trace_id, spans=tracing.closed_spans # type: ignore
trace=tracing.trace_id,
spans=tracing.closed_spans, # type: ignore
),
self.client,
)
Expand Down
2 changes: 1 addition & 1 deletion agenta-cli/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "agenta"
version = "0.19.7"
version = "0.19.8"
description = "The SDK for agenta is an open-source LLMOps platform."
readme = "README.md"
authors = ["Mahmoud Mabrouk <[email protected]>"]
Expand Down

0 comments on commit 3e2c090

Please sign in to comment.