diff --git a/agenta-cli/agenta/sdk/tracing/llm_tracing.py b/agenta-cli/agenta/sdk/tracing/llm_tracing.py index e087d252d1..195fd9b67e 100644 --- a/agenta-cli/agenta/sdk/tracing/llm_tracing.py +++ b/agenta-cli/agenta/sdk/tracing/llm_tracing.py @@ -73,10 +73,10 @@ def __init__( max_workers if max_workers else 4, logger=llm_logger ) self.active_span: Optional[CreateSpan] = None - self.active_trace_id: Union[str, None] = None + self.active_trace_id: Optional[str] = None self.pending_spans: List[CreateSpan] = [] self.tags: List[str] = [] - self.trace_config_cache: Dict[str, Any] = {} # used to save the trace configuration before starting the first span + self.trace_config_cache: Dict[str, Any] = {} # used to save the trace configuration before starting the first span self.span_dict: Dict[str, CreateSpan] = {} # type: ignore @property @@ -95,15 +95,13 @@ def set_span_attribute( self, attributes: Dict[str, Any] = {}, ): - assert not(self.active_span is None and self.parent_span_id is not None), "The parent_span_id is set, yet there is no active span!" - if self.parent_span_id is None and self.active_span is None: # This is the case where entrypoint wants to save the trace information but the parent span has not been initialized yet + if self.active_span is None: # This is the case where entrypoint wants to save the trace information but the parent span has not been initialized yet for key, value in attributes.items(): self.trace_config_cache[key] = value else: for key, value in attributes.items(): self.active_span.attributes[key] = value - def set_trace_tags(self, tags: List[str]): self.tags.extend(tags) @@ -140,10 +138,9 @@ def start_span( token_consumption=None, parent_span_id=None, ) - - if self.active_trace_id is None: # This is a parent span + + if self.active_trace_id is None: # This is a parent span self.active_trace_id = self._create_trace_id() - assert self.parent_span_id is None, "Creating a new trace, yet the parent_span_id is not None" span.environment = ( self.trace_config_cache.get("environment") if self.trace_config_cache is not None @@ -154,9 +151,8 @@ def start_span( if not config and self.trace_config_cache is not None else None ) - self.parent_span = span else: - self.parent_span_id = self.active_span.id + span.parent_span_id = self.active_span.id self.span_dict[span.id] = span self.active_span = span @@ -170,23 +166,22 @@ def end_span(self, outputs: Dict[str, Any]): """ Ends the active span, if it is a parent span, ends the trace too. """ - if self.active_span is not None: + if self.active_span is None: raise ValueError("There is no active span to end.") self.active_span.end_time = datetime.now(timezone.utc) - self.active_span.outputs = [outputs["message"]] + self.active_span.outputs = [outputs.get("message", "")] self.active_span.cost = outputs.get("cost", None) - self.active_span.tokens = outputs.get("usage") + self.active_span.tokens = outputs.get("usage", None) # Push span to list of recorded spans self.pending_spans.append(self.active_span) self.llm_logger.info( f"Pushed {self.active_span.spankind} span {self.active_span.id} to recorded spans." ) - if self.parent_span_id is None: + if self.active_span.parent_span_id is None: self.end_trace(parent_span=self.active_span) else: - self.active_span = self.span_dict[self.parent_span_id] - self.parent_span_id = self.active_span.parent_span_id + self.active_span = self.span_dict[self.active_span.parent_span_id] def end_trace(self, parent_span: CreateSpan): if self.api_key == "": @@ -206,13 +201,12 @@ def end_trace(self, parent_span: CreateSpan): self.client, ) self.llm_logger.info( - f"Tracing for {span.id} recorded successfully and sent for processing." + f"Tracing for {parent_span.id} recorded successfully and sent for processing." ) self._clear_pending_spans() self.active_trace_id = None - self.parent_span_id = None self.active_span = None - self.trace_config_cache = {} + self.trace_config_cache.clear() def _create_trace_id(self) -> str: """Creates a unique mongo id for the trace object.