Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add public method to get session start limits #10007

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions discord/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
subscription,
)
from .types.snowflake import Snowflake, SnowflakeList
from .types.gateway import SessionStartLimit

from types import TracebackType

Expand Down Expand Up @@ -2753,13 +2754,13 @@ def get_sku_subscription(self, sku_id: Snowflake, subscription_id: Snowflake) ->

# Misc

async def get_bot_gateway(self) -> Tuple[int, str]:
async def get_bot_gateway(self) -> Tuple[int, str, SessionStartLimit]:
try:
data = await self.request(Route('GET', '/gateway/bot'))
except HTTPException as exc:
raise GatewayNotFound() from exc

return data['shards'], data['url']
return data['shards'], data['url'], data['session_start_limit']

def get_user(self, user_id: Snowflake) -> Response[user.User]:
return self.request(Route('GET', '/users/{user_id}', user_id=user_id))
58 changes: 57 additions & 1 deletion discord/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,16 @@
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict

if TYPE_CHECKING:
from typing_extensions import Unpack
from .gateway import DiscordWebSocket
from .activity import BaseActivity
from .flags import Intents
from .types.gateway import SessionStartLimit

__all__ = (
'AutoShardedClient',
'ShardInfo',
'SessionStartLimits',
)

_log = logging.getLogger(__name__)
Expand Down Expand Up @@ -293,6 +296,32 @@ def is_ws_ratelimited(self) -> bool:
return self._parent.ws.is_ratelimited()


class SessionStartLimits:
"""A class that holds info about session start limits

Attributes
----------
total: :class:`int`
The total number of session starts the current user is allowed
remaining: :class:`int`
Remaining remaining number of session starts the current user is allowed
reset_after: :class:`int`
The number of milliseconds until the limit resets
max_concurrency: :class:`int`
The number of identify requests allowed per 5 seconds

.. versionadded:: 2.5
"""

__slots__ = ("total", "remaining", "reset_after", "max_concurrency")

def __init__(self, **kwargs: Unpack[SessionStartLimit]):
self.total: int = kwargs['total']
self.remaining: int = kwargs['remaining']
self.reset_after: int = kwargs['reset_after']
self.max_concurrency: int = kwargs['max_concurrency']


class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications
of sharding for the user into a more manageable and transparent single
Expand Down Expand Up @@ -415,6 +444,33 @@ def shards(self) -> Dict[int, ShardInfo]:
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()}

async def fetch_session_start_limits(self) -> SessionStartLimits:
"""|coro|

Get the session start limits.

This is not typically needed, and will be handled for you by default.

At the point where you are launching multiple instances
with manual shard ranges and are considered required to use large bot
sharding by Discord, this function when used along IPC and a
before_identity_hook can speed up session start.

.. versionadded:: 2.5

Returns
-------
:class:`SessionStartLimits`
A class containing the session start limits

Raises
------
GatewayNotFound
The gateway was unreachable
"""
_, _, limits = await self.http.get_bot_gateway()
return SessionStartLimits(**limits)

async def launch_shard(self, gateway: yarl.URL, shard_id: int, *, initial: bool = False) -> None:
try:
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
Expand All @@ -434,7 +490,7 @@ async def launch_shards(self) -> None:

if self.shard_count is None:
self.shard_count: int
self.shard_count, gateway_url = await self.http.get_bot_gateway()
self.shard_count, gateway_url, _session_start_limit = await self.http.get_bot_gateway()
gateway = yarl.URL(gateway_url)
else:
gateway = DiscordWebSocket.DEFAULT_GATEWAY
Expand Down
10 changes: 9 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,7 @@ Subscriptions
.. function:: on_subscription_delete(subscription)

Called when a subscription is deleted.

.. versionadded:: 2.5

:param subscription: The subscription that was deleted.
Expand Down Expand Up @@ -5209,6 +5209,14 @@ ShardInfo
.. autoclass:: ShardInfo()
:members:

SessionStartLimits
~~~~~~~~~~~~~~~~~~~~

.. attributetable:: SessionStartLimits

.. autoclass:: SessionStartLimits()
:members:

SKU
~~~~~~~~~~~

Expand Down
Loading