Skip to content

Commit

Permalink
make actual score column mandatory for regression model
Browse files Browse the repository at this point in the history
  • Loading branch information
khorshuheng committed Feb 28, 2024
1 parent dfb7462 commit f0fc4b1
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 33 deletions.
36 changes: 14 additions & 22 deletions api/client/model_regression_output.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions python/observation-publisher/publisher/observation_sink.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from merlin.observability.inference import (BinaryClassificationOutput,
InferenceSchema,
RankingOutput, RegressionOutput,
ValueType)
ValueType, add_prediction_id_column)

from publisher.config import ObservationSinkConfig, ObservationSinkType
from publisher.prediction_log_parser import (PREDICTION_LOG_MODEL_VERSION_COLUMN,
Expand Down Expand Up @@ -61,6 +61,8 @@ class ArizeSink(ObservationSink):
Writes prediction logs to Arize AI.
"""

ARIZE_PREDICTION_ID_COLUMN = "_prediction_id"

def __init__(
self,
inference_schema: InferenceSchema,
Expand All @@ -80,7 +82,7 @@ def __init__(
def _common_arize_schema_attributes(self) -> dict:
return dict(
feature_column_names=self._inference_schema.feature_columns,
prediction_id_column_name=self._inference_schema.prediction_id_column,
prediction_id_column_name=ArizeSink.ARIZE_PREDICTION_ID_COLUMN,
timestamp_column_name=PREDICTION_LOG_TIMESTAMP_COLUMN,
)

Expand Down Expand Up @@ -112,6 +114,7 @@ def _to_arize_schema(self) -> Tuple[ArizeModelType, ArizeSchema]:

def write(self, df: pd.DataFrame):
model_type, arize_schema = self._to_arize_schema()
df = add_prediction_id_column(df, self._inference_schema.session_id_column, self._inference_schema.row_id_column, ArizeSink.ARIZE_PREDICTION_ID_COLUMN)
try:
self._client.log(
dataframe=df,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,10 +118,10 @@ def test_log_to_dataframe():
)
expected_df = pd.DataFrame.from_records(
[
[0.8, 24, "FOOD", 0.9, "fraud", "1234", "a", "1234_a", request_timestamp, model_version],
[0.5, 2, "RIDE", 0.5, "fraud", "1234", "b", "1234_b", request_timestamp, model_version],
[1.0, 13, "CAR", 0.4, "non fraud", "5678", "c", "5678_c", request_timestamp, model_version],
[0.4, 60, "RIDE", 0.2, "non fraud", "5678", "d", "5678_d", request_timestamp, model_version],
[0.8, 24, "FOOD", 0.9, "fraud", "1234", "a", request_timestamp, model_version],
[0.5, 2, "RIDE", 0.5, "fraud", "1234", "b", request_timestamp, model_version],
[1.0, 13, "CAR", 0.4, "non fraud", "5678", "c", request_timestamp, model_version],
[0.4, 60, "RIDE", 0.2, "non fraud", "5678", "d", request_timestamp, model_version],
],
columns=[
"acceptance_rate",
Expand All @@ -131,7 +131,6 @@ def test_log_to_dataframe():
"_prediction_label",
"order_id",
"driver_id",
"prediction_id",
"request_timestamp",
"model_version",
],
Expand Down Expand Up @@ -182,7 +181,6 @@ def test_empty_column_conversion_to_dataframe():
"fraud",
"1234",
"a",
"1234_a",
datetime(2021, 1, 1, 0, 0, 0),
"0.1.0",
],
Expand All @@ -193,7 +191,6 @@ def test_empty_column_conversion_to_dataframe():
"_prediction_label",
"session_id",
"row_id",
"prediction_id",
"request_timestamp",
"model_version",
],
Expand Down
4 changes: 2 additions & 2 deletions python/sdk/client/models/regression_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import json


from typing import Any, ClassVar, Dict, List, Optional
from typing import Any, ClassVar, Dict, List
from pydantic import BaseModel, StrictStr
from client.models.model_prediction_output_class import ModelPredictionOutputClass
try:
Expand All @@ -31,7 +31,7 @@ class RegressionOutput(BaseModel):
RegressionOutput
""" # noqa: E501
prediction_score_column: StrictStr
actual_score_column: Optional[StrictStr] = None
actual_score_column: StrictStr
output_class: ModelPredictionOutputClass
__properties: ClassVar[List[str]] = ["prediction_score_column", "actual_score_column", "output_class"]

Expand Down
1 change: 1 addition & 0 deletions swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1456,6 +1456,7 @@ components:
type: object
required:
- prediction_score_column
- actual_score_column
- output_class
properties:
prediction_score_column:
Expand Down

0 comments on commit f0fc4b1

Please sign in to comment.