Skip to content

Commit

Permalink
Add public method to get session start limits
Browse files Browse the repository at this point in the history
  • Loading branch information
mikeshardmind committed Nov 10, 2024
1 parent af75985 commit a503acc
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 4 deletions.
5 changes: 3 additions & 2 deletions discord/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@
subscription,
)
from .types.snowflake import Snowflake, SnowflakeList
from .types.gateway import SessionStartLimit

from types import TracebackType

Expand Down Expand Up @@ -2753,13 +2754,13 @@ def get_sku_subscription(self, sku_id: Snowflake, subscription_id: Snowflake) ->

# Misc

async def get_bot_gateway(self) -> Tuple[int, str]:
async def get_bot_gateway(self) -> Tuple[int, str, SessionStartLimit]:
try:
data = await self.request(Route('GET', '/gateway/bot'))
except HTTPException as exc:
raise GatewayNotFound() from exc

return data['shards'], data['url']
return data['shards'], data['url'], data['session_start_limit']

def get_user(self, user_id: Snowflake) -> Response[user.User]:
return self.request(Route('GET', '/users/{user_id}', user_id=user_id))
58 changes: 57 additions & 1 deletion discord/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,16 @@
from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict

if TYPE_CHECKING:
from typing_extensions import Unpack
from .gateway import DiscordWebSocket
from .activity import BaseActivity
from .flags import Intents
from .types.gateway import SessionStartLimit

__all__ = (
'AutoShardedClient',
'ShardInfo',
'SessionStartLimits',
)

_log = logging.getLogger(__name__)
Expand Down Expand Up @@ -293,6 +296,32 @@ def is_ws_ratelimited(self) -> bool:
return self._parent.ws.is_ratelimited()


class SessionStartLimits:
"""A class that holds info about session start limits
Attributes
----------
total: :class:`int`
The total number of session starts the current user is allowed
remaining: :class:`int`
Remaining remaining number of session starts the current user is allowed
reset_after: :class:`int`
The number of milliseconds until the limit resets
max_concurrency: :class:`int`
The number of identify requests allowed per 5 seconds
.. versionadded:: 2.5
"""

__slots__ = ("total", "remaining", "reset_after", "max_concurrency")

def __init__(self, **kwargs: Unpack[SessionStartLimit]):
self.total: int = kwargs['total']
self.remaining: int = kwargs['remaining']
self.reset_after: int = kwargs['reset_after']
self.max_concurrency: int = kwargs['max_concurrency']


class AutoShardedClient(Client):
"""A client similar to :class:`Client` except it handles the complications
of sharding for the user into a more manageable and transparent single
Expand Down Expand Up @@ -415,6 +444,33 @@ def shards(self) -> Dict[int, ShardInfo]:
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()}

async def fetch_session_start_limits(self) -> SessionStartLimits:
"""|coro|
Get the session start limits.
This is not typically needed, and will be handled for you by default.
At the point where you are launching multiple instances
with manual shard ranges and are considered required to use large bot
sharding by Discord, this function when used along IPC and a
before_identity_hook can speed up session start.
.. versionadded:: 2.5
Returns
-------
:class:`SessionStartLimits`
A class containing the session start limits
Raises
------
GatewayNotFound
The gateway was unreachable
"""
_, _, limits = await self.http.get_bot_gateway()
return SessionStartLimits(**limits)

async def launch_shard(self, gateway: yarl.URL, shard_id: int, *, initial: bool = False) -> None:
try:
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
Expand All @@ -434,7 +490,7 @@ async def launch_shards(self) -> None:

if self.shard_count is None:
self.shard_count: int
self.shard_count, gateway_url = await self.http.get_bot_gateway()
self.shard_count, gateway_url, _session_start_limit = await self.http.get_bot_gateway()
gateway = yarl.URL(gateway_url)
else:
gateway = DiscordWebSocket.DEFAULT_GATEWAY
Expand Down
10 changes: 9 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,7 @@ Subscriptions
.. function:: on_subscription_delete(subscription)

Called when a subscription is deleted.

.. versionadded:: 2.5

:param subscription: The subscription that was deleted.
Expand Down Expand Up @@ -5209,6 +5209,14 @@ ShardInfo
.. autoclass:: ShardInfo()
:members:

SessionStartLimits
~~~~~~~~~~~~~~~~~~~~

.. attributetable:: SessionStartLimits

.. autoclass:: SessionStartLimits()
:members:

SKU
~~~~~~~~~~~

Expand Down

0 comments on commit a503acc

Please sign in to comment.