From b183593733d46b0e1aa9ff926798c90893f96564 Mon Sep 17 00:00:00 2001 From: Khor Shu Heng <32997938+khorshuheng@users.noreply.github.com> Date: Tue, 2 Jan 2024 10:53:51 +0800 Subject: [PATCH] feat: update observation publisher to use newer sdk (#509) **What this PR does / why we need it**: This is a follow up to https://github.com/caraml-dev/merlin/pull/504 , which updates the observation publisher to use the latest Merlin SDK. It also introduces the use of piptools to pin the dependencies. Another important change is using dict instead of the InferenceSchema dataclass. This is because OmegaConf (and by extension, Hydra) does not have a way to instantiate the correct subclass to an abstract class field, so we have to perform our own deserialization instead. **Which issue(s) this PR fixes**: Fixes # **Does this PR introduce a user-facing change?**: NONE ```release-note NONE ``` **Checklist** - [x] Added unit test, integration, and/or e2e tests - [x] Tested locally - [x] Updated documentation - [ ] Update Swagger spec if the PR introduces API changes - [ ] Regenerated Golang and Python client if the PR introduces API changes --- .github/workflows/merlin.yml | 48 +++ .github/workflows/release.yml | 20 ++ python/Makefile | 4 +- python/observation-publisher/Dockerfile | 12 +- python/observation-publisher/Makefile | 8 +- python/observation-publisher/README.md | 10 + .../conf/environment/example-override.yaml | 53 +--- .../publisher/__main__.py | 12 +- .../observation-publisher/publisher/config.py | 23 +- .../publisher/observability_backend.py | 122 ++++---- .../publisher/prediction_log_consumer.py | 21 +- .../publisher/prediction_log_parser.py | 4 +- .../requirements-dev.txt | 3 +- python/observation-publisher/requirements.in | 6 + python/observation-publisher/requirements.txt | 279 +++++++++++++++++- .../tests/test_config.py | 38 ++- .../tests/test_observability_backend.py | 128 +++++--- .../tests/test_prediction_log_consumer.py | 51 ++-- 18 files changed, 621 insertions(+), 
221 deletions(-) create mode 100644 python/observation-publisher/requirements.in diff --git a/.github/workflows/merlin.yml b/.github/workflows/merlin.yml index 40c22c5f3..5e118094f 100644 --- a/.github/workflows/merlin.yml +++ b/.github/workflows/merlin.yml @@ -186,6 +186,29 @@ jobs: POSTGRES_PASSWORD: ${{ secrets.DB_PASSWORD }} run: make it-test-api-ci + + test-observation-publisher: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: actions/setup-python@v4 + id: setup-python + with: + python-version: '3.10' + - uses: actions/cache@v3 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pip-${{ steps.setup-python.outputs.python-version }}- + - name: Install dependencies + working-directory: ./python/observation-publisher + run: | + make setup + - name: Unit test observation publisher + working-directory: ./python/observation-publisher + run: make test + build-ui: runs-on: ubuntu-latest steps: @@ -355,6 +378,30 @@ jobs: path: merlin-logger.${{ needs.create-version.outputs.version }}.tar retention-days: ${{ env.ARTIFACT_RETENTION_DAYS }} + build-observation-publisher: + runs-on: ubuntu-latest + needs: + - create-version + - test-observation-publisher + env: + DOCKER_REGISTRY: ghcr.io + DOCKER_IMAGE_TAG: "ghcr.io/${{ github.repository }}/merlin-observation-publisher:${{ needs.create-version.outputs.version }}" + steps: + - uses: actions/checkout@v2 + - name: Build Observation Publisher Docker + env: + OBSERVATION_PUBLISHER_IMAGE_TAG: ${{ env.DOCKER_IMAGE_TAG }} + run: make observation-publisher + working-directory: ./python + - name: Save Observation Publisher Docker + run: docker image save --output merlin-observation-publisher.${{ needs.create-version.outputs.version }}.tar ${{ env.DOCKER_IMAGE_TAG }} + - name: Publish Observation Publisher Docker Artifact + uses: actions/upload-artifact@v2 + with: + name: 
merlin-observation-publisher.${{ needs.create-version.outputs.version }}.tar + path: merlin-observation-publisher.${{ needs.create-version.outputs.version }}.tar + retention-days: ${{ env.ARTIFACT_RETENTION_DAYS }} + e2e-test: runs-on: ubuntu-latest needs: @@ -443,6 +490,7 @@ jobs: - build-api - build-batch-predictor-base - build-pyfunc-server-base + - build-observation-publisher - test-python-sdk - e2e-test with: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 8b0408a96..60d31146a 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -156,3 +156,23 @@ jobs: run: | docker image load --input merlin-pyfunc-base.${{ inputs.version }}.tar docker push ${{ env.DOCKER_IMAGE_TAG }} + + publish-observation-publisher: + runs-on: ubuntu-latest + env: + DOCKER_IMAGE_TAG: "ghcr.io/${{ github.repository }}/merlin-observation-publisher:${{ inputs.version }}" + steps: + - name: Log in to the Container registry + uses: docker/login-action@v1 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GITHUB_TOKEN }} + - name: Download Observation Publisher Docker Artifact + uses: actions/download-artifact@v2 + with: + name: merlin-observation-publisher.${{ inputs.version }}.tar + - name: Retag and Push Docker Image + run: | + docker image load --input merlin-observation-publisher.${{ inputs.version }}.tar + docker push ${{ env.DOCKER_IMAGE_TAG }} diff --git a/python/Makefile b/python/Makefile index aa443b7d6..09d88ccc5 100644 --- a/python/Makefile +++ b/python/Makefile @@ -1,6 +1,6 @@ -IMAGE_TAG=dev +OBSERVATION_PUBLISHER_IMAGE_TAG ?= observation-publisher:dev .PHONY: observation-publisher observation-publisher: @echo "Building image for observation publisher..." - @docker build -t observation-publisher:${IMAGE_TAG} -f observation-publisher/Dockerfile . + @docker build -t ${OBSERVATION_PUBLISHER_IMAGE_TAG} -f observation-publisher/Dockerfile --progress plain . 
diff --git a/python/observation-publisher/Dockerfile b/python/observation-publisher/Dockerfile index d4e2a96ad..0724fc312 100644 --- a/python/observation-publisher/Dockerfile +++ b/python/observation-publisher/Dockerfile @@ -1,14 +1,12 @@ FROM python:3.10 WORKDIR /mlobs - +COPY sdk ./sdk +WORKDIR /mlobs/observation-publisher COPY observation-publisher/requirements.txt . -COPY sdk/ ./sdk -ENV SDK_PATH=/mlobs/sdk RUN pip install -r requirements.txt RUN rm requirements.txt -RUN rm -rf sdk -COPY observation-publisher/conf/ ./conf -COPY observation-publisher/publisher/ ./publisher - +WORKDIR /mlobs +COPY observation-publisher ./observation-publisher +WORKDIR /mlobs/observation-publisher ENTRYPOINT ["python", "-m", "publisher"] \ No newline at end of file diff --git a/python/observation-publisher/Makefile b/python/observation-publisher/Makefile index 4f107d7d0..d30f50b97 100644 --- a/python/observation-publisher/Makefile +++ b/python/observation-publisher/Makefile @@ -4,6 +4,12 @@ ENVIRONMENT_CONFIG = "example-override" setup: @echo "Setting up environment..." @pip install -r requirements.txt --use-pep517 + @pip install -r requirements-dev.txt + +.PHONY: pip-compile +pip-compile: + @echo "Compiling requirements..." + @python -m piptools compile .PHONY: test test: @@ -13,4 +19,4 @@ test: .PHONY: run run: @echo "Running observation publisher..." 
- @python -m observation_publisher +environment=${ENVIRONMENT_CONFIG} \ No newline at end of file + @python -m publisher +environment=${ENVIRONMENT_CONFIG} diff --git a/python/observation-publisher/README.md b/python/observation-publisher/README.md index a3368b9ba..2a6698afb 100644 --- a/python/observation-publisher/README.md +++ b/python/observation-publisher/README.md @@ -16,6 +16,16 @@ make run ## Development +### Setup +```bash +make setup +``` + +### Updating requirements.txt +Make changes on requirements.in, then execute +```bash +make pip-compile +``` ### Run test ```bash diff --git a/python/observation-publisher/conf/environment/example-override.yaml b/python/observation-publisher/conf/environment/example-override.yaml index 03f59e2b2..276352df2 100644 --- a/python/observation-publisher/conf/environment/example-override.yaml +++ b/python/observation-publisher/conf/environment/example-override.yaml @@ -1,53 +1,20 @@ model_id: "test-model" model_version: "0.1.0" inference_schema: - # Supported model types: - # - BINARY_CLASSIFICATION - # - MULTICLASS_CLASSIFICATION - # - REGRESSION - # - RANKING - # The prediction output schema that is corresponded to the model - # type must be provided. + # Inference schema associated with the model id and version. For full documentation on the support configuration, + # refer to the Merlin SDK. Example below is for a binary classification model. - # Example for binary classification - type: "BINARY_CLASSIFICATION" - binary_classification: - # The classification label STRING of this event. - prediction_label_column: "label" - # Optional: The likelihood of the event (Probability between 0 to 1.0). 
+ model_prediction_output: + output_class: "BinaryClassificationOutput" prediction_score_column: "score" + actual_label_column: "actual_label" + positive_class_label: "positive" + negative_class_label: "negative" + score_threshold: 0.5 -# # Example for multiclass classification -# type: "MULTICLASS_CLASSIFICATION" -# multiclass_classification: -# # The classification label STRING of this event. -# prediction_label_column: "label" -# # Optional: The likelihood of the event (Probability between 0 to 1.0). -# prediction_score_column: "score" - -# # Example for regression -# type: "REGRESSION" -# regression: -# # FLOAT64 value for the prediction value. -# prediction_score_column: "score" - -# # Example for ranking -# type: "RANKING" -# ranking: -# # A group of predictions within which items are ranked.. -# prediction_group_id_column: "prediction_group" -# # INT64 value for the rank of the prediction within the group. -# rank_column: "rank" - - # Column name to data types mapping for feature columns. Supported types are: - # - INT64 - # - FLOAT64 - # - STRING feature_types: - distance: "INT64" - transaction: "FLOAT64" - # Optional: Column name to be used for prediction id. - # If not provided, it's assumed to be prediction_id. 
+ distance: "int64" + transaction: "float64" prediction_id_column: "prediction_id" observability_backend: # Supported backend types: diff --git a/python/observation-publisher/publisher/__main__.py b/python/observation-publisher/publisher/__main__.py index 26074136b..6454173bb 100644 --- a/python/observation-publisher/publisher/__main__.py +++ b/python/observation-publisher/publisher/__main__.py @@ -1,4 +1,5 @@ import hydra +from merlin.observability.inference import InferenceSchema from omegaconf import OmegaConf from publisher.config import PublisherConfig @@ -12,11 +13,18 @@ def start_consumer(cfg: PublisherConfig) -> None: if missing_keys: raise RuntimeError(f"Got missing keys in config:\n{missing_keys}") prediction_log_consumer = new_consumer(cfg.environment.observation_source) + inference_schema = InferenceSchema.from_dict( + OmegaConf.to_container(cfg.environment.inference_schema) + ) observation_sink = new_observation_sink( - cfg.environment.observability_backend, cfg.environment.model + config=cfg.environment.observability_backend, + inference_schema=inference_schema, + model_id=cfg.environment.model_id, + model_version=cfg.environment.model_version, ) prediction_log_consumer.start_polling( - observation_sink=observation_sink, model_spec=cfg.environment.model + observation_sink=observation_sink, + inference_schema=inference_schema, ) diff --git a/python/observation-publisher/publisher/config.py b/python/observation-publisher/publisher/config.py index a2e8e3d0c..f55f10da3 100644 --- a/python/observation-publisher/publisher/config.py +++ b/python/observation-publisher/publisher/config.py @@ -1,9 +1,8 @@ from dataclasses import dataclass -from enum import Enum, unique +from enum import Enum from typing import Optional from hydra.core.config_store import ConfigStore -from merlin.observability.inference import InferenceSchema @dataclass @@ -12,9 +11,8 @@ class ArizeConfig: space_key: str -@unique class ObservabilityBackendType(Enum): - ARIZE = 1 + ARIZE = 
"arize" @dataclass @@ -22,10 +20,15 @@ class ObservabilityBackend: type: ObservabilityBackendType arize_config: Optional[ArizeConfig] = None + def __post_init__(self): + if self.type == ObservabilityBackendType.ARIZE: + assert ( + self.arize_config is not None + ), "Arize config must be set for Arize observability backend" + -@unique class ObservationSource(Enum): - KAFKA = 1 + KAFKA = "kafka" @dataclass @@ -43,12 +46,18 @@ class ObservationSourceConfig: type: ObservationSource kafka_config: Optional[KafkaConsumerConfig] = None + def __post_init__(self): + if self.type == ObservationSource.KAFKA: + assert ( + self.kafka_config is not None + ), "Kafka config must be set for Kafka observation source" + @dataclass class Environment: model_id: str model_version: str - inference_schema: InferenceSchema + inference_schema: dict observability_backend: ObservabilityBackend observation_source: ObservationSourceConfig diff --git a/python/observation-publisher/publisher/observability_backend.py b/python/observation-publisher/publisher/observability_backend.py index f160210da..71463fe53 100644 --- a/python/observation-publisher/publisher/observability_backend.py +++ b/python/observation-publisher/publisher/observability_backend.py @@ -1,15 +1,21 @@ import abc -from typing import List +from typing import Tuple import pandas as pd from arize.pandas.logger import Client from arize.pandas.logger import Schema as ArizeSchema +from arize.pandas.validation.errors import ValidationFailure from arize.utils.types import Environments from arize.utils.types import ModelTypes as ArizeModelType -from merlin.observability.inference import InferenceSchema, InferenceType +from merlin.observability.inference import ( + InferenceSchema, + RegressionOutput, + BinaryClassificationOutput, + RankingOutput, + ObservationType, +) -from publisher.config import (ArizeConfig, ObservabilityBackend, - ObservabilityBackendType) +from publisher.config import ObservabilityBackend, ObservabilityBackendType 
from publisher.prediction_log_parser import PREDICTION_LOG_TIMESTAMP_COLUMN @@ -19,77 +25,74 @@ def write(self, dataframe: pd.DataFrame): raise NotImplementedError -def map_to_arize_schema(inference_schema: InferenceSchema) -> List[ArizeSchema]: - # One log will be published per model schema - match inference_schema.type: - case InferenceType.BINARY_CLASSIFICATION: - return [ - ArizeSchema( - feature_column_names=inference_schema.feature_columns, - prediction_label_column_name=inference_schema.binary_classification.prediction_label_column, - prediction_score_column_name=inference_schema.binary_classification.prediction_score_column, - prediction_id_column_name=inference_schema.prediction_id_column, - timestamp_column_name=PREDICTION_LOG_TIMESTAMP_COLUMN, - ) - ] - case InferenceType.MULTICLASS_CLASSIFICATION: - return [ - ArizeSchema( - feature_column_names=inference_schema.feature_columns, - prediction_label_column_name=prediction_label_column, - prediction_score_column_name=prediction_score_column, - prediction_id_column_name=inference_schema.prediction_id_column, - timestamp_column_name=PREDICTION_LOG_TIMESTAMP_COLUMN, - ) - for prediction_label_column, prediction_score_column in zip( - inference_schema.multiclass_classification.prediction_label_columns, - inference_schema.multiclass_classification.prediction_score_columns, - ) - ] - case InferenceType.REGRESSION: - return [ - ArizeSchema( - feature_column_names=inference_schema.feature_columns, - prediction_score_column_name=inference_schema.regression.prediction_score_column, - prediction_id_column_name=inference_schema.prediction_id_column, - timestamp_column_name=PREDICTION_LOG_TIMESTAMP_COLUMN, - ) - ] - case InferenceType.RANKING: - return [ - ArizeSchema( - feature_column_names=inference_schema.feature_columns, - rank_column_name=inference_schema.ranking.rank_column, - prediction_group_id_column_name=inference_schema.ranking.prediction_group_id_column, - 
prediction_id_column_name=inference_schema.prediction_id_column, - timestamp_column_name=PREDICTION_LOG_TIMESTAMP_COLUMN, - ) - ] - - class ArizeSink(ObservationSink): def __init__( self, - config: ArizeConfig, + arize_client: Client, inference_schema: InferenceSchema, model_id: str, model_version: str, ): - self._client = Client(space_key=config.space_key, api_key=config.api_key) + self._client = arize_client self._model_id = model_id self._model_version = model_version self._inference_schema = inference_schema + def common_arize_schema_attributes(self) -> dict: + return dict( + feature_column_names=self._inference_schema.feature_columns, + prediction_id_column_name=self._inference_schema.prediction_id_column, + timestamp_column_name=PREDICTION_LOG_TIMESTAMP_COLUMN, + tag_column_names=self._inference_schema.tag_columns, + ) + + def to_arize_schema(self) -> Tuple[ArizeModelType, ArizeSchema]: + prediction_output = self._inference_schema.model_prediction_output + if isinstance(prediction_output, BinaryClassificationOutput): + schema_attributes = self.common_arize_schema_attributes() | dict( + prediction_label_column_name=prediction_output.prediction_label_column, + prediction_score_column_name=prediction_output.prediction_score_column, + ) + model_type = ArizeModelType.BINARY_CLASSIFICATION + elif isinstance(prediction_output, RegressionOutput): + schema_attributes = self.common_arize_schema_attributes() | dict( + prediction_score_column_name=prediction_output.prediction_score_column, + ) + model_type = ArizeModelType.REGRESSION + elif isinstance(prediction_output, RankingOutput): + schema_attributes = self.common_arize_schema_attributes() | dict( + rank_column_name=prediction_output.rank_column, + prediction_group_id_column_name=prediction_output.prediction_group_id_column, + ) + model_type = ArizeModelType.RANKING + else: + raise ValueError( + f"Unknown prediction output type: {type(prediction_output)}" + ) + + return model_type, ArizeSchema(**schema_attributes) + 
def write(self, df: pd.DataFrame): - for schema in map_to_arize_schema(self._inference_schema): + processed_df = self._inference_schema.model_prediction_output.preprocess( + df, [ObservationType.FEATURE, ObservationType.PREDICTION] + ) + model_type, arize_schema = self.to_arize_schema() + try: self._client.log( - dataframe=df, + dataframe=processed_df, environment=Environments.PRODUCTION, - schema=schema, + schema=arize_schema, model_id=self._model_id, - model_type=ArizeModelType(self._inference_schema.type.name), + model_type=model_type, model_version=self._model_version, ) + except ValidationFailure as e: + error_message = "\n".join([err.error_message() for err in e.errors]) + print(f"Failed to log to Arize: {error_message}") + raise e + except Exception as e: + print(f"Failed to log to Arize: {e}") + raise e def new_observation_sink( @@ -99,8 +102,9 @@ def new_observation_sink( model_version: str, ) -> ObservationSink: if config.type == ObservabilityBackendType.ARIZE: + client = Client(space_key=config.arize_config.space_key, api_key=config.arize_config.api_key) return ArizeSink( - config=config.arize_config, + arize_client=client, inference_schema=inference_schema, model_id=model_id, model_version=model_version, diff --git a/python/observation-publisher/publisher/prediction_log_consumer.py b/python/observation-publisher/publisher/prediction_log_consumer.py index 21cffc461..c2e8f21f4 100644 --- a/python/observation-publisher/publisher/prediction_log_consumer.py +++ b/python/observation-publisher/publisher/prediction_log_consumer.py @@ -7,12 +7,17 @@ from confluent_kafka import Consumer, KafkaException from merlin.observability.inference import InferenceSchema -from publisher.config import (KafkaConsumerConfig, ObservationSource, - ObservationSourceConfig) +from publisher.config import ( + KafkaConsumerConfig, + ObservationSource, + ObservationSourceConfig, +) from publisher.observability_backend import ObservationSink -from publisher.prediction_log_parser import 
(PREDICTION_LOG_TIMESTAMP_COLUMN, - parse_struct_to_feature_table, - parse_struct_to_result_table) +from publisher.prediction_log_parser import ( + PREDICTION_LOG_TIMESTAMP_COLUMN, + parse_struct_to_feature_table, + parse_struct_to_result_table, +) class PredictionLogConsumer(abc.ABC): @@ -37,7 +42,7 @@ def start_polling( if len(logs) == 0: continue df = log_batch_to_dataframe(logs, inference_schema) - observation_sink.write(dataframe=df) + observation_sink.write(df) self.commit() finally: self.close() @@ -48,6 +53,7 @@ def __init__(self, config: KafkaConsumerConfig): consumer_config = { "bootstrap.servers": config.bootstrap_servers, "group.id": config.group_id, + "enable.auto.commit": False, } if config.additional_consumer_config is not None: @@ -61,13 +67,12 @@ def __init__(self, config: KafkaConsumerConfig): def poll_new_logs(self) -> List[PredictionLog]: messages = self._consumer.consume(self._batch_size, timeout=self._poll_timeout) errors = [msg.error() for msg in messages if msg.error() is not None] - if len(errors) > 0: print(f"Last encountered error: {errors[-1]}") raise KafkaException(errors[-1]) return [ - parse_message_to_prediction_log(msg.value().decode("utf-8")) + parse_message_to_prediction_log(msg.value()) for msg in messages if (msg is not None and msg.error() is None) ] diff --git a/python/observation-publisher/publisher/prediction_log_parser.py b/python/observation-publisher/publisher/prediction_log_parser.py index 3a9f67815..13024de66 100644 --- a/python/observation-publisher/publisher/prediction_log_parser.py +++ b/python/observation-publisher/publisher/prediction_log_parser.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import numpy as np from google.protobuf.internal.well_known_types import ListValue, Struct @@ -63,7 +63,7 @@ def parse_struct_to_result_table( table_struct: Struct, inference_schema: InferenceSchema ) -> 
PredictionLogResultsTable: columns = [c for c in table_struct["columns"]] - column_types = inference_schema.prediction_column_types + column_types = inference_schema.model_prediction_output.prediction_types() return PredictionLogResultsTable( columns=columns, rows=[ diff --git a/python/observation-publisher/requirements-dev.txt b/python/observation-publisher/requirements-dev.txt index 55b033e90..1c1df9505 100644 --- a/python/observation-publisher/requirements-dev.txt +++ b/python/observation-publisher/requirements-dev.txt @@ -1 +1,2 @@ -pytest \ No newline at end of file +pip-tools==7.3.0 +pytest==7.4.3 \ No newline at end of file diff --git a/python/observation-publisher/requirements.in b/python/observation-publisher/requirements.in new file mode 100644 index 000000000..2f5ec3225 --- /dev/null +++ b/python/observation-publisher/requirements.in @@ -0,0 +1,6 @@ +confluent-kafka>=2.3.0 +caraml-upi-protos>=1.0.0 +arize==7.7.* +hydra-core>=1.3.0 +pandas>=1.0.0 +-e file:../sdk \ No newline at end of file diff --git a/python/observation-publisher/requirements.txt b/python/observation-publisher/requirements.txt index a76377f92..2062495ee 100644 --- a/python/observation-publisher/requirements.txt +++ b/python/observation-publisher/requirements.txt @@ -1,7 +1,274 @@ -confluent-kafka>=2.3.0 -caraml-upi-protos>=1.0.0 -arize==7.6.* -hydra-core>=1.3.0 -pandas>=1.0.0 +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile +# +-e file:../sdk + # via -r requirements.in +alembic==1.13.0 + # via mlflow +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf +arize==7.7.2 + # via -r requirements.in +arrow==1.3.0 + # via cookiecutter +binaryornot==0.4.4 + # via cookiecutter +blinker==1.7.0 + # via flask +boto3==1.33.11 + # via merlin-sdk +botocore==1.33.11 + # via + # boto3 + # s3transfer +cachetools==5.3.2 + # via google-auth +caraml-auth-google==0.0.0.post7 + # via merlin-sdk +caraml-upi-protos==1.0.0 + # via + # -r 
requirements.in + # merlin-sdk +certifi==2023.11.17 + # via + # merlin-sdk + # requests +chardet==5.2.0 + # via binaryornot +charset-normalizer==3.3.2 + # via requests +click==8.1.3 + # via + # cookiecutter + # databricks-cli + # flask + # merlin-sdk + # mlflow +cloudpickle==2.0.0 + # via + # merlin-sdk + # mlflow +confluent-kafka==2.3.0 + # via -r requirements.in +cookiecutter==2.5.0 + # via merlin-sdk +databricks-cli==0.18.0 + # via mlflow +dataclasses-json==0.6.3 + # via merlin-sdk +docker==6.1.3 + # via + # merlin-sdk + # mlflow +entrypoints==0.4 + # via mlflow +flask==2.3.3 + # via + # mlflow + # prometheus-flask-exporter +gitdb==4.0.11 + # via gitpython +gitpython==3.1.40 + # via mlflow +google-api-core==2.15.0 + # via + # google-cloud-core + # google-cloud-storage +google-auth==2.25.2 + # via + # caraml-auth-google + # google-api-core + # google-cloud-core + # google-cloud-storage +google-cloud-core==2.4.1 + # via google-cloud-storage +google-cloud-storage==2.13.0 + # via merlin-sdk +google-crc32c==1.5.0 + # via + # google-cloud-storage + # google-resumable-media +google-resumable-media==2.6.0 + # via google-cloud-storage +googleapis-common-protos==1.62.0 + # via + # arize + # caraml-upi-protos + # google-api-core +grpcio==1.60.0 + # via grpcio-tools +grpcio-tools==1.60.0 + # via caraml-upi-protos +gunicorn==20.1.0 + # via mlflow +hydra-core==1.3.2 + # via -r requirements.in +idna==3.6 + # via requests +importlib-metadata==5.2.0 + # via mlflow +itsdangerous==2.1.2 + # via flask +jinja2==3.1.2 + # via + # cookiecutter + # flask +jmespath==1.0.1 + # via + # boto3 + # botocore +mako==1.3.0 + # via alembic +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.3 + # via + # jinja2 + # mako + # werkzeug +marshmallow==3.20.1 + # via dataclasses-json +mdurl==0.1.2 + # via markdown-it-py +mlflow==1.30.1 + # via merlin-sdk +mypy-extensions==1.0.0 + # via typing-inspect +numpy==1.23.5 + # via + # merlin-sdk + # mlflow + # pandas + # pyarrow + # scipy +oauthlib==3.2.2 + 
# via databricks-cli +omegaconf==2.3.0 + # via hydra-core +packaging==21.3 + # via + # docker + # hydra-core + # marshmallow + # mlflow +pandas==1.5.3 + # via + # -r requirements.in + # arize + # mlflow +prometheus-client==0.19.0 + # via prometheus-flask-exporter +prometheus-flask-exporter==0.23.0 + # via mlflow +protobuf==4.25.1 + # via + # arize + # google-api-core + # googleapis-common-protos + # grpcio-tools + # merlin-sdk + # mlflow +pyarrow==14.0.1 + # via arize +pyasn1==0.5.1 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth +pygments==2.17.2 + # via rich +pyjwt==2.8.0 + # via databricks-cli +pyparsing==3.1.1 + # via packaging +pyprind==2.11.3 + # via merlin-sdk +python-dateutil==2.8.2 + # via + # arrow + # botocore + # merlin-sdk + # pandas +python-slugify==8.0.1 + # via cookiecutter +pytz==2022.7.1 + # via + # mlflow + # pandas +pyyaml==6.0.1 + # via + # cookiecutter + # merlin-sdk + # mlflow + # omegaconf +querystring-parser==1.2.4 + # via mlflow +requests==2.31.0 + # via + # cookiecutter + # databricks-cli + # docker + # google-api-core + # google-cloud-storage + # mlflow + # requests-futures +requests-futures==1.0.0 + # via arize +rich==13.7.0 + # via cookiecutter +rsa==4.9 + # via google-auth +s3transfer==0.8.2 + # via boto3 +scipy==1.11.4 + # via mlflow +six==1.16.0 + # via + # databricks-cli + # merlin-sdk + # python-dateutil + # querystring-parser +smmap==5.0.1 + # via gitdb +sqlalchemy==1.4.50 + # via + # alembic + # mlflow +sqlparse==0.4.4 + # via mlflow +tabulate==0.9.0 + # via databricks-cli +text-unidecode==1.3 + # via python-slugify +tqdm==4.66.1 + # via arize +types-python-dateutil==2.8.19.14 + # via arrow +typing-extensions==4.9.0 + # via + # alembic + # typing-inspect +typing-inspect==0.9.0 + # via dataclasses-json +urllib3==2.0.7 + # via + # botocore + # databricks-cli + # docker + # merlin-sdk + # requests +websocket-client==1.7.0 + # via docker +werkzeug==3.0.1 + # via flask +zipp==3.17.0 + # via 
importlib-metadata -file:${SDK_PATH} +# The following packages are considered to be unsafe in a requirements file: +# setuptools diff --git a/python/observation-publisher/tests/test_config.py b/python/observation-publisher/tests/test_config.py index e8cf2d1e7..6483c7cde 100644 --- a/python/observation-publisher/tests/test_config.py +++ b/python/observation-publisher/tests/test_config.py @@ -1,14 +1,22 @@ import dataclasses from hydra import compose, initialize -from merlin.observability.inference import (BinaryClassificationOutput, - InferenceSchema, InferenceType, - ValueType) +from merlin.observability.inference import ( + InferenceSchema, + ValueType, +) +from omegaconf import OmegaConf -from publisher.config import (ArizeConfig, Environment, KafkaConsumerConfig, - ObservabilityBackend, ObservabilityBackendType, - ObservationSource, ObservationSourceConfig, - PublisherConfig) +from publisher.config import ( + ArizeConfig, + Environment, + KafkaConsumerConfig, + ObservabilityBackend, + ObservabilityBackendType, + ObservationSource, + ObservationSourceConfig, + PublisherConfig, +) def test_config_initialization(): @@ -18,15 +26,18 @@ def test_config_initialization(): environment=Environment( model_id="test-model", model_version="0.1.0", - inference_schema=InferenceSchema( - type=InferenceType.BINARY_CLASSIFICATION, + inference_schema=dict( feature_types={ "distance": ValueType.INT64, "transaction": ValueType.FLOAT64, }, - binary_classification=BinaryClassificationOutput( - prediction_label_column="label", + model_prediction_output=dict( + output_class="BinaryClassificationOutput", prediction_score_column="score", + actual_label_column="actual_label", + positive_class_label="positive", + negative_class_label="negative", + score_threshold=0.5, ), ), observability_backend=ObservabilityBackend( @@ -50,7 +61,10 @@ def test_config_initialization(): ), ) ) - assert cfg.environment.inference_schema == dataclasses.asdict( + parsed_schema: InferenceSchema = 
InferenceSchema.from_dict( + OmegaConf.to_container(cfg.environment.inference_schema) + ) + assert parsed_schema == InferenceSchema.from_dict( expected_cfg.environment.inference_schema ) assert cfg.environment.observability_backend == dataclasses.asdict( diff --git a/python/observation-publisher/tests/test_observability_backend.py b/python/observation-publisher/tests/test_observability_backend.py index f711d2c35..2a5437440 100644 --- a/python/observation-publisher/tests/test_observability_backend.py +++ b/python/observation-publisher/tests/test_observability_backend.py @@ -1,70 +1,106 @@ from datetime import datetime +from typing import Optional import pandas as pd import pyarrow as pa -from arize.pandas.validation.validator import Validator -from arize.utils.types import Environments -from arize.utils.types import ModelTypes as ArizeModelType -from merlin.observability.inference import (BinaryClassificationOutput, - InferenceSchema, InferenceType, - ValueType) +from arize.pandas.logger import Client +from merlin.observability.inference import ( + BinaryClassificationOutput, + InferenceSchema, + ValueType, + RankingOutput, +) +from requests import Response -from publisher.observability_backend import map_to_arize_schema +from publisher.observability_backend import ArizeSink -def test_arize_schema_mapping(): +class MockResponse(Response): + def __init__(self, df, reason, status_code): + super().__init__() + self.df = df + self.reason = reason + self.status_code = status_code + + +class MockArizeClient(Client): + def _post_file( + self, + path: str, + schema: bytes, + sync: Optional[bool], + timeout: Optional[float] = None, + ) -> Response: + return MockResponse(pa.ipc.open_stream(pa.OSFile(path)).read_pandas(), "Success", 200) + + +def test_binary_classification_model_preprocessing_for_arize(): inference_schema = InferenceSchema( - type=InferenceType.BINARY_CLASSIFICATION, feature_types={ - "acceptance_rate": ValueType.FLOAT64, - "minutes_since_last_order": 
ValueType.INT64, - "service_type": ValueType.STRING, - "prediction_score": ValueType.FLOAT64, - "prediction_label": ValueType.STRING, + "rating": ValueType.FLOAT64, }, - binary_classification=BinaryClassificationOutput( - prediction_label_column="prediction_label", + model_prediction_output=BinaryClassificationOutput( prediction_score_column="prediction_score", + actual_label_column="actual_label", + positive_class_label="fraud", + negative_class_label="non fraud", + score_threshold=0.5, ), ) - arize_schemas = map_to_arize_schema(inference_schema) - assert len(arize_schemas) == 1 - input_dataframe = pd.DataFrame.from_records( + arize_client = MockArizeClient(api_key="test", space_key="test") + arize_sink = ArizeSink( + arize_client, + inference_schema, + "test-model", + "0.1.0", + ) + request_timestamp = datetime.now() + input_df = pd.DataFrame.from_records( [ - [0.8, 24, "FOOD", 0.9, "non fraud", "1234a", datetime(2021, 1, 1, 0, 0, 0)], - [0.5, 2, "RIDE", 0.5, "fraud", "1234b", datetime(2021, 1, 1, 0, 0, 0)], - [1.0, 13, "CAR", 0.4, "non fraud", "5678c", datetime(2021, 1, 1, 0, 0, 0)], - [0.4, 60, "RIDE", 0.2, "non fraud", "5678d", datetime(2021, 1, 1, 0, 0, 0)], + [0.8, 0.4, "1234a", request_timestamp], + [0.5, 0.9, "1234b", request_timestamp], ], columns=[ - "acceptance_rate", - "minutes_since_last_order", - "service_type", + "rating", "prediction_score", - "prediction_label", "prediction_id", "request_timestamp", ], ) - errors = Validator.validate_required_checks( - dataframe=input_dataframe, - model_id="test_model", - schema=arize_schemas[0], - model_version="0.1.0", + arize_sink.write(input_df) + + +def test_ranking_model_preprocessing_for_arize(): + inference_schema = InferenceSchema( + feature_types={ + "rating": ValueType.FLOAT64, + }, + model_prediction_output=RankingOutput( + rank_score_column="rank_score", + prediction_group_id_column="order_id", + relevance_score_column="relevance_score_column", + ), ) - assert len(errors) == 0 - errors = 
Validator.validate_params( - dataframe=input_dataframe, - model_id="test_model", - model_type=ArizeModelType.BINARY_CLASSIFICATION, - environment=Environments.PRODUCTION, - schema=arize_schemas[0], - model_version="0.1.0", + request_timestamp = datetime.now() + input_df = pd.DataFrame.from_records( + [ + [5.0, 1.0, "1234", "1001", request_timestamp], + [4.0, 0.9, "1234", "1001", request_timestamp], + [3.0, 0.8, "1234", "1001", request_timestamp], + ], + columns=[ + "rating", + "rank_score", + "prediction_id", + "order_id", + "request_timestamp", + ], ) - assert len(errors) == 0 - Validator.validate_types( - model_type=ArizeModelType.BINARY_CLASSIFICATION, - schema=arize_schemas[0], - pyarrow_schema=pa.Schema.from_pandas(input_dataframe), + arize_client = MockArizeClient(api_key="test", space_key="test") + arize_sink = ArizeSink( + arize_client, + inference_schema, + "test-model", + "0.1.0", ) - assert len(errors) == 0 + arize_sink.write(input_df) diff --git a/python/observation-publisher/tests/test_prediction_log_consumer.py b/python/observation-publisher/tests/test_prediction_log_consumer.py index af0695a0b..db2a71569 100644 --- a/python/observation-publisher/tests/test_prediction_log_consumer.py +++ b/python/observation-publisher/tests/test_prediction_log_consumer.py @@ -4,9 +4,11 @@ import numpy as np import pandas as pd from caraml.upi.v1.prediction_log_pb2 import PredictionLog -from merlin.observability.inference import (BinaryClassificationOutput, - InferenceSchema, InferenceType, - ValueType) +from merlin.observability.inference import ( + BinaryClassificationOutput, + InferenceSchema, + ValueType, +) from pandas._testing import assert_frame_equal from publisher.prediction_log_consumer import log_batch_to_dataframe @@ -56,17 +58,17 @@ def test_log_to_dataframe(): model_id = "test_model" model_version = "0.1.0" inference_schema = InferenceSchema( - type=InferenceType.BINARY_CLASSIFICATION, feature_types={ "acceptance_rate": ValueType.FLOAT64, 
"minutes_since_last_order": ValueType.INT64, "service_type": ValueType.STRING, - "prediction_score": ValueType.FLOAT64, - "prediction_label": ValueType.STRING, }, - binary_classification=BinaryClassificationOutput( - prediction_label_column="prediction_label", + model_prediction_output=BinaryClassificationOutput( prediction_score_column="prediction_score", + actual_label_column="actual_label", + positive_class_label="fraud", + negative_class_label="non fraud", + score_threshold=0.5, ), ) input_columns = [ @@ -74,7 +76,7 @@ def test_log_to_dataframe(): "minutes_since_last_order", "service_type", ] - output_columns = ["prediction_score", "prediction_label"] + output_columns = ["prediction_score"] prediction_logs = [ new_prediction_log( prediction_id="1234", @@ -87,8 +89,8 @@ def test_log_to_dataframe(): ], output_columns=output_columns, output_data=[ - [0.9, "non fraud"], - [0.5, "fraud"], + [0.9], + [0.5], ], request_timestamp=datetime(2021, 1, 1, 0, 0, 0), row_ids=["a", "b"], @@ -104,8 +106,8 @@ def test_log_to_dataframe(): ], output_columns=output_columns, output_data=[ - [0.4, "non fraud"], - [0.2, "non fraud"], + [0.4], + [0.2], ], request_timestamp=datetime(2021, 1, 1, 0, 0, 0), row_ids=["c", "d"], @@ -114,17 +116,16 @@ def test_log_to_dataframe(): prediction_logs_df = log_batch_to_dataframe(prediction_logs, inference_schema) expected_df = pd.DataFrame.from_records( [ - [0.8, 24, "FOOD", 0.9, "non fraud", "1234a", datetime(2021, 1, 1, 0, 0, 0)], - [0.5, 2, "RIDE", 0.5, "fraud", "1234b", datetime(2021, 1, 1, 0, 0, 0)], - [1.0, 13, "CAR", 0.4, "non fraud", "5678c", datetime(2021, 1, 1, 0, 0, 0)], - [0.4, 60, "RIDE", 0.2, "non fraud", "5678d", datetime(2021, 1, 1, 0, 0, 0)], + [0.8, 24, "FOOD", 0.9, "1234a", datetime(2021, 1, 1, 0, 0, 0)], + [0.5, 2, "RIDE", 0.5, "1234b", datetime(2021, 1, 1, 0, 0, 0)], + [1.0, 13, "CAR", 0.4, "5678c", datetime(2021, 1, 1, 0, 0, 0)], + [0.4, 60, "RIDE", 0.2, "5678d", datetime(2021, 1, 1, 0, 0, 0)], ], columns=[ "acceptance_rate", 
"minutes_since_last_order", "service_type", "prediction_score", - "prediction_label", "prediction_id", "request_timestamp", ], @@ -136,13 +137,15 @@ def test_empty_column_conversion_to_dataframe(): model_id = "test_model" model_version = "0.1.0" inference_schema = InferenceSchema( - type=InferenceType.BINARY_CLASSIFICATION, feature_types={ "acceptance_rate": ValueType.FLOAT64, }, - binary_classification=BinaryClassificationOutput( - prediction_label_column="prediction_label", + model_prediction_output=BinaryClassificationOutput( prediction_score_column="prediction_score", + actual_label_column="actual_label", + positive_class_label="fraud", + negative_class_label="non fraud", + score_threshold=0.5, ), ) prediction_logs = [ @@ -154,9 +157,9 @@ def test_empty_column_conversion_to_dataframe(): input_data=[ [None], ], - output_columns=["prediction_label", "prediction_score"], + output_columns=["prediction_score"], output_data=[ - ["ACCEPTED", 0.5], + [0.5], ], request_timestamp=datetime(2021, 1, 1, 0, 0, 0), row_ids=["a"], @@ -167,7 +170,6 @@ def test_empty_column_conversion_to_dataframe(): [ [ np.NaN, - "ACCEPTED", 0.5, "1234a", datetime(2021, 1, 1, 0, 0, 0), @@ -175,7 +177,6 @@ def test_empty_column_conversion_to_dataframe(): ], columns=[ "acceptance_rate", - "prediction_label", "prediction_score", "prediction_id", "request_timestamp",