From ccf3a10bcdbebaa9564d33665aa79af87f47ddf6 Mon Sep 17 00:00:00 2001 From: Abram Date: Fri, 12 Apr 2024 21:12:29 +0100 Subject: [PATCH] Minor refactor - include loggers in each tracing step and update entrypoint to make use of active_trace --- agenta-cli/agenta/sdk/agenta_decorator.py | 4 +- agenta-cli/agenta/sdk/tracing/llm_tracing.py | 40 +++++++++++++------- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/agenta-cli/agenta/sdk/agenta_decorator.py b/agenta-cli/agenta/sdk/agenta_decorator.py index 75a78e5a95..57ded3c2da 100644 --- a/agenta-cli/agenta/sdk/agenta_decorator.py +++ b/agenta-cli/agenta/sdk/agenta_decorator.py @@ -96,7 +96,7 @@ async def wrapper(*args, **kwargs) -> Any: # End trace recording tracing.end_recording( outputs=llm_result.dict(), - span=tracing.active_span, + span=tracing.active_trace, environment="playground", # type: ignore #NOTE: wrapper is only called in playground ) return llm_result @@ -129,7 +129,7 @@ async def wrapper_deployed(*args, **kwargs) -> Any: # End trace recording tracing.end_recording( outputs=llm_result.dict(), - span=tracing.active_span, + span=tracing.active_trace, environment="playground", # type: ignore #NOTE: wrapper is only called in playground ) return llm_result diff --git a/agenta-cli/agenta/sdk/tracing/llm_tracing.py b/agenta-cli/agenta/sdk/tracing/llm_tracing.py index 2ff40cf1b1..7f64981f93 100644 --- a/agenta-cli/agenta/sdk/tracing/llm_tracing.py +++ b/agenta-cli/agenta/sdk/tracing/llm_tracing.py @@ -1,6 +1,6 @@ # Stdlib Imports from datetime import datetime, timezone -from typing import Optional, Dict, Any, List +from typing import Optional, Dict, Any, List, Union # Own Imports from agenta.sdk.tracing.logger import llm_logger @@ -48,7 +48,8 @@ def __init__( max_workers if max_workers else 4, logger=llm_logger ) self.active_span = CreateSpan - self.active_trace = None + self.active_trace = CreateSpan + self.recording_trace_id: Union[str, None] = None self.recorded_spans: List[CreateSpan] = [] self.tags: List[str] = [] self.span_dict: Dict[str, CreateSpan] = {} # type: ignore @@ -70,7 +71,7 @@ def set_span_attribute( ): span = self.span_dict[self.active_span.id] # type: ignore for key, value in attributes.items(): - self.set_attribute(span.attributes, key, value, parent_key) # type: ignore + self.set_attribute(span.attributes, key, value, parent_key) # type: ignore def set_attribute( self, @@ -107,9 +108,10 @@ def start_parent_span( status=SpanStatusCode.UNSET.value, start_time=datetime.now(timezone.utc), ) - self.active_span = span - self.active_trace = trace_id - self.parent_span_id = span_id + self.active_trace = span + self.recording_trace_id = trace_id + self.parent_span_id = span.id + self.llm_logger.info(f"Recorded active_trace and parent_span_id: {span.id}") def start_span( self, @@ -134,7 +136,8 @@ def start_span( self.active_span = span self.span_dict[span.id] = span - self.parent_span_id = span_id + self.parent_span_id = span.id + self.llm_logger.info(f"Recorded active_span and parent_span_id: {span.id}") return span def update_span_status(self, span: CreateSpan, value: str): @@ -153,17 +156,28 @@ def end_span(self, outputs: Dict[str, Any], span: CreateSpan, **kwargs): # Push span to list of recorded spans self.recorded_spans.append(updated_span) - self.llm_logger.info(f"Pushed {updated_span.spankind} span to recorded spans.") + self.llm_logger.info(f"Pushed {updated_span.spankind} span {updated_span.id} to recorded spans.") def end_recording(self, outputs: Dict[str, Any], span: CreateSpan, **kwargs): - self.end_span(outputs=outputs, span=span, **kwargs) - + updated_span = CreateSpan( + **span.dict(), + end_time=datetime.now(timezone.utc), + outputs=[outputs["message"]], + cost=outputs.get("cost", None), + environment=kwargs.get("environment"), + tokens=outputs.get("usage"), + ) + self.recorded_spans.append(updated_span) + self.llm_logger.info( + f"Pushed workflow span {updated_span.id} to recorded spans." + ) self.llm_logger.info(f"Preparing to send recorded spans for processing.") - self.llm_logger.info(f"Recorded spans: ", self.recorded_spans) self.tasks_manager.add_task( - self.active_trace, + self.active_trace.id, "trace", - self.client.create_traces(trace=self.active_trace, spans=self.recorded_spans), + self.client.create_traces( + trace=self.recording_trace_id, spans=self.recorded_spans + ), self.client, ) self.llm_logger.info(