diff --git a/python/observation-publisher/Makefile b/python/observation-publisher/Makefile index d30f50b97..eaf3aa0b1 100644 --- a/python/observation-publisher/Makefile +++ b/python/observation-publisher/Makefile @@ -14,7 +14,12 @@ pip-compile: .PHONY: test test: @echo "Running tests..." - @python -m pytest + @python -m pytest -m "not integration" + +.PHONY: test-integration +integration-test: + @echo "Running integration tests..." + @python -m pytest -m "integration" .PHONY: run run: diff --git a/python/observation-publisher/publisher/__main__.py b/python/observation-publisher/publisher/__main__.py index b048ce147..89860f0d8 100644 --- a/python/observation-publisher/publisher/__main__.py +++ b/python/observation-publisher/publisher/__main__.py @@ -35,6 +35,7 @@ def start_consumer(cfg: PublisherConfig) -> None: prediction_log_consumer.start_polling( observation_sinks=observation_sinks, inference_schema=inference_schema, + model_version=cfg.environment.model_version, ) diff --git a/python/observation-publisher/publisher/metric.py b/python/observation-publisher/publisher/metric.py index 26d5314af..575c15028 100644 --- a/python/observation-publisher/publisher/metric.py +++ b/python/observation-publisher/publisher/metric.py @@ -1,5 +1,5 @@ from pandas import Timestamp -from prometheus_client import Gauge, Counter +from prometheus_client import Counter, Gauge class MetricWriter(object): @@ -21,6 +21,7 @@ def __init__(self): self.total_prediction_logs_processed_counter = Counter( "total_prediction_logs_processed", "The total number of prediction logs processed by the publisher", + ["model_id", "model_version"], ) self._initialized = True diff --git a/python/observation-publisher/publisher/observation_sink.py b/python/observation-publisher/publisher/observation_sink.py index f501c60e0..a09ddbb6c 100644 --- a/python/observation-publisher/publisher/observation_sink.py +++ b/python/observation-publisher/publisher/observation_sink.py @@ -1,6 +1,6 @@ import abc +import time from dataclasses import dataclass -from datetime import datetime, timedelta from typing import List, Tuple import pandas as pd @@ -10,24 +10,19 @@ from arize.utils.types import Environments from arize.utils.types import ModelTypes as ArizeModelType from dataclasses_json import dataclass_json +from google.api_core.exceptions import NotFound from google.cloud.bigquery import Client as BigQueryClient -from google.cloud.bigquery import ( - SchemaField, - Table, - TimePartitioning, - TimePartitioningType, -) -from merlin.observability.inference import ( - BinaryClassificationOutput, - InferenceSchema, - ObservationType, - RankingOutput, - RegressionOutput, - ValueType, -) +from google.cloud.bigquery import (SchemaField, Table, TimePartitioning, + TimePartitioningType) +from merlin.observability.inference import (BinaryClassificationOutput, + InferenceSchema, ObservationType, + RankingOutput, RegressionOutput, + ValueType) from publisher.config import ObservationSinkConfig, ObservationSinkType -from publisher.prediction_log_parser import PREDICTION_LOG_TIMESTAMP_COLUMN +from publisher.prediction_log_parser import (MODEL_VERSION_COLUMN, + PREDICTION_LOG_TIMESTAMP_COLUMN, + ROW_ID_COLUMN, SESSION_ID_COLUMN) class ObservationSink(abc.ABC): @@ -74,6 +69,12 @@ def __init__( model_version: str, arize_client: ArizeClient, ): + """ + :param inference_schema: Inference schema for the ingested model + :param model_id: Merlin model id + :param model_version: Merlin model version + :param arize_client: Arize Pandas Logger client + """ super().__init__(inference_schema, model_id, model_version) self._client = arize_client @@ -101,7 +102,7 @@ def _to_arize_schema(self) -> Tuple[ArizeModelType, ArizeSchema]: elif isinstance(prediction_output, RankingOutput): schema_attributes = self._common_arize_schema_attributes() | dict( rank_column_name=prediction_output.rank_column, - prediction_group_id_column_name=prediction_output.prediction_group_id_column, + prediction_group_id_column_name=SESSION_ID_COLUMN, ) model_type = ArizeModelType.RANKING else: @@ -112,6 +113,14 @@ def _to_arize_schema(self) -> Tuple[ArizeModelType, ArizeSchema]: return model_type, ArizeSchema(**schema_attributes) def write(self, df: pd.DataFrame): + df[self._inference_schema.prediction_id_column] = ( + df[SESSION_ID_COLUMN] + df[ROW_ID_COLUMN] + ) + if isinstance(self._inference_schema.model_prediction_output, RankingOutput): + df[ + self._inference_schema.model_prediction_output.prediction_group_id_column + ] = df[SESSION_ID_COLUMN] + processed_df = self._inference_schema.model_prediction_output.preprocess( df, [ObservationType.FEATURE, ObservationType.PREDICTION] ) @@ -134,17 +143,44 @@ def write(self, df: pd.DataFrame): raise e +@dataclass_json +@dataclass +class BigQueryRetryConfig: + """ + Configuration for retrying failed write attempts. Write could fail due to BigQuery + taking time to update the table schema / create new table. + Attributes: + enabled: Whether to retry failed write attempts + retry_attempts: Number of retry attempts + retry_interval_seconds: Interval between retry attempts + """ + + enabled: bool = False + retry_attempts: int = 4 + retry_interval_seconds: int = 30 + + @dataclass_json @dataclass class BigQueryConfig: + """ + Configuration for writing to BigQuery + Attributes: + project: GCP project id + dataset: BigQuery dataset name + ttl_days: Time to live for the date partition + retry: Configuration for retrying failed write attempts + """ + project: str dataset: str ttl_days: int + retry: BigQueryRetryConfig = BigQueryRetryConfig() class BigQuerySink(ObservationSink): """ - Writes prediction logs to BigQuery. If the destination table doesn't exist, it will be created based on the inference schema.. + Writes prediction logs to BigQuery. If the destination table doesn't exist, it will be created based on the inference schema. """ def __init__( @@ -152,21 +188,54 @@ def __init__( inference_schema: InferenceSchema, model_id: str, model_version: str, - project: str, - dataset: str, - ttl_days: int, + config: BigQueryConfig, ): + """ + :param inference_schema: Inference schema for the ingested model + :param model_id: Merlin model id + :param model_version: Merlin model version + :param config: Configuration to write to bigquery sink + """ super().__init__(inference_schema, model_id, model_version) self._client = BigQueryClient() self._inference_schema = inference_schema self._model_id = model_id self._model_version = model_version - self._project = project - self._dataset = dataset - table = Table(self.write_location, schema=self.schema_fields) - table.time_partitioning = TimePartitioning(type_=TimePartitioningType.DAY) - table.expires = datetime.now() + timedelta(days=ttl_days) - self._table: Table = self._client.create_table(exists_ok=True, table=table) + self._config = config + self._table = self.create_or_update_table() + + @property + def project(self) -> str: + return self._config.project + + @property + def dataset(self) -> str: + return self._config.dataset + + @property + def retry(self) -> BigQueryRetryConfig: + return self._config.retry + + def create_or_update_table(self) -> Table: + try: + original_table = self._client.get_table(self.write_location) + original_schema = original_table.schema + migrated_schema = original_schema[:] + for field in self.schema_fields: + if field not in original_schema: + migrated_schema.append(field) + if migrated_schema == original_schema: + return original_table + original_table.schema = migrated_schema + return self._client.update_table(original_table, ["schema"]) + except NotFound: + table = Table(self.write_location, schema=self.schema_fields) + table.time_partitioning = TimePartitioning( + type_=TimePartitioningType.DAY, + field=PREDICTION_LOG_TIMESTAMP_COLUMN, + expiration_ms=self._config.ttl_days * 24 * 60 * 60 * 1000, + ) + return self._client.create_table(table=table) @property def schema_fields(self) -> List[SchemaField]: @@ -179,13 +248,21 @@ def schema_fields(self) -> List[SchemaField]: schema_fields = [ SchemaField( - name=self._inference_schema.prediction_id_column, + name=SESSION_ID_COLUMN, + field_type="STRING", + ), + SchemaField( + name=ROW_ID_COLUMN, field_type="STRING", ), SchemaField( name=PREDICTION_LOG_TIMESTAMP_COLUMN, field_type="TIMESTAMP", ), + SchemaField( + name=MODEL_VERSION_COLUMN, + field_type="STRING", + ), ] for feature, feature_type in self._inference_schema.feature_types.items(): schema_fields.append( @@ -207,13 +284,35 @@ def schema_fields(self) -> List[SchemaField]: @property def write_location(self) -> str: - table_name = f"prediction_log_{self._model_id}_{self._model_version}".replace( - "-", "_" - ).replace(".", "_") - return f"{self._project}.{self._dataset}.{table_name}" + table_name = f"prediction_log_{self._model_id}".replace("-", "_").replace( + ".", "_" + ) + return f"{self.project}.{self.dataset}.{table_name}" def write(self, dataframe: pd.DataFrame): - self._client.insert_rows_from_dataframe(dataframe=dataframe, table=self._table) + for i in range(0, self.retry.retry_attempts + 1): + try: + response = self._client.insert_rows_from_dataframe( + dataframe=dataframe, table=self._table + ) + errors = [error for error_chunk in response for error in error_chunk] + if len(errors) > 0: + if not self.retry.enabled: + print("Errors when inserting rows to BigQuery") + return + else: + print( + f"Errors when inserting rows to BigQuery, retrying attempt {i}/{self.retry.retry_attempts}" + ) + time.sleep(self.retry.retry_interval_seconds) + else: + return + except NotFound as e: + print( + f"Table not found: {e}, retrying attempt {i}/{self.retry.retry_attempts}" + ) + time.sleep(self.retry.retry_interval_seconds) + print(f"Failed to write to BigQuery after {self.retry.retry_attempts} attempts") def new_observation_sink( @@ -230,9 +329,7 @@ def new_observation_sink( inference_schema=inference_schema, model_id=model_id, model_version=model_version, - project=bq_config.project, - dataset=bq_config.dataset, - ttl_days=bq_config.ttl_days, + config=bq_config, ) case ObservationSinkType.ARIZE: arize_config: ArizeConfig = ArizeConfig.from_dict(sink_config.config) # type: ignore[attr-defined] diff --git a/python/observation-publisher/publisher/prediction_log_consumer.py b/python/observation-publisher/publisher/prediction_log_consumer.py index 1d6199dde..eef099c53 100644 --- a/python/observation-publisher/publisher/prediction_log_consumer.py +++ b/python/observation-publisher/publisher/prediction_log_consumer.py @@ -14,11 +14,11 @@ from publisher.config import ObservationSource, ObservationSourceConfig from publisher.metric import MetricWriter from publisher.observation_sink import ObservationSink -from publisher.prediction_log_parser import ( - PREDICTION_LOG_TIMESTAMP_COLUMN, - PredictionLogFeatureTable, - PredictionLogResultsTable, -) +from publisher.prediction_log_parser import (MODEL_VERSION_COLUMN, + PREDICTION_LOG_TIMESTAMP_COLUMN, + ROW_ID_COLUMN, SESSION_ID_COLUMN, + PredictionLogFeatureTable, + PredictionLogResultsTable) class PredictionLogConsumer(abc.ABC): @@ -42,6 +42,7 @@ def start_polling( self, observation_sinks: List[ObservationSink], inference_schema: InferenceSchema, + model_version: str, ): try: buffered_logs = [] @@ -58,7 +59,9 @@ def start_polling( and buffered_duration < buffered_max_duration_seconds ): continue - df = log_batch_to_dataframe(buffered_logs, inference_schema) + df = log_batch_to_dataframe( + buffered_logs, inference_schema, model_version + ) most_recent_prediction_timestamp = df[ PREDICTION_LOG_TIMESTAMP_COLUMN ].max() @@ -69,7 +72,7 @@ def start_polling( len(buffered_logs) ) write_tasks = [ - Thread(target=sink.write, args=(df,)) for sink in observation_sinks + Thread(target=sink.write, args=(df.copy(),)) for sink in observation_sinks ] for task in write_tasks: task.start() @@ -160,7 +163,7 @@ def parse_message_to_prediction_log(msg: str) -> PredictionLog: def log_to_records( - log: PredictionLog, inference_schema: InferenceSchema + log: PredictionLog, inference_schema: InferenceSchema, model_version: str ) -> Tuple[List[List[np.int64 | np.float64 | np.bool_ | np.str_]], List[str]]: request_timestamp = log.request_timestamp.ToDatetime() feature_table = PredictionLogFeatureTable.from_struct( @@ -171,7 +174,9 @@ def log_to_records( ) rows = [ - feature_row + prediction_row + [log.prediction_id + row_id, request_timestamp] + feature_row + + prediction_row + + [log.prediction_id, row_id, request_timestamp, model_version] for feature_row, prediction_row, row_id in zip( feature_table.rows, prediction_results_table.rows, @@ -182,18 +187,23 @@ def log_to_records( column_names = ( feature_table.columns + prediction_results_table.columns - + [inference_schema.prediction_id_column, PREDICTION_LOG_TIMESTAMP_COLUMN] + + [ + SESSION_ID_COLUMN, + ROW_ID_COLUMN, + PREDICTION_LOG_TIMESTAMP_COLUMN, + MODEL_VERSION_COLUMN, + ] ) return rows, column_names def log_batch_to_dataframe( - logs: List[PredictionLog], inference_schema: InferenceSchema + logs: List[PredictionLog], inference_schema: InferenceSchema, model_version: str ) -> pd.DataFrame: combined_records = [] column_names: List[str] = [] for log in logs: - rows, column_names = log_to_records(log, inference_schema) + rows, column_names = log_to_records(log, inference_schema, model_version) combined_records.extend(rows) return pd.DataFrame.from_records(combined_records, columns=column_names) diff --git a/python/observation-publisher/publisher/prediction_log_parser.py b/python/observation-publisher/publisher/prediction_log_parser.py index 669da15d2..97ecd4d09 100644 --- a/python/observation-publisher/publisher/prediction_log_parser.py +++ b/python/observation-publisher/publisher/prediction_log_parser.py @@ -7,7 +7,10 @@ from merlin.observability.inference import InferenceSchema, ValueType from typing_extensions import Self +SESSION_ID_COLUMN = "session_id" +ROW_ID_COLUMN = "row_id" PREDICTION_LOG_TIMESTAMP_COLUMN = "request_timestamp" +MODEL_VERSION_COLUMN = "model_version" @dataclass diff --git a/python/observation-publisher/pyproject.toml b/python/observation-publisher/pyproject.toml index ac89db93c..b0b21c481 100644 --- a/python/observation-publisher/pyproject.toml +++ b/python/observation-publisher/pyproject.toml @@ -2,6 +2,9 @@ addopts = [ "--import-mode=importlib", ] +markers = [ + "integration: mark a test as integration test" +] [tool.mypy] exclude = "test.*" diff --git a/python/observation-publisher/requirements-dev.txt b/python/observation-publisher/requirements-dev.txt index d8df3e33f..1a610b802 100644 --- a/python/observation-publisher/requirements-dev.txt +++ b/python/observation-publisher/requirements-dev.txt @@ -4,4 +4,5 @@ types-requests==2.31.0.20231231 types-PyYAML==6.0.12.12 types-jmespath==1.0.2.7 mypy==1.7.1 -mypy-extensions==1.0.0 \ No newline at end of file +mypy-extensions==1.0.0 +db-dtypes==1.2.0 \ No newline at end of file diff --git a/python/observation-publisher/requirements.in b/python/observation-publisher/requirements.in index f42e3bd7f..d7080b2c4 100644 --- a/python/observation-publisher/requirements.in +++ b/python/observation-publisher/requirements.in @@ -1,6 +1,6 @@ confluent-kafka>=2.3.0 caraml-upi-protos>=1.0.0 -arize==7.7.* +arize>=7.7.0 hydra-core>=1.3.0 pandas>=1.0.0 google-cloud-bigquery diff --git a/python/observation-publisher/requirements.txt b/python/observation-publisher/requirements.txt index ebd0d6930..0890b3c2d 100644 --- a/python/observation-publisher/requirements.txt +++ b/python/observation-publisher/requirements.txt @@ -8,6 +8,8 @@ # via -r requirements.in alembic==1.13.0 # via mlflow +annotated-types==0.6.0 + # via pydantic antlr4-python3-runtime==4.9.3 # via # hydra-core @@ -74,7 +76,9 @@ flask==2.3.3 gitdb==4.0.11 # via gitpython gitpython==3.1.40 - # via mlflow + # via + # merlin-sdk + # mlflow google-api-core==2.15.0 # via # google-cloud-bigquery @@ -191,6 +195,10 @@ pyasn1==0.5.1 # rsa pyasn1-modules==0.3.0 # via google-auth +pydantic==2.5.3 + # via merlin-sdk +pydantic-core==2.14.6 + # via pydantic pygments==2.17.2 # via rich pyjwt==2.8.0 @@ -266,6 +274,8 @@ typing-extensions==4.9.0 # via # -r requirements.in # alembic + # pydantic + # pydantic-core # typing-inspect typing-inspect==0.9.0 # via dataclasses-json diff --git a/python/observation-publisher/tests/__init__.py b/python/observation-publisher/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/python/observation-publisher/tests/common_fixtures.py b/python/observation-publisher/tests/common_fixtures.py new file mode 100644 index 000000000..db018fb57 --- /dev/null +++ b/python/observation-publisher/tests/common_fixtures.py @@ -0,0 +1,13 @@ +import os + +import pytest + + +@pytest.fixture +def bq_project() -> str: + return os.environ.get("INTEGRATION_TEST_BQ_PROJECT") + + +@pytest.fixture +def bq_dataset() -> str: + return os.environ.get("INTEGRATION_TEST_BQ_DATASET") diff --git a/python/observation-publisher/tests/test_config.py b/python/observation-publisher/tests/test_config.py index 3a3ba60f9..199042aa2 100644 --- a/python/observation-publisher/tests/test_config.py +++ b/python/observation-publisher/tests/test_config.py @@ -4,14 +4,9 @@ from merlin.observability.inference import InferenceSchema, ValueType from omegaconf import OmegaConf -from publisher.config import ( - Environment, - ObservationSinkConfig, - ObservationSinkType, - ObservationSource, - ObservationSourceConfig, - PublisherConfig, -) +from publisher.config import (Environment, ObservationSinkConfig, + ObservationSinkType, ObservationSource, + ObservationSourceConfig, PublisherConfig) def test_config_initialization(): diff --git a/python/observation-publisher/tests/test_observation_sink.py b/python/observation-publisher/tests/test_observation_sink.py index de93b2a72..f12306a8e 100644 --- a/python/observation-publisher/tests/test_observation_sink.py +++ b/python/observation-publisher/tests/test_observation_sink.py @@ -1,18 +1,79 @@ +import dataclasses +import time from datetime import datetime from typing import Optional import pandas as pd import pyarrow as pa +import pytest from arize.pandas.logger import Client -from merlin.observability.inference import ( - BinaryClassificationOutput, - InferenceSchema, - RankingOutput, - ValueType, -) +from dateutil import tz +from google.cloud.bigquery import Client as BigQueryClient +from google.cloud.bigquery import SchemaField +from merlin.observability.inference import (BinaryClassificationOutput, + InferenceSchema, RankingOutput, + ValueType) +from pandas._testing import assert_frame_equal from requests import Response -from publisher.observation_sink import ArizeSink +from publisher.observation_sink import (ArizeSink, BigQueryConfig, + BigQueryRetryConfig, BigQuerySink) +from tests.common_fixtures import bq_dataset, bq_project + + +@pytest.fixture +def binary_classification_inference_schema() -> InferenceSchema: + return InferenceSchema( + feature_types={ + "rating": ValueType.FLOAT64, + }, + model_prediction_output=BinaryClassificationOutput( + prediction_score_column="prediction_score", + actual_label_column="actual_label", + positive_class_label="fraud", + negative_class_label="non fraud", + score_threshold=0.5, + ), + ) + + +@pytest.fixture +def binary_classification_inference_logs() -> pd.DataFrame: + request_timestamp = datetime(2024, 1, 1, 0, 0, 0).astimezone(tz.UTC) + return pd.DataFrame.from_records( + [ + [0.8, 0.4, "1234", "a", request_timestamp, "0.1.0", "non fraud"], + [0.5, 0.9, "1234", "b", request_timestamp, "0.1.0", "fraud"], + ], + columns=[ + "rating", + "prediction_score", + "session_id", + "row_id", + "request_timestamp", + "model_version", + "_prediction_label", + ], + ) + + +@pytest.fixture +def ranking_inference_logs() -> pd.DataFrame: + request_timestamp = datetime(2024, 1, 1, 0, 0, 0).astimezone(tz.UTC) + return pd.DataFrame.from_records( + [ + [5.0, 1.0, "1234", "1001", request_timestamp], + [4.0, 0.9, "1234", "1002", request_timestamp], + [3.0, 0.8, "1234", "1003", request_timestamp], + ], + columns=[ + "rating", + "rank_score", + "session_id", + "row_id", + "request_timestamp", + ], + ) class MockResponse(Response): @@ -36,43 +97,24 @@ def _post_file( ) -def test_binary_classification_model_preprocessing_for_arize(): - inference_schema = InferenceSchema( - feature_types={ - "rating": ValueType.FLOAT64, - }, - model_prediction_output=BinaryClassificationOutput( - prediction_score_column="prediction_score", - actual_label_column="actual_label", - positive_class_label="fraud", - negative_class_label="non fraud", - score_threshold=0.5, - ), - ) +def test_binary_classification_model_preprocessing_for_arize( + binary_classification_inference_schema: InferenceSchema, + binary_classification_inference_logs: pd.DataFrame, +): arize_client = MockArizeClient(api_key="test", space_key="test") arize_sink = ArizeSink( - inference_schema, + binary_classification_inference_schema, "test-model", "0.1.0", arize_client, ) - request_timestamp = datetime.now() - input_df = pd.DataFrame.from_records( - [ - [0.8, 0.4, "1234a", request_timestamp], - [0.5, 0.9, "1234b", request_timestamp], - ], - columns=[ - "rating", - "prediction_score", - "prediction_id", - "request_timestamp", - ], - ) - arize_sink.write(input_df) + arize_sink.write(binary_classification_inference_logs) -def test_ranking_model_preprocessing_for_arize(): +def test_ranking_model_preprocessing_for_arize( + binary_classification_inference_logs: pd.DataFrame, + ranking_inference_logs: pd.DataFrame, +): inference_schema = InferenceSchema( feature_types={ "rating": ValueType.FLOAT64, @@ -83,21 +125,6 @@ def test_ranking_model_preprocessing_for_arize(): relevance_score_column="relevance_score_column", ), ) - request_timestamp = datetime.now() - input_df = pd.DataFrame.from_records( - [ - [5.0, 1.0, "1234", "1001", request_timestamp], - [4.0, 0.9, "1234", "1001", request_timestamp], - [3.0, 0.8, "1234", "1001", request_timestamp], - ], - columns=[ - "rating", - "rank_score", - "prediction_id", - "order_id", - "request_timestamp", - ], - ) arize_client = MockArizeClient(api_key="test", space_key="test") arize_sink = ArizeSink( inference_schema, @@ -105,4 +132,105 @@ def test_ranking_model_preprocessing_for_arize(): "0.1.0", arize_client, ) - arize_sink.write(input_df) + arize_sink.write(ranking_inference_logs) + + +@pytest.mark.integration +def test_bigquery_sink_schema_migration( + bq_project: str, + bq_dataset: str, + binary_classification_inference_schema: InferenceSchema, + binary_classification_inference_logs: pd.DataFrame, +): + client = BigQueryClient() + client.delete_table( + f"{bq_project}.{bq_dataset}.prediction_log_test_model", not_found_ok=True + ) + bq_sink = BigQuerySink( + binary_classification_inference_schema, + "test-model", + "0.1.0", + config=BigQueryConfig( + project=bq_project, + dataset=bq_dataset, + ttl_days=14, + retry=BigQueryRetryConfig( + enabled=True, retry_attempts=3, retry_interval_seconds=10 + ), + ), + ) + bq_sink.write(binary_classification_inference_logs) + migrated_schema = dataclasses.replace(binary_classification_inference_schema) + migrated_schema.feature_types = { + "rating_v2": ValueType.FLOAT64, + } + migrated_bq_sink = BigQuerySink( + migrated_schema, + "test-model", + "0.2.0", + config=BigQueryConfig( + project=bq_project, + dataset=bq_dataset, + ttl_days=14, + retry=BigQueryRetryConfig( + enabled=True, retry_attempts=5, retry_interval_seconds=30 + ), + ), + ) + migrated_inference_logs = binary_classification_inference_logs.rename( + columns={"rating": "rating_v2"} + ) + migrated_inference_logs["model_version"] = "0.2.0" + migrated_bq_sink.write(migrated_inference_logs) + version_update_bq_sink = BigQuerySink( + migrated_schema, + "test-model", + "0.3.0", + config=BigQueryConfig( + project=bq_project, + dataset=bq_dataset, + ttl_days=14, + ), + ) + version_update_inference_logs = migrated_inference_logs.copy() + version_update_inference_logs["model_version"] = "0.3.0" + version_update_bq_sink.write(version_update_inference_logs) + + table = client.get_table(f"{bq_project}.{bq_dataset}.prediction_log_test_model") + assert table.schema == [ + SchemaField(name="session_id", field_type="STRING"), + SchemaField(name="row_id", field_type="STRING"), + SchemaField(name="request_timestamp", field_type="TIMESTAMP"), + SchemaField(name="model_version", field_type="STRING"), + SchemaField(name="rating", field_type="FLOAT"), + SchemaField(name="prediction_score", field_type="FLOAT"), + SchemaField(name="_prediction_label", field_type="STRING"), + SchemaField(name="rating_v2", field_type="FLOAT"), + ] + df = client.query( + "SELECT * FROM `{}.{}.prediction_log_test_model`".format(bq_project, bq_dataset) + ).to_dataframe() + df.reset_index(drop=True, inplace=True) + event_timestamp = datetime(2024, 1, 1, 0, 0, 0).astimezone(tz.UTC) + expected_df = pd.DataFrame.from_records( + [ + [0.8, 0.4, "1234", "a", event_timestamp, "0.1.0", "non fraud", None], + [0.5, 0.9, "1234", "b", event_timestamp, "0.1.0", "fraud", None], + [None, 0.4, "1234", "a", event_timestamp, "0.2.0", "non fraud", 0.8], + [None, 0.9, "1234", "b", event_timestamp, "0.2.0", "fraud", 0.5], + [None, 0.4, "1234", "a", event_timestamp, "0.3.0", "non fraud", 0.8], + [None, 0.9, "1234", "b", event_timestamp, "0.3.0", "fraud", 0.5], + ], + columns=[ + "rating", + "prediction_score", + "session_id", + "row_id", + "request_timestamp", + "model_version", + "_prediction_label", + "rating_v2", + ], + ) + expected_df.reset_index(drop=True, inplace=True) + assert_frame_equal(df, expected_df, check_like=True) diff --git a/python/observation-publisher/tests/test_prediction_log_consumer.py b/python/observation-publisher/tests/test_prediction_log_consumer.py index db2a71569..c41f9b60d 100644 --- a/python/observation-publisher/tests/test_prediction_log_consumer.py +++ b/python/observation-publisher/tests/test_prediction_log_consumer.py @@ -4,11 +4,8 @@ import numpy as np import pandas as pd from caraml.upi.v1.prediction_log_pb2 import PredictionLog -from merlin.observability.inference import ( - BinaryClassificationOutput, - InferenceSchema, - ValueType, -) +from merlin.observability.inference import (BinaryClassificationOutput, + InferenceSchema, ValueType) from pandas._testing import assert_frame_equal from publisher.prediction_log_consumer import log_batch_to_dataframe @@ -77,6 +74,7 @@ def test_log_to_dataframe(): "service_type", ] output_columns = ["prediction_score"] + request_timestamp = datetime(2021, 1, 1, 0, 0, 0) prediction_logs = [ new_prediction_log( prediction_id="1234", @@ -92,7 +90,7 @@ def test_log_to_dataframe(): [0.9], [0.5], ], - request_timestamp=datetime(2021, 1, 1, 0, 0, 0), + request_timestamp=request_timestamp, row_ids=["a", "b"], ), new_prediction_log( @@ -109,25 +107,29 @@ def test_log_to_dataframe(): [0.4], [0.2], ], - request_timestamp=datetime(2021, 1, 1, 0, 0, 0), + request_timestamp=request_timestamp, row_ids=["c", "d"], ), ] - prediction_logs_df = log_batch_to_dataframe(prediction_logs, inference_schema) + prediction_logs_df = log_batch_to_dataframe( + prediction_logs, inference_schema, model_version + ) expected_df = pd.DataFrame.from_records( [ - [0.8, 24, "FOOD", 0.9, "1234a", datetime(2021, 1, 1, 0, 0, 0)], - [0.5, 2, "RIDE", 0.5, "1234b", datetime(2021, 1, 1, 0, 0, 0)], - [1.0, 13, "CAR", 0.4, "5678c", datetime(2021, 1, 1, 0, 0, 0)], - [0.4, 60, "RIDE", 0.2, "5678d", datetime(2021, 1, 1, 0, 0, 0)], + [0.8, 24, "FOOD", 0.9, "1234", "a", request_timestamp, model_version], + [0.5, 2, "RIDE", 0.5, "1234", "b", request_timestamp, model_version], + [1.0, 13, "CAR", 0.4, "5678", "c", request_timestamp, model_version], + [0.4, 60, "RIDE", 0.2, "5678", "d", request_timestamp, model_version], ], columns=[ "acceptance_rate", "minutes_since_last_order", "service_type", "prediction_score", - "prediction_id", + "session_id", + "row_id", "request_timestamp", + "model_version", ], ) assert_frame_equal(prediction_logs_df, expected_df) @@ -165,21 +167,27 @@ def test_empty_column_conversion_to_dataframe(): row_ids=["a"], ), ] - prediction_logs_df = log_batch_to_dataframe(prediction_logs, inference_schema) + prediction_logs_df = log_batch_to_dataframe( + prediction_logs, inference_schema, model_version + ) expected_df = pd.DataFrame.from_records( [ [ np.NaN, 0.5, - "1234a", + "1234", + "a", datetime(2021, 1, 1, 0, 0, 0), + "0.1.0", ], ], columns=[ "acceptance_rate", "prediction_score", - "prediction_id", + "session_id", + "row_id", "request_timestamp", + "model_version", ], ) assert_frame_equal(prediction_logs_df, expected_df)