Skip to content

Commit

Permalink
add ttl cache that kills subprocesses on eviction
Browse files Browse the repository at this point in the history
  • Loading branch information
matyson committed Apr 25, 2024
1 parent b2ecd3f commit 8fc70e9
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ Simple api to manage tensorboard instances on demand.
- [x] setup docker
- [x] setup reverse proxy with dynamic paths to tensorboards ports
- [x] setup ssl
- [] clean up idle tensorboards
- [x] clean up idle tensorboards
- [x] security token on api routes
- [ ] secure tensorboard instances?
- [ ] test connection w/ frontend
Expand Down
32 changes: 15 additions & 17 deletions app/main.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from typing import Annotated
import os

from fastapi import FastAPI, Depends, Body, HTTPException

from .models import CreateTensorboardInstanceRequest, TensorboardInstance
from .dependencies import verify_token, config
from .tensorboard import start_tensorboard
from .tva import create_mobius


tb_instances: dict[str, TensorboardInstance] = {}
hostname = config.hostname
ttl = 300
get, get_all, set, remove, contains = create_mobius(ttl)


app = FastAPI(root_path="/api", dependencies=[Depends(verify_token)])
Expand All @@ -22,10 +23,9 @@ def read_root():

@app.get("/tensorboard")
def get_tensorboard_instance(name: str):
tb_instance = tb_instances.get(name)
if tb_instance:
return tb_instance
return {"message": "Instance not found"}
if contains(name):
return get(name)
raise HTTPException(status_code=404, detail="Instance not found")


@app.post("/tensorboard/start")
Expand All @@ -34,32 +34,30 @@ def start_tensorboard_instance(
):
logdir = request.logdir
name = request.name
if name in tb_instances:
return {"message": "Instance already exists"}
if contains(name):
return get(name)

try:
p, port = start_tensorboard(logdir, name)
url = f"http://{hostname}/{name}/{port}"
tb_instance = TensorboardInstance(url=url, logdir=logdir, name=name, pid=p.pid)
tb_instances[name] = tb_instance
return tb_instance
instance = TensorboardInstance(url=url, logdir=logdir, name=name, pid=p.pid)
set(name, instance)
return instance
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))


@app.post("/tensorboard/kill/{name}")
def kill_tensorboard_instance(name: str):
tb_instance = tb_instances.get(name)
if tb_instance:
os.kill(tb_instance.pid, 9)
del tb_instances[name]
if contains(name):
remove(name)
return {"message": "Instance killed"}
return {"message": "Instance not found"}
raise HTTPException(status_code=404, detail="Instance not found")


@app.get("/tensorboard/instances")
def get_tensorboard_instances():
return tb_instances
return get_all()


if __name__ == "__main__":
Expand Down
65 changes: 65 additions & 0 deletions app/tva.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import os
import time

from .models import TensorboardInstance


def create_mobius(time_branch: int):
    """Build a TTL cache of tensorboard instances that kills evicted processes.

    Each entry maps a name to an instance object exposing a ``pid``
    attribute. An entry whose last access is more than ``time_branch``
    seconds old is considered expired: its process is sent signal 9 and the
    entry is dropped.

    Args:
        time_branch: Idle time-to-live in seconds, measured from the last
            access (``get``, ``contains``, ``get_all`` or ``set``).

    Returns:
        A tuple of closures ``(get, get_all, set, remove, contains)`` that
        share one private cache.
    """
    variants: dict[str, "TensorboardInstance"] = {}
    access_times: dict[str, float] = {}
    kill_signal = 9  # SIGKILL

    def evict(loki: str) -> None:
        """Drop ``loki`` from the cache and kill its process."""
        # Remove the bookkeeping first so the cache stays consistent even if
        # the kill below raises.
        instance = variants.pop(loki)
        access_times.pop(loki, None)
        try:
            os.kill(instance.pid, kill_signal)
        except ProcessLookupError:
            # The process already exited on its own; there is nothing left
            # to kill and the entry has been removed, so eviction succeeded.
            pass

    def get(loki: str):
        """Return the live instance named ``loki`` and refresh its TTL.

        Raises:
            KeyError: If the name is unknown, or if the entry had expired
                (in which case it is evicted and its process killed).
        """
        if loki not in variants:
            raise KeyError(f"Instance {loki} not found")

        current_time = time.time()
        if current_time - access_times[loki] > time_branch:
            evict(loki)
            raise KeyError(f"Instance {loki} has expired")

        access_times[loki] = current_time
        return variants[loki]

    def get_all():
        """Evict expired entries, refresh the rest, and return a snapshot.

        The returned dict is a shallow copy, so callers cannot alter the
        cache by mutating it.
        """
        prune()
        current_time = time.time()
        for loki in variants:
            access_times[loki] = current_time
        return dict(variants)

    def set(loki: str, instance: "TensorboardInstance"):
        """Store ``instance`` under ``loki``, then evict expired entries."""
        variants[loki] = instance
        access_times[loki] = time.time()
        prune()

    def remove(loki: str):
        """Kill and drop the instance named ``loki`` if it is still cached."""
        # contains() already evicts the entry when it has expired, so the
        # explicit eviction is only needed for a live entry.
        if contains(loki):
            evict(loki)

    def prune():
        """Evict every entry idle for longer than ``time_branch`` seconds."""
        current_time = time.time()
        # Collect the names first: evicting while iterating the dict would
        # change its size mid-iteration.
        expired = [
            loki
            for loki, access_time in access_times.items()
            if current_time - access_time > time_branch
        ]
        for loki in expired:
            evict(loki)

    def contains(loki: str):
        """Return whether ``loki`` is cached and not expired.

        This goes through ``get``, so a successful check also refreshes the
        entry's TTL and an expired entry is evicted.
        """
        try:
            get(loki)
        except KeyError:
            return False
        return True

    return get, get_all, set, remove, contains

0 comments on commit 8fc70e9

Please sign in to comment.