Skip to content

Commit

Permalink
refactor (backend): fix parameters required for rag evaluators & late…
Browse files Browse the repository at this point in the history
…ncy / cost computation
  • Loading branch information
aybruhm committed Nov 26, 2024
1 parent e08da07 commit 3e93e11
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 12 deletions.
10 changes: 10 additions & 0 deletions agenta-backend/agenta_backend/models/api/evaluation_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from pydantic import BaseModel, Field, model_validator

from agenta_backend.utils import traces
from agenta_backend.models.api.api_models import Result


Expand Down Expand Up @@ -98,6 +99,15 @@ class EvaluatorMappingInputInterface(BaseModel):
inputs: Dict[str, Any]
mapping: Dict[str, Any]

@model_validator(mode="before")
def remove_trace_prefix(cls, values: Dict) -> Dict:
mapping = values.get("mapping", {})
updated_mapping = traces.remove_trace_prefix(mapping_dict=mapping)

# Set the modified mapping back to the values
values["mapping"] = updated_mapping
return values


class EvaluatorMappingOutputInterface(BaseModel):
outputs: Dict[str, Any]
Expand Down
15 changes: 9 additions & 6 deletions agenta-backend/agenta_backend/services/evaluators_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
EvaluatorMappingOutputInterface,
)
from agenta_backend.utils.traces import (
remove_trace_prefix,
process_distributed_trace_into_trace_tree,
get_field_value_from_trace_tree,
)
Expand Down Expand Up @@ -934,9 +935,10 @@ async def rag_faithfulness(
)

# Get required keys for rag evaluator
question_key: Union[str, None] = settings_values.get("question_key", None)
answer_key: Union[str, None] = settings_values.get("answer_key", None)
contexts_key: Union[str, None] = settings_values.get("contexts_key", None)
mapping_keys = remove_trace_prefix(settings_values=settings_values)
question_key: Union[str, None] = mapping_keys.get("question_key", None)
answer_key: Union[str, None] = mapping_keys.get("answer_key", None)
contexts_key: Union[str, None] = mapping_keys.get("contexts_key", None)

if None in [question_key, answer_key, contexts_key]:
logging.error(
Expand Down Expand Up @@ -1046,9 +1048,10 @@ async def rag_context_relevancy(
)

# Get required keys for rag evaluator
question_key: Union[str, None] = settings_values.get("question_key", None)
answer_key: Union[str, None] = settings_values.get("answer_key", None)
contexts_key: Union[str, None] = settings_values.get("contexts_key", None)
mapping_keys = remove_trace_prefix(settings_values=settings_values)
question_key: Union[str, None] = mapping_keys.get("question_key", None)
answer_key: Union[str, None] = mapping_keys.get("answer_key", None)
contexts_key: Union[str, None] = mapping_keys.get("contexts_key", None)

if None in [question_key, answer_key, contexts_key]:
logging.error(
Expand Down
11 changes: 5 additions & 6 deletions agenta-backend/agenta_backend/services/llm_apps_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,10 @@ def extract_result_from_response(response: dict):
value["data"] = str(value.get("data"))

if "tree" in response:
trace_tree = (
response["tree"][0]
if isinstance(response.get("tree"), list)
else {}
)
trace_tree = response.get("tree", {}).get("nodes", [])[0]

latency = (
get_nested_value(trace_tree, ["time", "span"]) * 1_000_000
get_nested_value(trace_tree, ["time", "span"]) / 1_000_000
if trace_tree
else None
)
Expand Down Expand Up @@ -108,6 +105,8 @@ def extract_result_from_response(response: dict):
value = {"error": f"Unexpected error: {e}"}
kind = "error"

print("Cost: ", cost)
print("Latency: ", latency)
return value, kind, cost, latency


Expand Down

0 comments on commit 3e93e11

Please sign in to comment.