Skip to content

Commit

Permalink
WIP: temp commit
Browse files Browse the repository at this point in the history
  • Loading branch information
dormant-user committed Sep 5, 2024
1 parent a72bf84 commit ab7bf1c
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 22 deletions.
6 changes: 6 additions & 0 deletions pyninja/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,12 @@ <h4>Swap Usage</h4>
</div>

<script>
function getCookie(name) {
const value = `; ${document.cookie}`;
const parts = value.split(`; ${name}=`);
if (parts.length === 2) return parts.pop().split(';').shift();
}
const token = getCookie('session_token');
const wsProtocol = window.location.protocol === "https:" ? "wss" : "ws";
const wsHost = window.location.host;
const ws = new WebSocket(`${wsProtocol}://${wsHost}/ws/system`);
Expand Down
4 changes: 4 additions & 0 deletions pyninja/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,15 @@ class WSSettings(BaseModel):
"""

client_auth: Dict[str, str] = {}
template: FilePath = os.path.join(pathlib.Path(__file__).parent, "index.html")
cpu_interval: PositiveInt = 3
refresh_interval: PositiveInt = 5


ws_settings = WSSettings()


def get_service_manager() -> ServiceManager:
"""Get service manager filepath for the host operating system.
Expand Down
59 changes: 37 additions & 22 deletions pyninja/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import jinja2
import psutil
from fastapi import Depends, Header, Request
from fastapi import Depends, Header, Request, Response
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.routing import APIRoute, APIWebSocketRoute
from fastapi.security import (
Expand Down Expand Up @@ -340,7 +340,10 @@ async def health():
raise exceptions.APIResponse(status_code=HTTPStatus.OK, detail=HTTPStatus.OK.phrase)


def verify_monitor_creds(credentials: HTTPBasicCredentials = Depends(BASIC_AUTH)):
def verify_monitor_creds(
request: Request,
credentials: HTTPBasicCredentials = Depends(BASIC_AUTH)
):
"""Verify credentials for monitoring page.
Args:
Expand All @@ -353,7 +356,25 @@ def verify_monitor_creds(credentials: HTTPBasicCredentials = Depends(BASIC_AUTH)
credentials.password, models.env.monitor_pass
)
if username and password:
return credentials.username
LOGGER.info(
"Connection received from client-host: %s, host-header: %s, x-fwd-host: %s",
request.client.host,
request.headers.get("host"),
request.headers.get("x-forwarded-host"),
)
if user_agent := request.headers.get("user-agent"):
LOGGER.info("User agent: %s", user_agent)
client_token = {"username": credentials.username, "token": squire.keygen(), "timestamp": int(time.time())}
auth_payload = str(client_token).encode("utf-8")
expiration = int(time.time()) + 3_600 # Set to 1 hour
response = Response()
response.set_cookie(key="session_token",
value=auth_payload,
samesite="strict",
path="/",
httponly=False,
expires=expiration)
return response
raise exceptions.APIResponse(
status_code=HTTPStatus.UNAUTHORIZED,
detail="Incorrect username or password",
Expand All @@ -368,15 +389,14 @@ async def monitor():
HTMLResponse:
Returns an HTML response templated using Jinja2.
"""
ws_settings = models.WSSettings()
with open(ws_settings.template) as file:
with open(models.ws_settings.template) as file:
template_file = file.read()
template = jinja2.Template(template_file)
rendered = template.render(
HOST=models.env.ninja_host,
PORT=models.env.ninja_port,
DEFAULT_CPU_INTERVAL=ws_settings.cpu_interval,
DEFAULT_REFRESH_INTERVAL=ws_settings.refresh_interval,
DEFAULT_CPU_INTERVAL=models.ws_settings.cpu_interval,
DEFAULT_REFRESH_INTERVAL=models.ws_settings.refresh_interval,
)
return HTMLResponse(rendered)

Expand All @@ -387,31 +407,26 @@ async def websocket_endpoint(websocket: WebSocket):
Args:
websocket: Reference to the websocket object.
"""
credentials = await BASIC_AUTH(websocket)
try:
verify_monitor_creds(credentials)
except exceptions.APIResponse:
await websocket.close(code=1008)
return
await websocket.accept()
print(websocket.cookies)
print(websocket.headers.get("cookie"))
refresh_time = time.time()
ws_settings = models.WSSettings()
LOGGER.info("Websocket settings: %s", ws_settings.model_dump_json())
data = squire.system_resources(ws_settings.cpu_interval)
LOGGER.info("Intervals: {'CPU': %s, 'refresh': %s}", models.ws_settings.cpu_interval, models.ws_settings.refresh_interval)
data = squire.system_resources(models.ws_settings.cpu_interval)
while True:
if websocket.application_state == WebSocketState.CONNECTED:
try:
msg = await asyncio.wait_for(websocket.receive_text(), timeout=0.1)
if msg.startswith("refresh_interval:"):
ws_settings.refresh_interval = int(msg.split(":")[1].strip())
models.ws_settings.refresh_interval = int(msg.split(":")[1].strip())
LOGGER.info(
"Updating refresh interval to %s seconds",
ws_settings.refresh_interval,
models.ws_settings.refresh_interval,
)
elif msg.startswith("cpu_interval"):
ws_settings.cpu_interval = int(msg.split(":")[1].strip())
models.ws_settings.cpu_interval = int(msg.split(":")[1].strip())
LOGGER.info(
"Updating CPU interval to %s seconds", ws_settings.cpu_interval
"Updating CPU interval to %s seconds", models.ws_settings.cpu_interval
)
else:
LOGGER.error("Invalid WS message received: %s", msg)
Expand All @@ -420,10 +435,10 @@ async def websocket_endpoint(websocket: WebSocket):
pass
except WebSocketDisconnect:
break
if time.time() - refresh_time > ws_settings.refresh_interval:
if time.time() - refresh_time > models.ws_settings.refresh_interval:
refresh_time = time.time()
LOGGER.debug("Fetching new charts")
data = squire.system_resources(ws_settings.cpu_interval)
data = squire.system_resources(models.ws_settings.cpu_interval)
try:
await websocket.send_json(data)
await asyncio.sleep(1)
Expand Down
11 changes: 11 additions & 0 deletions pyninja/squire.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import logging
import math
import secrets
import os
import pathlib
import re
Expand Down Expand Up @@ -194,3 +195,13 @@ def load_env(**kwargs) -> EnvConfig:
file_env = {}
merged_env = {**file_env, **kwargs}
return EnvConfig(**merged_env)


def keygen() -> str:
"""Generate session token from secrets module, so that users are forced to log in when the server restarts.
Returns:
str:
Returns a URL safe 64-bit token.
"""
return secrets.token_urlsafe(64)

0 comments on commit ab7bf1c

Please sign in to comment.