Update PyFunc and support publishing log from PyFunc (#489)
<!--  Thanks for sending a pull request!  Here are some tips for you:

1. Run unit tests and ensure that they are passing
2. If your change introduces any API changes, make sure to update the
e2e tests
3. Make sure documentation is updated for your PR!

-->

**What this PR does / why we need it**:
<!-- Explain here the context and why you're making the change. What is
the problem you're trying to solve. -->
To onboard a model to model observability, we need to gather the features and
the prediction values of the model. With the current PyFunc contract we cannot
reliably obtain them, because the input and output of a PyFunc model are not
necessarily the features and the prediction values. This PR solves that by
adding a new PyFunc model interface that explicitly identifies which parts of
the data are the features and which are the prediction values. Once
identified, the data is published to Kafka for later processing.
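
As an illustration of how a model author might use the new interface (a
minimal sketch: the `infer` hook and its signature are assumptions, the real
contract is defined in `python/sdk/merlin/pyfunc.py`):

```python
import pandas as pd
from merlin.pyfunc import PyFuncV3Model  # added by this PR

# Hypothetical model; hook name and signature are illustrative only.
class FraudModel(PyFuncV3Model):
    def infer(self, features: pd.DataFrame) -> pd.DataFrame:
        # Because `features` and the returned frame are explicitly identified
        # as features and prediction values, the server can assemble a
        # prediction log from them and publish it to Kafka.
        score = (features["amount"] > 1_000).astype(float)
        return pd.DataFrame({"prediction": score})
```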

Modification:
* `python/sdk/merlin/pyfunc.py`
  * Adding a new `PyFuncV3Model` that differentiates features and prediction
    values
  * Introducing `PyFuncOutput` as the single output type for realtime PyFunc
    (`PyFuncModel` and `PyFuncV3Model`)
* `python/pyfunc-server/pyfuncserver/config.py` - Adding configuration for
  Kafka publishing and the sampling ratio
* `python/pyfunc-server/pyfuncserver/protocol/rest/handler.py` - Adding async
  publishing after getting the PyFunc model output
* `python/pyfunc-server/pyfuncserver/protocol/rest/server.py` - Creating a
  subclass of the Tornado web application that holds the Kafka producer
  instance
* `python/pyfunc-server/pyfuncserver/publisher/publisher.py` - Adding the
  asyncio publisher code
* `python/pyfunc-server/pyfuncserver/sampler/sampler.py` - Random sampling
  method based on the given ratio
* `python/pyfunc-server/pyfuncserver/publisher/kafka.py` - Kafka producer
  implementation that publishes a given `PyFuncOutput`

**Which issue(s) this PR fixes**:
<!--
*Automatically closes linked issue when PR is merged.
Usage: `Fixes #<issue number>`, or `Fixes (paste link of issue)`.
-->

Fixes #

**Does this PR introduce a user-facing change?**:
<!--
If no, just write "NONE" in the release-note block below.
If yes, a release note is required. Enter your extended release note in
the block below.
If the PR requires additional action from users switching to the new
release, include the string "action required".

For more information about release notes, see kubernetes' guide here:
http://git.k8s.io/community/contributors/guide/release-notes.md
-->

```release-note

```

**Checklist**

- [x] Added unit test, integration, and/or e2e tests
- [x] Tested locally
- [ ] Updated documentation
- [ ] Update Swagger spec if the PR introduces API changes
- [ ] Regenerated Golang and Python client if the PR introduces API
changes
tiopramayudi authored Nov 22, 2023
1 parent 95a6183 commit 9cc5432
Showing 29 changed files with 904 additions and 110 deletions.
75 changes: 68 additions & 7 deletions python/pyfunc-server/pyfuncserver/config.py
@@ -1,14 +1,17 @@
import json
import logging
import os
from typing import Optional

from merlin.protocol import Protocol
from dataclasses import dataclass

# Following environment variables are expected to be populated by Merlin
HTTP_PORT = ("CARAML_HTTP_PORT", 8080)
MODEL_NAME = ("CARAML_MODEL_NAME", "model")
MODEL_VERSION = ("CARAML_MODEL_VERSION", "1")
MODEL_FULL_NAME = ("CARAML_MODEL_FULL_NAME", "model-1")
PROJECT = ("CARAML_PROJECT", "project")
PROTOCOL = ("CARAML_PROTOCOL", "HTTP_JSON")

WORKERS = ("WORKERS", 1)
@@ -21,16 +24,24 @@
PUSHGATEWAY_URL = ("PUSHGATEWAY_URL", "localhost:9091")
PUSHGATEWAY_PUSH_INTERVAL_SEC = ("PUSHGATEWAY_PUSH_INTERVAL_SEC", 30)

PUBLISHER_KAFKA_TOPIC = ("PUBLISHER_KAFKA_TOPIC", "")
PUBLISHER_KAFKA_BROKERS = ("PUBLISHER_KAFKA_BROKERS", "")
PUBLISHER_KAFKA_LINGER_MS = ("PUBLISHER_KAFKA_LINGER_MS", 1000)
PUBLISHER_KAFKA_ACKS = ("PUBLISHER_KAFKA_ACKS", 0)
PUBLISHER_KAFKA_CONFIG = ("PUBLISHER_KAFKA_CONFIG", "{}")
PUBLISHER_SAMPLING_RATIO = ("PUBLISHER_SAMPLING_RATIO", 0.01)
PUBLISHER_ENABLED = ("PUBLISHER_ENABLED", "false")

+@dataclass
class ModelManifest:
    """
    Model Manifest
    """

-    def __init__(self, model_name: str, model_version: str, model_full_name: str, model_dir: str):
-        self.model_name = model_name
-        self.model_version = model_version
-        self.model_full_name = model_full_name
-        self.model_dir = model_dir
+    model_name: str
+    model_version: str
+    model_full_name: str
+    model_dir: str
+    project: str


class PushGateway:
@@ -39,6 +50,27 @@ def __init__(self, enabled, url, push_interval_sec):
self.enabled = enabled
self.push_interval_sec = push_interval_sec

@dataclass
class Kafka:
"""
Kafka configuration
"""
topic: str
brokers: str
linger_ms: int
acks: int
configuration: dict

@dataclass
class Publisher:
"""
Publisher configuration
"""
# sampling ratio of data that needs to be published
sampling_ratio: float
enabled: bool
kafka: Kafka


class Config:
"""
@@ -54,7 +86,8 @@ def __init__(self, model_dir: str):
model_name = os.getenv(*MODEL_NAME)
model_version = os.getenv(*MODEL_VERSION)
model_full_name = os.getenv(*MODEL_FULL_NAME)
-        self.model_manifest = ModelManifest(model_name, model_version, model_full_name, model_dir)
+        project = os.getenv(*PROJECT)
+        self.model_manifest = ModelManifest(model_name, model_version, model_full_name, model_dir, project)

self.workers = int(os.getenv(*WORKERS))
self.log_level = self._log_level()
@@ -68,6 +101,34 @@ def __init__(self, model_dir: str):
self.push_gateway = PushGateway(push_enabled,
push_url,
push_interval)

# Publisher
self.publisher = None
publisher_enabled = str_to_bool(os.getenv(*PUBLISHER_ENABLED))
if publisher_enabled:
sampling_ratio = float(os.getenv(*PUBLISHER_SAMPLING_RATIO))
kafka_topic = os.getenv(*PUBLISHER_KAFKA_TOPIC)
kafka_brokers = os.getenv(*PUBLISHER_KAFKA_BROKERS)
if kafka_topic == "":
raise ValueError("kafka topic must be set")
if kafka_brokers == "":
raise ValueError("kafka brokers must be set")
kafka_linger_ms = int(os.getenv(*PUBLISHER_KAFKA_LINGER_MS))
kafka_acks = int(os.getenv(*PUBLISHER_KAFKA_ACKS))
kafka_cfgs = self._kafka_config()
kafka = Kafka(
kafka_topic,
kafka_brokers,
kafka_linger_ms,
kafka_acks,
kafka_cfgs)
self.publisher = Publisher(sampling_ratio, publisher_enabled, kafka)


def _kafka_config(self):
raw_cfg = os.getenv(*PUBLISHER_KAFKA_CONFIG)
cfg = json.loads(raw_cfg)
return cfg

def _log_level(self):
log_level = os.getenv(*LOG_LEVEL)
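
For reference, the publisher can be switched on purely through the environment
variables introduced above; a minimal sketch (topic, brokers, and paths are
placeholders, and in practice Merlin injects these at deployment time):

```python
import os

os.environ["PUBLISHER_ENABLED"] = "true"
os.environ["PUBLISHER_KAFKA_TOPIC"] = "prediction-log"               # placeholder
os.environ["PUBLISHER_KAFKA_BROKERS"] = "kafka-1:9092,kafka-2:9092"  # placeholder
os.environ["PUBLISHER_SAMPLING_RATIO"] = "0.1"           # publish ~10% of outputs
os.environ["PUBLISHER_KAFKA_CONFIG"] = '{"retries": 3}'  # merged into producer config

from pyfuncserver.config import Config

config = Config(model_dir="/path/to/model")  # placeholder path
# config.publisher is now populated with the Kafka and sampling settings.
```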
6 changes: 3 additions & 3 deletions python/pyfunc-server/pyfuncserver/model/model.py
@@ -18,7 +18,7 @@
import grpc
from caraml.upi.v1 import upi_pb2
from merlin.protocol import Protocol
-from merlin.pyfunc import PYFUNC_EXTRA_ARGS_KEY, PYFUNC_GRPC_CONTEXT, PYFUNC_MODEL_INPUT_KEY, PYFUNC_PROTOCOL_KEY
+from merlin.pyfunc import PYFUNC_EXTRA_ARGS_KEY, PYFUNC_GRPC_CONTEXT, PYFUNC_MODEL_INPUT_KEY, PYFUNC_PROTOCOL_KEY, PyFuncOutput
from mlflow import pyfunc

from pyfuncserver.config import ModelManifest
@@ -75,7 +75,7 @@ def _get_pyfunc_model_version(self):

return PyFuncModelVersion.LATEST

-    def predict(self, inputs: dict, **kwargs) -> dict:
+    def predict(self, inputs: dict, **kwargs) -> PyFuncOutput:
if self.pyfunc_type == PyFuncModelVersion.OLD_PYFUNC_LATEST_MLFLOW:
# for case user specified old merlin-sdk as dependency and using mlflow without version specified
return self._model._model_impl.python_model.predict(inputs, **kwargs)
@@ -92,7 +92,7 @@ def predict(self, inputs: dict, **kwargs) -> dict:
return self._model.predict(model_inputs)

def upiv1_predict(self, request: upi_pb2.PredictValuesRequest,
-                      context: grpc.ServicerContext) -> upi_pb2.PredictValuesResponse:
+                      context: grpc.ServicerContext) -> PyFuncOutput:
model_inputs = {
PYFUNC_PROTOCOL_KEY: Protocol.UPI_V1,
PYFUNC_MODEL_INPUT_KEY: request,
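
`PyFuncOutput` itself is defined in `python/sdk/merlin/pyfunc.py`, which is not
rendered in this view. Judging from its usage in this diff, its shape is
roughly as follows (a sketch; the exact field set is an assumption):

```python
from dataclasses import dataclass
from typing import Any, Optional

# Illustrative shape only -- see python/sdk/merlin/pyfunc.py for the real class.
@dataclass
class PyFuncOutput:
    http_response: Optional[dict] = None  # body served for HTTP_JSON requests
    upi_response: Optional[Any] = None    # upi_pb2.PredictValuesResponse for UPI_V1
    model_input: Optional[Any] = None     # features, populated by PyFuncV3Model
    model_output: Optional[Any] = None    # prediction values, populated by PyFuncV3Model

    def contains_prediction_log(self) -> bool:
        # Only PyFuncV3Model fills the observability fields, so plain
        # PyFuncModel outputs are served but never published.
        return self.model_input is not None and self.model_output is not None
```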
10 changes: 6 additions & 4 deletions python/pyfunc-server/pyfuncserver/protocol/rest/handler.py
@@ -22,10 +22,10 @@

from pyfuncserver.model.model import PyFuncModel


class PredictHandler(tornado.web.RequestHandler):
def initialize(self, models):
self.models = models # pylint:disable=attribute-defined-outside-init
self.publisher = self.application.publisher

def get_model(self, full_name: str):
if full_name not in self.models:
@@ -57,12 +57,14 @@ def post(self, full_name: str):
request = self.validate(self.request)
headers = self.get_headers(self.request)

-        response = model.predict(request, headers=headers)
-
-        response_json = orjson.dumps(response)
+        output = model.predict(request, headers=headers)
+        response_json = orjson.dumps(output.http_response)
self.write(response_json)
self.set_header("Content-Type", "application/json; charset=UTF-8")

if self.publisher is not None and output.contains_prediction_log():
tornado.ioloop.IOLoop.current().spawn_callback(self.publisher.publish, output)

def write_error(self, status_code: int, **kwargs: Any) -> None:
logging.error(self._reason)
self.write({"status_code": status_code, "reason": self._reason})
43 changes: 31 additions & 12 deletions python/pyfunc-server/pyfuncserver/protocol/rest/server.py
@@ -11,6 +11,9 @@
from pyfuncserver.metrics.handler import MetricsHandler
from pyfuncserver.model.model import PyFuncModel
from pyfuncserver.protocol.rest.handler import HealthHandler, LivenessHandler, PredictHandler
from pyfuncserver.publisher.kafka import KafkaProducer
from pyfuncserver.publisher.publisher import Publisher
from pyfuncserver.sampler.sampler import RatioSampling


async def sig_handler(server):
@@ -22,34 +25,50 @@ async def sig_handler(server):
tornado.ioloop.IOLoop.current().stop()


class Application(tornado.web.Application):
def __init__(self, config: Config, metrics_registry: CollectorRegistry, registered_models: dict):
self.publisher = None
self.model_manifest = config.model_manifest
handlers = [
# Server Liveness API returns 200 if server is alive.
(r"/", LivenessHandler),
# Model Health API returns 200 if model is ready to serve.
(r"/v1/models/([a-zA-Z0-9_-]+)",
HealthHandler, dict(models=registered_models)),
(r"/v1/models/([a-zA-Z0-9_-]+):predict",
PredictHandler, dict(models=registered_models)),
(r"/metrics", MetricsHandler, dict(metrics_registry=metrics_registry))
]
super().__init__(handlers) # type: ignore # noqa


class HTTPServer:
def __init__(self, model: PyFuncModel, config: Config, metrics_registry: CollectorRegistry):
self.config = config
self.workers = config.workers
self.model_manifest = config.model_manifest
self.http_port = config.http_port
self.metrics_registry = metrics_registry
self.registered_models: dict = {}
self.register_model(model)

def create_application(self):
-        return tornado.web.Application([
-            # Server Liveness API returns 200 if server is alive.
-            (r"/", LivenessHandler),
-            # Model Health API returns 200 if model is ready to serve.
-            (r"/v1/models/([a-zA-Z0-9_-]+)",
-             HealthHandler, dict(models=self.registered_models)),
-            (r"/v1/models/([a-zA-Z0-9_-]+):predict",
-             PredictHandler, dict(models=self.registered_models)),
-            (r"/metrics", MetricsHandler, dict(metrics_registry=self.metrics_registry))
-        ])
+        return Application(self.config, self.metrics_registry, self.registered_models)

def start(self):
-        self._http_server = tornado.httpserver.HTTPServer(
-            self.create_application())
+        application = self.create_application()
+        self._http_server = tornado.httpserver.HTTPServer(application)
logging.info("Listening on port %s", self.http_port)
self._http_server.bind(self.http_port)
logging.info("Will fork %d workers", self.workers)
self._http_server.start(self.workers)

# The Kafka producer must be initialized after the process forks; librdkafka's background threads would not survive a fork
if self.config.publisher is not None:
kafka_producer = KafkaProducer(self.config.publisher, self.config.model_manifest)
sampler = RatioSampling(self.config.publisher.sampling_ratio)
application.publisher = Publisher(kafka_producer, sampler)

for signame in ('SIGINT', 'SIGTERM'):
asyncio.get_event_loop().add_signal_handler(getattr(signal, signame),
lambda: asyncio.create_task(sig_handler(self._http_server)))
34 changes: 28 additions & 6 deletions python/pyfunc-server/pyfuncserver/protocol/upi/server.py
@@ -3,22 +3,38 @@
from concurrent import futures

import grpc
from grpc import aio
import asyncio
from typing import Optional
from caraml.upi.v1 import upi_pb2, upi_pb2_grpc
from grpc_reflection.v1alpha import reflection
from grpc_health.v1.health import HealthServicer
from grpc_health.v1 import health_pb2_grpc

from pyfuncserver.config import Config
from pyfuncserver.model.model import PyFuncModel
from pyfuncserver.publisher.publisher import Publisher
from pyfuncserver.publisher.kafka import KafkaProducer
from pyfuncserver.sampler.sampler import RatioSampling

class PredictionService(upi_pb2_grpc.UniversalPredictionServiceServicer):
def __init__(self, model: PyFuncModel):
if not model.ready:
model.load()
self._model = model
self._publisher: Optional[Publisher] = None

def set_publisher(self, publisher: Publisher):
if self._publisher is None:
self._publisher = publisher

def PredictValues(self, request, context):
-        return self._model.upiv1_predict(request=request, context=context)
+        output = self._model.upiv1_predict(request=request, context=context)
+        if self._publisher is not None and output.contains_prediction_log():
+            # also check that the output contains a prediction log, since the pyfunc server does not know which model implementation is in use
+            asyncio.create_task(self._publisher.publish(output))
+
+        return output.upi_response


class UPIServer:
@@ -37,18 +53,24 @@ def start(self):
worker = multiprocessing.Process(target=self._run_server)
worker.start()
workers.append(worker)

if self._config.publisher is not None:
kafka_producer = KafkaProducer(self._config.publisher, self._config.model_manifest)
sampler = RatioSampling(self._config.publisher.sampling_ratio)
publisher = Publisher(kafka_producer, sampler)
self._predict_service.set_publisher(publisher)

-        self._run_server()
+        asyncio.get_event_loop().run_until_complete(self._run_server())

-    def _run_server(self):
+    async def _run_server(self):
"""
Start a server in a subprocess.
"""
options = self._config.grpc_options
options.append(('grpc.so_reuseport', 1))

-        server = grpc.server(futures.ThreadPoolExecutor(max_workers=self._config.grpc_concurrency),
+        server = aio.server(futures.ThreadPoolExecutor(max_workers=self._config.grpc_concurrency),
                             options=options)
upi_pb2_grpc.add_UniversalPredictionServiceServicer_to_server(self._predict_service, server)
health_pb2_grpc.add_HealthServicer_to_server(self._health_service, server)
@@ -63,5 +85,5 @@ def _run_server(self):
logging.info(
f"Starting grpc service at port {self._config.grpc_port} with options {self._config.grpc_options}")
server.add_insecure_port(f"[::]:{self._config.grpc_port}")
-        server.start()
-        server.wait_for_termination()
+        await server.start()
+        await server.wait_for_termination()
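
The switch from `grpc.server` to `grpc.aio` gives the server a running asyncio
event loop, which the `asyncio.create_task(...)` call in `PredictValues` relies
on to schedule publishing without blocking the response.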
25 changes: 25 additions & 0 deletions python/pyfunc-server/pyfuncserver/publisher/kafka.py
@@ -0,0 +1,25 @@
import uuid

from pyfuncserver.config import Publisher as PublisherConfig, ModelManifest
from pyfuncserver.utils.converter import build_prediction_log

from confluent_kafka import Producer
from merlin.pyfunc import PyFuncOutput


class KafkaProducer(Producer):
def __init__(self, publisher_config: PublisherConfig, model_manifest: ModelManifest) -> None:
conf = {
"bootstrap.servers": publisher_config.kafka.brokers,
"acks": publisher_config.kafka.acks,
"linger.ms": publisher_config.kafka.linger_ms
}
conf.update(publisher_config.kafka.configuration)
self.producer = Producer(**conf)
self.topic = publisher_config.kafka.topic
self.model_manifest = model_manifest

def produce(self, data: PyFuncOutput):
prediction_log = build_prediction_log(pyfunc_output=data, model_manifest=self.model_manifest)
serialized_data = prediction_log.SerializeToString()
self.producer.produce(topic=self.topic, value=serialized_data)
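
One design note: confluent-kafka's `produce()` enqueues the message in a
background buffer and returns immediately; with the default `acks=0` and no
`flush()` on shutdown, buffered prediction logs may be dropped, an acceptable
trade-off for sampled observability data.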
21 changes: 21 additions & 0 deletions python/pyfunc-server/pyfuncserver/publisher/publisher.py
@@ -0,0 +1,21 @@
from pyfuncserver.sampler.sampler import Sampler
from merlin.pyfunc import PyFuncOutput
from abc import ABC, abstractmethod
import asyncio

class Producer(ABC):

@abstractmethod
def produce(self, data: PyFuncOutput):
pass

class Publisher:
def __init__(self, producer: Producer, sampler: Sampler) -> None:
self.producer = producer
self.sampler = sampler

async def publish(self, output: PyFuncOutput):
if not self.sampler.should_sample():
return

self.producer.produce(output)
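
The `RatioSampling` implementation from
`python/pyfunc-server/pyfuncserver/sampler/sampler.py` is not rendered in this
view; a minimal sketch consistent with the `should_sample()` usage above
(behavior assumed) would be:

```python
import random
from abc import ABC, abstractmethod


class Sampler(ABC):
    @abstractmethod
    def should_sample(self) -> bool:
        ...


class RatioSampling(Sampler):
    """Samples outputs uniformly at random at the configured ratio (0.0-1.0)."""

    def __init__(self, ratio: float) -> None:
        self.ratio = ratio

    def should_sample(self) -> bool:
        # Publish roughly `ratio` of all PyFunc outputs.
        return random.random() < self.ratio
```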