diff --git a/lib/core/app.py b/lib/core/app.py index 179e5df..d513855 100644 --- a/lib/core/app.py +++ b/lib/core/app.py @@ -21,8 +21,8 @@ def __init__( def run(self): self._setup_routes() - self._bot.listen('on_connect')(self._injector.connect) - self._bot.listen('on_disconnect')(self._injector.disconnect) + self._bot.on_startup.append(self._injector.connect) + self._bot.on_shutdown.append(self._injector.disconnect) self._logger.info('Running bot...') self._bot.run() diff --git a/lib/core/bot.py b/lib/core/bot.py index ef376e2..ae8d351 100644 --- a/lib/core/bot.py +++ b/lib/core/bot.py @@ -35,6 +35,9 @@ def __init__(self, *args, logger: Logger, disconnect_timeout: float = 5.0, **kwa self._disconnect_timeout = disconnect_timeout self._consumers: List[asyncio.Task] = [] + self.on_startup: List[Callable] = [] + self.on_shutdown: List[Callable] = [] + @classmethod def __from_config__(cls, config: Dict[str, Dict], **clients) -> ClientProtocol: kwargs: Dict[str, Any] = config.get(cls.CONFIG_NAME, {}) @@ -44,7 +47,8 @@ async def __disconnect__(self): for consumer in self._consumers: consumer.cancel() - await asyncio.wait(self._consumers, timeout=self._disconnect_timeout) + if self._consumers: + await asyncio.wait(self._consumers, timeout=self._disconnect_timeout) def run(self, token: Optional[str] = None, **kwargs): if not token: @@ -52,6 +56,14 @@ def run(self, token: Optional[str] = None, **kwargs): super().run(token, **kwargs) + async def start(self, *args, **kwargs): + await self._do_startup() + + try: + await super().start(*args, **kwargs) + finally: + await self._do_shutdown() + async def queue_task(self, queue_name: str, task: Callable[[], Awaitable]) -> int: if queue_name not in self._queues: queue = Queue() @@ -74,6 +86,23 @@ def clear_queue(self, queue_name: str): if q: q.clear_queue() + async def _do_startup(self): + for handler in self.on_startup: + if asyncio.iscoroutinefunction(handler): + await handler() + else: + handler() + + async def _do_shutdown(self): + for handler in self.on_shutdown: + try: + if asyncio.iscoroutinefunction(handler): + await handler() + else: + handler() + except Exception: + self._logger.exception('Failed to execute on shutdown handler') + async def _consume_queue(self, q: Queue): async def _consume(): while self.loop.is_running(): @@ -86,5 +115,7 @@ async def _consume(): q.task_done() + self._logger.info('Finished consuming') + consumer = asyncio.create_task(_consume()) self._consumers.append(consumer)