feat: reuse the same bigquery table for multiple model versions for Arize BQ Sink (#531)

<!--  Thanks for sending a pull request!  Here are some tips for you:

1. Run unit tests and ensure that they are passing
2. If your change introduces any API changes, make sure to update the
e2e tests
3. Make sure documentation is updated for your PR!

-->
# Description
<!-- Briefly describe the motivation for the change. Please include
illustrations where appropriate. -->
The current sink creates a new BigQuery table per model version. This
makes it harder to implement Arize ground truth ingestion, because the
ground truth provided by users is typically model-version agnostic.

# Modifications
<!-- Summarize the key code changes. -->
- A single table is used per model ID, rather than per model version.
- A new column, `model_version`, is added.
- Session ID and row ID are used in favor of prediction ID, as the concept of
a prediction ID in Arize differs from Merlin's (see the sketch below).
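
As a minimal sketch of the resulting layout, assuming a hypothetical model ID and version and mirroring the `write_location` logic in `observation_sink.py`:

```python
# Illustration only: hypothetical model ID/version showing the shared-table naming.
model_id = "my-model"    # assumed example
model_version = "3"      # assumed example

# The table name now depends only on the model ID, so every version shares one table.
table_name = f"prediction_log_{model_id}".replace("-", "_").replace(".", "_")
print(table_name)  # prediction_log_my_model

# Each row additionally carries session ID, row ID, and model_version columns,
# so observations can still be filtered per model version.
```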

# Tests
<!-- Besides the existing / updated automated tests, what specific
scenarios should be tested? Consider the backward compatibility of the
changes, whether corner cases are covered, etc. Please describe the
tests and check the ones that have been completed. Eg:
- [x] Deploying new and existing standard models
- [ ] Deploying PyFunc models
-->

# Checklist
- [ ] Added PR label
- [ ] Added unit test, integration, and/or e2e tests
- [ ] Tested locally
- [ ] Updated documentation
- [ ] Updated Swagger spec if the PR introduces API changes
- [ ] Regenerated Golang and Python client if the PR introduces API
changes

# Release Notes
<!--
Does this PR introduce a user-facing change?
If no, just write "NONE" in the release-note block below.
If yes, a release note is required. Enter your extended release note in
the block below.
If the PR requires additional action from users switching to the new
release, include the string "action required".

For more information about release notes, see kubernetes' guide here:
http://git.k8s.io/community/contributors/guide/release-notes.md
-->

```release-note

```
khorshuheng authored and leonlnj committed Feb 20, 2024
1 parent 98b0dd6 commit 205a6c7
Showing 15 changed files with 404 additions and 129 deletions.
7 changes: 6 additions & 1 deletion python/observation-publisher/Makefile
@@ -14,7 +14,12 @@ pip-compile:
.PHONY: test
test:
@echo "Running tests..."
@python -m pytest
@python -m pytest -m "not integration"

.PHONY: test-integration
integration-test:
@echo "Running integration tests..."
@python -m pytest -m "integration"

.PHONY: run
run:
1 change: 1 addition & 0 deletions python/observation-publisher/publisher/__main__.py
@@ -35,6 +35,7 @@ def start_consumer(cfg: PublisherConfig) -> None:
prediction_log_consumer.start_polling(
observation_sinks=observation_sinks,
inference_schema=inference_schema,
model_version=cfg.environment.model_version,
)


3 changes: 2 additions & 1 deletion python/observation-publisher/publisher/metric.py
@@ -1,5 +1,5 @@
from pandas import Timestamp
from prometheus_client import Gauge, Counter
from prometheus_client import Counter, Gauge


class MetricWriter(object):
@@ -21,6 +21,7 @@ def __init__(self):
self.total_prediction_logs_processed_counter = Counter(
"total_prediction_logs_processed",
"The total number of prediction logs processed by the publisher",
["model_id", "model_version"],
)
self._initialized = True

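For reference, a hedged sketch of how a counter with the newly added labels behaves in `prometheus_client`; the counter below is a standalone illustration, not the `MetricWriter` singleton itself:

```python
from prometheus_client import Counter

# Same description and label names as in metric.py; the metric name is suffixed
# so it would not clash with the real counter if both were registered in one process.
example_counter = Counter(
    "total_prediction_logs_processed_example",
    "The total number of prediction logs processed by the publisher",
    ["model_id", "model_version"],
)

# Each (model_id, model_version) pair becomes its own time series.
example_counter.labels(model_id="my-model", model_version="3").inc(10)
```
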
169 changes: 133 additions & 36 deletions python/observation-publisher/publisher/observation_sink.py
@@ -1,6 +1,6 @@
import abc
import time
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List, Tuple

import pandas as pd
@@ -10,24 +10,19 @@
from arize.utils.types import Environments
from arize.utils.types import ModelTypes as ArizeModelType
from dataclasses_json import dataclass_json
from google.api_core.exceptions import NotFound
from google.cloud.bigquery import Client as BigQueryClient
from google.cloud.bigquery import (
SchemaField,
Table,
TimePartitioning,
TimePartitioningType,
)
from merlin.observability.inference import (
BinaryClassificationOutput,
InferenceSchema,
ObservationType,
RankingOutput,
RegressionOutput,
ValueType,
)
from google.cloud.bigquery import (SchemaField, Table, TimePartitioning,
TimePartitioningType)
from merlin.observability.inference import (BinaryClassificationOutput,
InferenceSchema, ObservationType,
RankingOutput, RegressionOutput,
ValueType)

from publisher.config import ObservationSinkConfig, ObservationSinkType
from publisher.prediction_log_parser import PREDICTION_LOG_TIMESTAMP_COLUMN
from publisher.prediction_log_parser import (MODEL_VERSION_COLUMN,
PREDICTION_LOG_TIMESTAMP_COLUMN,
ROW_ID_COLUMN, SESSION_ID_COLUMN)


class ObservationSink(abc.ABC):
@@ -74,6 +69,12 @@ def __init__(
model_version: str,
arize_client: ArizeClient,
):
"""
:param inference_schema: Inference schema for the ingested model
:param model_id: Merlin model id
:param model_version: Merlin model version
:param arize_client: Arize Pandas Logger client
"""
super().__init__(inference_schema, model_id, model_version)
self._client = arize_client

@@ -101,7 +102,7 @@ def _to_arize_schema(self) -> Tuple[ArizeModelType, ArizeSchema]:
elif isinstance(prediction_output, RankingOutput):
schema_attributes = self._common_arize_schema_attributes() | dict(
rank_column_name=prediction_output.rank_column,
prediction_group_id_column_name=prediction_output.prediction_group_id_column,
prediction_group_id_column_name=SESSION_ID_COLUMN,
)
model_type = ArizeModelType.RANKING
else:
@@ -112,6 +113,14 @@ def _to_arize_schema(self) -> Tuple[ArizeModelType, ArizeSchema]:
return model_type, ArizeSchema(**schema_attributes)

def write(self, df: pd.DataFrame):
df[self._inference_schema.prediction_id_column] = (
df[SESSION_ID_COLUMN] + df[ROW_ID_COLUMN]
)
if isinstance(self._inference_schema.model_prediction_output, RankingOutput):
df[
self._inference_schema.model_prediction_output.prediction_group_id_column
] = df[SESSION_ID_COLUMN]

processed_df = self._inference_schema.model_prediction_output.preprocess(
df, [ObservationType.FEATURE, ObservationType.PREDICTION]
)
@@ -134,39 +143,99 @@ def write(self, df: pd.DataFrame):
raise e


@dataclass_json
@dataclass
class BigQueryRetryConfig:
"""
Configuration for retrying failed write attempts. Write could fail due to BigQuery
taking time to update the table schema / create new table.
Attributes:
enabled: Whether to retry failed write attempts
retry_attempts: Number of retry attempts
retry_interval_seconds: Interval between retry attempts
"""

enabled: bool = False
retry_attempts: int = 4
retry_interval_seconds: int = 30


@dataclass_json
@dataclass
class BigQueryConfig:
"""
Configuration for writing to BigQuery
Attributes:
project: GCP project id
dataset: BigQuery dataset name
ttl_days: Time to live for the date partition
retry: Configuration for retrying failed write attempts
"""

project: str
dataset: str
ttl_days: int
retry: BigQueryRetryConfig = BigQueryRetryConfig()


class BigQuerySink(ObservationSink):
"""
Writes prediction logs to BigQuery. If the destination table doesn't exist, it will be created based on the inference schema..
Writes prediction logs to BigQuery. If the destination table doesn't exist, it will be created based on the inference schema.
"""

def __init__(
self,
inference_schema: InferenceSchema,
model_id: str,
model_version: str,
project: str,
dataset: str,
ttl_days: int,
config: BigQueryConfig,
):
"""
:param inference_schema: Inference schema for the ingested model
:param model_id: Merlin model id
:param model_version: Merlin model version
:param config: Configuration to write to bigquery sink
"""
super().__init__(inference_schema, model_id, model_version)
self._client = BigQueryClient()
self._inference_schema = inference_schema
self._model_id = model_id
self._model_version = model_version
self._project = project
self._dataset = dataset
table = Table(self.write_location, schema=self.schema_fields)
table.time_partitioning = TimePartitioning(type_=TimePartitioningType.DAY)
table.expires = datetime.now() + timedelta(days=ttl_days)
self._table: Table = self._client.create_table(exists_ok=True, table=table)
self._config = config
self._table = self.create_or_update_table()

@property
def project(self) -> str:
return self._config.project

@property
def dataset(self) -> str:
return self._config.dataset

@property
def retry(self) -> BigQueryRetryConfig:
return self._config.retry

def create_or_update_table(self) -> Table:
try:
original_table = self._client.get_table(self.write_location)
original_schema = original_table.schema
migrated_schema = original_schema[:]
for field in self.schema_fields:
if field not in original_schema:
migrated_schema.append(field)
if migrated_schema == original_schema:
return original_table
original_table.schema = migrated_schema
return self._client.update_table(original_table, ["schema"])
except NotFound:
table = Table(self.write_location, schema=self.schema_fields)
table.time_partitioning = TimePartitioning(
type_=TimePartitioningType.DAY,
field=PREDICTION_LOG_TIMESTAMP_COLUMN,
expiration_ms=self._config.ttl_days * 24 * 60 * 60 * 1000,
)
return self._client.create_table(table=table)

@property
def schema_fields(self) -> List[SchemaField]:
@@ -179,13 +248,21 @@ def schema_fields(self) -> List[SchemaField]:

schema_fields = [
SchemaField(
name=self._inference_schema.prediction_id_column,
name=SESSION_ID_COLUMN,
field_type="STRING",
),
SchemaField(
name=ROW_ID_COLUMN,
field_type="STRING",
),
SchemaField(
name=PREDICTION_LOG_TIMESTAMP_COLUMN,
field_type="TIMESTAMP",
),
SchemaField(
name=MODEL_VERSION_COLUMN,
field_type="STRING",
),
]
for feature, feature_type in self._inference_schema.feature_types.items():
schema_fields.append(
@@ -207,13 +284,35 @@ def schema_fields(self) -> List[SchemaField]:

@property
def write_location(self) -> str:
table_name = f"prediction_log_{self._model_id}_{self._model_version}".replace(
"-", "_"
).replace(".", "_")
return f"{self._project}.{self._dataset}.{table_name}"
table_name = f"prediction_log_{self._model_id}".replace("-", "_").replace(
".", "_"
)
return f"{self.project}.{self.dataset}.{table_name}"

def write(self, dataframe: pd.DataFrame):
self._client.insert_rows_from_dataframe(dataframe=dataframe, table=self._table)
for i in range(0, self.retry.retry_attempts + 1):
try:
response = self._client.insert_rows_from_dataframe(
dataframe=dataframe, table=self._table
)
errors = [error for error_chunk in response for error in error_chunk]
if len(errors) > 0:
if not self.retry.enabled:
print("Errors when inserting rows to BigQuery")
return
else:
print(
f"Errors when inserting rows to BigQuery, retrying attempt {i}/{self.retry.retry_attempts}"
)
time.sleep(self.retry.retry_interval_seconds)
else:
return
except NotFound as e:
print(
f"Table not found: {e}, retrying attempt {i}/{self.retry.retry_attempts}"
)
time.sleep(self.retry.retry_interval_seconds)
print(f"Failed to write to BigQuery after {self.retry.retry_attempts} attempts")


def new_observation_sink(
Expand All @@ -230,9 +329,7 @@ def new_observation_sink(
inference_schema=inference_schema,
model_id=model_id,
model_version=model_version,
project=bq_config.project,
dataset=bq_config.dataset,
ttl_days=bq_config.ttl_days,
config=bq_config,
)
case ObservationSinkType.ARIZE:
arize_config: ArizeConfig = ArizeConfig.from_dict(sink_config.config) # type: ignore[attr-defined]
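For context on the new configuration dataclasses, a hedged example of building a `BigQueryConfig` from a plain dict via the `from_dict` method generated by `@dataclass_json`; the project and dataset names are placeholders:

```python
# Assumes the publisher package from this repository is importable.
from publisher.observation_sink import BigQueryConfig

config = BigQueryConfig.from_dict(  # type: ignore[attr-defined]
    {
        "project": "my-gcp-project",      # placeholder GCP project
        "dataset": "model_observations",  # placeholder dataset
        "ttl_days": 30,
        "retry": {"enabled": True, "retry_attempts": 5, "retry_interval_seconds": 10},
    }
)
assert config.retry.retry_attempts == 5
```
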
34 changes: 22 additions & 12 deletions python/observation-publisher/publisher/prediction_log_consumer.py
@@ -14,11 +14,11 @@
from publisher.config import ObservationSource, ObservationSourceConfig
from publisher.metric import MetricWriter
from publisher.observation_sink import ObservationSink
from publisher.prediction_log_parser import (
PREDICTION_LOG_TIMESTAMP_COLUMN,
PredictionLogFeatureTable,
PredictionLogResultsTable,
)
from publisher.prediction_log_parser import (MODEL_VERSION_COLUMN,
PREDICTION_LOG_TIMESTAMP_COLUMN,
ROW_ID_COLUMN, SESSION_ID_COLUMN,
PredictionLogFeatureTable,
PredictionLogResultsTable)


class PredictionLogConsumer(abc.ABC):
@@ -42,6 +42,7 @@ def start_polling(
self,
observation_sinks: List[ObservationSink],
inference_schema: InferenceSchema,
model_version: str,
):
try:
buffered_logs = []
@@ -58,7 +59,9 @@
and buffered_duration < buffered_max_duration_seconds
):
continue
df = log_batch_to_dataframe(buffered_logs, inference_schema)
df = log_batch_to_dataframe(
buffered_logs, inference_schema, model_version
)
most_recent_prediction_timestamp = df[
PREDICTION_LOG_TIMESTAMP_COLUMN
].max()
@@ -69,7 +72,7 @@
len(buffered_logs)
)
write_tasks = [
Thread(target=sink.write, args=(df,)) for sink in observation_sinks
Thread(target=sink.write, args=(df.copy(),)) for sink in observation_sinks
]
for task in write_tasks:
task.start()
@@ -160,7 +163,7 @@ def parse_message_to_prediction_log(msg: str) -> PredictionLog:


def log_to_records(
log: PredictionLog, inference_schema: InferenceSchema
log: PredictionLog, inference_schema: InferenceSchema, model_version: str
) -> Tuple[List[List[np.int64 | np.float64 | np.bool_ | np.str_]], List[str]]:
request_timestamp = log.request_timestamp.ToDatetime()
feature_table = PredictionLogFeatureTable.from_struct(
@@ -171,7 +174,9 @@
)

rows = [
feature_row + prediction_row + [log.prediction_id + row_id, request_timestamp]
feature_row
+ prediction_row
+ [log.prediction_id, row_id, request_timestamp, model_version]
for feature_row, prediction_row, row_id in zip(
feature_table.rows,
prediction_results_table.rows,
@@ -182,18 +187,23 @@
column_names = (
feature_table.columns
+ prediction_results_table.columns
+ [inference_schema.prediction_id_column, PREDICTION_LOG_TIMESTAMP_COLUMN]
+ [
SESSION_ID_COLUMN,
ROW_ID_COLUMN,
PREDICTION_LOG_TIMESTAMP_COLUMN,
MODEL_VERSION_COLUMN,
]
)

return rows, column_names


def log_batch_to_dataframe(
logs: List[PredictionLog], inference_schema: InferenceSchema
logs: List[PredictionLog], inference_schema: InferenceSchema, model_version: str
) -> pd.DataFrame:
combined_records = []
column_names: List[str] = []
for log in logs:
rows, column_names = log_to_records(log, inference_schema)
rows, column_names = log_to_records(log, inference_schema, model_version)
combined_records.extend(rows)
return pd.DataFrame.from_records(combined_records, columns=column_names)
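
A hedged illustration of the dataframe shape that `log_batch_to_dataframe` now produces; the feature and prediction columns depend on the inference schema, and the trailing column names come from the `prediction_log_parser` constants (their exact string values are assumed here):

```python
import pandas as pd

# Hypothetical two-row batch: one session with two row IDs for one model version.
df = pd.DataFrame.from_records(
    [
        [0.5, 0.8, "session-1", "a", pd.Timestamp("2024-02-20T00:00:00Z"), "3"],
        [0.1, 0.2, "session-1", "b", pd.Timestamp("2024-02-20T00:00:00Z"), "3"],
    ],
    columns=[
        "feature_1", "prediction_score",       # schema-dependent columns
        "session_id", "row_id",                # assumed constant values
        "request_timestamp", "model_version",  # assumed constant values
    ],
)
print(df.dtypes)
```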