feat: reuse the same bigquery table for multiple model versions for Arize BQ Sink #531

Merged
merged 1 commit on Feb 19, 2024
7 changes: 6 additions & 1 deletion python/observation-publisher/Makefile
@@ -14,7 +14,12 @@ pip-compile:
.PHONY: test
test:
@echo "Running tests..."
@python -m pytest
@python -m pytest -m "not integration"

.PHONY: integration-test
integration-test:
@echo "Running integration tests..."
@python -m pytest -m "integration"

.PHONY: run
run:
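For illustration, a minimal sketch (not part of this diff) of how a test would opt into the new marker so that make test skips it and make integration-test runs it; only the marker name "integration" is taken from the -m expressions above.

import pytest


@pytest.mark.integration
def test_bigquery_sink_round_trip():
    # Selected by -m "integration" and excluded by -m "not integration".
    # Body omitted here: a real run would need BigQuery credentials.
    ...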
1 change: 1 addition & 0 deletions python/observation-publisher/publisher/__main__.py
@@ -35,6 +35,7 @@ def start_consumer(cfg: PublisherConfig) -> None:
prediction_log_consumer.start_polling(
observation_sinks=observation_sinks,
inference_schema=inference_schema,
model_version=cfg.environment.model_version,
)


3 changes: 2 additions & 1 deletion python/observation-publisher/publisher/metric.py
@@ -1,5 +1,5 @@
from pandas import Timestamp
from prometheus_client import Gauge, Counter
from prometheus_client import Counter, Gauge


class MetricWriter(object):
@@ -21,6 +21,7 @@ def __init__(self):
self.total_prediction_logs_processed_counter = Counter(
"total_prediction_logs_processed",
"The total number of prediction logs processed by the publisher",
["model_id", "model_version"],
)
self._initialized = True

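A short sketch, not from the diff, of how the newly labelled counter is incremented with prometheus_client; the label values below are placeholders.

from prometheus_client import Counter

logs_processed = Counter(
    "total_prediction_logs_processed",
    "The total number of prediction logs processed by the publisher",
    ["model_id", "model_version"],
)
# Once label names are declared, every increment must supply them.
logs_processed.labels(model_id="my-model", model_version="0.1.0").inc(10)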
169 changes: 133 additions & 36 deletions python/observation-publisher/publisher/observation_sink.py
@@ -1,6 +1,6 @@
import abc
import time
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import List, Tuple

import pandas as pd
@@ -10,24 +10,19 @@
from arize.utils.types import Environments
from arize.utils.types import ModelTypes as ArizeModelType
from dataclasses_json import dataclass_json
from google.api_core.exceptions import NotFound
from google.cloud.bigquery import Client as BigQueryClient
from google.cloud.bigquery import (
SchemaField,
Table,
TimePartitioning,
TimePartitioningType,
)
from merlin.observability.inference import (
BinaryClassificationOutput,
InferenceSchema,
ObservationType,
RankingOutput,
RegressionOutput,
ValueType,
)
from google.cloud.bigquery import (SchemaField, Table, TimePartitioning,
TimePartitioningType)
from merlin.observability.inference import (BinaryClassificationOutput,
InferenceSchema, ObservationType,
RankingOutput, RegressionOutput,
ValueType)

from publisher.config import ObservationSinkConfig, ObservationSinkType
from publisher.prediction_log_parser import PREDICTION_LOG_TIMESTAMP_COLUMN
from publisher.prediction_log_parser import (MODEL_VERSION_COLUMN,
PREDICTION_LOG_TIMESTAMP_COLUMN,
ROW_ID_COLUMN, SESSION_ID_COLUMN)


class ObservationSink(abc.ABC):
@@ -74,6 +69,12 @@ def __init__(
model_version: str,
arize_client: ArizeClient,
):
"""
:param inference_schema: Inference schema for the ingested model
:param model_id: Merlin model id
:param model_version: Merlin model version
:param arize_client: Arize Pandas Logger client
"""
super().__init__(inference_schema, model_id, model_version)
self._client = arize_client

@@ -101,7 +102,7 @@ def _to_arize_schema(self) -> Tuple[ArizeModelType, ArizeSchema]:
elif isinstance(prediction_output, RankingOutput):
schema_attributes = self._common_arize_schema_attributes() | dict(
rank_column_name=prediction_output.rank_column,
prediction_group_id_column_name=prediction_output.prediction_group_id_column,
prediction_group_id_column_name=SESSION_ID_COLUMN,
)
model_type = ArizeModelType.RANKING
else:
@@ -112,6 +113,14 @@ def _to_arize_schema(self) -> Tuple[ArizeModelType, ArizeSchema]:
return model_type, ArizeSchema(**schema_attributes)

def write(self, df: pd.DataFrame):
df[self._inference_schema.prediction_id_column] = (
df[SESSION_ID_COLUMN] + df[ROW_ID_COLUMN]
)
if isinstance(self._inference_schema.model_prediction_output, RankingOutput):
df[
self._inference_schema.model_prediction_output.prediction_group_id_column
] = df[SESSION_ID_COLUMN]

processed_df = self._inference_schema.model_prediction_output.preprocess(
df, [ObservationType.FEATURE, ObservationType.PREDICTION]
)
@@ -134,39 +143,99 @@ def write(self, df: pd.DataFrame):
raise e


@dataclass_json
@dataclass
class BigQueryRetryConfig:
"""
Configuration for retrying failed write attempts. Writes can fail while BigQuery is
still propagating a table schema update or creating a new table.
Attributes:
enabled: Whether to retry failed write attempts
retry_attempts: Number of retry attempts
retry_interval_seconds: Interval between retry attempts
"""

enabled: bool = False
retry_attempts: int = 4
retry_interval_seconds: int = 30


@dataclass_json
@dataclass
class BigQueryConfig:
"""
Configuration for writing to BigQuery
Attributes:
project: GCP project id
dataset: BigQuery dataset name
ttl_days: Time to live for the date partition
retry: Configuration for retrying failed write attempts
"""

project: str
dataset: str
ttl_days: int
retry: BigQueryRetryConfig = BigQueryRetryConfig()
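For illustration, a hypothetical sink configuration decoded through the from_dict method that dataclasses_json generates for the classes above; every value below is a placeholder, not something taken from this PR.

from publisher.observation_sink import BigQueryConfig

raw = {
    "project": "my-gcp-project",      # placeholder GCP project id
    "dataset": "model_observations",  # placeholder BigQuery dataset
    "ttl_days": 14,
    "retry": {"enabled": True, "retry_attempts": 4, "retry_interval_seconds": 30},
}
bq_config = BigQueryConfig.from_dict(raw)  # type: ignore[attr-defined]
assert bq_config.retry.enabled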


class BigQuerySink(ObservationSink):
"""
Writes prediction logs to BigQuery. If the destination table doesn't exist, it will be created based on the inference schema..
Writes prediction logs to BigQuery. If the destination table doesn't exist, it will be created based on the inference schema.
"""

def __init__(
self,
inference_schema: InferenceSchema,
model_id: str,
model_version: str,
project: str,
dataset: str,
ttl_days: int,
config: BigQueryConfig,
):
"""
:param inference_schema: Inference schema for the ingested model
:param model_id: Merlin model id
:param model_version: Merlin model version
:param config: Configuration to write to bigquery sink
"""
super().__init__(inference_schema, model_id, model_version)
self._client = BigQueryClient()
self._inference_schema = inference_schema
self._model_id = model_id
self._model_version = model_version
self._project = project
self._dataset = dataset
table = Table(self.write_location, schema=self.schema_fields)
table.time_partitioning = TimePartitioning(type_=TimePartitioningType.DAY)
table.expires = datetime.now() + timedelta(days=ttl_days)
self._table: Table = self._client.create_table(exists_ok=True, table=table)
self._config = config
self._table = self.create_or_update_table()

@property
def project(self) -> str:
return self._config.project

@property
def dataset(self) -> str:
return self._config.dataset

@property
def retry(self) -> BigQueryRetryConfig:
return self._config.retry

def create_or_update_table(self) -> Table:
try:
original_table = self._client.get_table(self.write_location)
original_schema = original_table.schema
migrated_schema = original_schema[:]
for field in self.schema_fields:
if field not in original_schema:
migrated_schema.append(field)
if migrated_schema == original_schema:
return original_table
original_table.schema = migrated_schema
return self._client.update_table(original_table, ["schema"])
except NotFound:
table = Table(self.write_location, schema=self.schema_fields)
table.time_partitioning = TimePartitioning(
type_=TimePartitioningType.DAY,
field=PREDICTION_LOG_TIMESTAMP_COLUMN,
expiration_ms=self._config.ttl_days * 24 * 60 * 60 * 1000,
)
return self._client.create_table(table=table)
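To make the additive migration above concrete: fields required by the current inference schema are appended when missing, while existing fields are never dropped or retyped, so columns created for earlier model versions survive. A small sketch with made-up field names:

from google.cloud.bigquery import SchemaField

existing = [SchemaField("session_id", "STRING"), SchemaField("feature_a", "FLOAT64")]
incoming = [SchemaField("session_id", "STRING"), SchemaField("feature_b", "FLOAT64")]

merged = existing[:]
for field in incoming:
    if field not in existing:  # SchemaField defines equality, so duplicates are skipped
        merged.append(field)
# merged now holds session_id, feature_a and feature_b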

@property
def schema_fields(self) -> List[SchemaField]:
@@ -179,13 +248,21 @@ def schema_fields(self) -> List[SchemaField]:

schema_fields = [
SchemaField(
name=self._inference_schema.prediction_id_column,
name=SESSION_ID_COLUMN,
field_type="STRING",
),
SchemaField(
name=ROW_ID_COLUMN,
field_type="STRING",
),
SchemaField(
name=PREDICTION_LOG_TIMESTAMP_COLUMN,
field_type="TIMESTAMP",
),
SchemaField(
name=MODEL_VERSION_COLUMN,
field_type="STRING",
),
]
for feature, feature_type in self._inference_schema.feature_types.items():
schema_fields.append(
@@ -207,13 +284,35 @@ def schema_fields(self) -> List[SchemaField]:

@property
def write_location(self) -> str:
table_name = f"prediction_log_{self._model_id}_{self._model_version}".replace(
"-", "_"
).replace(".", "_")
return f"{self._project}.{self._dataset}.{table_name}"
table_name = f"prediction_log_{self._model_id}".replace("-", "_").replace(
".", "_"
)
return f"{self.project}.{self.dataset}.{table_name}"
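A worked example of the new naming, with placeholder values: the version suffix is gone, so every version of a model writes to a single table and the version is carried in the MODEL_VERSION_COLUMN field instead.

project, dataset, model_id = "my-gcp-project", "model_observations", "my-model"
table_name = f"prediction_log_{model_id}".replace("-", "_").replace(".", "_")
print(f"{project}.{dataset}.{table_name}")
# my-gcp-project.model_observations.prediction_log_my_model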

def write(self, dataframe: pd.DataFrame):
self._client.insert_rows_from_dataframe(dataframe=dataframe, table=self._table)
for i in range(0, self.retry.retry_attempts + 1):
try:
response = self._client.insert_rows_from_dataframe(
dataframe=dataframe, table=self._table
)
errors = [error for error_chunk in response for error in error_chunk]
if len(errors) > 0:
if not self.retry.enabled:
print("Errors when inserting rows to BigQuery")
return
else:
print(
f"Errors when inserting rows to BigQuery, retrying attempt {i}/{self.retry.retry_attempts}"
)
time.sleep(self.retry.retry_interval_seconds)
else:
return
except NotFound as e:
print(
f"Table not found: {e}, retrying attempt {i}/{self.retry.retry_attempts}"
)
time.sleep(self.retry.retry_interval_seconds)
print(f"Failed to write to BigQuery after {self.retry.retry_attempts} attempts")


def new_observation_sink(
@@ -230,9 +329,7 @@ def new_observation_sink(
inference_schema=inference_schema,
model_id=model_id,
model_version=model_version,
project=bq_config.project,
dataset=bq_config.dataset,
ttl_days=bq_config.ttl_days,
config=bq_config,
)
case ObservationSinkType.ARIZE:
arize_config: ArizeConfig = ArizeConfig.from_dict(sink_config.config) # type: ignore[attr-defined]
34 changes: 22 additions & 12 deletions python/observation-publisher/publisher/prediction_log_consumer.py
@@ -14,11 +14,11 @@
from publisher.config import ObservationSource, ObservationSourceConfig
from publisher.metric import MetricWriter
from publisher.observation_sink import ObservationSink
from publisher.prediction_log_parser import (
PREDICTION_LOG_TIMESTAMP_COLUMN,
PredictionLogFeatureTable,
PredictionLogResultsTable,
)
from publisher.prediction_log_parser import (MODEL_VERSION_COLUMN,
PREDICTION_LOG_TIMESTAMP_COLUMN,
ROW_ID_COLUMN, SESSION_ID_COLUMN,
PredictionLogFeatureTable,
PredictionLogResultsTable)


class PredictionLogConsumer(abc.ABC):
@@ -42,6 +42,7 @@ def start_polling(
self,
observation_sinks: List[ObservationSink],
inference_schema: InferenceSchema,
model_version: str,
):
try:
buffered_logs = []
@@ -58,7 +59,9 @@
and buffered_duration < buffered_max_duration_seconds
):
continue
df = log_batch_to_dataframe(buffered_logs, inference_schema)
df = log_batch_to_dataframe(
buffered_logs, inference_schema, model_version
)
most_recent_prediction_timestamp = df[
PREDICTION_LOG_TIMESTAMP_COLUMN
].max()
@@ -69,7 +72,7 @@
len(buffered_logs)
)
write_tasks = [
Thread(target=sink.write, args=(df,)) for sink in observation_sinks
Thread(target=sink.write, args=(df.copy(),)) for sink in observation_sinks
]
for task in write_tasks:
task.start()
@@ -160,7 +163,7 @@ def parse_message_to_prediction_log(msg: str) -> PredictionLog:


def log_to_records(
log: PredictionLog, inference_schema: InferenceSchema
log: PredictionLog, inference_schema: InferenceSchema, model_version: str
) -> Tuple[List[List[np.int64 | np.float64 | np.bool_ | np.str_]], List[str]]:
request_timestamp = log.request_timestamp.ToDatetime()
feature_table = PredictionLogFeatureTable.from_struct(
@@ -171,7 +174,9 @@
)

rows = [
feature_row + prediction_row + [log.prediction_id + row_id, request_timestamp]
feature_row
+ prediction_row
+ [log.prediction_id, row_id, request_timestamp, model_version]
for feature_row, prediction_row, row_id in zip(
feature_table.rows,
prediction_results_table.rows,
@@ -182,18 +187,23 @@
column_names = (
feature_table.columns
+ prediction_results_table.columns
+ [inference_schema.prediction_id_column, PREDICTION_LOG_TIMESTAMP_COLUMN]
+ [
SESSION_ID_COLUMN,
ROW_ID_COLUMN,
PREDICTION_LOG_TIMESTAMP_COLUMN,
MODEL_VERSION_COLUMN,
]
)

return rows, column_names


def log_batch_to_dataframe(
logs: List[PredictionLog], inference_schema: InferenceSchema
logs: List[PredictionLog], inference_schema: InferenceSchema, model_version: str
) -> pd.DataFrame:
combined_records = []
column_names: List[str] = []
for log in logs:
rows, column_names = log_to_records(log, inference_schema)
rows, column_names = log_to_records(log, inference_schema, model_version)
combined_records.extend(rows)
return pd.DataFrame.from_records(combined_records, columns=column_names)
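For reference, a toy dataframe with the shape this batch conversion now produces; the feature and prediction columns are made up, and the trailing column names are assumptions standing in for the constants imported from prediction_log_parser, which is not shown in this diff.

import pandas as pd

columns = [
    "feature_a",          # from the feature table
    "prediction_score",   # from the prediction results table
    "session_id",         # stands in for SESSION_ID_COLUMN
    "row_id",             # stands in for ROW_ID_COLUMN
    "request_timestamp",  # stands in for PREDICTION_LOG_TIMESTAMP_COLUMN
    "model_version",      # stands in for MODEL_VERSION_COLUMN
]
rows = [
    [0.3, 0.91, "session-1", "a", pd.Timestamp("2024-02-19T00:00:00Z"), "0.1.0"],
    [0.5, 0.12, "session-1", "b", pd.Timestamp("2024-02-19T00:00:00Z"), "0.1.0"],
]
df = pd.DataFrame.from_records(rows, columns=columns)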