Merge pull request #2289 from Agenta-AI/fix/litellm-streaming-warnings
[Fix] litellm streaming warnings
mmabrouk authored Nov 22, 2024
2 parents 34ea66b + d06cc9d commit ad20ab9
Showing 2 changed files with 89 additions and 68 deletions.
91 changes: 23 additions & 68 deletions agenta-cli/agenta/sdk/litellm/litellm.py
@@ -63,6 +63,10 @@ def log_pre_api_call(
             log.error("LiteLLM callback error: span not found.")
             return
 
+        if not self.span.is_recording():
+            log.error("Agenta SDK - litellm span not recording.")
+            return
+
         self.span.set_attributes(
             attributes={"inputs": {"prompt": kwargs["messages"]}},
             namespace="data",
@@ -89,40 +93,8 @@ def log_stream_event(
             log.error("LiteLLM callback error: span not found.")
             return
 
-        try:
-            result = []
-            for choice in response_obj.choices:
-                message = choice.message.__dict__
-                result.append(message)
-
-            outputs = {"completion": result}
-            self.span.set_attributes(
-                attributes={"outputs": outputs},
-                namespace="data",
-            )
-
-        except Exception as e:
-            pass
-
-        self.span.set_attributes(
-            attributes={"total": kwargs.get("response_cost")},
-            namespace="metrics.unit.costs",
-        )
-
-        self.span.set_attributes(
-            attributes=(
-                {
-                    "prompt": response_obj.usage.prompt_tokens,
-                    "completion": response_obj.usage.completion_tokens,
-                    "total": response_obj.usage.total_tokens,
-                }
-            ),
-            namespace="metrics.unit.tokens",
-        )
-
-        self.span.set_status(status="OK")
-
-        self.span.end()
+        if not self.span.is_recording():
+            return
 
     def log_success_event(
         self,
@@ -131,10 +103,16 @@ def log_success_event(
         start_time,
         end_time,
     ):
+        if kwargs.get("stream"):
+            return
+
         if not self.span:
             log.error("LiteLLM callback error: span not found.")
             return
 
+        if not self.span.is_recording():
+            return
+
         try:
             result = []
             for choice in response_obj.choices:
@@ -181,6 +159,9 @@ def log_failure_event(
             log.error("LiteLLM callback error: span not found.")
             return
 
+        if not self.span.is_recording():
+            return
+
         self.span.record_exception(kwargs["exception"])
 
         self.span.set_status(status="ERROR")
@@ -198,40 +179,8 @@ async def async_log_stream_event(
             log.error("LiteLLM callback error: span not found.")
             return
 
-        try:
-            result = []
-            for choice in response_obj.choices:
-                message = choice.message.__dict__
-                result.append(message)
-
-            outputs = {"completion": result}
-            self.span.set_attributes(
-                attributes={"outputs": outputs},
-                namespace="data",
-            )
-
-        except Exception as e:
-            pass
-
-        self.span.set_attributes(
-            attributes={"total": kwargs.get("response_cost")},
-            namespace="metrics.unit.costs",
-        )
-
-        self.span.set_attributes(
-            attributes=(
-                {
-                    "prompt": response_obj.usage.prompt_tokens,
-                    "completion": response_obj.usage.completion_tokens,
-                    "total": response_obj.usage.total_tokens,
-                }
-            ),
-            namespace="metrics.unit.tokens",
-        )
-
-        self.span.set_status(status="OK")
-
-        self.span.end()
+        if not self.span.is_recording():
+            return
 
     async def async_log_success_event(
         self,
@@ -244,6 +193,9 @@ async def async_log_success_event(
             log.error("LiteLLM callback error: span not found.")
             return
 
+        if not self.span.is_recording():
+            return
+
         try:
             result = []
             for choice in response_obj.choices:
@@ -290,6 +242,9 @@ async def async_log_failure_event(
             log.error("LiteLLM callback error: span not found.")
             return
 
+        if not self.span.is_recording():
+            return
+
         self.span.record_exception(kwargs["exception"])
 
         self.span.set_status(status="ERROR")
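
Net effect of the litellm.py changes: every handler now returns early when the span is no longer recording, the per-chunk stream handlers (log_stream_event, async_log_stream_event) no longer set output, cost, or token attributes or end the span, and log_success_event skips streamed calls (kwargs.get("stream")), so the span is finalized by exactly one code path. The sketch below reproduces the warning this guards against; it is a minimal illustration using the plain OpenTelemetry SDK rather than Agenta's span wrapper, with a hypothetical span name, on the assumption (suggested but not stated by the diff) that the old per-chunk span.end() is what triggered the streaming warnings.

# Minimal repro sketch: plain OpenTelemetry SDK, hypothetical span name.
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)

span = tracer.start_span("llm-call")
span.end()  # the old stream handlers ended the span on the first chunk

# Any later chunk then wrote to an ended span; the SDK logs
# "Setting attribute on ended span" for each such call.
span.set_attribute("outputs.completion", "chunk 2")

# The fix's pattern: is_recording() is False once end() has run, so the
# write is skipped and no warning is emitted.
if span.is_recording():
    span.set_attribute("outputs.completion", "chunk 2")

The second changed file is a new example script that streams completions through both the async and sync paths: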
@@ -0,0 +1,66 @@
+import litellm
+import agenta as ag
+
+
+ag.init()
+
+
+@ag.instrument()
+async def agenerate_completion():
+    litellm.callbacks = [ag.callbacks.litellm_handler()]
+
+    messages = [
+        {
+            "role": "user",
+            "content": "Hey, how's it going? Please respond in German.",
+        }
+    ]
+
+    response = await litellm.acompletion(
+        model="gpt-4-turbo",
+        messages=messages,
+        stream=True,
+    )
+
+    chat_completion = ""
+    async for part in response:
+        print(part.choices[0].delta.content)
+        chat_completion += part.choices[0].delta.content or ""
+
+    print(chat_completion)
+
+    return chat_completion
+
+
+@ag.instrument()
+def generate_completion():
+    litellm.callbacks = [ag.callbacks.litellm_handler()]
+
+    messages = [
+        {
+            "role": "user",
+            "content": "Hey, how's it going? Please respond in Spanish.",
+        }
+    ]
+
+    response = litellm.completion(
+        model="gpt-4-turbo",
+        messages=messages,
+        stream=True,
+    )
+
+    chat_completion = ""
+    for part in response:
+        print(part.choices[0].delta.content)
+        chat_completion += part.choices[0].delta.content or ""
+
+    print(chat_completion)
+
+    return chat_completion
+
+
+if __name__ == "__main__":
+    import asyncio
+
+    asyncio.run(agenerate_completion())
+    # generate_completion()
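
The commented-out generate_completion() exercises the synchronous handlers the same way. To run both paths in one process — a minimal sketch, assuming the environment the example already requires (an LLM provider key and a configured Agenta SDK):

import asyncio

# Run the async streaming path first, then the sync one, so both the
# async_* and the synchronous callback handlers are exercised; with the
# fix applied, neither should emit span warnings while streaming.
asyncio.run(agenerate_completion())
generate_completion()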
