Skip to content

Commit

Permalink
fix: EntitlementIterator behavior and type-hinting (#2555)
Browse files Browse the repository at this point in the history
* fix: EntitlementIterator behaviour and type-hinting

* style(pre-commit): auto fixes from pre-commit.com hooks

* simplify if's

* add changelog entry

* style(pre-commit): auto fixes from pre-commit.com hooks

* revert missclick

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Lala Sabathil <[email protected]>
  • Loading branch information
3 people authored Aug 28, 2024
1 parent 556be08 commit 2136691
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 18 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@ These changes are available on the `master` branch, but have not yet been releas
- Added `Guild.fetch_role` method.
([#2528](https://github.com/Pycord-Development/pycord/pull/2528))

### Fixed

- Fixed `EntitlementIterator` behavior with `limit > 100`.
([#2555](https://github.com/Pycord-Development/pycord/pull/2555))

## [2.6.0] - 2024-07-09

### Added
Expand Down
74 changes: 56 additions & 18 deletions discord/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
from .types.audit_log import AuditLog as AuditLogPayload
from .types.guild import Guild as GuildPayload
from .types.message import Message as MessagePayload
from .types.monetization import Entitlement as EntitlementPayload
from .types.threads import Thread as ThreadPayload
from .types.user import PartialUser as PartialUserPayload
from .user import User
Expand Down Expand Up @@ -988,11 +989,21 @@ def __init__(
self.guild_id = guild_id
self.exclude_ended = exclude_ended

self._filter = None

if self.before and self.after:
self._retrieve_entitlements = self._retrieve_entitlements_before_strategy
self._filter = lambda e: int(e["id"]) > self.after.id
elif self.after:
self._retrieve_entitlements = self._retrieve_entitlements_after_strategy
else:
self._retrieve_entitlements = self._retrieve_entitlements_before_strategy

self.state = state
self.get_entitlements = state.http.list_entitlements
self.entitlements = asyncio.Queue()

async def next(self) -> BanEntry:
async def next(self) -> Entitlement:
if self.entitlements.empty():
await self.fill_entitlements()

Expand All @@ -1014,30 +1025,57 @@ async def fill_entitlements(self):
if not self._get_retrieve():
return

data = await self._retrieve_entitlements(self.retrieve)

if self._filter:
data = list(filter(self._filter, data))

if len(data) < 100:
self.limit = 0 # terminate loop

for element in data:
await self.entitlements.put(Entitlement(data=element, state=self.state))

async def _retrieve_entitlements(self, retrieve) -> list[Entitlement]:
"""Retrieve entitlements and update next parameters."""
raise NotImplementedError

async def _retrieve_entitlements_before_strategy(
self, retrieve: int
) -> list[EntitlementPayload]:
"""Retrieve entitlements using before parameter."""
before = self.before.id if self.before else None
after = self.after.id if self.after else None
data = await self.get_entitlements(
self.state.application_id,
before=before,
after=after,
limit=self.retrieve,
limit=retrieve,
user_id=self.user_id,
guild_id=self.guild_id,
sku_ids=self.sku_ids,
exclude_ended=self.exclude_ended,
)
if data:
if self.limit is not None:
self.limit -= retrieve
self.before = Object(id=int(data[-1]["id"]))
return data

if not data:
# no data, terminate
return

if self.limit:
self.limit -= self.retrieve

if len(data) < 100:
self.limit = 0 # terminate loop

self.after = Object(id=int(data[-1]["id"]))

for element in reversed(data):
await self.entitlements.put(Entitlement(data=element, state=self.state))
async def _retrieve_entitlements_after_strategy(
self, retrieve: int
) -> list[EntitlementPayload]:
"""Retrieve entitlements using after parameter."""
after = self.after.id if self.after else None
data = await self.get_entitlements(
self.state.application_id,
after=after,
limit=retrieve,
user_id=self.user_id,
guild_id=self.guild_id,
sku_ids=self.sku_ids,
exclude_ended=self.exclude_ended,
)
if data:
if self.limit is not None:
self.limit -= retrieve
self.after = Object(id=int(data[-1]["id"]))
return data

0 comments on commit 2136691

Please sign in to comment.