generated from datalad/datalad-extension-template
-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
d455d63
commit 25268a4
Showing
2 changed files
with
229 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,213 @@ | ||
from __future__ import annotations | ||
|
||
import subprocess | ||
from collections import deque | ||
from collections.abc import Generator | ||
from contextlib import contextmanager | ||
from pathlib import Path | ||
from queue import Queue | ||
from subprocess import DEVNULL | ||
from typing import ( | ||
Any, | ||
IO, | ||
) | ||
|
||
from datalad_next.runners import ( | ||
GeneratorMixIn, | ||
Protocol, | ||
ThreadedRunner, | ||
) | ||
|
||
|
||
class _ProtocolShell: | ||
def __init__(self, | ||
base_class: type[Protocol], | ||
base_kwargs: dict, | ||
introduced_timeout: bool, | ||
terminate_time: int | None, | ||
kill_time: int | None, | ||
armed: bool | ||
) -> None: | ||
|
||
self.base = base_class(**base_kwargs) | ||
self.introduced_timeout = introduced_timeout | ||
self.terminate_time = terminate_time | ||
self.kill_time = ( | ||
((terminate_time or 0) + kill_time) | ||
if kill_time is not None | ||
else kill_time | ||
) | ||
self.process: subprocess.Popen | None = None | ||
self.armed = armed | ||
self.kill_counter = 0 | ||
|
||
def arm(self) -> None: | ||
self.kill_counter = 0 | ||
self.armed = True | ||
|
||
def __getattr__(self, item: Any): | ||
""" Forward instance attribute access to the base object """ | ||
try: | ||
return self.__getattribute__(item) | ||
except AttributeError: | ||
return self.base.__getattribute__(item) | ||
|
||
def connection_made(self, process: subprocess.Popen) -> None: | ||
self.process = process | ||
self.base.connection_made(process) | ||
|
||
def timeout(self, fd: int | None) -> bool: | ||
if self.armed: | ||
self.kill_counter += 1 | ||
if self.kill_time and self.kill_counter >= self.kill_time: | ||
self.process.kill() | ||
self.kill_time = None | ||
if self.terminate_time and self.kill_counter > self.terminate_time: | ||
self.process.terminate() | ||
self.terminate_time = None | ||
if self.introduced_timeout: | ||
return False | ||
return self.base.timeout(fd) | ||
|
||
|
||
class _GeneratorProtocolShell(_ProtocolShell, GeneratorMixIn): | ||
def __init__(self, | ||
base_class: type[Protocol], | ||
base_kwargs: dict, | ||
introduced_timeout: bool, | ||
terminate_time: int | None, | ||
kill_time: int | None, | ||
armed: bool, | ||
) -> None: | ||
|
||
GeneratorMixIn.__init__(self) | ||
_ProtocolShell.__init__( | ||
self, | ||
base_class, | ||
base_kwargs, | ||
introduced_timeout, | ||
terminate_time, | ||
kill_time, | ||
armed, | ||
) | ||
|
||
@property | ||
def result_queue(self) -> deque: | ||
return self.base.result_queue | ||
|
||
|
||
|
||
@contextmanager | ||
def run( | ||
cmd: list, | ||
protocol_class = type[Protocol], | ||
*, | ||
cwd: Path | None = None, | ||
input: int | IO | bytes | Queue[bytes | None] | None = None, | ||
timeout: float | None = None, | ||
terminate_time: int | None = None, | ||
kill_time: int | None = None, | ||
) -> dict | Generator: | ||
|
||
introduces_timeout = False | ||
if timeout is None: | ||
introduces_timeout = True | ||
timeout = 1.0 | ||
|
||
runner_protocol_class, armed = ( | ||
(_GeneratorProtocolShell, False) | ||
if issubclass(protocol_class, GeneratorMixIn) | ||
else (_ProtocolShell, True) | ||
) | ||
|
||
# This is a little bit ugly, implement class-attribute forwarding instead | ||
runner_protocol_class.proc_out = protocol_class.proc_out | ||
runner_protocol_class.proc_err = protocol_class.proc_err | ||
|
||
runner = ThreadedRunner( | ||
cmd=cmd, | ||
protocol_class=runner_protocol_class, | ||
stdin=DEVNULL if input is None else input, | ||
protocol_kwargs=dict( | ||
base_class=protocol_class, | ||
base_kwargs=dict(), | ||
introduced_timeout=introduces_timeout, | ||
terminate_time=terminate_time, | ||
kill_time=kill_time, | ||
armed=armed, | ||
), | ||
timeout=timeout, | ||
exception_on_error=False, | ||
cwd=cwd, | ||
) | ||
result = runner.run() | ||
if isinstance(result, dict): | ||
try: | ||
yield result | ||
finally: | ||
pass | ||
else: | ||
try: | ||
yield result | ||
finally: | ||
runner.protocol.arm() | ||
tuple(result) | ||
|
||
|
||
x = ''' | ||
with run(cmd=['find', '/home/cristian/datalad/longnow-podcasts'], | ||
protocol_class=StdOutCaptureGeneratorProtocol, | ||
terminate_time=10, | ||
kill_time=5) as r: | ||
for line in r: | ||
print(line) | ||
print(r.return_code) | ||
''' | ||
|
||
from datalad_next.runners import StdOutCaptureGeneratorProtocol | ||
|
||
with run(cmd=['sleep', '100'], | ||
protocol_class=StdOutCaptureGeneratorProtocol, | ||
terminate_time=3, | ||
kill_time=3) as r: | ||
pass | ||
|
||
print(r.return_code) | ||
|
||
|
||
|
||
with run(cmd=['sleep', '100'], | ||
protocol_class=StdOutCaptureGeneratorProtocol, | ||
terminate_time=3, | ||
kill_time=3) as r: | ||
pass | ||
|
||
print(r.return_code) | ||
|
||
|
||
|
||
py_prog = ''' | ||
import sys | ||
import time | ||
i = 0 | ||
while True: | ||
try: | ||
print(i, flush=True) | ||
i += 1 | ||
time.sleep(1) | ||
except BaseException as e: | ||
pass | ||
''' | ||
|
||
import sys | ||
|
||
with run(cmd=[sys.executable, '-c', py_prog], | ||
protocol_class=StdOutCaptureGeneratorProtocol, | ||
terminate_time=3, | ||
kill_time=3) as r: | ||
print(next(r)) | ||
print(next(r)) | ||
|
||
print(r.return_code) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters