Skip to content

Commit

Permalink
Merge pull request #6 from RB387/dev
Browse files Browse the repository at this point in the history
Bug fix consumers disconnect
  • Loading branch information
RB387 authored Oct 9, 2021
2 parents 9b1c534 + 4e47940 commit bee5ef8
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 3 deletions.
4 changes: 2 additions & 2 deletions lib/core/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ def __init__(
def run(self):
self._setup_routes()

self._bot.listen('on_connect')(self._injector.connect)
self._bot.listen('on_disconnect')(self._injector.disconnect)
self._bot.on_startup.append(self._injector.connect)
self._bot.on_shutdown.append(self._injector.disconnect)

self._logger.info('Running bot...')
self._bot.run()
Expand Down
33 changes: 32 additions & 1 deletion lib/core/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def __init__(self, *args, logger: Logger, disconnect_timeout: float = 5.0, **kwa
self._disconnect_timeout = disconnect_timeout
self._consumers: List[asyncio.Task] = []

self.on_startup: List[Callable] = []
self.on_shutdown: List[Callable] = []

@classmethod
def __from_config__(cls, config: Dict[str, Dict], **clients) -> ClientProtocol:
kwargs: Dict[str, Any] = config.get(cls.CONFIG_NAME, {})
Expand All @@ -44,14 +47,23 @@ async def __disconnect__(self):
for consumer in self._consumers:
consumer.cancel()

await asyncio.wait(self._consumers, timeout=self._disconnect_timeout)
if self._consumers:
await asyncio.wait(self._consumers, timeout=self._disconnect_timeout)

def run(self, token: Optional[str] = None, **kwargs):
if not token:
token = os.environ['DISCORD_TOKEN']

super().run(token, **kwargs)

async def start(self, *args, **kwargs):
await self._do_startup()

try:
await super().start(*args, **kwargs)
finally:
await self._do_shutdown()

async def queue_task(self, queue_name: str, task: Callable[[], Awaitable]) -> int:
if queue_name not in self._queues:
queue = Queue()
Expand All @@ -74,6 +86,23 @@ def clear_queue(self, queue_name: str):
if q:
q.clear_queue()

async def _do_startup(self):
for handler in self.on_startup:
if asyncio.iscoroutinefunction(handler):
await handler()
else:
handler()

async def _do_shutdown(self):
for handler in self.on_shutdown:
try:
if asyncio.iscoroutinefunction(handler):
await handler()
else:
handler()
except Exception:
self._logger.exception('Failed to execute on shutdown handler')

async def _consume_queue(self, q: Queue):
async def _consume():
while self.loop.is_running():
Expand All @@ -86,5 +115,7 @@ async def _consume():

q.task_done()

self._logger.info('Finished consuming')

consumer = asyncio.create_task(_consume())
self._consumers.append(consumer)

0 comments on commit bee5ef8

Please sign in to comment.