Commit

mikeshardmind committed Apr 5, 2024
1 parent 0872f6c commit 7693571
Showing 7 changed files with 32 additions and 57 deletions.
16 changes: 8 additions & 8 deletions async_utils/_cpython_stuff.py
@@ -14,19 +14,19 @@

from __future__ import annotations

from collections.abc import Callable, Sized
from collections.abc import Callable, Hashable, Sized
from typing import Any


class _HashedSeq(list[Any]):
""" This class guarantees that hash() will be called no more than once
per element. This is important because the lru_cache() will hash
the key multiple times on a cache miss.
"""This class guarantees that hash() will be called no more than once
per element. This is important because the lru_cache() will hash
the key multiple times on a cache miss.
"""

__slots__ = ('hashvalue',)
__slots__ = ("hashvalue",)

def __init__(self, tup: tuple[Any, ...], hash: Callable[[object], int]=hash): # noqa: A002
def __init__(self, tup: tuple[Any, ...], hash: Callable[[object], int] = hash): # noqa: A002
self[:] = tup
self.hashvalue: int = hash(tup)

@@ -40,8 +40,8 @@ def make_key(
kwd_mark: tuple[object] = (object(),),
fasttypes: set[type] = {int, str}, # noqa: B006
type: type[type] = type, # noqa: A002
len: Callable[[Sized], int] = len # noqa: A002
) -> _HashedSeq:
len: Callable[[Sized], int] = len, # noqa: A002
) -> Hashable:
"""Make a cache key from optionally typed positional and keyword arguments
The key is constructed in a way that is flat as possible rather than
as a nested structure that would take more memory.
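The hunks above are formatting changes plus a looser Hashable return annotation; make_key's behaviour is unchanged. A minimal usage sketch follows (the import path is assumed from the file name, and the two-argument call mirrors how task_cache.py below invokes make_key):

```python
from async_utils._cpython_stuff import make_key  # import path assumed from the file name above

# Identical (args, kwargs) shapes hash and compare equal, so the key can index a plain dict.
key_a = make_key(("user", 42), {"region": "eu"})
key_b = make_key(("user", 42), {"region": "eu"})
assert key_a == key_b and hash(key_a) == hash(key_b)

cache: dict = {}
cache[key_a] = "cached result"
assert cache[key_b] == "cached result"
```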
3 changes: 3 additions & 0 deletions async_utils/keyed_locks.py
@@ -24,6 +24,7 @@

KT = TypeVar("KT", bound=Hashable)


class KeyedLocks(Generic[KT]):
"""Locks per hashable resource type
Currently implemented with a weakvalue dictionary + asyncio.Locks
@@ -32,7 +33,9 @@ class KeyedLocks(Generic[KT]):
some of the functionality of asyncio locks. May revisit later, intent here
is that if I do, everything I use like this improves at once.
"""

def __init__(self) -> None:
self._locks: WeakValueDictionary[KT, asyncio.Lock] = WeakValueDictionary()

def __getitem__(self, item: KT) -> asyncio.Lock:
return self._locks.get(item, self._locks.setdefault(item, asyncio.Lock()))
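The keyed_locks.py change only adds blank lines. For context, a usage sketch (import path assumed from the file name): __getitem__ hands back a per-key asyncio.Lock, so holding it with async with serializes work on one key while other keys proceed concurrently.

```python
import asyncio

from async_utils.keyed_locks import KeyedLocks  # import path assumed from the file name above


async def update(locks: KeyedLocks[str], user_id: str) -> None:
    # locks[user_id] creates the Lock on first use; the WeakValueDictionary lets it be
    # collected once no waiter holds a reference any longer.
    async with locks[user_id]:
        await asyncio.sleep(0.05)  # stand-in for the per-key critical section


async def main() -> None:
    locks: KeyedLocks[str] = KeyedLocks()
    await asyncio.gather(update(locks, "a"), update(locks, "a"), update(locks, "b"))


asyncio.run(main())
```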
8 changes: 2 additions & 6 deletions async_utils/ratelimiter.py
@@ -39,16 +39,12 @@ def __init__(self, rate_limit: int, period: float, granularity: float):
async def __aenter__(self):
# The ordering of these conditions matters to avoid an async context switch between
# confirming the ratelimit isn't exhausted and allowing the user code to continue.
while (len(self._monotonics) >= self.rate_limit) and await asyncio.sleep(
self.granularity, True
):
while (len(self._monotonics) >= self.rate_limit) and await asyncio.sleep(self.granularity, True):
now = time.monotonic()
while self._monotonics and (now - self._monotonics[0] > self.period):
self._monotonics.popleft()

self._monotonics.append(time.monotonic())

async def __aexit__(
self, exc_type: type[Exception], exc: Exception, tb: TracebackType
):
async def __aexit__(self, exc_type: type[Exception], exc: Exception, tb: TracebackType):
pass
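The ratelimiter.py edits are cosmetic reflows of the same sliding-window logic: callers sleep in granularity-sized steps until fewer than rate_limit timestamps fall inside the last period seconds. A usage sketch follows; the class name RateLimiter is an assumption, since only __init__ and the context-manager methods appear in this hunk.

```python
import asyncio

from async_utils.ratelimiter import RateLimiter  # class name assumed; not shown in this hunk


async def call_api(limiter: RateLimiter, i: int) -> int:
    # At most rate_limit entries into this block per period seconds; extra callers
    # wait inside __aenter__, polling every granularity seconds.
    async with limiter:
        return i


async def main() -> None:
    limiter = RateLimiter(rate_limit=2, period=1.0, granularity=0.05)
    print(await asyncio.gather(*(call_api(limiter, i) for i in range(5))))


asyncio.run(main())
```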
6 changes: 1 addition & 5 deletions async_utils/scheduler.py
@@ -32,7 +32,6 @@ class CancelationToken:

@total_ordering
class _Task(Generic[T]):

__slots__ = ("timestamp", "payload", "canceled", "cancel_token")

def __init__(self, timestamp: float, payload: T, /):
@@ -46,7 +45,6 @@ def __lt__(self, other: _Task[T]):


class Scheduler(Generic[T]):

__tasks: dict[CancelationToken, _Task[T]]
__tqueue: asyncio.PriorityQueue[_Task[T]]
__closed: bool
@@ -69,9 +67,7 @@ async def __aenter__(self):

return self

async def __aexit__(
self, exc_type: type[Exception], exc: Exception, tb: TracebackType
):
async def __aexit__(self, exc_type: type[Exception], exc: Exception, tb: TracebackType):
self.__closed = True

def __aiter__(self):
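The scheduler.py hunks only drop blank lines and reflow __aexit__. The ordering machinery they touch, a @total_ordering wrapper compared by timestamp so an asyncio.PriorityQueue yields the earliest payload first, can be sketched in isolation like this (a standalone illustration, not the library's own _Task):

```python
import asyncio
import time
from functools import total_ordering
from typing import Any


@total_ordering
class TimedItem:
    __slots__ = ("timestamp", "payload")

    def __init__(self, timestamp: float, payload: Any) -> None:
        self.timestamp = timestamp
        self.payload = payload

    # total_ordering fills in the remaining comparisons from __eq__ and __lt__,
    # which is enough for the heap ordering PriorityQueue relies on.
    def __eq__(self, other: object) -> bool:
        return isinstance(other, TimedItem) and self.timestamp == other.timestamp

    def __lt__(self, other: "TimedItem") -> bool:
        return self.timestamp < other.timestamp


async def main() -> None:
    queue: asyncio.PriorityQueue[TimedItem] = asyncio.PriorityQueue()
    now = time.monotonic()
    await queue.put(TimedItem(now + 2, "later"))
    await queue.put(TimedItem(now + 1, "sooner"))
    print((await queue.get()).payload)  # -> sooner


asyncio.run(main())
```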
15 changes: 4 additions & 11 deletions async_utils/task_cache.py
@@ -15,7 +15,7 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable, Coroutine
from collections.abc import Callable, Coroutine, Hashable
from functools import partial
from typing import Any, ParamSpec, TypeVar

@@ -42,11 +42,8 @@ def taskcache(
Consider not wrapping instance methods, but what those methods call when feasible in cases where this may matter.
"""

def wrapper(
coro: Callable[P, Coroutine[Any, Any, T]]
) -> Callable[P, asyncio.Task[T]]:

internal_cache: dict[Any, asyncio.Task[T]] = {}
def wrapper(coro: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, asyncio.Task[T]]:
internal_cache: dict[Hashable, asyncio.Task[T]] = {}

def wrapped(*args: P.args, **kwargs: P.kwargs) -> asyncio.Task[T]:
key = make_key(args, kwargs)
@@ -63,11 +60,7 @@ def wrapped(*args: P.args, **kwargs: P.kwargs) -> asyncio.Task[T]:
internal_cache.pop,
key,
)
task.add_done_callback(call_after_ttl) # pyright: ignore[reportArgumentType]
# call_after_ttl is incorrectly determined to be a function taking a single argument
# with the same type as the value type of internal_case
# dict.pop *has* overloads for this, but the lack of bidirectional inference
# with functools.partial use in pyright breaks this.
task.add_done_callback(call_after_ttl)
return task

return wrapped
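In task_cache.py the wrapper is reflowed, the cache is keyed as dict[Hashable, ...] to match make_key's new annotation, and the pyright ignore becomes unnecessary. A usage sketch of the decorator (the ttl keyword is an assumption; only TTL-driven eviction via call_after_ttl is visible above):

```python
import asyncio

from async_utils.task_cache import taskcache  # import path assumed from the file name above


@taskcache(ttl=30)  # ttl keyword assumed; eviction after a TTL is implied by call_after_ttl
async def fetch_config(name: str) -> dict:
    await asyncio.sleep(0.1)  # stand-in for an expensive lookup
    return {"name": name}


async def main() -> None:
    first = fetch_config("service-a")
    second = fetch_config("service-a")
    # Same make_key(args, kwargs) -> the same cached asyncio.Task is returned, not a new one.
    assert first is second
    print(await first)


asyncio.run(main())
```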
40 changes: 13 additions & 27 deletions async_utils/waterfall.py
@@ -18,13 +18,7 @@
import asyncio
import time
from collections.abc import Callable, Coroutine, Sequence
from typing import (
Any,
Generic,
Literal,
TypeVar,
overload
)
from typing import Any, Generic, Literal, TypeVar, overload

T = TypeVar("T")

@@ -45,15 +39,14 @@ def __init__(
self.max_wait: float = max_wait
self.max_wait_finalize: int = max_wait_finalize
self.max_quantity: int = max_quantity
self.callback: Callable[
[Sequence[T]], Coroutine[Any, Any, Any]
] = async_callback
self.callback: Callable[[Sequence[T]], Coroutine[Any, Any, Any]] = async_callback
self.task: asyncio.Task[None] | None = None
self._alive: bool = False

def start(self):
if self.task is not None:
raise RuntimeError("Already Running")
msg = "Already Running"
raise RuntimeError(msg)

self._alive = True
self.task = asyncio.create_task(self._loop())
@@ -63,33 +56,30 @@ def stop(self, wait: Literal[True]) -> Coroutine[Any, Any, None]:
...

@overload
def stop(self, wait: Literal[False]):
def stop(self, wait: Literal[False]) -> None:
...

@overload
def stop(self, wait: bool = False) -> Coroutine[Any, Any, None] | None:
...

def stop(self, wait: bool = False):
def stop(self, wait: bool = False) -> Coroutine[Any, Any, None] | None:
self._alive = False
if wait:
return self.queue.join()
return self.queue.join() if wait else None

def put(self, item: T):
if not self._alive:
raise RuntimeError("Can't put something in a non-running Waterfall.")
msg = "Can't put something in a non-running Waterfall."
raise RuntimeError(msg)
self.queue.put_nowait(item)

async def _loop(self):
async def _loop(self) -> None:
try:

while self._alive:
queue_items: Sequence[T] = []
iter_start = time.monotonic()

while (
this_max_wait := (time.monotonic() - iter_start)
) < self.max_wait:
while (this_max_wait := (time.monotonic() - iter_start)) < self.max_wait:
try:
n = await asyncio.wait_for(self.queue.get(), this_max_wait)
except asyncio.TimeoutError:
@@ -113,8 +103,7 @@ async def _loop(self):
f = asyncio.create_task(self._finalize(), name="waterfall.finalizer")
await asyncio.wait_for(f, timeout=self.max_wait_finalize)

async def _finalize(self):

async def _finalize(self) -> None:
# WARNING: Do not allow an async context switch before the gather below

self._alive = False
@@ -136,10 +125,7 @@ async def _finalize(self):

pending_futures: list[asyncio.Task[Any]] = []

for chunk in (
remaining_items[p : p + self.max_quantity]
for p in range(0, num_remaining, self.max_quantity)
):
for chunk in (remaining_items[p : p + self.max_quantity] for p in range(0, num_remaining, self.max_quantity)):
fut = asyncio.create_task(self.callback(chunk))
pending_futures.append(fut)

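The waterfall.py hunks reflow long signatures, tighten stop()'s return annotation, and move exception messages into a local before raising. A usage sketch follows; keyword arguments mirror the attribute names assigned in __init__, since the constructor's positional order is not shown in this diff.

```python
import asyncio
from collections.abc import Sequence

from async_utils.waterfall import Waterfall  # import path assumed from the file name above


async def flush(batch: Sequence[int]) -> None:
    # Receives up to max_quantity items, or whatever arrived within max_wait seconds.
    print("flushing", list(batch))


async def main() -> None:
    waterfall: Waterfall[int] = Waterfall(
        max_wait=1.0, max_quantity=5, async_callback=flush, max_wait_finalize=3
    )
    waterfall.start()
    for item in range(12):
        waterfall.put(item)
    await waterfall.stop(wait=True)  # returns queue.join() when wait=True, per the overloads


asyncio.run(main())
```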
1 change: 1 addition & 0 deletions pyproject.toml
@@ -51,4 +51,5 @@ extend-ignore = [
"E501", # reccomended by ruff when using ruff format
"ISC001", # reccomended by ruff when using ruff format
"Q003", # reccomended by ruff when using ruff format
"RUF006", # Don't actually need to store a task in all cases, and I'm aware of which cases.
]
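The new RUF006 entry silences Ruff's asyncio-dangling-task rule, which flags asyncio.create_task calls whose result is discarded (the event loop keeps only a weak reference, so an unreferenced task can be garbage collected before it finishes). A small sketch of the flagged pattern and the reference-keeping variant the rule prefers:

```python
import asyncio


async def background_job() -> None:
    await asyncio.sleep(1)


async def main() -> None:
    # RUF006 flags this: nothing retains the Task, so it may be collected before completing.
    asyncio.create_task(background_job())

    # The shape RUF006 prefers: hold a reference until the task is done.
    tasks: set[asyncio.Task[None]] = set()
    task = asyncio.create_task(background_job())
    tasks.add(task)
    task.add_done_callback(tasks.discard)

    await asyncio.sleep(1.5)  # let the background work finish in this sketch


asyncio.run(main())
```

The commit's comment notes that the discarded-task cases in this codebase are deliberate, which is why the rule is ignored project-wide rather than suppressed per call site.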
