From 36d311f62223e8e4b4091a0baed3080f0ff9da53 Mon Sep 17 00:00:00 2001 From: Abram Date: Tue, 26 Nov 2024 15:24:16 +0100 Subject: [PATCH] feat (backend): create remove_trace_prefix utility function --- agenta-backend/agenta_backend/utils/traces.py | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/agenta-backend/agenta_backend/utils/traces.py b/agenta-backend/agenta_backend/utils/traces.py index 897d313375..807466550b 100644 --- a/agenta-backend/agenta_backend/utils/traces.py +++ b/agenta-backend/agenta_backend/utils/traces.py @@ -1,14 +1,59 @@ import logging import traceback from copy import deepcopy -from typing import Any, Dict, Union from collections import OrderedDict +from typing import Any, Dict, Union, Optional logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) +def remove_trace_prefix( + mapping_dict: Optional[Dict] = None, settings_values: Optional[Dict] = None +) -> Dict: + """ + Modify the values of the mapping dictionary to remove 'trace.' prefix if it exists. + + Args: + mapping_dict (Optional[Dict]): A dictionary containing the mapping values. + settings_values (Optional[Dict]): A dictionary with keys like "answer_key", + "contexts_key", "question_key" to override + specific mapping values. + + Returns: + Dict: A dictionary with the 'trace.' prefix removed from any string values. + + Raises: + ValueError: If neither `mapping_dict` nor `settings_values` is provided. + + """ + + if mapping_dict is None and settings_values is None: + raise ValueError("No mapping dictionary or settings values provided") + + # Determine which dictionary to use + if settings_values: + mapping_values = { + "answer_key": settings_values.get("answer_key"), + "contexts_key": settings_values.get("contexts_key"), + "question_key": settings_values.get("question_key"), + } + elif mapping_dict: + mapping_values = mapping_dict + else: + mapping_values = {} + + # Update the mapping by removing the 'trace.' prefix from string values + updated_mapping_dict = { + key: value.replace("trace.", "") if isinstance(value, str) else value + for key, value in mapping_values.items() + if value is not None + } + + return updated_mapping_dict + + def _make_spans_id_tree(trace): """ Creates spans tree (id only) from flat spans list