Add basic Prometheus instrumentation for workers (#111)
* Add basic Prometheus instrumentation for workers

* Fix tests

---------

Co-authored-by: Christian Stefanescu <[email protected]>
3 people authored Oct 13, 2023
1 parent 0b6e034 commit a294c28
Showing 5 changed files with 176 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,3 +5,4 @@ sqlalchemy==2.0.4
structlog==23.2.0
colorama==0.4.6
pika==1.3.2
prometheus-client==0.17.1
4 changes: 4 additions & 0 deletions servicelayer/settings.py
@@ -49,3 +49,7 @@
SENTRY_DSN = env.get("SENTRY_DSN")
SENTRY_ENVIRONMENT = env.get("SENTRY_ENVIRONMENT", "")
SENTRY_RELEASE = env.get("SENTRY_RELEASE", "")

# Instrumentation
PROMETHEUS_ENABLED = env.to_bool("PROMETHEUS_ENABLED", False)
PROMETHEUS_PORT = env.to_int("PROMETHEUS_PORT", 9090)
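
For context, a minimal sketch (not part of this commit) of how these flags would be toggled in practice. It assumes that env.to_bool accepts the string "true" and that servicelayer.settings reads the environment at import time:

import os

# Hypothetical deployment snippet: set the variables before settings are imported.
os.environ["PROMETHEUS_ENABLED"] = "true"  # assumed to parse as True via env.to_bool
os.environ["PROMETHEUS_PORT"] = "9100"     # overrides the 9090 default

from servicelayer import settings  # noqa: E402

print(settings.PROMETHEUS_ENABLED, settings.PROMETHEUS_PORT)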
67 changes: 67 additions & 0 deletions servicelayer/worker.py
@@ -1,10 +1,13 @@
import signal
import logging
from timeit import default_timer
import sys
from threading import Thread
from banal import ensure_list
from abc import ABC, abstractmethod

from prometheus_client import start_http_server, Counter, Histogram

from servicelayer import settings
from servicelayer.jobs import Stage
from servicelayer.cache import get_redis
@@ -19,6 +22,47 @@
INTERVAL = 2
TASK_FETCH_RETRY = 60 / INTERVAL

TASK_STARTED = Counter(
    "task_started_total",
    "Number of tasks that a worker started processing",
    ["stage"],
)

TASK_SUCCEEDED = Counter(
    "task_succeeded_total",
    "Number of successfully processed tasks",
    ["stage", "retries"],
)

TASK_FAILED = Counter(
    "task_failed_total",
    "Number of failed tasks",
    ["stage", "retries", "failed_permanently"],
)

TASK_DURATION = Histogram(
    "task_duration_seconds",
    "Task duration in seconds",
    ["stage"],
    # The bucket sizes are a rough guess for now; we may want to adjust them
    # later based on observed runtimes.
    buckets=[
        0.25,
        0.5,
        1,
        5,
        15,
        30,
        60,
        60 * 15,
        60 * 30,
        60 * 60,
        60 * 60 * 2,
        60 * 60 * 6,
        60 * 60 * 24,
    ],
)


class Worker(ABC):
"""Workers of all microservices, unite!"""
@@ -51,8 +95,15 @@ def _handle_signal(self, signal, frame):
        sys.exit(self.exit_code)

    def handle_safe(self, task):
        retries = unpack_int(task.context.get("retries"))

        try:
            TASK_STARTED.labels(task.stage.stage).inc()
            start_time = default_timer()
            self.handle(task)
            duration = max(0, default_timer() - start_time)
            TASK_DURATION.labels(task.stage.stage).observe(duration)
            TASK_SUCCEEDED.labels(task.stage.stage, retries).inc()
        except SystemExit as exc:
            self.exit_code = exc.code
            self.retry(task)
@@ -72,19 +123,34 @@ def init_internal(self):
        self.exit_code = 0
        self.boot()

    def run_prometheus_server(self):
        if not settings.PROMETHEUS_ENABLED:
            return

        def run_server():
            port = settings.PROMETHEUS_PORT
            log.info(f"Running Prometheus metrics server on port {port}")
            start_http_server(port)

        thread = Thread(target=run_server)
        thread.start()
        thread.join()

    def retry(self, task):
        retries = unpack_int(task.context.get("retries"))
        if retries < settings.WORKER_RETRY:
            retry_count = retries + 1
            log.warning(
                f"Queueing failed task for retry #{retry_count}/{settings.WORKER_RETRY}..."  # noqa
            )
            TASK_FAILED.labels(task.stage.stage, retries, False).inc()
            task.context["retries"] = retry_count
            task.stage.queue(task.payload, task.context)
        else:
            log.warning(
                f"Failed task, exhausted retry count of {settings.WORKER_RETRY}"
            )
            TASK_FAILED.labels(task.stage.stage, retries, True).inc()

    def process(self, blocking=True, interval=INTERVAL):
        retries = 0
@@ -119,6 +185,7 @@ def run(self, blocking=True, interval=INTERVAL):
        signal.signal(signal.SIGINT, self._handle_signal)
        signal.signal(signal.SIGTERM, self._handle_signal)
        self.init_internal()
        self.run_prometheus_server()

        def process():
            return self.process(blocking=blocking, interval=interval)
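
To make the metric shapes concrete, here is a small self-contained sketch (not part of the diff) of what the labelled counters and the histogram export. The metric and label names mirror the definitions above, while the registry, bucket subset, and sample values are purely illustrative:

from prometheus_client import CollectorRegistry, Counter, Histogram, generate_latest

# A private registry so this example does not clash with the module-level
# metrics defined in servicelayer.worker.
registry = CollectorRegistry()
started = Counter("task_started_total", "Tasks started", ["stage"], registry=registry)
duration = Histogram(
    "task_duration_seconds",
    "Task duration in seconds",
    ["stage"],
    buckets=[0.25, 0.5, 1, 5],
    registry=registry,
)

started.labels("ingest").inc()
duration.labels("ingest").observe(0.3)

# generate_latest() renders the same exposition format that the HTTP server
# started by start_http_server() serves on /metrics, including the
# task_duration_seconds_bucket/_count/_sum series derived from the histogram.
print(generate_latest(registry).decode())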
1 change: 1 addition & 0 deletions setup.py
@@ -36,6 +36,7 @@
"structlog >= 20.2.0, < 24.0.0",
"colorama >= 0.4.4, < 1.0.0",
"pika >= 1.3.1, < 2.0.0",
"prometheus-client >= 0.17.1, < 0.18.0",
],
extras_require={
"amazon": ["boto3 >= 1.11.9, <2.0.0"],
103 changes: 103 additions & 0 deletions tests/test_worker.py
@@ -1,5 +1,7 @@
import logging
import pytest
from prometheus_client import REGISTRY
from prometheus_client.metrics import MetricWrapperBase

from servicelayer.cache import get_fakeredis
from servicelayer.jobs import Job
@@ -17,11 +19,112 @@ def handle(self, task):
        self.test_done += 1


class FailingWorker(worker.Worker):
    def handle(self, task):
        raise Exception("Woops")


class NoOpWorker(worker.Worker):
    def handle(self, task):
        pass


class PrometheusTests:
    def setup_method(self, method):
        # This relies on internal implementation details of the client to reset
        # previously collected metrics before every test execution. Unfortunately,
        # there is no clean way of achieving the same thing that doesn't add a lot
        # of complexity to the test and application code.
        collectors = REGISTRY._collector_to_names.keys()
        for collector in collectors:
            if isinstance(collector, MetricWrapperBase):
                collector._metrics.clear()
                collector._metric_init()

    def test_prometheus_succeeded(self):
        conn = get_fakeredis()
        worker = CountingWorker(conn=conn, stages=["ingest"])
        job = Job.create(conn, "test")
        stage = job.get_stage("ingest")
        stage.queue({}, {})
        worker.sync()

        labels = {"stage": "ingest"}
        success_labels = {"stage": "ingest", "retries": "0"}

        started = REGISTRY.get_sample_value("task_started_total", labels)
        succeeded = REGISTRY.get_sample_value("task_succeeded_total", success_labels)

        # Under the hood, histogram metrics create multiple time series tracking
        # the number and sum of observations, as well as individual histogram buckets.
        duration_sum = REGISTRY.get_sample_value("task_duration_seconds_sum", labels)
        duration_count = REGISTRY.get_sample_value(
            "task_duration_seconds_count",
            labels,
        )

        assert started == 1
        assert succeeded == 1
        assert duration_sum > 0
        assert duration_count == 1

    def test_prometheus_failed(self):
        conn = get_fakeredis()
        worker = FailingWorker(conn=conn, stages=["ingest"])
        job = Job.create(conn, "test")
        stage = job.get_stage("ingest")
        stage.queue({}, {})
        labels = {"stage": "ingest"}

        worker.sync()

        assert REGISTRY.get_sample_value("task_started_total", labels) == 1
        assert REGISTRY.get_sample_value(
            "task_failed_total",
            {
                "stage": "ingest",
                "retries": "0",
                "failed_permanently": "False",
            },
        )

        worker.sync()

        assert REGISTRY.get_sample_value("task_started_total", labels) == 2
        assert REGISTRY.get_sample_value(
            "task_failed_total",
            {
                "stage": "ingest",
                "retries": "1",
                "failed_permanently": "False",
            },
        )

        worker.sync()

        assert REGISTRY.get_sample_value("task_started_total", labels) == 3
        assert REGISTRY.get_sample_value(
            "task_failed_total",
            {
                "stage": "ingest",
                "retries": "2",
                "failed_permanently": "False",
            },
        )

        worker.sync()

        assert REGISTRY.get_sample_value("task_started_total", labels) == 4
        assert REGISTRY.get_sample_value(
            "task_failed_total",
            {
                "stage": "ingest",
                "retries": "3",
                "failed_permanently": "True",
            },
        )


def test_run():
    conn = get_fakeredis()
    operation = "lala"
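
If the registry-reset logic in setup_method ever needs to be shared across test modules, one possible refactor (not part of this commit, and still relying on the same private prometheus_client attributes) is an autouse pytest fixture, roughly:

import pytest
from prometheus_client import REGISTRY
from prometheus_client.metrics import MetricWrapperBase


@pytest.fixture(autouse=True)
def reset_prometheus_metrics():
    # Same caveat as in setup_method: this touches private client state because
    # prometheus_client has no public "reset all metrics" API.
    for collector in list(REGISTRY._collector_to_names.keys()):
        if isinstance(collector, MetricWrapperBase):
            collector._metrics.clear()
            collector._metric_init()
    yield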
