Skip to content

Commit

Permalink
GH-338 allow garbage collection of non-daemon processes
Browse files Browse the repository at this point in the history
  • Loading branch information
spyoungtech committed Aug 17, 2024
1 parent 46bce5c commit 62e5940
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 45 deletions.
77 changes: 53 additions & 24 deletions ahk/_async/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@
else:
from typing import TypeAlias, TypeGuard

if sys.version_info < (3, 11):
from typing_extensions import Self
else:
from typing import Self

T_AsyncFuture = TypeVar('T_AsyncFuture') # unasync: remove
T_SyncFuture = TypeVar('T_SyncFuture')
Expand Down Expand Up @@ -110,6 +114,9 @@ def async_assert_send_nonblocking_type_correct(
class Communicable(Protocol):
runargs: List[str]

async def start(self, atexit_cleanup: bool = True) -> None: ...
def astart(self, *args: Any, **kwargs: Any) -> None: ... # unasync: remove

def communicate(self, input_bytes: Optional[bytes], timeout: Optional[int] = None) -> Tuple[bytes, bytes]: ...

async def acommunicate( # unasync: remove
Expand All @@ -119,6 +126,8 @@ async def acommunicate( # unasync: remove
@property
def returncode(self) -> Optional[int]: ...

def kill(self) -> None: ...


class AsyncAHKProcess:
def __init__(self, runargs: List[str]):
Expand All @@ -130,9 +139,12 @@ def returncode(self) -> Optional[int]:
assert self._proc is not None
return self._proc.returncode

async def start(self) -> None:
def astart(self, *args: Any, **kwargs: Any) -> None: ... # unasync: remove

async def start(self, atexit_cleanup: bool = True) -> None:
self._proc = await async_create_process(self.runargs)
atexit.register(kill, self._proc)
if atexit_cleanup:
atexit.register(kill, self._proc)
return None

async def adrain_stdin(self) -> None: # unasync: remove
Expand Down Expand Up @@ -183,6 +195,17 @@ def communicate(self, input_bytes: Optional[bytes] = None, timeout: Optional[int
assert isinstance(self._proc, subprocess.Popen)
return self._proc.communicate(input=input_bytes, timeout=timeout)

async def __aenter__(self) -> Self:
await self.start(atexit_cleanup=False)
return self

async def __aexit__(self, exc_type: Any, exc: Any, tb: Any) -> Literal[False]:
try:
self.kill()
except Exception:
pass
return False


async def async_create_process(runargs: List[str]) -> asyncio.subprocess.Process: # unasync: remove
return await asyncio.subprocess.create_subprocess_exec(
Expand Down Expand Up @@ -635,7 +658,8 @@ async def start(self) -> None:
assert self._proc is None, 'cannot start a process twice'
with warnings.catch_warnings(record=True) as caught_warnings:
async with self.lock:
self._proc = await self._create_process()
self._proc = self._create_process()
await self._proc.start()
if caught_warnings:
for warning in caught_warnings:
warnings.warn(warning.message, warning.category, stacklevel=2)
Expand All @@ -659,9 +683,7 @@ def lock(self) -> Any:
return self._a_execution_lock # unasync: remove
return self._execution_lock

async def _create_process(
self, template: Optional[jinja2.Template] = None, **template_kwargs: Any
) -> AsyncAHKProcess:
def _create_process(self, template: Optional[jinja2.Template] = None, **template_kwargs: Any) -> AsyncAHKProcess:
if template is None:
if template_kwargs:
raise ValueError('template kwargs were specified, but no template was provided')
Expand All @@ -684,15 +706,13 @@ async def _create_process(
atexit.register(try_remove, tempscript.name)
runargs = [self._executable_path, '/CP65001', '/ErrorStdOut', daemon_script]
proc = AsyncAHKProcess(runargs=runargs)
await proc.start()
return proc

async def _send_nonblocking(
self, request: RequestMessage, engine: Optional[AsyncAHK[Any]] = None
) -> Union[None, Tuple[int, int], int, str, bool, AsyncWindow, List[AsyncWindow], List[AsyncControl]]:
msg = request.format()
proc = await self._create_process()
try:
async with self._create_process() as proc:
proc.write(msg)
await proc.adrain_stdin()
tom = await proc.readline()
Expand All @@ -715,11 +735,6 @@ async def _send_nonblocking(
part = await proc.readline()
content_buffer.write(part)
content = content_buffer.getvalue()[:-1]
finally:
try:
proc.kill()
except: # noqa
pass
response = ResponseMessage.from_bytes(content, engine=engine)
return response.unpack() # type: ignore

Expand Down Expand Up @@ -781,11 +796,17 @@ async def _async_run_nonblocking( # unasync: remove
loop = asyncio.get_running_loop()

async def f() -> str:
stdout, stderr = await proc.acommunicate(script_bytes, timeout)
try:
await proc.start(atexit_cleanup=False)
stdout, stderr = await proc.acommunicate(script_bytes, timeout)
finally:
try:
proc.kill()
except Exception:
pass
if proc.returncode != 0:
assert proc.returncode is not None
raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr)

return stdout.decode('utf-8')

task = loop.create_task(f())
Expand All @@ -797,15 +818,23 @@ def _sync_run_nonblocking(
script_bytes: Optional[bytes],
timeout: Optional[int] = None,
) -> FutureResult[str]:
pool = ThreadPoolExecutor(max_workers=1)
raise RuntimeError('This method can only be called from the sync API') # unasync: remove

def f() -> str:
stdout, stderr = proc.communicate(script_bytes, timeout)
try:
proc.astart(atexit_cleanup=False)
stdout, stderr = proc.communicate(script_bytes, timeout)
finally:
try:
proc.kill()
except Exception:
pass
if proc.returncode != 0:
assert proc.returncode is not None
raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr)
return stdout.decode('utf-8')

pool = ThreadPoolExecutor(max_workers=1)
fut = pool.submit(f)
pool.shutdown(wait=False)
return FutureResult(fut)
Expand All @@ -830,13 +859,13 @@ async def run_script(
script_bytes = bytes(script_text_or_path, 'utf-8')
runargs = [self._executable_path, '/CP65001', '/ErrorStdOut', '*']
proc = AsyncAHKProcess(runargs)
await proc.start()
if blocking:
stdout, stderr = await proc.acommunicate(script_bytes, timeout=timeout)
if proc.returncode != 0:
assert proc.returncode is not None
raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr)
return stdout.decode('utf-8')
async with proc:
stdout, stderr = await proc.acommunicate(script_bytes, timeout=timeout)
if proc.returncode != 0:
assert proc.returncode is not None
raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr)
return stdout.decode('utf-8')
else:
return await self._async_run_nonblocking(proc, script_bytes, timeout=timeout)

Expand Down
62 changes: 41 additions & 21 deletions ahk/_sync/transport.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,10 @@
else:
from typing import TypeAlias, TypeGuard

if sys.version_info < (3, 11):
from typing_extensions import Self
else:
from typing import Self

T_SyncFuture = TypeVar('T_SyncFuture')

Expand Down Expand Up @@ -102,12 +106,16 @@ def async_assert_send_nonblocking_type_correct(
class Communicable(Protocol):
runargs: List[str]

def start(self, atexit_cleanup: bool = True) -> None: ...

def communicate(self, input_bytes: Optional[bytes], timeout: Optional[int] = None) -> Tuple[bytes, bytes]: ...


@property
def returncode(self) -> Optional[int]: ...

def kill(self) -> None: ...


class SyncAHKProcess:
def __init__(self, runargs: List[str]):
Expand All @@ -119,9 +127,11 @@ def returncode(self) -> Optional[int]:
assert self._proc is not None
return self._proc.returncode

def start(self) -> None:

def start(self, atexit_cleanup: bool = True) -> None:
self._proc = sync_create_process(self.runargs)
atexit.register(kill, self._proc)
if atexit_cleanup:
atexit.register(kill, self._proc)
return None


Expand Down Expand Up @@ -160,6 +170,17 @@ def communicate(self, input_bytes: Optional[bytes] = None, timeout: Optional[int
assert isinstance(self._proc, subprocess.Popen)
return self._proc.communicate(input=input_bytes, timeout=timeout)

def __enter__(self) -> Self:
self.start(atexit_cleanup=False)
return self

def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> Literal[False]:
try:
self.kill()
except Exception:
pass
return False




Expand Down Expand Up @@ -601,6 +622,7 @@ def start(self) -> None:
with warnings.catch_warnings(record=True) as caught_warnings:
with self.lock:
self._proc = self._create_process()
self._proc.start()
if caught_warnings:
for warning in caught_warnings:
warnings.warn(warning.message, warning.category, stacklevel=2)
Expand All @@ -623,9 +645,7 @@ def _render_script(self, template: Optional[jinja2.Template] = None, **kwargs: A
def lock(self) -> Any:
return self._execution_lock

def _create_process(
self, template: Optional[jinja2.Template] = None, **template_kwargs: Any
) -> SyncAHKProcess:
def _create_process(self, template: Optional[jinja2.Template] = None, **template_kwargs: Any) -> SyncAHKProcess:
if template is None:
if template_kwargs:
raise ValueError('template kwargs were specified, but no template was provided')
Expand All @@ -648,15 +668,13 @@ def _create_process(
atexit.register(try_remove, tempscript.name)
runargs = [self._executable_path, '/CP65001', '/ErrorStdOut', daemon_script]
proc = SyncAHKProcess(runargs=runargs)
proc.start()
return proc

def _send_nonblocking(
self, request: RequestMessage, engine: Optional[AHK[Any]] = None
) -> Union[None, Tuple[int, int], int, str, bool, Window, List[Window], List[Control]]:
msg = request.format()
proc = self._create_process()
try:
with self._create_process() as proc:
proc.write(msg)
proc.drain_stdin()
tom = proc.readline()
Expand All @@ -679,11 +697,6 @@ def _send_nonblocking(
part = proc.readline()
content_buffer.write(part)
content = content_buffer.getvalue()[:-1]
finally:
try:
proc.kill()
except: # noqa
pass
response = ResponseMessage.from_bytes(content, engine=engine)
return response.unpack() # type: ignore

Expand Down Expand Up @@ -738,15 +751,22 @@ def _sync_run_nonblocking(
script_bytes: Optional[bytes],
timeout: Optional[int] = None,
) -> FutureResult[str]:
pool = ThreadPoolExecutor(max_workers=1)

def f() -> str:
stdout, stderr = proc.communicate(script_bytes, timeout)
try:
proc.start(atexit_cleanup=False)
stdout, stderr = proc.communicate(script_bytes, timeout)
finally:
try:
proc.kill()
except Exception:
pass
if proc.returncode != 0:
assert proc.returncode is not None
raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr)
return stdout.decode('utf-8')

pool = ThreadPoolExecutor(max_workers=1)
fut = pool.submit(f)
pool.shutdown(wait=False)
return FutureResult(fut)
Expand All @@ -771,13 +791,13 @@ def run_script(
script_bytes = bytes(script_text_or_path, 'utf-8')
runargs = [self._executable_path, '/CP65001', '/ErrorStdOut', '*']
proc = SyncAHKProcess(runargs)
proc.start()
if blocking:
stdout, stderr = proc.communicate(script_bytes, timeout=timeout)
if proc.returncode != 0:
assert proc.returncode is not None
raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr)
return stdout.decode('utf-8')
with proc:
stdout, stderr = proc.communicate(script_bytes, timeout=timeout)
if proc.returncode != 0:
assert proc.returncode is not None
raise subprocess.CalledProcessError(proc.returncode, proc.runargs, stdout, stderr)
return stdout.decode('utf-8')
else:
return self._sync_run_nonblocking(proc, script_bytes, timeout=timeout)

Expand Down
1 change: 1 addition & 0 deletions buildunasync.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
'AsyncFutureResult': 'FutureResult',
'_async_run_nonblocking': '_sync_run_nonblocking',
'acommunicate': 'communicate',
'astart': 'start',
# "__aenter__": "__aenter__",
},
),
Expand Down

0 comments on commit 62e5940

Please sign in to comment.