Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Update Client.run to have a better async I/O usage #2645

Open
wants to merge 18 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ These changes are available on the `master` branch, but have not yet been releas
apps. ([#2650](https://github.com/Pycord-Development/pycord/pull/2650))
- Fixed type annotations of cached properties.
([#2635](https://github.com/Pycord-Development/pycord/issues/2635))
- Fixed Async I/O errors that could be raised when using `Client.run`.
([#2645](https://github.com/Pycord-Development/pycord/pull/2645))

### Changed

Expand Down
123 changes: 66 additions & 57 deletions discord/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

import asyncio
import logging
import signal
import sys
import traceback
from types import TracebackType
Expand Down Expand Up @@ -122,6 +121,12 @@ class Client:

A number of options can be passed to the :class:`Client`.

.. container:: operations

.. describe:: async with x

Asynchronously initializes the client.

Parameters
-----------
max_messages: Optional[:class:`int`]
Expand Down Expand Up @@ -221,14 +226,12 @@ class Client:
def __init__(
self,
*,
loop: asyncio.AbstractEventLoop | None = None,
loop: asyncio.AbstractEventLoop = MISSING,
**options: Any,
):
# self.ws is set in the connect method
self.ws: DiscordWebSocket = None # type: ignore
self.loop: asyncio.AbstractEventLoop = (
asyncio.get_event_loop() if loop is None else loop
)
self.loop: asyncio.AbstractEventLoop = loop
self._listeners: dict[str, list[tuple[asyncio.Future, Callable[..., bool]]]] = (
{}
)
Expand Down Expand Up @@ -256,7 +259,8 @@ def __init__(
self._enable_debug_events: bool = options.pop("enable_debug_events", False)
self._connection: ConnectionState = self._get_state(**options)
self._connection.shard_count = self.shard_count
self._closed: bool = False
self._closed: asyncio.Event = asyncio.Event()
self._closing_task: asyncio.Lock = asyncio.Lock()
self._ready: asyncio.Event = asyncio.Event()
self._connection._get_websocket = self._get_websocket
self._connection._get_client = lambda: self
Expand All @@ -270,12 +274,23 @@ def __init__(
self._tasks = set()

async def __aenter__(self) -> Client:
loop = asyncio.get_running_loop()
self.loop = loop
self.http.loop = loop
self._connection.loop = loop
if self.loop is MISSING:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
# No event loop was found, this should not happen
# because entering on this context manager means a
# loop is already active, but we need to handle it
# anyways just to prevent future errors.

# Maybe handle different system event loop policies?
self.loop = asyncio.new_event_loop()

self.http.loop = self.loop
self._connection.loop = self.loop

self._ready = asyncio.Event()
self._closed = asyncio.Event()

return self

Expand Down Expand Up @@ -712,23 +727,24 @@ async def close(self) -> None:

Closes the connection to Discord.
"""
if self._closed:
if self.is_closed():
return

await self.http.close()
self._closed = True
async with self._closing_task:
await self.http.close()

for voice in self.voice_clients:
try:
await voice.disconnect(force=True)
except Exception:
# if an error happens during disconnects, disregard it.
pass
for voice in self.voice_clients:
try:
await voice.disconnect(force=True)
except Exception:
# if an error happens during disconnects, disregard it.
pass

if self.ws is not None and self.ws.open:
await self.ws.close(code=1000)
if self.ws is not None and self.ws.open:
await self.ws.close(code=1000)

self._ready.clear()
self._ready.clear()
self._closed.set()

def clear(self) -> None:
"""Clears the internal state of the bot.
Expand All @@ -752,10 +768,11 @@ async def start(self, token: str, *, reconnect: bool = True) -> None:
TypeError
An unexpected keyword argument was received.
"""
# Update the loop to get the running one in case the one set is MISSING
await self.login(token)
await self.connect(reconnect=reconnect)

def run(self, *args: Any, **kwargs: Any) -> None:
def run(self, token: str, *, reconnect: bool = True) -> None:
"""A blocking call that abstracts away the event loop
initialisation from you.

Expand All @@ -766,60 +783,52 @@ def run(self, *args: Any, **kwargs: Any) -> None:
Roughly Equivalent to: ::

try:
loop.run_until_complete(start(*args, **kwargs))
asyncio.run(start(token))
except KeyboardInterrupt:
loop.run_until_complete(close())
# cancel all tasks lingering
finally:
loop.close()
return

Parameters
----------
token: :class:`str`
The authentication token. Do not prefix this token with
anything as the library will do it for you.
reconnect: :class:`bool`
If we should attempt reconnecting to the gateway, either due to internet
failure or a specific failure on Discord's part. Certain
disconnects that lead to bad state will not be handled (such as
invalid sharding payloads or bad tokens).

.. warning::

This function must be the last function to call due to the fact that it
is blocking. That means that registration of events or anything being
called after this function call will not execute until it returns.
"""
loop = self.loop

try:
loop.add_signal_handler(signal.SIGINT, loop.stop)
loop.add_signal_handler(signal.SIGTERM, loop.stop)
except (NotImplementedError, RuntimeError):
pass

async def runner():
try:
await self.start(*args, **kwargs)
finally:
if not self.is_closed():
await self.close()
async with self:
await self.start(token=token, reconnect=reconnect)

def stop_loop_on_completion(f):
loop.stop()
run = asyncio.run

if self.loop is not MISSING:
run = self.loop.run_until_complete

future = asyncio.ensure_future(runner(), loop=loop)
future.add_done_callback(stop_loop_on_completion)
try:
loop.run_forever()
except KeyboardInterrupt:
_log.info("Received signal to terminate bot and event loop.")
run(runner())
finally:
future.remove_done_callback(stop_loop_on_completion)
_log.info("Cleaning up tasks.")
_cleanup_loop(loop)
# Ensure the bot is closed
if not self.is_closed():
self.loop.run_until_complete(self.close())

if not future.cancelled():
try:
return future.result()
except KeyboardInterrupt:
# I am unsure why this gets raised here but suppress it anyway
return None
_log.info("Cleaning up tasks.")
_cleanup_loop(self.loop)

# properties

def is_closed(self) -> bool:
"""Indicates if the WebSocket connection is closed."""
return self._closed
return self._closed.is_set()

@property
def activity(self) -> ActivityTypes | None:
Expand Down
Loading