Skip to content

Commit

Permalink
fix: ensure retrieval documents are unnested when converting from spa…
Browse files Browse the repository at this point in the history
…n to dataset example (#5523)
  • Loading branch information
axiomofjoy authored Nov 26, 2024
1 parent b0ef597 commit 2f6e22e
Show file tree
Hide file tree
Showing 5 changed files with 373 additions and 382 deletions.
117 changes: 77 additions & 40 deletions src/phoenix/server/api/helpers/dataset_helpers.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,65 @@
import json
from collections.abc import Mapping
from typing import Any, Literal, Optional, Protocol
from typing import Any, Literal, Optional

from openinference.semconv.trace import (
MessageAttributes,
OpenInferenceMimeTypeValues,
OpenInferenceSpanKindValues,
SpanAttributes,
ToolCallAttributes,
)

from phoenix.db.models import Span
from phoenix.trace.attributes import get_attribute_value


class HasSpanIO(Protocol):
"""
An interface that contains the information needed to extract dataset example
input and output values from a span.
"""

span_kind: Optional[str]
input_value: Any
input_mime_type: Optional[str]
output_value: Any
output_mime_type: Optional[str]
llm_prompt_template_variables: Any
llm_input_messages: Any
llm_output_messages: Any
retrieval_documents: Any


def get_dataset_example_input(span: HasSpanIO) -> dict[str, Any]:
def get_dataset_example_input(span: Span) -> dict[str, Any]:
"""
Extracts the input value from a span and returns it as a dictionary. Input
values from LLM spans are extracted from the input messages and prompt
template variables (if present). For other span kinds, the input is
extracted from the input value and input mime type attributes.
"""
input_value = span.input_value
input_mime_type = span.input_mime_type
if span.span_kind == OpenInferenceSpanKindValues.LLM.value:
span_kind = span.span_kind
attributes = span.attributes
input_value = get_attribute_value(attributes, INPUT_VALUE)
input_mime_type = get_attribute_value(attributes, INPUT_MIME_TYPE)
prompt_template_variables = get_attribute_value(attributes, LLM_PROMPT_TEMPLATE_VARIABLES)
input_messages = get_attribute_value(attributes, LLM_INPUT_MESSAGES)
if span_kind == LLM:
return _get_llm_span_input(
input_messages=span.llm_input_messages,
input_messages=input_messages,
input_value=input_value,
input_mime_type=input_mime_type,
prompt_template_variables=span.llm_prompt_template_variables,
prompt_template_variables=prompt_template_variables,
)
return _get_generic_io_value(io_value=input_value, mime_type=input_mime_type, kind="input")


def get_dataset_example_output(span: HasSpanIO) -> dict[str, Any]:
def get_dataset_example_output(span: Span) -> dict[str, Any]:
"""
Extracts the output value from a span and returns it as a dictionary. Output
values from LLM spans are extracted from the output messages (if present).
Output from retriever spans are extracted from the retrieval documents (if
present). For other span kinds, the output is extracted from the output
value and output mime type attributes.
"""

output_value = span.output_value
output_mime_type = span.output_mime_type
if (span_kind := span.span_kind) == OpenInferenceSpanKindValues.LLM.value:
span_kind = span.span_kind
attributes = span.attributes
output_value = get_attribute_value(attributes, OUTPUT_VALUE)
output_mime_type = get_attribute_value(attributes, OUTPUT_MIME_TYPE)
output_messages = get_attribute_value(attributes, LLM_OUTPUT_MESSAGES)
retrieval_documents = get_attribute_value(attributes, RETRIEVAL_DOCUMENTS)
if span_kind == LLM:
return _get_llm_span_output(
output_messages=span.llm_output_messages,
output_messages=output_messages,
output_value=output_value,
output_mime_type=output_mime_type,
)
if span_kind == OpenInferenceSpanKindValues.RETRIEVER.value:
return _get_retriever_span_output(
retrieval_documents=span.retrieval_documents,
retrieval_documents=retrieval_documents,
output_value=output_value,
output_mime_type=output_mime_type,
)
Expand All @@ -90,8 +82,8 @@ def _get_llm_span_input(
input["messages"] = messages
if not input:
input = _get_generic_io_value(io_value=input_value, mime_type=input_mime_type, kind="input")
if prompt_template_variables:
input = {**input, "prompt_template_variables": prompt_template_variables}
if prompt_template_variables_data := _safely_json_decode(prompt_template_variables):
input["prompt_template_variables"] = prompt_template_variables_data
return input


Expand All @@ -118,7 +110,7 @@ def _get_retriever_span_output(
Extracts the output value from a retriever span and returns it as a dictionary.
The output is extracted from the retrieval documents (if present).
"""
if retrieval_documents is not None:
if (retrieval_documents := _parse_retrieval_documents(retrieval_documents)) is not None:
return {"documents": retrieval_documents}
return _get_generic_io_value(io_value=output_value, mime_type=output_mime_type, kind="output")

Expand All @@ -130,12 +122,14 @@ def _get_generic_io_value(
Makes a best-effort attempt to extract the input or output value from a span
and returns it as a dictionary.
"""
if mime_type == OpenInferenceMimeTypeValues.JSON.value:
parsed_value = json.loads(io_value)
if isinstance(parsed_value, dict):
return parsed_value
if (
mime_type == OpenInferenceMimeTypeValues.JSON.value
and (io_value_data := _safely_json_decode(io_value)) is not None
):
if isinstance(io_value_data, dict):
return io_value_data
else:
return {kind: parsed_value}
return {kind: io_value_data}
if isinstance(io_value, str):
return {kind: io_value}
return {}
Expand Down Expand Up @@ -169,12 +163,55 @@ def _get_message(message: Mapping[str, Any]) -> dict[str, Any]:
}


def _parse_retrieval_documents(retrieval_documents: Any) -> Optional[list[dict[str, Any]]]:
"""
Safely un-nests a list of retrieval documents.
Example: [{"document": {"content": "..."}}] -> [{"content": "..."}]
"""
if not isinstance(retrieval_documents, list):
return None
docs = []
for retrieval_doc in retrieval_documents:
if not isinstance(retrieval_doc, dict) or not (doc := retrieval_doc.get("document")):
return None
docs.append(doc)
return docs


def _safely_json_decode(value: Any) -> Any:
"""
Safely decodes a JSON-encoded value.
"""
if not isinstance(value, str):
return None
try:
return json.loads(value)
except json.JSONDecodeError:
return None


# MessageAttributes
MESSAGE_CONTENT = MessageAttributes.MESSAGE_CONTENT
MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON = MessageAttributes.MESSAGE_FUNCTION_CALL_ARGUMENTS_JSON
MESSAGE_FUNCTION_CALL_NAME = MessageAttributes.MESSAGE_FUNCTION_CALL_NAME
MESSAGE_NAME = MessageAttributes.MESSAGE_NAME
MESSAGE_ROLE = MessageAttributes.MESSAGE_ROLE
MESSAGE_TOOL_CALLS = MessageAttributes.MESSAGE_TOOL_CALLS

TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
# OpenInferenceSpanKindValues
LLM = OpenInferenceSpanKindValues.LLM.value

# SpanAttributes
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
INPUT_VALUE = SpanAttributes.INPUT_VALUE
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS

# ToolCallAttributes
TOOL_CALL_FUNCTION_ARGUMENTS_JSON = ToolCallAttributes.TOOL_CALL_FUNCTION_ARGUMENTS_JSON
TOOL_CALL_FUNCTION_NAME = ToolCallAttributes.TOOL_CALL_FUNCTION_NAME
40 changes: 10 additions & 30 deletions src/phoenix/server/api/mutations/dataset_mutations.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,42 +136,18 @@ async def add_spans_to_dataset(
.returning(models.DatasetVersion.id)
)
spans = (
await session.execute(
select(
models.Span.id,
models.Span.span_kind,
models.Span.attributes,
_span_attribute(INPUT_MIME_TYPE),
_span_attribute(INPUT_VALUE),
_span_attribute(OUTPUT_MIME_TYPE),
_span_attribute(OUTPUT_VALUE),
_span_attribute(LLM_PROMPT_TEMPLATE_VARIABLES),
_span_attribute(LLM_INPUT_MESSAGES),
_span_attribute(LLM_OUTPUT_MESSAGES),
_span_attribute(RETRIEVAL_DOCUMENTS),
)
.select_from(models.Span)
.where(models.Span.id.in_(span_rowids))
)
await session.scalars(select(models.Span).where(models.Span.id.in_(span_rowids)))
).all()
if missing_span_rowids := span_rowids - {span.id for span in spans}:
raise ValueError(
f"Could not find spans with rowids: {', '.join(map(str, missing_span_rowids))}"
) # todo: implement error handling types https://github.com/Arize-ai/phoenix/issues/3221

span_annotations = (
await session.execute(
select(
models.SpanAnnotation.span_rowid,
models.SpanAnnotation.name,
models.SpanAnnotation.label,
models.SpanAnnotation.score,
models.SpanAnnotation.explanation,
models.SpanAnnotation.metadata_,
models.SpanAnnotation.annotator_kind,
await session.scalars(
select(models.SpanAnnotation).where(
models.SpanAnnotation.span_rowid.in_(span_rowids)
)
.select_from(models.SpanAnnotation)
.where(models.SpanAnnotation.span_rowid.in_(span_rowids))
)
).all()

Expand Down Expand Up @@ -214,8 +190,12 @@ async def add_spans_to_dataset(
DatasetExampleRevision.input.key: get_dataset_example_input(span),
DatasetExampleRevision.output.key: get_dataset_example_output(span),
DatasetExampleRevision.metadata_.key: {
**span.attributes,
"annotations": span_annotations_by_span[span.id],
"span_kind": span.span_kind,
**(
{"annotations": annotations}
if (annotations := span_annotations_by_span[span.id])
else {}
),
},
DatasetExampleRevision.revision_kind.key: "CREATE",
}
Expand Down
61 changes: 12 additions & 49 deletions src/phoenix/server/api/types/Span.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
import json
from collections.abc import Mapping, Sized
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, cast

import numpy as np
import strawberry
from openinference.semconv.trace import EmbeddingAttributes, SpanAttributes
from openinference.semconv.trace import SpanAttributes
from strawberry import ID, UNSET
from strawberry.relay import Node, NodeID
from strawberry.types import Info
Expand Down Expand Up @@ -39,18 +38,6 @@
if TYPE_CHECKING:
from phoenix.server.api.types.Project import Project

EMBEDDING_EMBEDDINGS = SpanAttributes.EMBEDDING_EMBEDDINGS
EMBEDDING_VECTOR = EmbeddingAttributes.EMBEDDING_VECTOR
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
INPUT_VALUE = SpanAttributes.INPUT_VALUE
LLM_PROMPT_TEMPLATE_VARIABLES = SpanAttributes.LLM_PROMPT_TEMPLATE_VARIABLES
LLM_INPUT_MESSAGES = SpanAttributes.LLM_INPUT_MESSAGES
LLM_OUTPUT_MESSAGES = SpanAttributes.LLM_OUTPUT_MESSAGES
METADATA = SpanAttributes.METADATA
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS


@strawberry.enum
class SpanKind(Enum):
Expand Down Expand Up @@ -236,21 +223,7 @@ async def descendants(
description="The span's attributes translated into an example revision for a dataset",
) # type: ignore
async def as_example_revision(self, info: Info[Context, None]) -> SpanAsExampleRevision:
db_span = self.db_span
attributes = db_span.attributes
span_io = _SpanIO(
span_kind=db_span.span_kind,
input_value=get_attribute_value(attributes, INPUT_VALUE),
input_mime_type=get_attribute_value(attributes, INPUT_MIME_TYPE),
output_value=get_attribute_value(attributes, OUTPUT_VALUE),
output_mime_type=get_attribute_value(attributes, OUTPUT_MIME_TYPE),
llm_prompt_template_variables=get_attribute_value(
attributes, LLM_PROMPT_TEMPLATE_VARIABLES
),
llm_input_messages=get_attribute_value(attributes, LLM_INPUT_MESSAGES),
llm_output_messages=get_attribute_value(attributes, LLM_OUTPUT_MESSAGES),
retrieval_documents=get_attribute_value(attributes, RETRIEVAL_DOCUMENTS),
)
span = self.db_span

# Fetch annotations associated with this span
span_annotations = await self.span_annotations(info)
Expand All @@ -265,13 +238,13 @@ async def as_example_revision(self, info: Info[Context, None]) -> SpanAsExampleR
}
# Merge annotations into the metadata
metadata = {
**attributes,
"annotations": annotations,
"span_kind": span.span_kind,
**({"annotations": annotations} if annotations else {}),
}

return SpanAsExampleRevision(
input=get_dataset_example_input(span_io),
output=get_dataset_example_output(span_io),
input=get_dataset_example_input(span),
output=get_dataset_example_output(span),
metadata=metadata,
)

Expand Down Expand Up @@ -435,19 +408,9 @@ def _convert_metadata_to_string(metadata: Any) -> Optional[str]:
return str(metadata)


@dataclass
class _SpanIO:
"""
An class that contains the information needed to extract dataset example
input and output values from a span.
"""

span_kind: Optional[str]
input_value: Any
input_mime_type: Optional[str]
output_value: Any
output_mime_type: Optional[str]
llm_prompt_template_variables: Any
llm_input_messages: Any
llm_output_messages: Any
retrieval_documents: Any
INPUT_MIME_TYPE = SpanAttributes.INPUT_MIME_TYPE
INPUT_VALUE = SpanAttributes.INPUT_VALUE
METADATA = SpanAttributes.METADATA
OUTPUT_MIME_TYPE = SpanAttributes.OUTPUT_MIME_TYPE
OUTPUT_VALUE = SpanAttributes.OUTPUT_VALUE
RETRIEVAL_DOCUMENTS = SpanAttributes.RETRIEVAL_DOCUMENTS
Loading

0 comments on commit 2f6e22e

Please sign in to comment.