Add basic Prometheus instrumentation for workers
tillprochaska committed Sep 20, 2023
1 parent b657e3b commit ca2d21d
Showing 5 changed files with 175 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.txt
@@ -5,3 +5,4 @@ sqlalchemy==2.0.4
structlog==23.1.0
colorama==0.4.6
pika==1.3.2
prometheus-client==0.17.1
4 changes: 4 additions & 0 deletions servicelayer/settings.py
@@ -49,3 +49,7 @@
SENTRY_DSN = env.get("SENTRY_DSN")
SENTRY_ENVIRONMENT = env.get("SENTRY_ENVIRONMENT", "")
SENTRY_RELEASE = env.get("SENTRY_RELEASE", "")

# Instrumentation
PROMETHEUS_ENABLED = env.to_bool("PROMETHEUS_ENABLED", False)
PROMETHEUS_PORT = env.to_int("PROMETHEUS_PORT", 9090)
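
For a quick local smoke test (a sketch, not part of this commit), the two new settings can be toggled through environment variables before servicelayer.settings is first imported, since the module reads the environment at import time. The truthy spelling "true" is an assumption about env.to_bool:

import os

# Must be set before the first import of servicelayer.settings,
# because the module reads the environment at import time.
os.environ["PROMETHEUS_ENABLED"] = "true"  # assumed truthy spelling
os.environ["PROMETHEUS_PORT"] = "9100"     # overrides the default of 9090

from servicelayer import settings  # noqa: E402

assert settings.PROMETHEUS_ENABLED
assert settings.PROMETHEUS_PORT == 9100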
68 changes: 68 additions & 0 deletions servicelayer/worker.py
@@ -1,10 +1,13 @@
import signal
import logging
from timeit import default_timer
import sys
from threading import Thread
from banal import ensure_list
from abc import ABC, abstractmethod

from prometheus_client import start_http_server, Counter, Histogram

from servicelayer import settings
from servicelayer.jobs import Stage
from servicelayer.cache import get_redis
@@ -19,6 +22,47 @@
INTERVAL = 2
TASK_FETCH_RETRY = 60 / INTERVAL

TASK_STARTED = Counter(
    "task_started_total",
    "Number of tasks that a worker started processing",
    ["stage"],
)

TASK_SUCCEEDED = Counter(
    "task_succeeded_total",
    "Number of successfully processed tasks",
    ["stage", "retries"],
)

TASK_FAILED = Counter(
    "task_failed_total",
    "Number of failed tasks",
    ["stage", "retries", "failed_permanently"],
)

TASK_DURATION = Histogram(
    "task_duration_seconds",
    "Task duration in seconds",
    ["stage"],
    # The bucket sizes are a rough guess right now, we might want to adjust
    # them later based on observed runtimes
    buckets=[
        0.25,
        0.5,
        1,
        5,
        15,
        30,
        60,
        60 * 15,
        60 * 30,
        60 * 60,
        60 * 60 * 2,
        60 * 60 * 6,
        60 * 60 * 24,
    ],
)


class Worker(ABC):
"""Workers of all microservices, unite!"""
@@ -51,8 +95,15 @@ def _handle_signal(self, signal, frame):
        sys.exit(self.exit_code)

    def handle_safe(self, task):
        retries = unpack_int(task.context.get("retries"))

        try:
            TASK_STARTED.labels(task.stage.stage).inc()
            start_time = default_timer()
            self.handle(task)
            duration = max(0, default_timer() - start_time)
            TASK_DURATION.labels(task.stage.stage).observe(duration)
            TASK_SUCCEEDED.labels(task.stage.stage, retries).inc()
        except SystemExit as exc:
            self.exit_code = exc.code
            self.retry(task)
@@ -72,12 +123,28 @@ def init_internal(self):
        self.exit_code = 0
        self.boot()

    def run_prometheus_server(self):
        if not settings.PROMETHEUS_ENABLED:
            return

        def run_server():
            port = settings.PROMETHEUS_PORT
            log.info(f"Running Prometheus metrics server on port {port}")
            start_http_server(port)

        thread = Thread(target=run_server)
        thread.start()
        thread.join()

    def retry(self, task):
        retries = unpack_int(task.context.get("retries"))
        if retries < settings.WORKER_RETRY:
            log.warning("Queue failed task for re-try...")
            TASK_FAILED.labels(task.stage.stage, retries, False).inc()
            task.context["retries"] = retries + 1
            task.stage.queue(task.payload, task.context)
        else:
            TASK_FAILED.labels(task.stage.stage, retries, True).inc()

    def process(self, blocking=True, interval=INTERVAL):
        retries = 0
@@ -112,6 +179,7 @@ def run(self, blocking=True, interval=INTERVAL):
        signal.signal(signal.SIGINT, self._handle_signal)
        signal.signal(signal.SIGTERM, self._handle_signal)
        self.init_internal()
        self.run_prometheus_server()

        def process():
            return self.process(blocking=blocking, interval=interval)
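
As of prometheus-client 0.17, start_http_server launches the listener in a daemon thread and returns immediately, so the thread.join() above returns right away instead of blocking the worker loop. A minimal sketch (not part of the commit, assuming the default port of 9090 and a worker running locally) to verify that the endpoint is up:

import urllib.request

# Fetch the text exposition format and check that the new metric
# families are exported. The histogram shows up as several series:
# task_duration_seconds_bucket, _sum, and _count.
with urllib.request.urlopen("http://localhost:9090/metrics") as resp:
    body = resp.read().decode("utf-8")

assert "task_duration_seconds" in body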
1 change: 1 addition & 0 deletions setup.py
@@ -36,6 +36,7 @@
"structlog >= 20.2.0, < 24.0.0",
"colorama >= 0.4.4, < 1.0.0",
"pika >= 1.3.1, < 2.0.0",
"prometheus-client >= 0.17.1, < 0.18.0",
],
extras_require={
"amazon": ["boto3 >= 1.11.9, <2.0.0"],
101 changes: 101 additions & 0 deletions tests/test_worker.py
@@ -1,5 +1,7 @@
from unittest import TestCase
import pytest
from prometheus_client import REGISTRY
from prometheus_client.metrics import MetricWrapperBase

from servicelayer.cache import get_fakeredis
from servicelayer.jobs import Job
@@ -17,7 +19,23 @@ def handle(self, task):
        self.test_done += 1


class FailingWorker(worker.Worker):
    def handle(self, task):
        raise Exception("Woops")


class WorkerTest(TestCase):
    def setup_method(self, _):
        # This relies on internal implementation details of the client to reset
        # previously collected metrics before every test execution. Unfortunately,
        # there is no clean way of achieving the same thing that doesn't add a lot
        # of complexity to the test and application code.
        collectors = REGISTRY._collector_to_names.keys()
        for collector in collectors:
            if isinstance(collector, MetricWrapperBase):
                collector._metrics.clear()
                collector._metric_init()

    def test_run(self):
        conn = get_fakeredis()
        operation = "lala"
@@ -53,3 +71,86 @@ def test_run(self):
        assert exc.code == 5, exc.code
        with pytest.raises(SystemExit) as exc:  # noqa
            worker._handle_signal(5, None)

    def test_prometheus_succeeded(self):
        conn = get_fakeredis()
        worker = CountingWorker(conn=conn, stages=["ingest"])
        job = Job.create(conn, "test")
        stage = job.get_stage("ingest")
        stage.queue({}, {})
        worker.sync()

        labels = {"stage": "ingest"}
        success_labels = {"stage": "ingest", "retries": "0"}

        started = REGISTRY.get_sample_value("task_started_total", labels)
        succeeded = REGISTRY.get_sample_value("task_succeeded_total", success_labels)

        # Under the hood, histogram metrics create multiple time series tracking
        # the number and sum of observations, as well as individual histogram buckets.
        duration_sum = REGISTRY.get_sample_value("task_duration_seconds_sum", labels)
        duration_count = REGISTRY.get_sample_value(
            "task_duration_seconds_count",
            labels,
        )

        assert started == 1
        assert succeeded == 1
        assert duration_sum > 0
        assert duration_count == 1

    def test_prometheus_failed(self):
        conn = get_fakeredis()
        worker = FailingWorker(conn=conn, stages=["ingest"])
        job = Job.create(conn, "test")
        stage = job.get_stage("ingest")
        stage.queue({}, {})
        labels = {"stage": "ingest"}

        worker.sync()

        assert REGISTRY.get_sample_value("task_started_total", labels) == 1
        assert REGISTRY.get_sample_value(
            "task_failed_total",
            {
                "stage": "ingest",
                "retries": "0",
                "failed_permanently": "False",
            },
        )

        worker.sync()

        assert REGISTRY.get_sample_value("task_started_total", labels) == 2
        assert REGISTRY.get_sample_value(
            "task_failed_total",
            {
                "stage": "ingest",
                "retries": "1",
                "failed_permanently": "False",
            },
        )

        worker.sync()

        assert REGISTRY.get_sample_value("task_started_total", labels) == 3
        assert REGISTRY.get_sample_value(
            "task_failed_total",
            {
                "stage": "ingest",
                "retries": "2",
                "failed_permanently": "False",
            },
        )

        worker.sync()

        assert REGISTRY.get_sample_value("task_started_total", labels) == 4
        assert REGISTRY.get_sample_value(
            "task_failed_total",
            {
                "stage": "ingest",
                "retries": "3",
                "failed_permanently": "True",
            },
        )
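
When one of these assertions fails, it can help to dump everything the default registry currently holds. This is a debugging sketch (not part of the commit) using prometheus_client's public generate_latest helper, which renders the same text exposition format the metrics server serves:

from prometheus_client import REGISTRY, generate_latest

# Prints every collected sample, including the per-label children of
# the counters and the _bucket/_sum/_count series of the histogram.
print(generate_latest(REGISTRY).decode("utf-8"))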
