From 80cf9113effa827a9880b15e3312d7dfbb4305b6 Mon Sep 17 00:00:00 2001 From: Khor Shu Heng <32997938+khorshuheng@users.noreply.github.com> Date: Wed, 22 Nov 2023 11:30:34 +0800 Subject: [PATCH] feat: add data struct for inference schema (#494) **What this PR does / why we need it**: This PR provides the data types required to support ML observability within Merlin. The new data types are expected to be used for https://github.com/caraml-dev/merlin/pull/488 , flyte workflows, and also in the future when we store an optional inference schema within Merlin. dataclasses_json dependency is introduced here as it is required if we want to pass dataclass as an input to Flyte. **Which issue(s) this PR fixes**: Fixes # **Does this PR introduce a user-facing change?**: ```release-note NONE ``` **Checklist** - [ ] Added unit test, integration, and/or e2e tests - [ ] Tested locally - [ ] Updated documentation - [ ] Update Swagger spec if the PR introduce API changes - [ ] Regenerated Golang and Python client if the PR introduce API changes --- python/sdk/merlin/observability/__init__.py | 0 python/sdk/merlin/observability/inference.py | 111 +++++++++++++++++++ python/sdk/setup.py | 1 + 3 files changed, 112 insertions(+) create mode 100644 python/sdk/merlin/observability/__init__.py create mode 100644 python/sdk/merlin/observability/inference.py diff --git a/python/sdk/merlin/observability/__init__.py b/python/sdk/merlin/observability/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/sdk/merlin/observability/inference.py b/python/sdk/merlin/observability/inference.py new file mode 100644 index 000000000..63878ef83 --- /dev/null +++ b/python/sdk/merlin/observability/inference.py @@ -0,0 +1,111 @@ +from dataclasses import dataclass +from enum import unique, Enum +from typing import Dict, Optional, List + +from dataclasses_json import dataclass_json + + +@unique +class ValueType(Enum): + FLOAT64 = 1 + INT64 = 2 + BOOLEAN = 3 + STRING = 4 + + +@dataclass_json +@dataclass +class RegressionOutput: + prediction_score_column: str + + @property + def column_types(self) -> Dict[str, ValueType]: + return {self.prediction_score_column: ValueType.FLOAT64} + + +@dataclass_json +@dataclass +class BinaryClassificationOutput: + prediction_label_column: str + prediction_score_column: Optional[str] = None + + @property + def column_types(self) -> Dict[str, ValueType]: + column_types_mapping = {self.prediction_label_column: ValueType.STRING} + if self.prediction_score_column is not None: + column_types_mapping[self.prediction_score_column] = ValueType.FLOAT64 + return column_types_mapping + + +@dataclass_json +@dataclass +class MulticlassClassificationOutput: + prediction_label_columns: List[str] + prediction_score_columns: Optional[List[str]] = None + + @property + def column_types(self) -> Dict[str, ValueType]: + column_types_mapping = { + label_column: ValueType.STRING + for label_column in self.prediction_label_columns + } + if self.prediction_score_columns is not None: + for column_name in self.prediction_score_columns: + column_types_mapping[column_name] = ValueType.FLOAT64 + return column_types_mapping + + +@dataclass_json +@dataclass +class RankingOutput: + rank_column: str + prediction_group_id_column: str + + @property + def column_types(self) -> Dict[str, ValueType]: + return { + self.rank_column: ValueType.INT64, + self.prediction_group_id_column: ValueType.STRING, + } + + +@unique +class InferenceType(Enum): + BINARY_CLASSIFICATION = 1 + MULTICLASS_CLASSIFICATION = 2 + REGRESSION = 3 + RANKING = 4 + + +@dataclass_json +@dataclass +class InferenceSchema: + feature_types: Dict[str, ValueType] + type: InferenceType + binary_classification: Optional[BinaryClassificationOutput] = None + multiclass_classification: Optional[MulticlassClassificationOutput] = None + regression: Optional[RegressionOutput] = None + ranking: Optional[RankingOutput] = None + prediction_id_column: Optional[str] = "prediction_id" + tag_columns: Optional[List[str]] = None + + @property + def feature_columns(self) -> List[str]: + return list(self.feature_types.keys()) + + @property + def prediction_column_types(self) -> Dict[str, ValueType]: + if self.type == InferenceType.BINARY_CLASSIFICATION: + assert self.binary_classification is not None + return self.binary_classification.column_types + elif self.type == InferenceType.MULTICLASS_CLASSIFICATION: + assert self.multiclass_classification is not None + return self.multiclass_classification.column_types + elif self.type == InferenceType.REGRESSION: + assert self.regression is not None + return self.regression.column_types + elif self.type == InferenceType.RANKING: + assert self.ranking is not None + return self.ranking.column_types + else: + raise ValueError(f"Unknown prediction type: {self.type}") diff --git a/python/sdk/setup.py b/python/sdk/setup.py index 2c77aa823..aead2e0d1 100644 --- a/python/sdk/setup.py +++ b/python/sdk/setup.py @@ -28,6 +28,7 @@ "Click>=7.0,<8.1.4", "cloudpickle==2.0.0", # used by mlflow "cookiecutter>=1.7.2", + "dataclasses-json>=0.5.2", # allow Flyte version 1.2.0 or above to import Merlin SDK "docker>=4.2.1", "google-cloud-storage>=1.19.0", "protobuf>=3.12.0,<5.0.0", # Determined by the mlflow dependency