From e11a6cdffd71cf822313332e0698d54118accf4c Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Thu, 21 Nov 2024 21:45:29 +0100 Subject: [PATCH 1/2] fix streaming ? --- agenta-cli/agenta/sdk/litellm/litellm.py | 91 +++++-------------- agenta-cli/agenta/sdk/middleware/auth.py | 22 +++-- .../03_agenta_implementationa_streaming.py | 66 ++++++++++++++ 3 files changed, 103 insertions(+), 76 deletions(-) create mode 100644 agenta-cli/tests/observability_sdk/integrations/litellm/03_agenta_implementationa_streaming.py diff --git a/agenta-cli/agenta/sdk/litellm/litellm.py b/agenta-cli/agenta/sdk/litellm/litellm.py index c8aa43096c..ccdf81aa31 100644 --- a/agenta-cli/agenta/sdk/litellm/litellm.py +++ b/agenta-cli/agenta/sdk/litellm/litellm.py @@ -63,6 +63,10 @@ def log_pre_api_call( log.error("LiteLLM callback error: span not found.") return + if not self.span.is_recording(): + log.error("Agenta SDK - litellm span not recording.") + return + self.span.set_attributes( attributes={"inputs": {"prompt": kwargs["messages"]}}, namespace="data", @@ -89,40 +93,8 @@ def log_stream_event( log.error("LiteLLM callback error: span not found.") return - try: - result = [] - for choice in response_obj.choices: - message = choice.message.__dict__ - result.append(message) - - outputs = {"completion": result} - self.span.set_attributes( - attributes={"outputs": outputs}, - namespace="data", - ) - - except Exception as e: - pass - - self.span.set_attributes( - attributes={"total": kwargs.get("response_cost")}, - namespace="metrics.unit.costs", - ) - - self.span.set_attributes( - attributes=( - { - "prompt": response_obj.usage.prompt_tokens, - "completion": response_obj.usage.completion_tokens, - "total": response_obj.usage.total_tokens, - } - ), - namespace="metrics.unit.tokens", - ) - - self.span.set_status(status="OK") - - self.span.end() + if not self.span.is_recording(): + return def log_success_event( self, @@ -131,10 +103,16 @@ def log_success_event( start_time, end_time, ): + if kwargs.get("stream"): + return + if not self.span: log.error("LiteLLM callback error: span not found.") return + if not self.span.is_recording(): + return + try: result = [] for choice in response_obj.choices: @@ -181,6 +159,9 @@ def log_failure_event( log.error("LiteLLM callback error: span not found.") return + if not self.span.is_recording(): + return + self.span.record_exception(kwargs["exception"]) self.span.set_status(status="ERROR") @@ -198,40 +179,8 @@ async def async_log_stream_event( log.error("LiteLLM callback error: span not found.") return - try: - result = [] - for choice in response_obj.choices: - message = choice.message.__dict__ - result.append(message) - - outputs = {"completion": result} - self.span.set_attributes( - attributes={"outputs": outputs}, - namespace="data", - ) - - except Exception as e: - pass - - self.span.set_attributes( - attributes={"total": kwargs.get("response_cost")}, - namespace="metrics.unit.costs", - ) - - self.span.set_attributes( - attributes=( - { - "prompt": response_obj.usage.prompt_tokens, - "completion": response_obj.usage.completion_tokens, - "total": response_obj.usage.total_tokens, - } - ), - namespace="metrics.unit.tokens", - ) - - self.span.set_status(status="OK") - - self.span.end() + if not self.span.is_recording(): + return async def async_log_success_event( self, @@ -244,6 +193,9 @@ async def async_log_success_event( log.error("LiteLLM callback error: span not found.") return + if not self.span.is_recording(): + return + try: result = [] for choice in response_obj.choices: @@ -290,6 +242,9 @@ async def async_log_failure_event( log.error("LiteLLM callback error: span not found.") return + if not self.span.is_recording(): + return + self.span.record_exception(kwargs["exception"]) self.span.set_status(status="ERROR") diff --git a/agenta-cli/agenta/sdk/middleware/auth.py b/agenta-cli/agenta/sdk/middleware/auth.py index 7a017ef220..0bda02fe30 100644 --- a/agenta-cli/agenta/sdk/middleware/auth.py +++ b/agenta-cli/agenta/sdk/middleware/auth.py @@ -21,6 +21,12 @@ 15 * 60, # 15 minutes ) +AGENTA_SDK_AUTH_CACHE = str(environ.get("AGENTA_SDK_AUTH_CACHE", True)).lower() in ( + "true", + "1", + "t", +) + AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED = str( environ.get("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", False) ).lower() in ("true", "1", "t") @@ -89,9 +95,11 @@ async def dispatch( sort_keys=True, ) - cached_policy = cache.get(_hash) + policy = None + if AGENTA_SDK_AUTH_CACHE: + policy = cache.get(_hash) - if not cached_policy: + if not policy: async with httpx.AsyncClient() as client: response = await client.get( f"{self.host}/api/permissions/verify", @@ -110,19 +118,17 @@ async def dispatch( cache.put(_hash, {"effect": "deny"}) return Deny() - cached_policy = { + policy = { "effect": "allow", "credentials": auth.get("credentials"), } - cache.put(_hash, cached_policy) + cache.put(_hash, policy) - if cached_policy.get("effect") == "deny": + if policy.get("effect") == "deny": return Deny() - request.state.credentials = cached_policy.get("credentials") - - print(f"credentials: {request.state.credentials}") + request.state.credentials = policy.get("credentials") return await call_next(request) diff --git a/agenta-cli/tests/observability_sdk/integrations/litellm/03_agenta_implementationa_streaming.py b/agenta-cli/tests/observability_sdk/integrations/litellm/03_agenta_implementationa_streaming.py new file mode 100644 index 0000000000..4b994d670a --- /dev/null +++ b/agenta-cli/tests/observability_sdk/integrations/litellm/03_agenta_implementationa_streaming.py @@ -0,0 +1,66 @@ +import litellm +import agenta as ag + + +ag.init() + + +@ag.instrument() +async def agenerate_completion(): + litellm.callbacks = [ag.callbacks.litellm_handler()] + + messages = [ + { + "role": "user", + "content": "Hey, how's it going? please respond in a in German.", + } + ] + + response = await litellm.acompletion( + model="gpt-4-turbo", + messages=messages, + stream=True, + ) + + chat_completion = "" + async for part in response: + print(part.choices[0].delta.content) + chat_completion += part.choices[0].delta.content or "" + + print(chat_completion) + + return chat_completion + + +@ag.instrument() +def generate_completion(): + litellm.callbacks = [ag.callbacks.litellm_handler()] + + messages = [ + { + "role": "user", + "content": "Hey, how's it going? please respond in a in Spanish.", + } + ] + + response = litellm.completion( + model="gpt-4-turbo", + messages=messages, + stream=True, + ) + + chat_completion = "" + for part in response: + print(part.choices[0].delta.content) + chat_completion += part.choices[0].delta.content or "" + + print(chat_completion) + + return chat_completion + + +if __name__ == "__main__": + import asyncio + + asyncio.run(agenerate_completion()) + # generate_completion() From d06cc9db10d2855005543ae878fa9be5a0211035 Mon Sep 17 00:00:00 2001 From: Juan Pablo Vega Date: Thu, 21 Nov 2024 21:47:48 +0100 Subject: [PATCH 2/2] revert undue changes --- agenta-cli/agenta/sdk/middleware/auth.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/agenta-cli/agenta/sdk/middleware/auth.py b/agenta-cli/agenta/sdk/middleware/auth.py index 0bda02fe30..7a017ef220 100644 --- a/agenta-cli/agenta/sdk/middleware/auth.py +++ b/agenta-cli/agenta/sdk/middleware/auth.py @@ -21,12 +21,6 @@ 15 * 60, # 15 minutes ) -AGENTA_SDK_AUTH_CACHE = str(environ.get("AGENTA_SDK_AUTH_CACHE", True)).lower() in ( - "true", - "1", - "t", -) - AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED = str( environ.get("AGENTA_UNAUTHORIZED_EXECUTION_ALLOWED", False) ).lower() in ("true", "1", "t") @@ -95,11 +89,9 @@ async def dispatch( sort_keys=True, ) - policy = None - if AGENTA_SDK_AUTH_CACHE: - policy = cache.get(_hash) + cached_policy = cache.get(_hash) - if not policy: + if not cached_policy: async with httpx.AsyncClient() as client: response = await client.get( f"{self.host}/api/permissions/verify", @@ -118,17 +110,19 @@ async def dispatch( cache.put(_hash, {"effect": "deny"}) return Deny() - policy = { + cached_policy = { "effect": "allow", "credentials": auth.get("credentials"), } - cache.put(_hash, policy) + cache.put(_hash, cached_policy) - if policy.get("effect") == "deny": + if cached_policy.get("effect") == "deny": return Deny() - request.state.credentials = policy.get("credentials") + request.state.credentials = cached_policy.get("credentials") + + print(f"credentials: {request.state.credentials}") return await call_next(request)