Skip to content

Commit

Permalink
feat: add ground truth columns to prediction output
Browse files Browse the repository at this point in the history
  • Loading branch information
khorshuheng committed Nov 27, 2023
1 parent 3340e2f commit 3731e01
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,41 @@ inference_schema:
# Example for binary classification
type: "BINARY_CLASSIFICATION"
binary_classification:
# The likelihood of the event (Probability between 0 to 1.0).
prediction_score_column: "score"
# The actual event (either 0.0 or 1.0)
actual_score_column: "actual_score"
# (Optional) The classification label STRING of this event. Can be derived from prediction_score.
prediction_label_column: "label"
# (Optional) The actual classification label STRING of this event. Can be derived from actual_score.
actual_label_column: "actual_label"

# # Example for multiclass classification
# type: "MULTICLASS_CLASSIFICATION"
# multiclass_classification:
# # The likelihood of the event (Probability between 0 to 1.0), for all possible classes.
# prediction_score_columns:
# - "score_class_1"
# - "score_class_2"
# # The actual event (either 0.0 or 1.0), for all possible classes. Order must correspond to that
# # specified on prediction_score_columns.
# actual_score_columns:
# - "actual_score_class_1"
# - "actual_score_class_2"
# # (Optional) The classification label STRING of the actual event.
# actual_label_column: "actual_label"
# # (Optional) The class names. Order must correspond to that specified on prediction_score_columns.
# prediction_label_columns:
# - "label_class_1"
# - "label_class_2"

# # Example for regression
# type: "REGRESSION"
# regression:
# # FLOAT64 value for the prediction value.
# prediction_score_column: "score"
# # FLOAT64 value for the actual value.
# actual_score_column: "actual_score"

# # Example for ranking
# type: "RANKING"
Expand All @@ -38,6 +55,10 @@ inference_schema:
# prediction_group_id_column: "prediction_group"
# # INT64 value for the rank of the prediction within the group.
# rank_column: "rank"
# # Ground truth representing the relevance of the prediction, in FLOAT64.
# relevance_score_column: "relevance_score"
# # (Optional) Ground truth label representing the relevance of the prediction, in STRING.
# relevance_label_column: "relevance_label"

# Column name to data types mapping for feature columns. Supported types are:
# - INT64
Expand Down
2 changes: 2 additions & 0 deletions python/observation-publisher/tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ def test_config_initialization():
binary_classification=BinaryClassificationOutput(
prediction_label_column="label",
prediction_score_column="score",
actual_label_column="actual_label",
actual_score_column="actual_score",
),
),
observability_backend=ObservabilityBackend(
Expand Down
30 changes: 27 additions & 3 deletions python/sdk/merlin/observability/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,59 @@ class ValueType(Enum):
@dataclass
class RegressionOutput:
    """Inference schema for a regression model's observation columns.

    Holds the names of the columns carrying the predicted value and the
    observed (ground truth) value; both are FLOAT64-typed.
    """

    # Column holding the FLOAT64 predicted value.
    prediction_score_column: str
    # Column holding the FLOAT64 actual (ground truth) value.
    actual_score_column: str

    @property
    def column_types(self) -> Dict[str, ValueType]:
        """Return the column-name -> ValueType mapping for this schema.

        Note: the pasted diff left a stale early `return` of the
        prediction-only mapping above this dict; it has been removed so the
        ground-truth column is included as intended by the change.
        """
        return {
            self.prediction_score_column: ValueType.FLOAT64,
            self.actual_score_column: ValueType.FLOAT64,
        }


@dataclass_json
@dataclass
class BinaryClassificationOutput:
    """Inference schema for a binary classification model's observation columns.

    Score columns (FLOAT64) are required; label columns (STRING) are optional
    and may be derived from the corresponding scores.
    """

    # Column holding the predicted event likelihood (0.0 to 1.0).
    prediction_score_column: str
    # Column holding the actual outcome (0.0 or 1.0).
    actual_score_column: str
    # Optional column holding the predicted class label STRING.
    prediction_label_column: Optional[str] = None
    # Optional column holding the actual class label STRING.
    actual_label_column: Optional[str] = None

    @property
    def column_types(self) -> Dict[str, ValueType]:
        """Return the column-name -> ValueType mapping for this schema.

        The stale single-column assignment left over from the diff (it was
        immediately overwritten) has been dropped.
        """
        column_types_mapping = {
            self.prediction_score_column: ValueType.FLOAT64,
            self.actual_score_column: ValueType.FLOAT64,
        }
        if self.prediction_label_column is not None:
            column_types_mapping[self.prediction_label_column] = ValueType.STRING
        if self.actual_label_column is not None:
            column_types_mapping[self.actual_label_column] = ValueType.STRING
        return column_types_mapping


@dataclass_json
@dataclass
class MulticlassClassificationOutput:
    """Inference schema for a multiclass classification model's observation columns.

    Per-class score columns (FLOAT64) are required for both prediction and
    ground truth; label columns (STRING) are optional.
    """

    # One FLOAT64 column per class holding the predicted likelihood.
    prediction_score_columns: List[str]
    # One FLOAT64 column per class holding the actual outcome (0.0 or 1.0).
    actual_score_columns: List[str]
    # Optional STRING columns naming each class; order matches the score columns.
    prediction_label_columns: Optional[List[str]] = None
    # Optional STRING column holding the actual class label.
    actual_label_column: Optional[str] = None

    @property
    def column_types(self) -> Dict[str, ValueType]:
        """Return the column-name -> ValueType mapping for this schema."""
        mapping: Dict[str, ValueType] = {}
        # Prediction scores first, then ground-truth scores (same order as before).
        for score_column in self.prediction_score_columns:
            mapping[score_column] = ValueType.FLOAT64
        for score_column in self.actual_score_columns:
            mapping[score_column] = ValueType.FLOAT64
        if self.prediction_label_columns is not None:
            for label_column in self.prediction_label_columns:
                mapping[label_column] = ValueType.STRING
        if self.actual_label_column is not None:
            mapping[self.actual_label_column] = ValueType.STRING
        return mapping


Expand All @@ -60,13 +78,19 @@ def column_types(self) -> Dict[str, ValueType]:
class RankingOutput:
    """Inference schema for a ranking model's observation columns.

    Rank and group id are required; relevance score (FLOAT64 ground truth)
    is required, and a STRING relevance label is optional.
    """

    # INT64 column holding the rank of the prediction within its group.
    rank_column: str
    # STRING column identifying the prediction group.
    prediction_group_id_column: str
    # FLOAT64 ground-truth column holding the relevance of the prediction.
    relevance_score_column: str
    # Optional STRING ground-truth column holding the relevance label.
    relevance_label_column: Optional[str] = None

    @property
    def column_types(self) -> Dict[str, ValueType]:
        """Return the column-name -> ValueType mapping for this schema.

        The stale `return {` line left over from the diff (which collided
        with the new assignment and broke the syntax) has been removed.
        """
        column_types_mapping = {
            self.rank_column: ValueType.INT64,
            self.prediction_group_id_column: ValueType.STRING,
            self.relevance_score_column: ValueType.FLOAT64,
        }
        if self.relevance_label_column is not None:
            column_types_mapping[self.relevance_label_column] = ValueType.STRING
        return column_types_mapping


@unique
Expand Down

0 comments on commit 3731e01

Please sign in to comment.