diff --git a/examples/dl_example.py b/examples/dl_example.py index 1267824..d89e9ab 100644 --- a/examples/dl_example.py +++ b/examples/dl_example.py @@ -22,6 +22,7 @@ class Checkpoints(BaseModel): class State(BaseModel): + training_stopped: bool = False lr: float = 1e-3 checkpoints: Checkpoints = Checkpoints() model: Model = Model() @@ -39,11 +40,14 @@ def epoch_loop(config: State, current_epoch: int): basicConfig(level="INFO") state = State() - control(state) + server = control(state) current_epoch = 0 - while True: + while not state.training_stopped: print(f"state: {state}") epoch_loop(state, current_epoch) current_epoch += 1 + else: + print("Training stopped!") + server.stop() diff --git a/freak/freak.py b/freak/freak.py index d21b0d5..c3a3697 100644 --- a/freak/freak.py +++ b/freak/freak.py @@ -8,7 +8,7 @@ from starlette.responses import JSONResponse from uvicorn import Config -from freak.uvicorn_threaded import UvicornServer +from freak.uvicorn_threaded import ThreadedUvicorn logger = getLogger(__name__) @@ -81,6 +81,8 @@ def __init__( self.port = port self.uvicorn_log_level = uvicorn_log_level + self.should_stop = False + def control(self, state: T, serve: bool = True): if not state.Config.allow_mutation: state.Config.allow_mutation = True @@ -94,14 +96,19 @@ def control(self, state: T, serve: bool = True): self.serve() def serve(self): - self.server = UvicornServer( + self.server = ThreadedUvicorn( config=Config(app=self.app, host=self.host, port=self.port, log_level=self.uvicorn_log_level) ) - self.server.run_in_thread() - # logger.info(f"Running Freak on http://{self.host}:{self.port}") + self.server.start() + logger.info(f"Running Freak at {self.host}:{self.port}") def stop(self): - self.server.cleanup() + logger.info("Stopping Freak Server") + self.server.stop() + + @property + def running(self) -> bool: + return self.server.thread.is_alive() def add_routes(self, app: FastAPI, state: T) -> FastAPI: init_state = state.copy(deep=True) @@ -110,10 +117,6 @@ def add_routes(self, app: FastAPI, state: T) -> FastAPI: state_name = state.__repr_name__() - @router.post("/stop", description="Stop the Freak server", tags=["stop"]) - async def stop_server(): # pyright: ignore - self.stop() - @router.get("/get", description=f"Get the whole {state_name}", tags=[state_name]) async def get_state() -> type(state): # pyright: ignore return state diff --git a/freak/uvicorn_threaded.py b/freak/uvicorn_threaded.py index 5bf7ddf..cdf7619 100644 --- a/freak/uvicorn_threaded.py +++ b/freak/uvicorn_threaded.py @@ -1,66 +1,26 @@ +# taken from https://github.com/encode/uvicorn/discussions/1103#discussioncomment-6187606 + +import asyncio import threading -import time import uvicorn -# this code is taken from freqtrade - - -def asyncio_setup() -> None: # pragma: no cover - # Set eventloop for win32 setups - # Reverts a change done in uvicorn 0.15.0 - which now sets the eventloop - # via policy. - import sys - - if sys.version_info >= (3, 8) and sys.platform == "win32": - import asyncio - import selectors - - selector = selectors.SelectSelector() - loop = asyncio.SelectorEventLoop(selector) - asyncio.set_event_loop(loop) - - -class UvicornServer(uvicorn.Server): - """ - Multithreaded server - as found in https://github.com/encode/uvicorn/issues/742 - Removed install_signal_handlers() override based on changes from this commit: - https://github.com/encode/uvicorn/commit/ce2ef45a9109df8eae038c0ec323eb63d644cbc6 +class ThreadedUvicorn: + def __init__(self, config: uvicorn.Config): + self.server = uvicorn.Server(config) + self.thread = threading.Thread(daemon=True, target=self.server.run) - Cannot rely on asyncio.get_event_loop() to create new event loop because of this check: - https://github.com/python/cpython/blob/4d7f11e05731f67fd2c07ec2972c6cb9861d52be/Lib/asyncio/events.py#L638 - - Fix by overriding run() and forcing creation of new event loop if uvloop is available - """ - - def run(self, sockets=None): - import asyncio - - """ - Parent implementation calls self.config.setup_event_loop(), - but we need to create uvloop event loop manually - """ - try: - import uvloop # pyright: ignore[reportMissingImports] - except ImportError: # pragma: no cover - asyncio_setup() - else: - asyncio.set_event_loop(uvloop.new_event_loop()) - try: - loop = asyncio.get_running_loop() - except RuntimeError: - # When running in a thread, we'll not have an eventloop yet. - loop = asyncio.new_event_loop() - - loop.run_until_complete(self.serve(sockets=sockets)) - - def run_in_thread(self): - self.thread = threading.Thread(target=self.run) + def start(self): self.thread.start() - while not self.started: - time.sleep(1e-3) + asyncio.run(self.wait_for_started()) + + async def wait_for_started(self): + while not self.server.started: + await asyncio.sleep(0.1) - def cleanup(self): - self.should_exit = True - self.thread.join() + def stop(self): + if self.thread.is_alive(): + self.server.should_exit = True + while self.thread.is_alive(): + continue