From 28fab357755bf85cfce07c5d696adf8c47551e4d Mon Sep 17 00:00:00 2001 From: DA-344 <108473820+DA-344@users.noreply.github.com> Date: Wed, 11 Dec 2024 16:30:57 +0100 Subject: [PATCH] chore: Update Client.close to prevent double closing and race conditions --- discord/client.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/discord/client.py b/discord/client.py index e6832231b6..8c64fe447b 100644 --- a/discord/client.py +++ b/discord/client.py @@ -259,7 +259,8 @@ def __init__( self._enable_debug_events: bool = options.pop("enable_debug_events", False) self._connection: ConnectionState = self._get_state(**options) self._connection.shard_count = self.shard_count - self._closed: bool = False + self._closed: asyncio.Event = asyncio.Event() + self._closing_task: asyncio.Lock = asyncio.Lock() self._ready: asyncio.Event = asyncio.Event() self._connection._get_websocket = self._get_websocket self._connection._get_client = lambda: self @@ -289,6 +290,7 @@ async def __aenter__(self) -> Client: self._connection.loop = self.loop self._ready = asyncio.Event() + self._closed = asyncio.Event() return self @@ -725,23 +727,24 @@ async def close(self) -> None: Closes the connection to Discord. """ - if self._closed: + if self.is_closed(): return - await self.http.close() - self._closed = True + async with self._closing_task: + await self.http.close() - for voice in self.voice_clients: - try: - await voice.disconnect(force=True) - except Exception: - # if an error happens during disconnects, disregard it. - pass + for voice in self.voice_clients: + try: + await voice.disconnect(force=True) + except Exception: + # if an error happens during disconnects, disregard it. + pass - if self.ws is not None and self.ws.open: - await self.ws.close(code=1000) + if self.ws is not None and self.ws.open: + await self.ws.close(code=1000) - self._ready.clear() + self._ready.clear() + self._closed.set() def clear(self) -> None: """Clears the internal state of the bot. @@ -818,14 +821,14 @@ async def runner(): if not self.is_closed(): self.loop.run_until_complete(self.close()) - _log.info("Cleaning up tasks.") - _cleanup_loop(self.loop) + _log.info("Cleaning up tasks.") + _cleanup_loop(self.loop) # properties def is_closed(self) -> bool: """Indicates if the WebSocket connection is closed.""" - return self._closed + return self._closed.is_set() @property def activity(self) -> ActivityTypes | None: