Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Worker ID functionality #261

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
3 changes: 2 additions & 1 deletion loky/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@
from .reusable_executor import get_reusable_executor
from .cloudpickle_wrapper import wrap_non_picklable_objects
from .process_executor import BrokenProcessPool, ProcessPoolExecutor
from .worker_id import get_worker_id


__all__ = ["get_reusable_executor", "cpu_count", "wait", "as_completed",
"Future", "Executor", "ProcessPoolExecutor",
"BrokenProcessPool", "CancelledError", "TimeoutError",
"FIRST_COMPLETED", "FIRST_EXCEPTION", "ALL_COMPLETED",
"wrap_non_picklable_objects", "set_loky_pickler"]
"wrap_non_picklable_objects", "set_loky_pickler", "get_worker_id"]


__version__ = '3.0.0.dev0'
35 changes: 31 additions & 4 deletions loky/process_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def _sendback_result(result_queue, work_id, result=None, exception=None):

def _process_worker(call_queue, result_queue, initializer, initargs,
processes_management_lock, timeout, worker_exit_lock,
current_depth):
current_depth, worker_id):
"""Evaluates calls from call_queue and places the results in result_queue.

This worker is run in a separate process.
Expand Down Expand Up @@ -398,6 +398,9 @@ def _process_worker(call_queue, result_queue, initializer, initargs,
_last_memory_leak_check = None
pid = os.getpid()

# set the worker_id environment variable
os.environ["LOKY_WORKER_ID"] = str(worker_id)

mp.util.debug('Worker started with timeout=%s' % timeout)
while True:
try:
Expand Down Expand Up @@ -521,6 +524,9 @@ def weakref_cb(_,
# A list of the ctx.Process instances used as workers.
self.processes = executor._processes

# A dict mapping worker pids to worker IDs
self.process_worker_ids = executor._process_worker_ids

# A ctx.Queue that will be filled with _CallItems derived from
# _WorkItems for processing by the process workers.
self.call_queue = executor._call_queue
Expand Down Expand Up @@ -668,6 +674,7 @@ def process_result_item(self, result_item):
# itself: we should not mark the executor as broken.
with self.processes_management_lock:
p = self.processes.pop(result_item, None)
self.process_worker_ids.pop(result_item, None)

# p can be None is the executor is concurrently shutting down.
if p is not None:
Expand Down Expand Up @@ -760,7 +767,10 @@ def kill_workers(self):
# terminates descendant workers of the children in case there is some
# nested parallelism.
while self.processes:
_, p = self.processes.popitem()
pid = list(self.processes.keys())[0]
pmla marked this conversation as resolved.
Show resolved Hide resolved
pid, p = self.processes.popitem()
self.process_worker_ids.pop(pid, None)

mp.util.debug('terminate process {}'.format(p.name))
try:
recursive_terminate(p)
Expand Down Expand Up @@ -983,8 +993,10 @@ def __init__(self, max_workers=None, job_reducers=None,
# Map of pids to processes
self._processes = {}

# Map of pids to process worker IDs
self._process_worker_ids = {}

# Internal variables of the ProcessPoolExecutor
self._processes = {}
self._queue_count = 0
self._pending_work_items = {}
self._running_work_items = []
Expand Down Expand Up @@ -1069,13 +1081,27 @@ def weakref_cb(
process_pool_executor_at_exit = threading._register_atexit(
_python_exit)

def _get_available_worker_id(self):
    """Return a worker ID in ``[0, max_workers)`` not used by a live worker.

    Worker IDs are only tracked at the first level of parallelism:
    in nested executors (``_CURRENT_DEPTH > 0``) this always returns -1.
    Also returns -1 when every ID is already assigned to a running
    worker process.
    """
    if _CURRENT_DEPTH > 0:
        return -1

    used_ids = set(self._process_worker_ids.values())
    available_ids = set(range(self._max_workers)) - used_ids
    # set.pop() returns an arbitrary free ID; any free ID is valid here.
    if available_ids:
        return available_ids.pop()
    return -1

def _adjust_process_count(self):
for _ in range(len(self._processes), self._max_workers):
worker_exit_lock = self._context.BoundedSemaphore(1)
worker_id = self._get_available_worker_id()
args = (self._call_queue, self._result_queue, self._initializer,
self._initargs, self._processes_management_lock,
self._timeout, worker_exit_lock, _CURRENT_DEPTH + 1)
self._timeout, worker_exit_lock, _CURRENT_DEPTH + 1,
worker_id)
worker_exit_lock.acquire()

try:
# Try to spawn the process with some environment variable to
# overwrite but it only works with the loky context for now.
Expand All @@ -1086,6 +1112,7 @@ def _adjust_process_count(self):
p._worker_exit_lock = worker_exit_lock
p.start()
self._processes[p.pid] = p
self._process_worker_ids[p.pid] = worker_id
mp.util.debug('Adjust process count : {}'.format(self._processes))

def _ensure_executor_running(self):
Expand Down
15 changes: 15 additions & 0 deletions loky/worker_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import os


def get_worker_id():
    """Get the worker ID of the current process. For a `ReusableExecutor`
    with `max_workers=n`, the worker ID is in the range [0..n). This is
    suited for reuse of persistent objects such as GPU IDs. This function
    only works at the first level of parallelization (i.e. not for nested
    parallelization). Resizing the `ReusableExecutor` will result in
    unpredictable return values. Returns -1 on failure.
    """
    # The executor exports the assigned ID through this environment
    # variable when it spawns the worker; it is absent in the main
    # process and in non-loky processes, which therefore report -1.
    wid = os.environ.get('LOKY_WORKER_ID', None)
    if wid is None:
        return -1
    return int(wid)
36 changes: 36 additions & 0 deletions tests/test_worker_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import time
import pytest
import numpy as np
from collections import defaultdict
from loky import get_reusable_executor, get_worker_id


def random_sleep(args):
    """Sleep for a seed-determined random duration and report timing.

    ``args`` is a ``(seed, max_duration)`` pair. Returns a tuple
    ``(worker_id, start_time, end_time)`` so the caller can check which
    worker ran the task and over which wall-clock interval.
    """
    seed, max_duration = args
    duration = np.random.RandomState(seed=seed).uniform(0, max_duration)
    start = time.time()
    time.sleep(duration)
    end = time.time()
    return (get_worker_id(), start, end)


@pytest.mark.parametrize("max_duration,timeout,kmax", [(0.05, 2, 100),
                                                       (1, 0.01, 4)])
def test_worker_ids(max_duration, timeout, kmax):
    """Test that worker IDs are always unique, with re-use over time."""
    num_workers = 4
    # Pass the parametrized timeout through: the (1, 0.01, 4) case is
    # meant to exercise worker shutdown/respawn and ID reuse, which a
    # hard-coded timeout=2 would never trigger.
    executor = get_reusable_executor(max_workers=num_workers,
                                     timeout=timeout)
    results = executor.map(random_sleep, [(k, max_duration)
                                          for k in range(kmax)])

    all_intervals = defaultdict(list)
    for wid, t0, t1 in results:
        # -1 (the failure value) must never leak into a worker task.
        assert 0 <= wid < num_workers
        all_intervals[wid].append((t0, t1))

    # A worker ID must never be held by two tasks at the same time:
    # sorted by start time, each interval must end before the next begins.
    for intervals in all_intervals.values():
        intervals = sorted(intervals)
        for i in range(len(intervals) - 1):
            assert intervals[i + 1][0] >= intervals[i][1]