Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
mmabrouk committed May 29, 2024
1 parent 5749f0f commit bf561e1
Showing 1 changed file with 13 additions and 19 deletions.
32 changes: 13 additions & 19 deletions agenta-cli/agenta/sdk/tracing/llm_tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def __init__(
max_workers if max_workers else 4, logger=llm_logger
)
self.active_span: Optional[CreateSpan] = None
self.active_trace_id: Union[str, None] = None
self.active_trace_id: Optional[str] = None
self.pending_spans: List[CreateSpan] = []
self.tags: List[str] = []
self.trace_config_cache: Dict[str, Any] = {} # used to save the trace configuration before starting the first span
self.trace_config_cache: Dict[str, Any] = {} # used to save the trace configuration before starting the first span
self.span_dict: Dict[str, CreateSpan] = {} # type: ignore

@property
Expand All @@ -95,15 +95,13 @@ def set_span_attribute(
self,
attributes: Dict[str, Any] = {},
):
assert not(self.active_span is None and self.parent_span_id is not None), "The parent_span_id is set, yet there is no active span!"
if self.parent_span_id is None and self.active_span is None: # This is the case where entrypoint wants to save the trace information but the parent span has not been initialized yet
if self.active_span is None: # This is the case where entrypoint wants to save the trace information but the parent span has not been initialized yet
for key, value in attributes.items():
self.trace_config_cache[key] = value
else:
for key, value in attributes.items():
self.active_span.attributes[key] = value


def set_trace_tags(self, tags: List[str]):
self.tags.extend(tags)

Expand Down Expand Up @@ -140,10 +138,9 @@ def start_span(
token_consumption=None,
parent_span_id=None,
)
if self.active_trace_id is None: # This is a parent span

if self.active_trace_id is None: # This is a parent span
self.active_trace_id = self._create_trace_id()
assert self.parent_span_id is None, "Creating a new trace, yet the parent_span_id is not None"
span.environment = (
self.trace_config_cache.get("environment")
if self.trace_config_cache is not None
Expand All @@ -154,9 +151,8 @@ def start_span(
if not config and self.trace_config_cache is not None
else None
)
self.parent_span = span
else:
self.parent_span_id = self.active_span.id
span.parent_span_id = self.active_span.id
self.span_dict[span.id] = span
self.active_span = span

Expand All @@ -170,23 +166,22 @@ def end_span(self, outputs: Dict[str, Any]):
"""
Ends the active span, if it is a parent span, ends the trace too.
"""
if self.active_span is not None:
if self.active_span is None:
raise ValueError("There is no active span to end.")
self.active_span.end_time = datetime.now(timezone.utc)
self.active_span.outputs = [outputs["message"]]
self.active_span.outputs = [outputs.get("message", "")]
self.active_span.cost = outputs.get("cost", None)
self.active_span.tokens = outputs.get("usage")
self.active_span.tokens = outputs.get("usage", None)

# Push span to list of recorded spans
self.pending_spans.append(self.active_span)
self.llm_logger.info(
f"Pushed {self.active_span.spankind} span {self.active_span.id} to recorded spans."
)
if self.parent_span_id is None:
if self.active_span.parent_span_id is None:
self.end_trace(parent_span=self.active_span)
else:
self.active_span = self.span_dict[self.parent_span_id]
self.parent_span_id = self.active_span.parent_span_id
self.active_span = self.span_dict[self.active_span.parent_span_id]

def end_trace(self, parent_span: CreateSpan):
if self.api_key == "":
Expand All @@ -206,13 +201,12 @@ def end_trace(self, parent_span: CreateSpan):
self.client,
)
self.llm_logger.info(
f"Tracing for {span.id} recorded successfully and sent for processing."
f"Tracing for {parent_span.id} recorded successfully and sent for processing."
)
self._clear_pending_spans()
self.active_trace_id = None
self.parent_span_id = None
self.active_span = None
self.trace_config_cache = {}
self.trace_config_cache.clear()

def _create_trace_id(self) -> str:
"""Creates a unique mongo id for the trace object.
Expand Down

0 comments on commit bf561e1

Please sign in to comment.