Skip to content

Commit

Permalink
feat: Add training data publisher
Browse files Browse the repository at this point in the history
  • Loading branch information
khorshuheng committed Jan 5, 2024
1 parent 1664f3d commit 8c17b9b
Show file tree
Hide file tree
Showing 13 changed files with 853 additions and 0 deletions.
23 changes: 23 additions & 0 deletions .github/workflows/merlin.yml
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,29 @@ jobs:
path: merlin-observation-publisher.${{ needs.create-version.outputs.version }}.tar
retention-days: ${{ env.ARTIFACT_RETENTION_DAYS }}

# Builds the batch observation publisher Docker image and uploads it as a
# tarball artifact named after the version output by the create-version job.
build-batch-observation-publisher:
  runs-on: ubuntu-latest
  needs:
    - create-version
  env:
    DOCKER_REGISTRY: ghcr.io
    DOCKER_IMAGE_TAG: "ghcr.io/${{ github.repository }}/merlin-batch-observation-publisher:${{ needs.create-version.outputs.version }}"
  steps:
    - uses: actions/checkout@v2
    # Step names carry the "Batch" prefix so every step in this job is
    # labelled consistently in the run log.
    - name: Build Batch Observation Publisher Docker
      env:
        # Read by the batch-observation-publisher target of python/Makefile.
        BATCH_OBSERVATION_PUBLISHER_IMAGE_TAG: ${{ env.DOCKER_IMAGE_TAG }}
      run: make batch-observation-publisher
      working-directory: ./python
    - name: Save Batch Observation Publisher Docker
      run: docker image save --output merlin-batch-observation-publisher.${{ needs.create-version.outputs.version }}.tar ${{ env.DOCKER_IMAGE_TAG }}
    - name: Publish Batch Observation Publisher Docker Artifact
      uses: actions/upload-artifact@v2
      with:
        name: merlin-batch-observation-publisher.${{ needs.create-version.outputs.version }}.tar
        path: merlin-batch-observation-publisher.${{ needs.create-version.outputs.version }}.tar
        # NOTE(review): ARTIFACT_RETENTION_DAYS is not set in this job's env;
        # presumably it is defined at workflow level - confirm.
        retention-days: ${{ env.ARTIFACT_RETENTION_DAYS }}

e2e-test:
runs-on: ubuntu-latest
needs:
Expand Down
20 changes: 20 additions & 0 deletions .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,23 @@ jobs:
run: |
docker image load --input merlin-observation-publisher.${{ inputs.version }}.tar
docker push ${{ env.DOCKER_IMAGE_TAG }}
# Loads the batch observation publisher image from the artifact tarball for
# the released version and pushes it to the container registry.
publish-batch-observation-publisher:
  runs-on: ubuntu-latest
  env:
    DOCKER_IMAGE_TAG: "ghcr.io/${{ github.repository }}/merlin-batch-observation-publisher:${{ inputs.version }}"
  steps:
    - name: Log in to the Container registry
      uses: docker/login-action@v1
      with:
        # NOTE(review): DOCKER_REGISTRY is not defined in this job's env;
        # presumably it comes from a workflow-level env block - confirm.
        registry: ${{ env.DOCKER_REGISTRY }}
        username: ${{ github.actor }}
        password: ${{ secrets.GITHUB_TOKEN }}
    - name: Download Batch Observation Publisher Docker Artifact
      uses: actions/download-artifact@v2
      with:
        name: merlin-batch-observation-publisher.${{ inputs.version }}.tar
    - name: Retag and Push Docker Image
      # The tarball already contains the image under DOCKER_IMAGE_TAG, so the
      # script only loads it and pushes that tag.
      run: |
        docker image load --input merlin-batch-observation-publisher.${{ inputs.version }}.tar
        docker push ${{ env.DOCKER_IMAGE_TAG }}
6 changes: 6 additions & 0 deletions python/Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# Image tags for the publisher images. `?=` keeps any value already supplied
# through the environment or the make command line.
OBSERVATION_PUBLISHER_IMAGE_TAG ?= observation-publisher:dev
BATCH_OBSERVATION_PUBLISHER_IMAGE_TAG ?= batch-observation-publisher:dev

# Build the observation publisher image, using this directory as the
# Docker build context.
.PHONY: observation-publisher
observation-publisher:
	@echo "Building image for observation publisher..."
	@docker build -t $(OBSERVATION_PUBLISHER_IMAGE_TAG) -f observation-publisher/Dockerfile --progress plain .

# Build the batch observation publisher image, using this directory as the
# Docker build context.
.PHONY: batch-observation-publisher
batch-observation-publisher:
	@echo "Building image for batch observation publisher..."
	@docker build -t $(BATCH_OBSERVATION_PUBLISHER_IMAGE_TAG) -f batch-observation-publisher/Dockerfile --progress plain .
3 changes: 3 additions & 0 deletions python/batch-observation-publisher/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
.idea
venv
__pycache__
18 changes: 18 additions & 0 deletions python/batch-observation-publisher/Dockerfile
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Image for the batch observation publisher workflow.
# The build context must contain both sdk/ and batch-observation-publisher/.
FROM python:3.10-slim-buster

WORKDIR /root
ENV LANG=C.UTF-8
ENV LC_ALL=C.UTF-8

# Build tooling for Python packages, plus curl for the flytectl installer.
RUN apt-get update && apt-get install -y build-essential curl
# Install flytectl v0.8.5 into /usr/local/bin.
RUN curl -sL https://ctl.flyte.org/install | bash -s -- -b /usr/local/bin v0.8.5

# The SDK sources are copied to /root/sdk, next to the publisher directory.
COPY sdk/ ./sdk
WORKDIR /root/batch-observation-publisher
# Requirements are installed before the sources are copied so that code-only
# changes do not invalidate the dependency layer.
COPY batch-observation-publisher/requirements.txt requirements.txt
RUN pip install -r requirements.txt
COPY batch-observation-publisher/publisher publisher
COPY batch-observation-publisher/Makefile Makefile
ENV PYTHONPATH=/root/batch-observation-publisher
# `tag` is a build argument; FLYTE_INTERNAL_IMAGE is empty unless the build
# passes --build-arg tag=<image>.
ARG tag
ENV FLYTE_INTERNAL_IMAGE=$tag
25 changes: 25 additions & 0 deletions python/batch-observation-publisher/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Defaults for packaging and registration. `?=` keeps any value already
# supplied through the environment or the make command line.
WORKFLOW_IMAGE ?= batch-observation-publisher:dev
WORKFLOW_VERSION ?= dev
FLYTE_PROJECT ?= flyteexamples
FLYTE_DOMAIN ?= development

# Install runtime and development requirements.
.PHONY: setup
setup:
	@echo "Setting up environment..."
	@pip install -r requirements.txt --use-pep517
	@pip install -r requirements-dev.txt

# Run pip-tools' compile command.
.PHONY: pip-compile
pip-compile:
	@echo "Compiling requirements..."
	@python -m piptools compile

# Package the `publisher` package with pyflyte against WORKFLOW_IMAGE.
.PHONY: package
package:
	@echo "Packaging..."
	@pyflyte --pkgs publisher package -f --image $(WORKFLOW_IMAGE)

# Package, then register flyte-package.tgz with flytectl.
# The flytectl command is not silenced, so make echoes it before running it.
.PHONY: register
register: package
	@echo "Registering..."
	flytectl register files --version $(WORKFLOW_VERSION) --archive flyte-package.tgz --project $(FLYTE_PROJECT) --domain $(FLYTE_DOMAIN)
Empty file.
161 changes: 161 additions & 0 deletions python/batch-observation-publisher/publisher/ingestion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
from datetime import datetime

import pandas as pd
from arize.pandas.logger import Client
from arize.utils.types import (
Environments,
ModelTypes as ArizeModelType,
Schema as ArizeSchema,
)
from flytekit import current_context, task, Secret
from merlin.observability.inference import (
InferenceSchema,
ObservationType,
BinaryClassificationOutput,
RankingOutput,
RegressionOutput,
PredictionOutput,
)


def get_arize_model_type(prediction_output: PredictionOutput) -> ArizeModelType:
    """Map a prediction output definition to the matching Arize model type.

    :param prediction_output: prediction output definition of the model.
    :return: the Arize model type for binary classification, regression or ranking.
    :raises ValueError: if the prediction output is none of the supported types.
    """
    # Checked in order; the first matching class wins.
    supported_types = (
        (BinaryClassificationOutput, ArizeModelType.BINARY_CLASSIFICATION),
        (RegressionOutput, ArizeModelType.REGRESSION),
        (RankingOutput, ArizeModelType.RANKING),
    )
    for output_class, arize_type in supported_types:
        if isinstance(prediction_output, output_class):
            return arize_type
    raise ValueError(f"Unknown prediction output type: {type(prediction_output)}")


def get_prediction_attributes(prediction_output: PredictionOutput) -> dict:
    """Build the Arize schema keyword arguments that describe the prediction columns.

    :param prediction_output: prediction output definition of the model.
    :return: dict of Arize schema field name to dataframe column name.
    :raises ValueError: if the prediction output is none of the supported types.
    """
    if isinstance(prediction_output, BinaryClassificationOutput):
        return dict(
            prediction_label_column_name=prediction_output.prediction_label_column,
            prediction_score_column_name=prediction_output.prediction_score_column,
        )
    if isinstance(prediction_output, RegressionOutput):
        return dict(
            prediction_score_column_name=prediction_output.prediction_score_column,
        )
    if isinstance(prediction_output, RankingOutput):
        # A rank is only meaningful within its prediction group, so the group
        # id column is declared together with the rank column.
        return dict(
            rank_column_name=prediction_output.rank_column,
            prediction_group_id_column_name=prediction_output.prediction_group_id_column,
        )
    raise ValueError(f"Unknown prediction output type: {type(prediction_output)}")


def get_ground_truth_attributes(prediction_output: PredictionOutput) -> dict:
    """Build the Arize schema keyword arguments that describe the ground truth columns.

    :param prediction_output: prediction output definition of the model.
    :return: dict of Arize schema field name to dataframe column name.
    :raises ValueError: if the prediction output is none of the supported types.
    """
    if isinstance(prediction_output, BinaryClassificationOutput):
        return dict(
            actual_label_column_name=prediction_output.actual_label_column,
        )
    if isinstance(prediction_output, RegressionOutput):
        return dict(
            actual_score_column_name=prediction_output.actual_score_column,
        )
    if isinstance(prediction_output, RankingOutput):
        # The ground truth of a ranking model is its relevance score. The rank
        # column is a prediction attribute, and `rank_column` is not an Arize
        # schema field name.
        return dict(
            relevance_score_column_name=prediction_output.relevance_score_column,
        )
    raise ValueError(f"Unknown prediction output type: {type(prediction_output)}")


def get_arize_training_schema(
    inference_schema: InferenceSchema, include_prediction_and_ground_truth: bool
) -> ArizeSchema:
    """Assemble the Arize schema used to log training data.

    :param inference_schema: inference schema of the model.
    :param include_prediction_and_ground_truth: when true, the prediction id
        column is declared in the schema as well.
    :return: Arize schema with tag, feature, prediction and ground truth columns.
    """
    # NOTE(review): block structure reconstructed from a source without
    # indentation; only the prediction id is taken to be conditional.
    attributes = {"tag_column_names": inference_schema.tag_columns}
    if include_prediction_and_ground_truth:
        attributes["prediction_id_column_name"] = inference_schema.prediction_id_column
    attributes["feature_column_names"] = inference_schema.feature_columns

    prediction_output = inference_schema.model_prediction_output
    attributes.update(get_prediction_attributes(prediction_output))
    attributes.update(get_ground_truth_attributes(prediction_output))
    return ArizeSchema(**attributes)


def add_default_prediction_and_ground_truth_columns(
    inference_schema: InferenceSchema, df: pd.DataFrame
):
    """Fill `df` in place with placeholder prediction and ground truth columns.

    :param inference_schema: inference schema that names the columns to fill.
    :param df: dataframe to modify in place.
    :raises ValueError: if the prediction output is none of the supported types.
    """
    output = inference_schema.model_prediction_output

    if isinstance(output, BinaryClassificationOutput):
        df[output.prediction_score_column] = 0.0
        df[output.actual_label_column] = output.negative_class_label
        return

    if isinstance(output, RegressionOutput):
        df[output.prediction_score_column] = 0.0
        df[output.actual_score_column] = 0.0
        return

    if isinstance(output, RankingOutput):
        # Rank rows 1..n within each prediction group, in row order.
        position_in_group = df.groupby(output.prediction_group_id_column).cumcount()
        df[output.rank_column] = position_in_group + 1
        df[output.relevance_score_column] = 0.0
        return

    raise ValueError(f"Unknown prediction output type: {type(output)}")


# Secret group and key names used to request the Arize credentials for the
# publishing task and to read them from the task context.
ARIZE_SECRET_GROUP = "mlobs"
ARIZE_SPACE_KEY_SECRET_NAME = "arize_space_key"
ARIZE_API_KEY_SECRET_NAME = "arize_api_key"


@task(
    secret_requests=[
        Secret(ARIZE_SECRET_GROUP, ARIZE_SPACE_KEY_SECRET_NAME),
        Secret(ARIZE_SECRET_GROUP, ARIZE_API_KEY_SECRET_NAME),
    ]
)
def publish_training_data_to_arize(
    df: pd.DataFrame,
    inference_schema: dict,
    model_id: str,
    model_version: str,
    include_prediction_and_ground_truth: bool,
):
    """Log a dataframe to Arize as training data.

    :param df: data to publish. Placeholder prediction and ground truth
        columns are added to it in place when
        `include_prediction_and_ground_truth` is false.
    :param inference_schema: inference schema of the model, in dict form.
    :param model_id: model id to log under.
    :param model_version: model version; the logged version is suffixed with
        the current date and the execution name.
    :param include_prediction_and_ground_truth: whether `df` already carries
        prediction and ground truth columns.
    :raises RuntimeError: if Arize does not accept the upload.
    """
    inference_schema = InferenceSchema.from_dict(inference_schema)
    space_key = current_context().secrets.get(
        ARIZE_SECRET_GROUP, ARIZE_SPACE_KEY_SECRET_NAME
    )
    api_key = current_context().secrets.get(
        ARIZE_SECRET_GROUP, ARIZE_API_KEY_SECRET_NAME
    )
    client = Client(space_key=space_key, api_key=api_key)
    if not include_prediction_and_ground_truth:
        add_default_prediction_and_ground_truth_columns(inference_schema, df)
    schema = get_arize_training_schema(
        inference_schema, include_prediction_and_ground_truth
    )
    arize_environment = Environments.TRAINING
    processed_df = inference_schema.model_prediction_output.preprocess(
        df,
        [
            ObservationType.FEATURE,
            ObservationType.PREDICTION,
            ObservationType.GROUND_TRUTH,
        ],
    )
    processed_df.reset_index(drop=True, inplace=True)
    version = f"{model_version}-{datetime.now().strftime('%Y%m%d')}-{current_context().execution_id.name}"
    arize_model_types = get_arize_model_type(inference_schema.model_prediction_output)

    response = client.log(
        dataframe=processed_df,
        environment=arize_environment,
        schema=schema,
        model_id=model_id,
        model_type=arize_model_types,
        model_version=version,
    )
    # The Arize client reports a rejected upload through the HTTP response
    # instead of raising; without this check a failed upload would be
    # reported as a successful task.
    if response.status_code != 200:
        raise RuntimeError(
            f"Failed to publish training data to Arize: "
            f"status {response.status_code}, reason: {response.text}"
        )
55 changes: 55 additions & 0 deletions python/batch-observation-publisher/publisher/sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import abc
from typing import Any

import farmhash
import pandas as pd
from flytekit import task
from merlin.observability.inference import InferenceSchema


class Sampler(abc.ABC):
    """Abstract base class for dataframe sampling strategies."""

    @abc.abstractmethod
    def sample(self, df: pd.DataFrame, sampling_rate: float) -> pd.DataFrame:
        """Return a sample of `df` for the given sampling rate.

        :param df: dataframe to sample from.
        :param sampling_rate: sampling rate.
        :return: the sampled dataframe.
        """
        raise NotImplementedError()


class RandomSampler(Sampler):
    """Sampler that draws a random fraction of the rows."""

    def sample(self, df: pd.DataFrame, sampling_rate: float) -> pd.DataFrame:
        """Return a random sample containing `sampling_rate` of the rows of `df`."""
        sampled = df.sample(frac=sampling_rate)
        return sampled


def generate_farmhash_fingerprint(value: Any) -> int:
    """Return the farmhash 64-bit fingerprint of the string form of `value`."""
    text = str(value)
    return farmhash.fingerprint64(text)


class ConsistentSampler(Sampler):
    """Deterministic sampler that keeps rows by the hash of one column.

    Each row falls into one of 100 buckets derived from the farmhash
    fingerprint of its hash column value, so a given value is always either
    kept or dropped for a given sampling rate.
    """

    def __init__(self, hash_column: str):
        """
        :param hash_column: name of the column whose values are hashed.
        """
        self._hash_column = hash_column

    def sample(self, df: pd.DataFrame, sampling_rate: float) -> pd.DataFrame:
        """Keep the rows whose hash bucket falls below the sampling rate.

        :param df: dataframe to sample from; must contain the hash column.
        :param sampling_rate: sampling rate, rounded to the nearest whole
            percent.
        :return: the rows of `df` that were selected.
        """
        # round() instead of int(): 100 * 0.29 evaluates to 28.99..., which
        # truncation would turn into 28 percent.
        kept_buckets = round(100 * sampling_rate)
        # Series.apply forwards extra keyword arguments to the function, so no
        # `axis` argument is passed here.
        buckets = df[self._hash_column].apply(
            lambda value: generate_farmhash_fingerprint(value) % 100
        )
        # Buckets range over 0..99; a strict comparison keeps exactly
        # `kept_buckets` of them (none at rate 0.0, all at rate 1.0).
        selected = buckets < kept_buckets
        return df[selected]


@task
def sample_dataframe(
    df: pd.DataFrame,
    sampling_rate: float,
    sampling_strategy: str,
    inference_schema: dict,
) -> pd.DataFrame:
    """Sample a dataframe with the requested strategy.

    :param df: dataframe to sample from.
    :param sampling_rate: sampling rate handed to the sampler.
    :param sampling_strategy: "random" or "consistent".
    :param inference_schema: inference schema of the model, in dict form; its
        prediction id column is the hash column for consistent sampling.
    :return: the sampled dataframe.
    :raises ValueError: if the sampling strategy is not recognised.
    """
    # The schema is parsed up front, whichever strategy is chosen.
    parsed_schema = InferenceSchema.from_dict(inference_schema)
    if sampling_strategy == "random":
        sampler = RandomSampler()
    elif sampling_strategy == "consistent":
        sampler = ConsistentSampler(hash_column=parsed_schema.prediction_id_column)
    else:
        raise ValueError(f"Invalid sampling strategy: {sampling_strategy}")
    return sampler.sample(df=df, sampling_rate=sampling_rate)
56 changes: 56 additions & 0 deletions python/batch-observation-publisher/publisher/workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import pandas as pd
from flytekit import dynamic, Resources

from publisher.ingestion import publish_training_data_to_arize
from publisher.sampling import sample_dataframe


@dynamic
def ingest_training_data(
    df: pd.DataFrame,
    model_id: str,
    model_version: str,
    inference_schema: dict,
    include_prediction_and_ground_truth: bool,
    sampling_rate: float,
    sampling_strategy: str,
    cpu: str,
    memory: str,
):
    """
    Samples batch data (optionally) and publishes it to Arize.
    :param df: batch data.
    :param model_id: Model id.
    :param model_version: Model version.
    :param inference_schema: Inference schema for the given model id and version
    :param include_prediction_and_ground_truth: Send prediction and ground truth together with the features.
    If false, only send the features.
    :param sampling_rate: Sampling rate (0.0 - 1.0). Sampling is skipped unless the rate is greater than 0.0.
    :param sampling_strategy: One of ("random", "consistent"). Random sampling will randomly sample the data.
    Consistent sampling is deterministic, but will round the sampling rate to the closest percentage integer.
    :param cpu: CPU request and limit applied to the sampling and ingestion tasks
    :param memory: Memory request and limit applied to the sampling and ingestion tasks
    :return:
    """

    def resource_overrides() -> dict:
        # Fresh Resources objects per task; requests and limits are identical.
        return dict(
            requests=Resources(cpu=cpu, mem=memory),
            limits=Resources(cpu=cpu, mem=memory),
        )

    data_to_publish = df
    if sampling_rate > 0.0:
        data_to_publish = sample_dataframe(
            df=df,
            sampling_rate=sampling_rate,
            sampling_strategy=sampling_strategy,
            inference_schema=inference_schema,
        ).with_overrides(**resource_overrides())

    publish_task = publish_training_data_to_arize(
        df=data_to_publish,
        inference_schema=inference_schema,
        model_id=model_id,
        model_version=model_version,
        include_prediction_and_ground_truth=include_prediction_and_ground_truth,
    )
    publish_task.with_overrides(**resource_overrides())
2 changes: 2 additions & 0 deletions python/batch-observation-publisher/requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
pip-tools==7.3.0
pytest==7.4.3
Loading

0 comments on commit 8c17b9b

Please sign in to comment.