Skip to content

Commit

Permalink
feat(main): Add metrics endpoint for prometheus
Browse files Browse the repository at this point in the history
Add basic metrics support for the monitoring system; it will be
extended further later.

Signed-off-by: Denys Fedoryshchenko <[email protected]>
  • Loading branch information
nuclearcat committed Oct 17, 2024
1 parent 8431443 commit 0323eae
Showing 1 changed file with 87 additions and 3 deletions.
90 changes: 87 additions & 3 deletions api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
import re
from typing import List, Union, Optional
import threading
from fastapi import (
Depends,
FastAPI,
Expand Down Expand Up @@ -57,6 +58,54 @@
# models etc.
API_VERSIONS = ['v0']


class Metrics:
    """Thread-safe store of named numeric metrics.

    Values are kept in a plain dictionary (``self.metrics``) and every
    read or write of that dictionary is done while holding ``self.lock``.
    """

    def __init__(self):
        """Initialize metrics dictionary and lock"""
        # Seed the request counter so it is present (as 0) before the
        # first request has been counted.
        self.metrics = {'http_requests_total': 0}
        self.lock = threading.Lock()

    # Various internal metrics
    def update(self):
        """Update metrics (reserved for future use)"""

    def add(self, key, value):
        """Add `value` to the metric `key`

        A metric that does not exist yet is created starting from 0.
        """
        with self.lock:
            self.metrics[key] = self.metrics.get(key, 0) + value

    def get(self, key):
        """Get the value of the metric `key`, or 0 if it is unknown"""
        self.update()
        with self.lock:
            return self.metrics.get(key, 0)

    def all(self):
        """Get a snapshot of all the metrics

        Returns a shallow copy of the metrics dictionary taken under the
        lock, so the caller can iterate over (or modify) the result
        without racing against concurrent `add()` calls and without
        altering the stored values.
        """
        self.update()
        with self.lock:
            return dict(self.metrics)


metrics = Metrics()


app = FastAPI()
db = Database(service=(os.getenv('MONGO_SERVICE') or 'mongodb://db:27017'))
auth = Authentication(token_url="user/login")
Expand Down Expand Up @@ -113,6 +162,7 @@ async def invalid_id_exception_handler(
@app.get('/')
async def root():
    """Handle requests to the API root

    Counts the request and answers with a fixed identification message.
    """
    metrics.add('http_requests_total', 1)
    greeting = {"message": "KernelCI API"}
    return greeting

# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -153,6 +203,7 @@ async def register(request: Request, user: UserCreate,
This handler will convert them to `UserGroup` objects and
insert user object to database.
"""
metrics.add('http_requests_total', 1)
existing_user = await db.find_one(User, username=user.username)
if existing_user:
raise HTTPException(
Expand Down Expand Up @@ -226,6 +277,7 @@ async def update_me(request: Request, user: UserUpdate,
Custom user update router handler will only allow users to update
its own profile.
"""
metrics.add('http_requests_total', 1)
if user.username and user.username != current_user.username:
existing_user = await db.find_one(User, username=user.username)
if existing_user:
Expand Down Expand Up @@ -254,6 +306,7 @@ async def update_me(request: Request, user: UserUpdate,
async def update_user(user_id: str, request: Request, user: UserUpdate,
current_user: User = Depends(get_current_superuser)):
"""Router to allow admin users to update other user account"""
metrics.add('http_requests_total', 1)
user_from_id = await db.find_by_id(User, user_id)
if not user_from_id:
raise HTTPException(
Expand Down Expand Up @@ -323,6 +376,7 @@ async def get_users(request: Request,
current_user: User = Depends(get_current_user)):
"""Get all the users if no request parameters have passed.
Get the matching users otherwise."""
metrics.add('http_requests_total', 1)
query_params = dict(request.query_params)
# Drop pagination parameters from query as they're already in arguments
for pg_key in ['limit', 'offset']:
Expand All @@ -338,6 +392,7 @@ async def update_password(request: Request,
credentials: OAuth2PasswordRequestForm = Depends(),
new_password: str = Form(None)):
"""Update user password"""
metrics.add('http_requests_total', 1)
user = await user_manager.authenticate(credentials)
if user is None or not user.is_active:
raise HTTPException(
Expand All @@ -360,6 +415,7 @@ async def post_user_group(
group: UserGroup,
current_user: User = Depends(get_current_superuser)):
"""Create new user group"""
metrics.add('http_requests_total', 1)
try:
obj = await db.create(group)
except DuplicateKeyError as error:
Expand All @@ -377,6 +433,7 @@ async def post_user_group(
async def get_user_groups(request: Request):
"""Get all the user groups if no request parameters have passed.
Get all the matching user groups otherwise."""
metrics.add('http_requests_total', 1)
query_params = dict(request.query_params)

# Drop pagination parameters from query as they're already in arguments
Expand All @@ -393,13 +450,15 @@ async def get_user_groups(request: Request):
response_model_by_alias=False)
async def get_group(group_id: str):
"""Get user group information from the provided group id"""
metrics.add('http_requests_total', 1)
return await db.find_by_id(UserGroup, group_id)


@app.delete('/group/{group_id}', response_model=PageModel)
async def delete_group(group_id: str,
current_user: User = Depends(get_current_superuser)):
"""Delete user group matching the provided group id"""
metrics.add('http_requests_total', 1)
group_from_id = await db.find_by_id(UserGroup, group_id)
if not group_from_id:
raise HTTPException(
Expand Down Expand Up @@ -455,6 +514,7 @@ async def translate_null_query_params(query_params: dict):
response_model_by_alias=False)
async def get_node(node_id: str):
"""Get node information from the provided node id"""
metrics.add('http_requests_total', 1)
try:
return await db.find_by_id(Node, node_id)
except KeyError as error:
Expand Down Expand Up @@ -483,6 +543,7 @@ def serialize_paginated_data(model, data: list):
async def get_nodes(request: Request):
"""Get all the nodes if no request parameters have passed.
Get all the matching nodes otherwise, within the pagination limit."""
metrics.add('http_requests_total', 1)
query_params = dict(request.query_params)

# Drop pagination parameters from query as they're already in arguments
Expand Down Expand Up @@ -513,6 +574,7 @@ async def get_nodes(request: Request):
async def get_nodes_count(request: Request):
"""Get the count of all the nodes if no request parameters have passed.
Get the count of all the matching nodes otherwise."""
metrics.add('http_requests_total', 1)
query_params = dict(request.query_params)

query_params = await translate_null_query_params(query_params)
Expand Down Expand Up @@ -555,6 +617,7 @@ async def post_node(node: Node,
authorization: str | None = Header(default=None),
current_user: User = Depends(get_current_user)):
"""Create a new node"""
metrics.add('http_requests_total', 1)
# [TODO] Remove translation below once we can use it in the pipeline
node = _translate_version_fields(node)

Expand Down Expand Up @@ -589,6 +652,7 @@ async def post_node(node: Node,
async def put_node(node_id: str, node: Node,
user: str = Depends(authorize_user)):
"""Update an already added node"""
metrics.add('http_requests_total', 1)
node.id = ObjectId(node_id)
node_from_id = await db.find_by_id(Node, node_id)
if not node_from_id:
Expand Down Expand Up @@ -647,6 +711,7 @@ async def put_nodes(
authorization: str | None = Header(default=None),
user: str = Depends(authorize_user)):
"""Add a hierarchy of nodes to an existing root node"""
metrics.add('http_requests_total', 1)
nodes.node.id = ObjectId(node_id)
# Retrieve the root node from the DB and submitter
node_from_id = await db.find_by_id(Node, node_id)
Expand Down Expand Up @@ -675,6 +740,7 @@ async def put_nodes(
async def subscribe(channel: str, user: User = Depends(get_current_user),
promisc: Optional[bool] = Query(None)):
"""Subscribe handler for Pub/Sub channel"""
metrics.add('http_requests_total', 1)
options = {}
if promisc:
options['promiscuous'] = promisc
Expand All @@ -684,6 +750,7 @@ async def subscribe(channel: str, user: User = Depends(get_current_user),
@app.post('/unsubscribe/{sub_id}')
async def unsubscribe(sub_id: int, user: User = Depends(get_current_user)):
"""Unsubscribe handler for Pub/Sub channel"""
metrics.add('http_requests_total', 1)
try:
await pubsub.unsubscribe(sub_id, user.username)
except KeyError as error:
Expand All @@ -701,6 +768,7 @@ async def unsubscribe(sub_id: int, user: User = Depends(get_current_user)):
@app.get('/listen/{sub_id}')
async def listen(sub_id: int, user: User = Depends(get_current_user)):
"""Listen messages from a subscribed Pub/Sub channel"""
metrics.add('http_requests_total', 1)
try:
return await pubsub.listen(sub_id, user.username)
except KeyError as error:
Expand All @@ -719,6 +787,7 @@ async def listen(sub_id: int, user: User = Depends(get_current_user)):
async def publish(event: PublishEvent, channel: str,
user: User = Depends(get_current_user)):
"""Publish an event on the provided Pub/Sub channel"""
metrics.add('http_requests_total', 1)
event_dict = PublishEvent.dict(event)
# 1 - Extract data and attributes from the event
# 2 - Add the owner as an extra attribute
Expand All @@ -739,6 +808,7 @@ async def publish(event: PublishEvent, channel: str,
async def push(raw: dict, list_name: str,
user: User = Depends(get_current_user)):
"""Push a message on the provided list"""
metrics.add('http_requests_total', 1)
attributes = dict(raw)
data = attributes.pop('data')
await pubsub.push_cloudevent(list_name, data, attributes)
Expand All @@ -747,20 +817,22 @@ async def push(raw: dict, list_name: str,
@app.get('/pop/{list_name}')
async def pop(list_name: str, user: User = Depends(get_current_user)):
    """Pop a single message from the list named `list_name`

    Counts the request, then returns whatever `pubsub.pop()` yields for
    that list.
    """
    metrics.add('http_requests_total', 1)
    message = await pubsub.pop(list_name)
    return message


@app.get('/stats/subscriptions', response_model=List[SubscriptionStats])
async def stats(user: User = Depends(get_current_superuser)):
    """Report details of every existing subscription

    Counts the request, then returns the result of
    `pubsub.subscription_stats()`.
    """
    metrics.add('http_requests_total', 1)
    subscription_details = await pubsub.subscription_stats()
    return subscription_details


@app.get('/viewer')
async def viewer():
"""Serve simple HTML page to view the API /static/viewer.html
Set various no-cache tag we might update it often"""

metrics.add('http_requests_total', 1)
root_dir = os.path.dirname(os.path.abspath(__file__))
viewer_path = os.path.join(root_dir, 'templates', 'viewer.html')
with open(viewer_path, 'r', encoding='utf-8') as file:
Expand All @@ -778,7 +850,7 @@ async def viewer():
async def dashboard():
"""Serve simple HTML page to view the API dashboard.html
Set various no-cache tag we might update it often"""

metrics.add('http_requests_total', 1)
root_dir = os.path.dirname(os.path.abspath(__file__))
dashboard_path = os.path.join(root_dir, 'templates', 'dashboard.html')
with open(dashboard_path, 'r', encoding='utf-8') as file:
Expand All @@ -795,7 +867,7 @@ async def dashboard():
@app.get('/manage')
async def manage():
"""Serve simple HTML page to submit custom nodes"""

metrics.add('http_requests_total', 1)
root_dir = os.path.dirname(os.path.abspath(__file__))
manage_path = os.path.join(root_dir, 'templates', 'manage.html')
with open(manage_path, 'r', encoding='utf-8') as file:
Expand All @@ -812,6 +884,7 @@ async def manage():
@app.get('/icons/{icon_name}')
async def icons(icon_name: str):
"""Serve icons from /static/icons"""
metrics.add('http_requests_total', 1)
root_dir = os.path.dirname(os.path.abspath(__file__))
if not re.match(r'^[A-Za-z0-9_.-]+\.png$', icon_name):
raise HTTPException(
Expand All @@ -821,6 +894,17 @@ async def icons(icon_name: str):
icon_path = os.path.join(root_dir, 'templates', icon_name)
return FileResponse(icon_path)

@app.get('/metrics')
async def get_metrics():
    """Get metrics

    Counts the request itself, then returns every stored metric as plain
    text in Prometheus format: one `name{instance="api"} value` line per
    metric, each terminated by a newline.
    """
    metrics.add('http_requests_total', 1)
    # return metrics as plaintext in prometheus format
    all_metrics = metrics.all()
    # Build the body in one join instead of repeated string `+=`
    response = ''.join(
        f'{key}{{instance="api"}} {value}\n'
        for key, value in all_metrics.items()
    )
    return PlainTextResponse(response)


versioned_app = VersionedFastAPI(
app,
Expand Down

0 comments on commit 0323eae

Please sign in to comment.