Skip to content

Commit

Permalink
fix: use infer_variance=True
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday committed Feb 23, 2024
1 parent 673f91c commit be8c7a4
Show file tree
Hide file tree
Showing 7 changed files with 47 additions and 50 deletions.
12 changes: 6 additions & 6 deletions src/async_wrapper/convert/_async.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from __future__ import annotations

from functools import partial, wraps
from typing import Any, Callable, Coroutine, TypeVar
from typing import Any, Callable, Coroutine

from anyio import to_thread
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, TypeVar

ValueT_co = TypeVar("ValueT_co", covariant=True)
ValueT = TypeVar("ValueT", infer_variance=True)
ParamT = ParamSpec("ParamT")

__all__ = ["sync_to_async"]


def sync_to_async(
func: Callable[ParamT, ValueT_co],
) -> Callable[ParamT, Coroutine[Any, Any, ValueT_co]]:
func: Callable[ParamT, ValueT],
) -> Callable[ParamT, Coroutine[Any, Any, ValueT]]:
"""
Convert a synchronous function to an asynchronous function.
Expand Down Expand Up @@ -55,7 +55,7 @@ def sync_to_async(
"""

@wraps(func)
async def inner(*args: ParamT.args, **kwargs: ParamT.kwargs) -> ValueT_co:
async def inner(*args: ParamT.args, **kwargs: ParamT.kwargs) -> ValueT:
return await to_thread.run_sync(partial(func, *args, **kwargs))

return inner
22 changes: 10 additions & 12 deletions src/async_wrapper/convert/_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
from concurrent.futures import ThreadPoolExecutor, wait
from contextvars import ContextVar
from functools import partial, wraps
from typing import Awaitable, Callable, TypeVar
from typing import Awaitable, Callable

import anyio
from sniffio import AsyncLibraryNotFoundError, current_async_library
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, TypeVar

ValueT_co = TypeVar("ValueT_co", covariant=True)
ValueT = TypeVar("ValueT", infer_variance=True)
ParamT = ParamSpec("ParamT")

__all__ = ["async_to_sync"]
Expand All @@ -19,8 +19,8 @@


def async_to_sync(
func: Callable[ParamT, Awaitable[ValueT_co]],
) -> Callable[ParamT, ValueT_co]:
func: Callable[ParamT, Awaitable[ValueT]],
) -> Callable[ParamT, ValueT]:
"""
Convert an awaitable function to a synchronous function.
Expand Down Expand Up @@ -93,7 +93,7 @@ def async_to_sync(
sync_func = _as_sync(func)

@wraps(func)
def inner(*args: ParamT.args, **kwargs: ParamT.kwargs) -> ValueT_co:
def inner(*args: ParamT.args, **kwargs: ParamT.kwargs) -> ValueT:
backend = _get_current_backend()
use_uvloop = _check_uvloop()
with ThreadPoolExecutor(
Expand All @@ -106,21 +106,19 @@ def inner(*args: ParamT.args, **kwargs: ParamT.kwargs) -> ValueT_co:
return inner


def _as_sync(
func: Callable[ParamT, Awaitable[ValueT_co]],
) -> Callable[ParamT, ValueT_co]:
def _as_sync(func: Callable[ParamT, Awaitable[ValueT]]) -> Callable[ParamT, ValueT]:
@wraps(func)
def inner(*args: ParamT.args, **kwargs: ParamT.kwargs) -> ValueT_co:
def inner(*args: ParamT.args, **kwargs: ParamT.kwargs) -> ValueT:
return _run(func, *args, **kwargs)

return inner


def _run(
func: Callable[ParamT, Awaitable[ValueT_co]],
func: Callable[ParamT, Awaitable[ValueT]],
*args: ParamT.args,
**kwargs: ParamT.kwargs,
) -> ValueT_co:
) -> ValueT:
backend = _get_current_backend()
new_func = partial(func, *args, **kwargs)
backend_options = {}
Expand Down
6 changes: 3 additions & 3 deletions src/async_wrapper/convert/main.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

from inspect import iscoroutinefunction
from typing import Any, Callable, Coroutine, TypeVar, overload
from typing import Any, Callable, Coroutine, overload

from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, TypeVar

from ._async import sync_to_async
from ._sync import async_to_sync

ValueT = TypeVar("ValueT")
ValueT = TypeVar("ValueT", infer_variance=True)
ParamT = ParamSpec("ParamT")

__all__ = ["toggle_func", "async_to_sync", "sync_to_async"]
Expand Down
7 changes: 3 additions & 4 deletions src/async_wrapper/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
Generic,
Literal,
NoReturn,
TypeVar,
)

from anyio import WouldBlock, create_memory_object_stream, create_task_group, fail_after
from anyio.streams.memory import BrokenResourceError, ClosedResourceError, EndOfStream
from typing_extensions import override
from typing_extensions import TypeVar, override

from async_wrapper.exception import (
QueueBrokenError,
Expand All @@ -43,8 +42,8 @@

__all__ = ["Queue", "create_queue"]

ValueT = TypeVar("ValueT")
QueueT_co = TypeVar("QueueT_co", covariant=True, bound="Queue")
ValueT = TypeVar("ValueT", infer_variance=True)
QueueT = TypeVar("QueueT", infer_variance=True, bound="Queue")


class Queue(Generic[ValueT]):
Expand Down
28 changes: 13 additions & 15 deletions src/async_wrapper/task_group/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from contextlib import AsyncExitStack
from functools import partial, wraps
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Coroutine, Generic

from anyio import create_task_group as _create_task_group
from anyio.abc import TaskGroup as _TaskGroup
from typing_extensions import Concatenate, ParamSpec, Self, override
from typing_extensions import Concatenate, ParamSpec, Self, TypeVar, override

from async_wrapper.task_group.value import SoonValue

Expand All @@ -15,8 +15,8 @@

from anyio.abc import CancelScope, CapacityLimiter, Lock, Semaphore

ValueT_co = TypeVar("ValueT_co", covariant=True)
OtherValueT_co = TypeVar("OtherValueT_co", covariant=True)
ValueT = TypeVar("ValueT", infer_variance=True)
OtherValueT = TypeVar("OtherValueT", infer_variance=True)
ParamT = ParamSpec("ParamT")
OtherParamT = ParamSpec("OtherParamT")

Expand Down Expand Up @@ -102,11 +102,11 @@ async def __aexit__(

def wrap(
self,
func: Callable[ParamT, Awaitable[ValueT_co]],
func: Callable[ParamT, Awaitable[ValueT]],
semaphore: Semaphore | None = None,
limiter: CapacityLimiter | None = None,
lock: Lock | None = None,
) -> SoonWrapper[ParamT, ValueT_co]:
) -> SoonWrapper[ParamT, ValueT]:
"""
Wrap a function to be used within a wrapper.
Expand All @@ -124,14 +124,14 @@ def wrap(
return SoonWrapper(func, self, semaphore=semaphore, limiter=limiter, lock=lock)


class SoonWrapper(Generic[ParamT, ValueT_co]):
class SoonWrapper(Generic[ParamT, ValueT]):
"""wrapped func using in :class:`TaskGroupWrapper`"""

__slots__ = ("func", "task_group", "semaphore", "limiter", "lock", "_wrapped")

def __init__( # noqa: PLR0913
self,
func: Callable[ParamT, Awaitable[ValueT_co]],
func: Callable[ParamT, Awaitable[ValueT]],
task_group: _TaskGroup,
semaphore: Semaphore | None = None,
limiter: CapacityLimiter | None = None,
Expand All @@ -147,26 +147,24 @@ def __init__( # noqa: PLR0913

def __call__(
self, *args: ParamT.args, **kwargs: ParamT.kwargs
) -> SoonValue[ValueT_co]:
value: SoonValue[ValueT_co] = SoonValue()
) -> SoonValue[ValueT]:
value: SoonValue[ValueT] = SoonValue()
wrapped = partial(self.wrapped, value, *args, **kwargs)
self.task_group.start_soon(wrapped)
return value

@property
def wrapped(
self,
) -> Callable[
Concatenate[SoonValue[ValueT_co], ParamT], Coroutine[Any, Any, ValueT_co]
]:
) -> Callable[Concatenate[SoonValue[ValueT], ParamT], Coroutine[Any, Any, ValueT]]:
"""wrapped func using semaphore"""
if self._wrapped is not None:
return self._wrapped

@wraps(self.func)
async def wrapped(
value: SoonValue[ValueT_co], *args: ParamT.args, **kwargs: ParamT.kwargs
) -> ValueT_co:
value: SoonValue[ValueT], *args: ParamT.args, **kwargs: ParamT.kwargs
) -> ValueT:
async with AsyncExitStack() as stack:
if self.semaphore is not None:
await stack.enter_async_context(self.semaphore)
Expand Down
12 changes: 7 additions & 5 deletions src/async_wrapper/task_group/value.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
from __future__ import annotations

from threading import local
from typing import Generic, TypeVar
from typing import Generic

from typing_extensions import TypeVar

from async_wrapper.exception import PendingError

ValueT_co = TypeVar("ValueT_co", covariant=True)
ValueT = TypeVar("ValueT", infer_variance=True)
Pending = local()

__all__ = ["SoonValue"]


class SoonValue(Generic[ValueT_co]):
class SoonValue(Generic[ValueT]):
"""A class representing a value that will be available soon."""

__slots__ = ("_value",)

def __init__(self) -> None:
self._value: ValueT_co | local = Pending
self._value: ValueT | local = Pending

def __repr__(self) -> str:
status = "pending" if self._value is Pending else "done"
return f"<SoonValue: status={status}>"

@property
def value(self) -> ValueT_co:
def value(self) -> ValueT:
"""
Gets the soon-to-be available value.
Expand Down
10 changes: 5 additions & 5 deletions src/async_wrapper/wait.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from functools import partial
from threading import local
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable, TypeVar
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Iterable

from anyio import EndOfStream, Event, create_memory_object_stream, create_task_group
from typing_extensions import ParamSpec, Self, override
from typing_extensions import ParamSpec, Self, TypeVar, override

from async_wrapper.exception import PendingError

Expand All @@ -19,7 +19,7 @@

__all__ = ["Waiter", "Completed", "wait_for"]

ValueT_co = TypeVar("ValueT_co", covariant=True)
ValueT = TypeVar("ValueT", infer_variance=True)
ParamT = ParamSpec("ParamT")
Pending = local()

Expand Down Expand Up @@ -299,10 +299,10 @@ async def __anext__(self) -> Any:

async def wait_for(
event: Event | Iterable[Event],
func: Callable[ParamT, Awaitable[ValueT_co]],
func: Callable[ParamT, Awaitable[ValueT]],
*args: ParamT.args,
**kwargs: ParamT.kwargs,
) -> ValueT_co:
) -> ValueT:
"""
Wait for an event before executing an awaitable function.
Expand Down

0 comments on commit be8c7a4

Please sign in to comment.