Skip to content

Commit

Permalink
fix stopping
Browse files Browse the repository at this point in the history
  • Loading branch information
danielgafni committed Nov 7, 2023
1 parent 2616fc3 commit 528f686
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 69 deletions.
8 changes: 6 additions & 2 deletions examples/dl_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class Checkpoints(BaseModel):


class State(BaseModel):
training_stopped: bool = False
lr: float = 1e-3
checkpoints: Checkpoints = Checkpoints()
model: Model = Model()
Expand All @@ -39,11 +40,14 @@ def epoch_loop(config: State, current_epoch: int):
basicConfig(level="INFO")

state = State()
control(state)
server = control(state)

current_epoch = 0

while True:
while not state.training_stopped:
print(f"state: {state}")
epoch_loop(state, current_epoch)
current_epoch += 1
else:
print("Training stopped!")
server.stop()
21 changes: 12 additions & 9 deletions freak/freak.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from starlette.responses import JSONResponse
from uvicorn import Config

from freak.uvicorn_threaded import UvicornServer
from freak.uvicorn_threaded import ThreadedUvicorn

logger = getLogger(__name__)

Expand Down Expand Up @@ -81,6 +81,8 @@ def __init__(
self.port = port
self.uvicorn_log_level = uvicorn_log_level

self.should_stop = False

def control(self, state: T, serve: bool = True):
if not state.Config.allow_mutation:
state.Config.allow_mutation = True
Expand All @@ -94,14 +96,19 @@ def control(self, state: T, serve: bool = True):
self.serve()

def serve(self):
self.server = UvicornServer(
self.server = ThreadedUvicorn(
config=Config(app=self.app, host=self.host, port=self.port, log_level=self.uvicorn_log_level)
)
self.server.run_in_thread()
# logger.info(f"Running Freak on http://{self.host}:{self.port}")
self.server.start()
logger.info(f"Running Freak at {self.host}:{self.port}")

def stop(self):
self.server.cleanup()
logger.info("Stopping Freak Server")
self.server.stop()

@property
def running(self) -> bool:
return self.server.thread.is_alive()

def add_routes(self, app: FastAPI, state: T) -> FastAPI:
init_state = state.copy(deep=True)
Expand All @@ -110,10 +117,6 @@ def add_routes(self, app: FastAPI, state: T) -> FastAPI:

state_name = state.__repr_name__()

@router.post("/stop", description="Stop the Freak server", tags=["stop"])
async def stop_server(): # pyright: ignore
self.stop()

@router.get("/get", description=f"Get the whole {state_name}", tags=[state_name])
async def get_state() -> type(state): # pyright: ignore
return state
Expand Down
76 changes: 18 additions & 58 deletions freak/uvicorn_threaded.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,26 @@
# taken from https://github.com/encode/uvicorn/discussions/1103#discussioncomment-6187606

import asyncio
import threading
import time

import uvicorn

# this code is taken from freqtrade


def asyncio_setup() -> None: # pragma: no cover
# Set eventloop for win32 setups
# Reverts a change done in uvicorn 0.15.0 - which now sets the eventloop
# via policy.
import sys

if sys.version_info >= (3, 8) and sys.platform == "win32":
import asyncio
import selectors

selector = selectors.SelectSelector()
loop = asyncio.SelectorEventLoop(selector)
asyncio.set_event_loop(loop)


class UvicornServer(uvicorn.Server):
"""
Multithreaded server - as found in https://github.com/encode/uvicorn/issues/742

Removed install_signal_handlers() override based on changes from this commit:
https://github.com/encode/uvicorn/commit/ce2ef45a9109df8eae038c0ec323eb63d644cbc6
class ThreadedUvicorn:
def __init__(self, config: uvicorn.Config):
self.server = uvicorn.Server(config)
self.thread = threading.Thread(daemon=True, target=self.server.run)

Cannot rely on asyncio.get_event_loop() to create new event loop because of this check:
https://github.com/python/cpython/blob/4d7f11e05731f67fd2c07ec2972c6cb9861d52be/Lib/asyncio/events.py#L638
Fix by overriding run() and forcing creation of new event loop if uvloop is available
"""

def run(self, sockets=None):
import asyncio

"""
Parent implementation calls self.config.setup_event_loop(),
but we need to create uvloop event loop manually
"""
try:
import uvloop # pyright: ignore[reportMissingImports]
except ImportError: # pragma: no cover
asyncio_setup()
else:
asyncio.set_event_loop(uvloop.new_event_loop())
try:
loop = asyncio.get_running_loop()
except RuntimeError:
# When running in a thread, we'll not have an eventloop yet.
loop = asyncio.new_event_loop()

loop.run_until_complete(self.serve(sockets=sockets))

def run_in_thread(self):
self.thread = threading.Thread(target=self.run)
def start(self):
self.thread.start()
while not self.started:
time.sleep(1e-3)
asyncio.run(self.wait_for_started())

async def wait_for_started(self):
while not self.server.started:
await asyncio.sleep(0.1)

def cleanup(self):
self.should_exit = True
self.thread.join()
def stop(self):
if self.thread.is_alive():
self.server.should_exit = True
while self.thread.is_alive():
continue

0 comments on commit 528f686

Please sign in to comment.