-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
1664f3d
commit 8c17b9b
Showing
13 changed files
with
853 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,12 @@ | ||
OBSERVATION_PUBLISHER_IMAGE_TAG ?= observation-publisher:dev | ||
BATCH_OBSERVATION_PUBLISHER_IMAGE_TAG ?= batch-observation-publisher:dev | ||
|
||
.PHONY: observation-publisher | ||
observation-publisher: | ||
@echo "Building image for observation publisher..." | ||
@docker build -t ${OBSERVATION_PUBLISHER_IMAGE_TAG} -f observation-publisher/Dockerfile --progress plain . | ||
|
||
.PHONY: batch-observation-publisher | ||
batch-observation-publisher: | ||
@echo "Building image for batch observation publisher..." | ||
@docker build -t ${BATCH_OBSERVATION_PUBLISHER_IMAGE_TAG} -f batch-observation-publisher/Dockerfile --progress plain . |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
.idea | ||
venv | ||
__pycache__ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
FROM python:3.10-slim-buster | ||
|
||
WORKDIR /root | ||
ENV LANG C.UTF-8 | ||
ENV LC_ALL C.UTF-8 | ||
|
||
RUN apt-get update && apt-get install -y build-essential curl | ||
RUN curl -sL https://ctl.flyte.org/install | bash -s -- -b /usr/local/bin v0.8.5 | ||
|
||
COPY sdk/ ./sdk | ||
WORKDIR batch-observation-publisher | ||
COPY batch-observation-publisher/requirements.txt requirements.txt | ||
RUN pip install -r requirements.txt | ||
COPY batch-observation-publisher/publisher publisher | ||
COPY batch-observation-publisher/Makefile Makefile | ||
ENV PYTHONPATH /root/batch-observation-publisher | ||
ARG tag | ||
ENV FLYTE_INTERNAL_IMAGE $tag |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
WORKFLOW_IMAGE ?= batch-observation-publisher:dev | ||
WORKFLOW_VERSION ?= dev | ||
FLYTE_PROJECT ?= flyteexamples | ||
FLYTE_DOMAIN ?= development | ||
|
||
.PHONY: setup | ||
setup: | ||
@echo "Setting up environment..." | ||
@pip install -r requirements.txt --use-pep517 | ||
@pip install -r requirements-dev.txt | ||
|
||
.PHONY: pip-compile | ||
pip-compile: | ||
@echo "Compiling requirements..." | ||
@python -m piptools compile | ||
|
||
.PHONY: package | ||
package: | ||
@echo "Packaging..." | ||
@pyflyte --pkgs publisher package -f --image ${WORKFLOW_IMAGE} | ||
|
||
.PHONY: register | ||
register: package | ||
@echo "Registering..." | ||
flytectl register files --version ${WORKFLOW_VERSION} --archive flyte-package.tgz --project ${FLYTE_PROJECT} --domain ${FLYTE_DOMAIN} |
Empty file.
161 changes: 161 additions & 0 deletions
161
python/batch-observation-publisher/publisher/ingestion.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
from datetime import datetime | ||
|
||
import pandas as pd | ||
from arize.pandas.logger import Client | ||
from arize.utils.types import ( | ||
Environments, | ||
ModelTypes as ArizeModelType, | ||
Schema as ArizeSchema, | ||
) | ||
from flytekit import current_context, task, Secret | ||
from merlin.observability.inference import ( | ||
InferenceSchema, | ||
ObservationType, | ||
BinaryClassificationOutput, | ||
RankingOutput, | ||
RegressionOutput, | ||
PredictionOutput, | ||
) | ||
|
||
|
||
def get_arize_model_type(prediction_output: PredictionOutput) -> ArizeModelType: | ||
if isinstance(prediction_output, BinaryClassificationOutput): | ||
return ArizeModelType.BINARY_CLASSIFICATION | ||
elif isinstance(prediction_output, RegressionOutput): | ||
return ArizeModelType.REGRESSION | ||
elif isinstance(prediction_output, RankingOutput): | ||
return ArizeModelType.RANKING | ||
else: | ||
raise ValueError(f"Unknown prediction output type: {type(prediction_output)}") | ||
|
||
|
||
def get_prediction_attributes(prediction_output: PredictionOutput) -> dict: | ||
if isinstance(prediction_output, BinaryClassificationOutput): | ||
return dict( | ||
prediction_label_column_name=prediction_output.prediction_label_column, | ||
prediction_score_column_name=prediction_output.prediction_score_column, | ||
) | ||
elif isinstance(prediction_output, RegressionOutput): | ||
return dict( | ||
prediction_score_column_name=prediction_output.prediction_score_column, | ||
) | ||
elif isinstance(prediction_output, RankingOutput): | ||
return dict( | ||
rank_column_name=prediction_output.rank_column, | ||
) | ||
else: | ||
raise ValueError(f"Unknown prediction output type: {type(prediction_output)}") | ||
|
||
|
||
def get_ground_truth_attributes(prediction_output: PredictionOutput) -> dict: | ||
if isinstance(prediction_output, BinaryClassificationOutput): | ||
return dict( | ||
actual_label_column_name=prediction_output.actual_label_column, | ||
) | ||
elif isinstance(prediction_output, RegressionOutput): | ||
return dict( | ||
actual_score_column_name=prediction_output.actual_score_column, | ||
) | ||
elif isinstance(prediction_output, RankingOutput): | ||
return dict( | ||
rank_column=prediction_output.rank_column, | ||
) | ||
else: | ||
raise ValueError(f"Unknown prediction output type: {type(prediction_output)}") | ||
|
||
|
||
def get_arize_training_schema( | ||
inference_schema: InferenceSchema, include_prediction_and_ground_truth: bool | ||
) -> ArizeSchema: | ||
schema_attributes = dict( | ||
tag_column_names=inference_schema.tag_columns, | ||
) | ||
if include_prediction_and_ground_truth: | ||
schema_attributes |= dict( | ||
prediction_id_column_name=inference_schema.prediction_id_column, | ||
) | ||
schema_attributes |= dict( | ||
feature_column_names=inference_schema.feature_columns, | ||
) | ||
schema_attributes |= get_prediction_attributes( | ||
inference_schema.model_prediction_output | ||
) | ||
schema_attributes |= get_ground_truth_attributes( | ||
inference_schema.model_prediction_output | ||
) | ||
return ArizeSchema(**schema_attributes) | ||
|
||
|
||
def add_default_prediction_and_ground_truth_columns( | ||
inference_schema: InferenceSchema, df: pd.DataFrame | ||
): | ||
prediction_output = inference_schema.model_prediction_output | ||
if isinstance(prediction_output, BinaryClassificationOutput): | ||
df[prediction_output.prediction_score_column] = 0.0 | ||
df[ | ||
prediction_output.actual_label_column | ||
] = prediction_output.negative_class_label | ||
elif isinstance(prediction_output, RegressionOutput): | ||
df[prediction_output.prediction_score_column] = 0.0 | ||
df[prediction_output.actual_score_column] = 0.0 | ||
elif isinstance(prediction_output, RankingOutput): | ||
df[prediction_output.rank_column] = ( | ||
df.groupby(prediction_output.prediction_group_id_column).cumcount() + 1 | ||
) | ||
df[prediction_output.relevance_score_column] = 0.0 | ||
else: | ||
raise ValueError(f"Unknown prediction output type: {type(prediction_output)}") | ||
|
||
|
||
ARIZE_SECRET_GROUP = "mlobs" | ||
ARIZE_SPACE_KEY_SECRET_NAME = "arize_space_key" | ||
ARIZE_API_KEY_SECRET_NAME = "arize_api_key" | ||
|
||
|
||
@task( | ||
secret_requests=[ | ||
Secret(ARIZE_SECRET_GROUP, ARIZE_SPACE_KEY_SECRET_NAME), | ||
Secret(ARIZE_SECRET_GROUP, ARIZE_API_KEY_SECRET_NAME), | ||
] | ||
) | ||
def publish_training_data_to_arize( | ||
df: pd.DataFrame, | ||
inference_schema: dict, | ||
model_id: str, | ||
model_version: str, | ||
include_prediction_and_ground_truth: bool, | ||
): | ||
inference_schema = InferenceSchema.from_dict(inference_schema) | ||
space_key = current_context().secrets.get( | ||
ARIZE_SECRET_GROUP, ARIZE_SPACE_KEY_SECRET_NAME | ||
) | ||
api_key = current_context().secrets.get( | ||
ARIZE_SECRET_GROUP, ARIZE_API_KEY_SECRET_NAME | ||
) | ||
client = Client(space_key=space_key, api_key=api_key) | ||
if not include_prediction_and_ground_truth: | ||
add_default_prediction_and_ground_truth_columns(inference_schema, df) | ||
schema = get_arize_training_schema( | ||
inference_schema, include_prediction_and_ground_truth | ||
) | ||
arize_environment = Environments.TRAINING | ||
processed_df = inference_schema.model_prediction_output.preprocess( | ||
df, | ||
[ | ||
ObservationType.FEATURE, | ||
ObservationType.PREDICTION, | ||
ObservationType.GROUND_TRUTH, | ||
], | ||
) | ||
processed_df.reset_index(drop=True, inplace=True) | ||
version = f"{model_version}-{datetime.now().strftime('%Y%m%d')}-{current_context().execution_id.name}" | ||
arize_model_types = get_arize_model_type(inference_schema.model_prediction_output) | ||
|
||
client.log( | ||
dataframe=processed_df, | ||
environment=arize_environment, | ||
schema=schema, | ||
model_id=model_id, | ||
model_type=arize_model_types, | ||
model_version=version, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import abc | ||
from typing import Any | ||
|
||
import farmhash | ||
import pandas as pd | ||
from flytekit import task | ||
from merlin.observability.inference import InferenceSchema | ||
|
||
|
||
class Sampler(abc.ABC): | ||
@abc.abstractmethod | ||
def sample(self, df: pd.DataFrame, sampling_rate: float) -> pd.DataFrame: | ||
raise NotImplementedError | ||
|
||
|
||
class RandomSampler(Sampler): | ||
def sample(self, df: pd.DataFrame, sampling_rate: float) -> pd.DataFrame: | ||
return df.sample(frac=sampling_rate) | ||
|
||
|
||
def generate_farmhash_fingerprint(value: Any) -> int: | ||
return farmhash.fingerprint64(str(value)) | ||
|
||
|
||
class ConsistentSampler(Sampler): | ||
def __init__(self, hash_column: str): | ||
self._hash_column = hash_column | ||
|
||
def sample(self, df: pd.DataFrame, sampling_rate: float) -> pd.DataFrame: | ||
filtered = df[self._hash_column].apply( | ||
lambda x: (generate_farmhash_fingerprint(x) % 100) | ||
<= int(100 * sampling_rate), | ||
axis=1, | ||
) | ||
return df[filtered] | ||
|
||
|
||
@task | ||
def sample_dataframe( | ||
df: pd.DataFrame, | ||
sampling_rate: float, | ||
sampling_strategy: str, | ||
inference_schema: dict, | ||
) -> pd.DataFrame: | ||
inference_schema = InferenceSchema.from_dict(inference_schema) | ||
match sampling_strategy: | ||
case "random": | ||
sampler = RandomSampler() | ||
case "consistent": | ||
sampler = ConsistentSampler( | ||
hash_column=inference_schema.prediction_id_column | ||
) | ||
case _: | ||
raise ValueError(f"Invalid sampling strategy: {sampling_strategy}") | ||
return sampler.sample(df=df, sampling_rate=sampling_rate) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import pandas as pd | ||
from flytekit import dynamic, Resources | ||
|
||
from publisher.ingestion import publish_training_data_to_arize | ||
from publisher.sampling import sample_dataframe | ||
|
||
|
||
@dynamic | ||
def ingest_training_data( | ||
df: pd.DataFrame, | ||
model_id: str, | ||
model_version: str, | ||
inference_schema: dict, | ||
include_prediction_and_ground_truth: bool, | ||
sampling_rate: float, | ||
sampling_strategy: str, | ||
cpu: str, | ||
memory: str, | ||
): | ||
""" | ||
Ingests batch data to Arize. | ||
:param df: batch data. | ||
:param inference_schema: Inference schema for the given model id and version | ||
:param model_id: Model id. | ||
:param model_version: Model version. | ||
:param include_prediction_and_ground_truth: Send prediction and ground truth together with the features. | ||
If false, only send the features. | ||
:param sampling_rate: Sampling rate (0.0 - 1.0). | ||
:param sampling_strategy: One of ("random", "consistent"). Random sampling will randomly sample the data. | ||
Consistent sampling is deterministic, but will round the sampling rate to the closest percentage integer. | ||
:param memory: Memory request for the ingestion task | ||
:param cpu: CPU request for the ingestion task | ||
:return: | ||
""" | ||
sampled_dataframe = df | ||
if sampling_rate > 0.0: | ||
sampled_dataframe = sample_dataframe( | ||
df=df, | ||
sampling_rate=sampling_rate, | ||
sampling_strategy=sampling_strategy, | ||
inference_schema=inference_schema, | ||
).with_overrides( | ||
requests=Resources(cpu=cpu, mem=memory), | ||
limits=Resources(cpu=cpu, mem=memory), | ||
) | ||
|
||
publish_training_data_to_arize( | ||
df=sampled_dataframe, | ||
inference_schema=inference_schema, | ||
model_id=model_id, | ||
model_version=model_version, | ||
include_prediction_and_ground_truth=include_prediction_and_ground_truth, | ||
).with_overrides( | ||
requests=Resources(cpu=cpu, mem=memory), limits=Resources(cpu=cpu, mem=memory) | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
pip-tools==7.3.0 | ||
pytest==7.4.3 |
Oops, something went wrong.