feat: Add support for multiple sinks (#513)
<!--  Thanks for sending a pull request!  Here are some tips for you:

1. Run unit tests and ensure that they are passing
2. If your change introduces any API changes, make sure to update the
e2e tests
3. Make sure documentation is updated for your PR!

-->
# Description
<!-- Briefly describe the motivation for the change. Please include
illustrations where appropriate. -->
Allow the observation publisher to publish to multiple sinks. Supported sinks are Arize and BigQuery.

# Modifications
<!-- Summarize the key code changes. -->
- Modified the configuration format to support multiple sinks
- Updated the Python requirement
- Fixed mypy linter errors
- Added a BigQuery sink implementation

# Tests
<!-- Besides the existing / updated automated tests, what specific
scenarios should be tested? Consider the backward compatibility of the
changes, whether corner cases are covered, etc. Please describe the
tests and check the ones that have been completed. Eg:
- [x] Deploying new and existing standard models
- [ ] Deploying PyFunc models
-->

# Checklist
- [x] Added PR label
- [x] Added unit test, integration, and/or e2e tests
- [x] Tested locally
- [ ] Updated documentation
- [ ] Updated Swagger spec if the PR introduces API changes
- [ ] Regenerated Golang and Python client if the PR introduces API changes

# Release Notes
<!--
Does this PR introduce a user-facing change?
If no, just write "NONE" in the release-note block below.
If yes, a release note is required. Enter your extended release note in
the block below.
If the PR requires additional action from users switching to the new
release, include the string "action required".

For more information about release notes, see kubernetes' guide here:
http://git.k8s.io/community/contributors/guide/release-notes.md
-->

```release-note

```
khorshuheng authored Jan 31, 2024
1 parent 10e5498 commit 7c835f5
Showing 15 changed files with 581 additions and 241 deletions.
```diff
@@ -16,20 +16,31 @@ inference_schema:
     distance: "int64"
     transaction: "float64"
   prediction_id_column: "prediction_id"
-observability_backend:
-  # Supported backend types:
+observation_sinks:
+  # Supported sink types:
   # - ARIZE
-  type: "ARIZE"
-  # Required if observability_backend.type is ARIZE
-  arize_config:
-    api_key: "SECRET_API_KEY"
-    space_key: "SECRET_SPACE_KEY"
+  # - BIGQUERY
+  - type: "ARIZE"
+    config:
+      api_key: "SECRET_API_KEY"
+      space_key: "SECRET_SPACE_KEY"
+  - type: "BIGQUERY"
+    config:
+      # GCP project for the dataset
+      project: "test-project"
+      # GCP dataset to store the observation data on
+      dataset: "test-dataset"
+      # Number of days before the created table will expire
+      ttl_days: 14
 observation_source:
   # Supported consumer types:
   # - KAFKA
   type: "KAFKA"
-  # Required if consumer.type is KAFKA
-  kafka_config:
+  # (Optional) Number of messages to be kept in-memory before being sent to the sinks. Default: 10
+  buffer_capacity: 10
+  # (Optional) Maximum duration in seconds to keep messages in-memory before being sent to the sinks, if the capacity is not met. Default: 60
+  buffer_max_duration_seconds: 60
+  config:
     topic: "test-topic"
     bootstrap_servers: "localhost:9092"
     group_id: "test-group"
```
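
For context on the two new buffer settings: the consumer holds messages in memory and flushes them to the sinks once either limit is reached. A minimal sketch of that flush condition, with illustrative function and argument names that are not taken from the publisher code:

```python
import time


def should_flush(
    buffered_messages: int,
    buffer_started_at: float,
    buffer_capacity: int = 10,
    buffer_max_duration_seconds: int = 60,
) -> bool:
    """Flush when the in-memory buffer is full, or when it has been held
    longer than the configured maximum duration."""
    if buffered_messages >= buffer_capacity:
        return True
    return time.monotonic() - buffer_started_at >= buffer_max_duration_seconds
```
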
python/observation-publisher/publisher/__main__.py (18 additions, 8 deletions)
```diff
@@ -1,9 +1,11 @@
 import hydra
 from merlin.observability.inference import InferenceSchema
 from omegaconf import OmegaConf
+from prometheus_client import start_http_server

 from publisher.config import PublisherConfig
-from publisher.observability_backend import new_observation_sink
+from publisher.metric import MetricWriter
+from publisher.observation_sink import new_observation_sink
 from publisher.prediction_log_consumer import new_consumer


@@ -12,18 +14,26 @@ def start_consumer(cfg: PublisherConfig) -> None:
     missing_keys: set[str] = OmegaConf.missing_keys(cfg)
     if missing_keys:
         raise RuntimeError(f"Got missing keys in config:\n{missing_keys}")
+
+    start_http_server(cfg.environment.prometheus_port)
+    MetricWriter().setup(
+        model_id=cfg.environment.model_id, model_version=cfg.environment.model_version
+    )
     prediction_log_consumer = new_consumer(cfg.environment.observation_source)
     inference_schema = InferenceSchema.from_dict(
         OmegaConf.to_container(cfg.environment.inference_schema)
     )
-    observation_sink = new_observation_sink(
-        config=cfg.environment.observability_backend,
-        inference_schema=inference_schema,
-        model_id=cfg.environment.model_id,
-        model_version=cfg.environment.model_version,
-    )
+    observation_sinks = [
+        new_observation_sink(
+            sink_config=sink_config,
+            inference_schema=inference_schema,
+            model_id=cfg.environment.model_id,
+            model_version=cfg.environment.model_version,
+        )
+        for sink_config in cfg.environment.observation_sinks
+    ]
     prediction_log_consumer.start_polling(
-        observation_sink=observation_sink,
+        observation_sinks=observation_sinks,
         inference_schema=inference_schema,
     )
```
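
`start_polling` now takes a list of sinks rather than a single sink. The consumer-side changes are not part of this excerpt, but the intended fan-out is roughly the sketch below; the helper name is hypothetical, and it assumes each sink exposes a `write(DataFrame)` method:

```python
from typing import List

import pandas as pd


def write_to_sinks(observations: pd.DataFrame, observation_sinks: List) -> None:
    # Every configured sink receives the same batch, so adding a BigQuery
    # sink does not change what the existing Arize sink receives.
    for sink in observation_sinks:
        sink.write(observations)
```
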
python/observation-publisher/publisher/config.py (11 additions, 35 deletions)
```diff
@@ -1,65 +1,41 @@
 from dataclasses import dataclass
 from enum import Enum
-from typing import Optional
+from typing import List

 from hydra.core.config_store import ConfigStore


-@dataclass
-class ArizeConfig:
-    api_key: str
-    space_key: str
-
-
-class ObservabilityBackendType(Enum):
+class ObservationSinkType(Enum):
     ARIZE = "arize"
+    BIGQUERY = "bigquery"


 @dataclass
-class ObservabilityBackend:
-    type: ObservabilityBackendType
-    arize_config: Optional[ArizeConfig] = None
-
-    def __post_init__(self):
-        if self.type == ObservabilityBackendType.ARIZE:
-            assert (
-                self.arize_config is not None
-            ), "Arize config must be set for Arize observability backend"
+class ObservationSinkConfig:
+    type: ObservationSinkType
+    config: dict


 class ObservationSource(Enum):
     KAFKA = "kafka"


-@dataclass
-class KafkaConsumerConfig:
-    topic: str
-    bootstrap_servers: str
-    group_id: str
-    batch_size: int = 100
-    poll_timeout_seconds: float = 1.0
-    additional_consumer_config: Optional[dict] = None
-
-
 @dataclass
 class ObservationSourceConfig:
     type: ObservationSource
-    kafka_config: Optional[KafkaConsumerConfig] = None
-
-    def __post_init__(self):
-        if self.type == ObservationSource.KAFKA:
-            assert (
-                self.kafka_config is not None
-            ), "Kafka config must be set for Kafka observation source"
+    config: dict
+    buffer_capacity: int = 10
+    buffer_max_duration_seconds: int = 60


 @dataclass
 class Environment:
     model_id: str
     model_version: str
     inference_schema: dict
-    observability_backend: ObservabilityBackend
+    observation_sinks: List[ObservationSinkConfig]
     observation_source: ObservationSourceConfig
     prometheus_port: int = 8000


 @dataclass
```
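
Sink-specific options now live in an untyped `config` dict on each `ObservationSinkConfig`, to be validated by the sink implementation rather than by `__post_init__` assertions. A hedged sketch of loading such a config into these dataclasses with OmegaConf structured configs; the file path is hypothetical, the YAML is assumed to have the `Environment` fields at its top level, and the real entrypoint wires this up through Hydra instead:

```python
from omegaconf import OmegaConf

from publisher.config import Environment

# Overlay a YAML file shaped like the example config above onto the
# structured Environment schema ("environment.yaml" is a made-up path).
schema = OmegaConf.structured(Environment)
cfg = OmegaConf.merge(schema, OmegaConf.load("environment.yaml"))

for sink_config in cfg.observation_sinks:
    print(sink_config.type, sink_config.config)
```
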
python/observation-publisher/publisher/metric.py (new file, 60 additions)

```python
from pandas import Timestamp
from prometheus_client import Gauge, Counter


class MetricWriter(object):
    """
    Singleton class for writing metrics to Prometheus.
    """

    _instance = None

    def __init__(self):
        if not self._initialized:
            self.model_id = ""
            self.model_version = ""
            self.last_processed_timestamp_gauge = Gauge(
                "last_processed_timestamp",
                "The timestamp of the last prediction log processed by the publisher",
                ["model_id", "model_version"],
            )
            # The label names must match the .labels() call in
            # increment_total_prediction_logs_processed below.
            self.total_prediction_logs_processed_counter = Counter(
                "total_prediction_logs_processed",
                "The total number of prediction logs processed by the publisher",
                ["model_id", "model_version"],
            )
            self._initialized = True

    def __new__(cls):
        if not cls._instance:
            cls._instance = super(MetricWriter, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def setup(self, model_id: str, model_version: str):
        """
        Must be called before sending metrics, so that the singleton instance
        labels every metric with the correct model identity.

        :param model_id: id of the model whose prediction logs are published
        :param model_version: version of that model
        """
        self.model_id = model_id
        self.model_version = model_version

    def update_last_processed_timestamp(self, last_processed_timestamp: Timestamp):
        """
        Updates the last_processed_timestamp gauge with the given value.

        :param last_processed_timestamp: timestamp of the most recently processed prediction log
        """
        self.last_processed_timestamp_gauge.labels(
            model_id=self.model_id, model_version=self.model_version
        ).set(last_processed_timestamp.timestamp())

    def increment_total_prediction_logs_processed(self, value: int):
        """
        Increments the total_prediction_logs_processed counter by value.

        :param value: number of prediction logs processed in the latest batch
        """
        self.total_prediction_logs_processed_counter.labels(
            model_id=self.model_id, model_version=self.model_version
        ).inc(value)
```
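
`MetricWriter` is a process-wide singleton: `setup()` records the model identity once, and the per-batch methods then label every sample with it. A small usage sketch with illustrative values:

```python
from pandas import Timestamp

from publisher.metric import MetricWriter

# Called once at startup (see __main__.py above).
MetricWriter().setup(model_id="my-model", model_version="1")

# Called after each processed batch of prediction logs.
MetricWriter().increment_total_prediction_logs_processed(100)
MetricWriter().update_last_processed_timestamp(Timestamp.now(tz="UTC"))
```
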
python/observation-publisher/publisher/observability_backend.py (113 deletions)

This file was deleted.
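
`observability_backend.py` is superseded by `observation_sink.py` (see the updated import in `__main__.py` above), whose contents are not included in this excerpt. A hedged sketch of what the `new_observation_sink` factory plausibly looks like, dispatching on `ObservationSinkType`; the `ArizeSink` and `BigQuerySink` placeholders and their constructor signatures are assumptions, not the actual implementation:

```python
from dataclasses import dataclass

from merlin.observability.inference import InferenceSchema

from publisher.config import ObservationSinkConfig, ObservationSinkType


@dataclass
class ArizeSink:  # placeholder; the real sink lives in observation_sink.py
    config: dict
    inference_schema: InferenceSchema
    model_id: str
    model_version: str


@dataclass
class BigQuerySink:  # placeholder; the real sink lives in observation_sink.py
    config: dict
    inference_schema: InferenceSchema
    model_id: str
    model_version: str


def new_observation_sink(
    sink_config: ObservationSinkConfig,
    inference_schema: InferenceSchema,
    model_id: str,
    model_version: str,
):
    # Illustrative dispatch only; the real factory also sets up the Arize
    # client or the BigQuery dataset/table from sink_config.config.
    if sink_config.type == ObservationSinkType.ARIZE:
        return ArizeSink(sink_config.config, inference_schema, model_id, model_version)
    if sink_config.type == ObservationSinkType.BIGQUERY:
        return BigQuerySink(sink_config.config, inference_schema, model_id, model_version)
    raise ValueError(f"Unsupported observation sink type: {sink_config.type}")
```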
