Skip to content

Commit

Permalink
fix: ruff, pyright errors
Browse files Browse the repository at this point in the history
  • Loading branch information
phi-friday committed Jul 30, 2024
1 parent cd6af21 commit dbc56a5
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 20 deletions.
19 changes: 13 additions & 6 deletions src/timeout_executor/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from itertools import chain
from pathlib import Path
from types import FunctionType
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generic, Iterable, overload
from typing import TYPE_CHECKING, Any, Callable, Generic, overload
from uuid import UUID, uuid4

import anyio
Expand All @@ -25,6 +25,8 @@
from timeout_executor.types import Callback, CallbackArgs, ExecutorArgs, ProcessCallback

if TYPE_CHECKING:
from collections.abc import Coroutine, Iterable

from timeout_executor.main import TimeoutExecutor

__all__ = ["apply_func", "delay_func"]
Expand Down Expand Up @@ -75,7 +77,7 @@ def _dump_args(
) -> bytes:
input_args = (self._func, args, kwargs, output_file)
logger.debug("%r before dump input args", self)
input_args_as_bytes = cloudpickle.dumps(input_args) # pyright: ignore[reportUnknownMemberType]
input_args_as_bytes = cloudpickle.dumps(input_args)
logger.debug(
"%r after dump input args :: size: %d", self, len(input_args_as_bytes)
)
Expand All @@ -86,8 +88,8 @@ def _create_process(
) -> subprocess.Popen[str]:
command = self._command(stacklevel=stacklevel + 1)
logger.debug("%r before create new process", self, stacklevel=stacklevel)
process = subprocess.Popen(
command, # noqa: S603
process = subprocess.Popen( # noqa: S603
command,
env={TIMEOUT_EXECUTOR_INPUT_FILE: input_file.as_posix()},
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
Expand Down Expand Up @@ -152,6 +154,7 @@ async def delay(self, *args: P.args, **kwargs: P.kwargs) -> AsyncResult[P, T]:

return self._init_process(input_file, output_file)

@override
def __repr__(self) -> str:
return f"<{type(self).__name__}: {self._func_name}>"

Expand Down Expand Up @@ -200,8 +203,10 @@ def apply_func(
"""run function with deadline
Args:
timeout: deadline
timeout_or_executor: deadline
func: func(sync or async)
*args: func args
**kwargs: func kwargs
Returns:
async result container
Expand Down Expand Up @@ -242,8 +247,10 @@ async def delay_func(
"""run function with deadline
Args:
timeout: deadline
timeout_or_executor: deadline
func: func(sync or async)
*args: func args
**kwargs: func kwargs
Returns:
async result container
Expand Down
11 changes: 10 additions & 1 deletion src/timeout_executor/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@

from collections import deque
from contextlib import suppress
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Generic, Iterable, overload
from typing import TYPE_CHECKING, Any, Callable, Generic, overload

from typing_extensions import ParamSpec, Self, TypeVar, override

from timeout_executor.executor import apply_func, delay_func
from timeout_executor.types import Callback, ProcessCallback

if TYPE_CHECKING:
from collections.abc import Coroutine, Iterable

from timeout_executor.result import AsyncResult

__all__ = ["TimeoutExecutor"]
Expand Down Expand Up @@ -49,6 +51,8 @@ def apply(
Args:
func: func(sync or async)
*args: func args
**kwargs: func kwargs
Returns:
async result container
Expand All @@ -73,6 +77,8 @@ async def delay(
Args:
func: func(sync or async)
*args: func args
**kwargs: func kwargs
Returns:
async result container
Expand All @@ -99,12 +105,15 @@ async def apply_async(
Args:
func: func(sync or async)
*args: func args
**kwargs: func kwargs
Returns:
async result container
"""
return await self.delay(func, *args, **kwargs)

@override
def __repr__(self) -> str:
return f"<{type(self).__name__}, timeout: {self.timeout:.2f}s>"

Expand Down
4 changes: 3 additions & 1 deletion src/timeout_executor/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import subprocess
from functools import partial
from typing import TYPE_CHECKING, Any, Generic, Iterable
from typing import TYPE_CHECKING, Any, Generic

import anyio
import cloudpickle
Expand All @@ -14,6 +14,7 @@
from timeout_executor.types import Callback, ProcessCallback

if TYPE_CHECKING:
from collections.abc import Iterable
from pathlib import Path

from timeout_executor.terminate import Terminator
Expand Down Expand Up @@ -125,6 +126,7 @@ async def _load_output(self) -> T:
await self._output.parent.rmdir()
return await self._load_output()

@override
def __repr__(self) -> str:
return f"<{type(self).__name__}: {self._func_name}>"

Expand Down
14 changes: 7 additions & 7 deletions src/timeout_executor/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@

import cloudpickle
from tblib.pickling_support import (
pickle_exception, # pyright: ignore[reportUnknownVariableType]
pickle_traceback, # pyright: ignore[reportUnknownVariableType]
unpickle_exception, # pyright: ignore[reportUnknownVariableType]
unpickle_traceback, # pyright: ignore[reportUnknownVariableType]
pickle_exception,
pickle_traceback,
unpickle_exception,
unpickle_traceback,
)

__all__ = ["dumps_error", "loads_error", "serialize_error", "deserialize_error"]
Expand All @@ -34,7 +34,7 @@ class SerializedError:


def serialize_traceback(traceback: TracebackType) -> tuple[Any, ...]:
return pickle_traceback(traceback) # pyright: ignore[reportUnknownVariableType]
return pickle_traceback(traceback)


def serialize_error(error: BaseException) -> SerializedError:
Expand Down Expand Up @@ -84,15 +84,15 @@ def deserialize_error(error: SerializedError) -> BaseException:
traceback = unpickle_traceback(*value)
result.insert(index + salt, traceback)

return unpickle_exception(*arg_exception, *exception) # pyright: ignore[reportUnknownVariableType]
return unpickle_exception(*arg_exception, *exception)


def dumps_error(error: BaseException | SerializedError) -> bytes:
"""serialize exception as bytes"""
if not isinstance(error, SerializedError):
error = serialize_error(error)

return cloudpickle.dumps(error) # pyright: ignore[reportUnknownMemberType]
return cloudpickle.dumps(error)


def loads_error(error: bytes | SerializedError) -> BaseException:
Expand Down
7 changes: 5 additions & 2 deletions src/timeout_executor/subprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from inspect import iscoroutinefunction
from os import environ
from pathlib import Path
from typing import Any, Callable, Coroutine
from typing import TYPE_CHECKING, Any, Callable

import anyio
import cloudpickle
Expand All @@ -14,6 +14,9 @@
from timeout_executor.const import TIMEOUT_EXECUTOR_INPUT_FILE
from timeout_executor.serde import dumps_error

if TYPE_CHECKING:
from collections.abc import Coroutine

__all__ = []

P = ParamSpec("P")
Expand All @@ -37,7 +40,7 @@ def run_in_subprocess() -> None:
def dumps_value(value: Any) -> bytes:
if isinstance(value, BaseException):
return dumps_error(value)
return cloudpickle.dumps(value) # pyright: ignore[reportUnknownMemberType]
return cloudpickle.dumps(value)


def output_to_file(
Expand Down
6 changes: 5 additions & 1 deletion src/timeout_executor/terminate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
from collections import deque
from contextlib import suppress
from itertools import chain
from typing import Any, Callable, Generic, Iterable
from typing import TYPE_CHECKING, Any, Callable, Generic

from psutil import pid_exists
from typing_extensions import ParamSpec, Self, TypeVar, override

from timeout_executor.logging import logger
from timeout_executor.types import Callback, CallbackArgs, ExecutorArgs, ProcessCallback

if TYPE_CHECKING:
from collections.abc import Iterable

__all__ = []

P = ParamSpec("P")
Expand Down Expand Up @@ -142,6 +145,7 @@ def close(self, name: str | None = None) -> None:
if text:
sys.stderr.write(text)

@override
def __repr__(self) -> str:
return f"<{type(self).__name__}: {self.func_name}>"

Expand Down
5 changes: 3 additions & 2 deletions src/timeout_executor/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
from abc import ABC, abstractmethod
from collections import deque
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, Iterable
from typing import TYPE_CHECKING, Any, Callable, Generic

from typing_extensions import ParamSpec, TypeVar

from timeout_executor.logging import logger

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup # type: ignore
from exceptiongroup import ExceptionGroup

if TYPE_CHECKING:
import subprocess
from collections.abc import Iterable
from pathlib import Path

import anyio
Expand Down

0 comments on commit dbc56a5

Please sign in to comment.