From 8fc70e9fb149a58fe0480a5f94b0019b6e9cf4d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matheus=20Lu=C3=ADs?= Date: Thu, 25 Apr 2024 09:27:07 -0300 Subject: [PATCH] add ttl cache that kills subprocesses on eviction --- README.md | 2 +- app/main.py | 32 +++++++++++++------------- app/tva.py | 65 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 18 deletions(-) create mode 100644 app/tva.py diff --git a/README.md b/README.md index 6b15f4a..02c78ad 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ Simple api to manage tensorboard instances on demand. - [x] setup docker - [x] setup reverse proxy with dynamic paths to tensorboards ports - [x] setup ssl -- [] clean up idle tensorboards +- [x] clean up idle tensorboards - [x] security token on api routes - [] secure tensorboard instances ? - [] test connection w/ frontend diff --git a/app/main.py b/app/main.py index caba272..93f827b 100644 --- a/app/main.py +++ b/app/main.py @@ -1,15 +1,16 @@ from typing import Annotated -import os from fastapi import FastAPI, Depends, Body, HTTPException from .models import CreateTensorboardInstanceRequest, TensorboardInstance from .dependencies import verify_token, config from .tensorboard import start_tensorboard +from .tva import create_mobius -tb_instances: dict[str, TensorboardInstance] = {} hostname = config.hostname +ttl = 300 +get, get_all, set, remove, contains = create_mobius(ttl) app = FastAPI(root_path="/api", dependencies=[Depends(verify_token)]) @@ -22,10 +23,9 @@ def read_root(): @app.get("/tensorboard") def get_tensorboard_instance(name: str): - tb_instance = tb_instances.get(name) - if tb_instance: - return tb_instance - return {"message": "Instance not found"} + if contains(name): + return get(name) + raise HTTPException(status_code=404, detail="Instance not found") @app.post("/tensorboard/start") @@ -34,32 +34,30 @@ def start_tensorboard_instance( ): logdir = request.logdir name = request.name - if name in tb_instances: - return {"message": "Instance already exists"} + if contains(name): + return get(name) try: p, port = start_tensorboard(logdir, name) url = f"http://{hostname}/{name}/{port}" - tb_instance = TensorboardInstance(url=url, logdir=logdir, name=name, pid=p.pid) - tb_instances[name] = tb_instance - return tb_instance + instance = TensorboardInstance(url=url, logdir=logdir, name=name, pid=p.pid) + set(name, instance) + return instance except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @app.post("/tensorboard/kill/{name}") def kill_tensorboard_instance(name: str): - tb_instance = tb_instances.get(name) - if tb_instance: - os.kill(tb_instance.pid, 9) - del tb_instances[name] + if contains(name): + remove(name) return {"message": "Instance killed"} - return {"message": "Instance not found"} + raise HTTPException(status_code=404, detail="Instance not found") @app.get("/tensorboard/instances") def get_tensorboard_instances(): - return tb_instances + return get_all() if __name__ == "__main__": diff --git a/app/tva.py b/app/tva.py new file mode 100644 index 0000000..adeabe0 --- /dev/null +++ b/app/tva.py @@ -0,0 +1,65 @@ +import os +import time + +from .models import TensorboardInstance + + +def create_mobius(time_branch: int): + variants: dict[str, TensorboardInstance] = {} + access_times: dict[str, float] = {} + + def get(loki: str): + if loki in variants: + current_time = time.time() + + if current_time - access_times[loki] > time_branch: + os.kill(variants[loki].pid, 9) + del variants[loki] + del access_times[loki] + raise KeyError(f"Instance {loki} has expired") + + access_times[loki] = current_time + return variants[loki] + else: + raise KeyError(f"Instance {loki} not found") + + def get_all(): + prune() + current_time = time.time() + for loki in variants: + access_times[loki] = current_time + return variants + + def set(loki: str, instance: TensorboardInstance): + current_time = time.time() + variants[loki] = instance + access_times[loki] = current_time + prune() + + def remove(loki: str): + if contains(loki): + os.kill(variants[loki].pid, 9) + del variants[loki] + del access_times[loki] + + def prune(): + current_time = time.time() + expired = [ + loki + for loki, access_time in access_times.items() + if current_time - access_time > time_branch + ] + + for loki in expired: + os.kill(variants[loki].pid, 9) + del variants[loki] + del access_times[loki] + + def contains(loki: str): + try: + get(loki) + return True + except KeyError: + return False + + return get, get_all, set, remove, contains