Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug Fixes]: Release Cycle 63 #2302

Merged
merged 16 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions agenta-backend/agenta_backend/models/api/evaluation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pydantic import BaseModel, Field, model_validator

from agenta_backend.utils import traces
from agenta_backend.models.api.api_models import Result


Expand Down Expand Up @@ -98,6 +99,15 @@ class EvaluatorMappingInputInterface(BaseModel):
inputs: Dict[str, Any]
mapping: Dict[str, Any]

@model_validator(mode="before")
def remove_trace_prefix(cls, values: Dict) -> Dict:
mapping = values.get("mapping", {})
updated_mapping = traces.remove_trace_prefix(mapping_dict=mapping)

# Set the modified mapping back to the values
values["mapping"] = updated_mapping
return values


class EvaluatorMappingOutputInterface(BaseModel):
outputs: Dict[str, Any]
Expand Down
2 changes: 1 addition & 1 deletion agenta-backend/agenta_backend/models/shared_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class CorrectAnswer(BaseModel):
class EvaluationScenarioInput(BaseModel):
name: str
type: str
value: str
value: Any


class EvaluationScenarioOutput(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ def sum_float_from_llm_app_response(
raise ValueError(f"No valid values found for {key} sum aggregation.")

total_value = sum(values)

return Result(type=key, value=total_value)
except Exception as exc:
return Result(
Expand Down
53 changes: 39 additions & 14 deletions agenta-backend/agenta_backend/services/evaluators_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
EvaluatorMappingOutputInterface,
)
from agenta_backend.utils.traces import (
remove_trace_prefix,
process_distributed_trace_into_trace_tree,
get_field_value_from_trace_tree,
)
Expand Down Expand Up @@ -934,9 +935,10 @@ async def rag_faithfulness(
)

# Get required keys for rag evaluator
question_key: Union[str, None] = settings_values.get("question_key", None)
answer_key: Union[str, None] = settings_values.get("answer_key", None)
contexts_key: Union[str, None] = settings_values.get("contexts_key", None)
mapping_keys = remove_trace_prefix(settings_values=settings_values)
question_key: Union[str, None] = mapping_keys.get("question_key", None)
answer_key: Union[str, None] = mapping_keys.get("answer_key", None)
contexts_key: Union[str, None] = mapping_keys.get("contexts_key", None)

if None in [question_key, answer_key, contexts_key]:
logging.error(
Expand All @@ -947,12 +949,23 @@ async def rag_faithfulness(
)

# Turn distributed trace into trace tree
trace = process_distributed_trace_into_trace_tree(output["trace"])
trace = {}
version = output.get("version")
if version == "3.0":
trace = output.get("tree", {})
elif version == "2.0":
trace = output.get("trace", {})

trace = process_distributed_trace_into_trace_tree(trace, version)

# Get value of required keys for rag evaluator
question_val: Any = get_field_value_from_trace_tree(trace, question_key)
answer_val: Any = get_field_value_from_trace_tree(trace, answer_key)
contexts_val: Any = get_field_value_from_trace_tree(trace, contexts_key)
question_val: Any = get_field_value_from_trace_tree(
trace, question_key, version
)
answer_val: Any = get_field_value_from_trace_tree(trace, answer_key, version)
contexts_val: Any = get_field_value_from_trace_tree(
trace, contexts_key, version
)

if None in [question_val, answer_val, contexts_val]:
logging.error(
Expand Down Expand Up @@ -1035,9 +1048,10 @@ async def rag_context_relevancy(
)

# Get required keys for rag evaluator
question_key: Union[str, None] = settings_values.get("question_key", None)
answer_key: Union[str, None] = settings_values.get("answer_key", None)
contexts_key: Union[str, None] = settings_values.get("contexts_key", None)
mapping_keys = remove_trace_prefix(settings_values=settings_values)
question_key: Union[str, None] = mapping_keys.get("question_key", None)
answer_key: Union[str, None] = mapping_keys.get("answer_key", None)
contexts_key: Union[str, None] = mapping_keys.get("contexts_key", None)

if None in [question_key, answer_key, contexts_key]:
logging.error(
Expand All @@ -1048,12 +1062,23 @@ async def rag_context_relevancy(
)

# Turn distributed trace into trace tree
trace = process_distributed_trace_into_trace_tree(output["trace"])
trace = {}
version = output.get("version")
if version == "3.0":
trace = output.get("tree", {})
elif version == "2.0":
trace = output.get("trace", {})

trace = process_distributed_trace_into_trace_tree(trace, version)

# Get value of required keys for rag evaluator
question_val: Any = get_field_value_from_trace_tree(trace, question_key)
answer_val: Any = get_field_value_from_trace_tree(trace, answer_key)
contexts_val: Any = get_field_value_from_trace_tree(trace, contexts_key)
question_val: Any = get_field_value_from_trace_tree(
trace, question_key, version
)
answer_val: Any = get_field_value_from_trace_tree(trace, answer_key, version)
contexts_val: Any = get_field_value_from_trace_tree(
trace, contexts_key, version
)

if None in [question_val, answer_val, contexts_val]:
logging.error(
Expand Down
9 changes: 3 additions & 6 deletions agenta-backend/agenta_backend/services/llm_apps_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,10 @@ def extract_result_from_response(response: dict):
value["data"] = str(value.get("data"))

if "tree" in response:
trace_tree = (
response["tree"][0]
if isinstance(response.get("tree"), list)
else {}
)
trace_tree = response.get("tree", {}).get("nodes", [])[0]

latency = (
get_nested_value(trace_tree, ["time", "span"]) * 1_000_000
get_nested_value(trace_tree, ["time", "span"]) / 1_000_000
aybruhm marked this conversation as resolved.
Show resolved Hide resolved
if trace_tree
else None
)
Expand Down
98 changes: 86 additions & 12 deletions agenta-backend/agenta_backend/utils/traces.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,59 @@
import logging
import traceback
from copy import deepcopy
from typing import Any, Dict
from collections import OrderedDict
from typing import Any, Dict, Union, Optional


logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)


def remove_trace_prefix(
mapping_dict: Optional[Dict] = None, settings_values: Optional[Dict] = None
) -> Dict:
"""
Modify the values of the mapping dictionary to remove 'trace.' prefix if it exists.

Args:
mapping_dict (Optional[Dict]): A dictionary containing the mapping values.
settings_values (Optional[Dict]): A dictionary with keys like "answer_key",
"contexts_key", "question_key" to override
specific mapping values.

Returns:
Dict: A dictionary with the 'trace.' prefix removed from any string values.

Raises:
ValueError: If neither `mapping_dict` nor `settings_values` is provided.

"""

if mapping_dict is None and settings_values is None:
raise ValueError("No mapping dictionary or settings values provided")

# Determine which dictionary to use
if settings_values:
mapping_values = {
"answer_key": settings_values.get("answer_key"),
"contexts_key": settings_values.get("contexts_key"),
"question_key": settings_values.get("question_key"),
}
elif mapping_dict:
mapping_values = mapping_dict
else:
mapping_values = {}

# Update the mapping by removing the 'trace.' prefix from string values
updated_mapping_dict = {
key: value.replace("trace.", "") if isinstance(value, str) else value
for key, value in mapping_values.items()
if value is not None
}

return updated_mapping_dict


def _make_spans_id_tree(trace):
"""
Creates spans tree (id only) from flat spans list
Expand Down Expand Up @@ -63,18 +108,31 @@ def _make_nested_nodes_tree(tree: dict):

ordered_tree = OrderedDict()

def add_node(node: dict, parent_tree: dict):
def add_node(node: Union[dict, list], parent_tree: dict):
"""
Recursively adds a node and its children to the parent tree.
"""
if isinstance(node, list):
# If node is a list, process each item as a child node
for child_node in node:
add_node(child_node, parent_tree)
return

# If the node is a dictionary, proceed with its normal structure
node_id = node["node"]["id"]
parent_tree[node_id] = OrderedDict()

# If there are child nodes, recursively add them
if "nodes" in node and node["nodes"] is not None:
for child_key, child_node in node["nodes"].items():
add_node(child_node, parent_tree[node_id])
if "nodes" in node and node["nodes"]:
child_nodes = node["nodes"]
if isinstance(child_nodes, list):
# If child nodes are a list, iterate over each one
for child_node in child_nodes:
add_node(child_node, parent_tree[node_id])
elif isinstance(child_nodes, dict):
# If child nodes are a dictionary, add them recursively
for child_key, child_node in child_nodes.items():
add_node(child_node, parent_tree[node_id])

# Process the top-level nodes
for node in tree["nodes"]:
Expand Down Expand Up @@ -116,10 +174,24 @@ def gather_nodes(nodes: list):
stack = nodes[:]
while stack:
current = stack.pop()
if isinstance(current, list):
# If current is a list, process each item as a child node
stack.extend(current) # Add each item of the list to the stack
continue # Skip the rest of the logic for this item since it's a list

node_id = current["node"]["id"]
result[node_id] = current
if "nodes" in current and current["nodes"] is not None:
stack.extend(current["nodes"].values())
# If there are child nodes, add them to the stack for further processing
child_nodes = current["nodes"]
if isinstance(child_nodes, list):
stack.extend(
child_nodes
) # If the child nodes are a list, add each to the stack
elif isinstance(child_nodes, dict):
stack.extend(
child_nodes.values()
) # If child nodes are a dict, add the values to the stack
return result

def extract_node_details(node_id: str, nodes: dict):
Expand All @@ -135,14 +207,9 @@ def extract_node_details(node_id: str, nodes: dict):
"node": node_data.get("node", {}),
"parent": node_data.get("parent", None),
"time": node_data.get("time", {}),
"status": node_data.get("status"),
"exception": node_data.get("exception"),
"data": node_data.get("data"),
"metrics": node_data.get("metrics"),
"meta": node_data.get("meta"),
"refs": node_data.get("refs"),
"links": node_data.get("links"),
"otel": node_data.get("otel"),
}

def recursive_flatten(current_nodes_id: dict, result: dict, nodes: dict):
Expand All @@ -156,12 +223,19 @@ def recursive_flatten(current_nodes_id: dict, result: dict, nodes: dict):

# Recursively process child nodes
if child_nodes:
recursive_flatten(child_nodes, result, nodes)
if isinstance(child_nodes, list):
for child_node in child_nodes:
recursive_flatten(
{child_node["node"]["id"]: child_node}, result, nodes
)
elif isinstance(child_nodes, dict):
recursive_flatten(child_nodes, result, nodes)

# Initialize the ordered dictionary and start the recursion
ordered_result = dict()
nodes = gather_nodes(nodes=tree_nodes)
recursive_flatten(current_nodes_id=nodes_id, result=ordered_result, nodes=nodes)

return list(ordered_result.values())


Expand Down
Loading