# feat: prepare and pool processes (#87)
Closes #85

### Summary of Changes

- Use a process pool to keep started processes waiting for work (see the sketch after this list).
- The maximum number of pipeline processes is now set to `4`.
- Reuse started processes. This should be correct, as the same pipeline
process cannot be used by multiple pipelines at the same time. Since the
`metapath` is reset to remove the custom generated Safe-DS pipeline
code, only global library imports (and settings) should remain in a
reused process. If this is a concern, `maxtasksperchild` can be set to
`1`, in which case pipeline processes are not reused.
- Reuse the shared memory location when saving placeholders, if the
memoization infrastructure has already attached such a location to the
object being saved.
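
Below is a minimal sketch (not the runner's actual code) of the pooling approach: a spawn-based `ProcessPoolExecutor` with four workers that stay alive between submissions, mirroring the `_process_pool` property added in this commit. The `_run_pipeline` function and its argument are hypothetical placeholders, and the commented-out `max_tasks_per_child=1` line (the executor's keyword argument, available from Python 3.11) only illustrates how process reuse could be switched off.

```python
# Minimal sketch, assuming a trivial task function: a spawn-based pool that
# keeps up to four worker processes alive and reuses them for later tasks.
import multiprocessing
from concurrent.futures import ProcessPoolExecutor


def _run_pipeline(pipeline_id: str) -> str:
    # Hypothetical stand-in for the real pipeline execution entry point.
    return f"finished {pipeline_id}"


if __name__ == "__main__":
    pool = ProcessPoolExecutor(
        max_workers=4,  # cap on concurrent pipeline processes
        mp_context=multiprocessing.get_context("spawn"),  # fresh interpreters, nothing inherited
        # max_tasks_per_child=1,  # would disable process reuse (Python 3.11+)
    )
    future = pool.submit(_run_pipeline, "example")  # workers are started once and then reused
    exception = future.exception()  # blocks until the task finishes
    if exception is not None:
        print("pipeline failed:", exception)
    else:
        print(future.result())
    pool.shutdown(wait=True, cancel_futures=True)
```

Submitting further tasks to the same pool reuses the already-started workers instead of spawning new processes, which is what removes the per-run process start-up cost.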

---------

Co-authored-by: megalinter-bot <129584137+megalinter-bot@users.noreply.github.com>
Co-authored-by: Lars Reimann <mail@larsreimann.com>
3 people committed Apr 21, 2024
1 parent 50d831f commit e5e7011
Showing 2 changed files with 75 additions and 43 deletions.
45 changes: 39 additions & 6 deletions src/safeds_runner/server/_pipeline_manager.py
@@ -2,13 +2,15 @@

import asyncio
import json
import linecache
import logging
import multiprocessing
import os
import queue
import runpy
import threading
import typing
from concurrent.futures import ProcessPoolExecutor
from functools import cached_property
from multiprocessing.managers import SyncManager
from pathlib import Path
@@ -17,6 +19,13 @@
import stack_data

from ._memoization_map import MemoizationMap
from ._memoization_utils import (
ExplicitIdentityWrapper,
ExplicitIdentityWrapperLazy,
_has_explicit_identity_memory,
_is_deterministically_hashable,
_is_not_primitive,
)
from ._messages import (
Message,
MessageDataProgram,
@@ -53,6 +62,10 @@ def _multiprocessing_manager(self) -> SyncManager:
def _messages_queue(self) -> queue.Queue[Message]:
return self._multiprocessing_manager.Queue()

@cached_property
def _process_pool(self) -> ProcessPoolExecutor:
return ProcessPoolExecutor(max_workers=4, mp_context=multiprocessing.get_context("spawn"))

@cached_property
def _messages_queue_thread(self) -> threading.Thread:
return threading.Thread(target=self._handle_queue_messages, daemon=True, args=(asyncio.get_event_loop(),))
@@ -75,6 +88,8 @@ def startup(self) -> None:
_mq = self._messages_queue # Initialize it here before starting a thread to avoid potential race condition
if not self._messages_queue_thread.is_alive():
self._messages_queue_thread.start()
# Ensure that pool is started
_pool = self._process_pool

def _handle_queue_messages(self, event_loop: asyncio.AbstractEventLoop) -> None:
"""
@@ -144,7 +159,7 @@ def execute_pipeline(
self._placeholder_map[execution_id],
self._memoization_map,
)
process.execute()
process.execute(self._process_pool)

def get_placeholder(self, execution_id: str, placeholder_name: str) -> tuple[str | None, Any]:
"""
@@ -167,6 +182,8 @@ def get_placeholder(self, execution_id: str, placeholder_name: str) -> tuple[str
if placeholder_name not in self._placeholder_map[execution_id]:
return None, None
value = self._placeholder_map[execution_id][placeholder_name]
if isinstance(value, ExplicitIdentityWrapper | ExplicitIdentityWrapperLazy):
value = value.value
return _get_placeholder_type(value), value

def shutdown(self) -> None:
@@ -176,6 +193,7 @@ def shutdown(self) -> None:
This should only be called if this PipelineManager is not intended to be reused again.
"""
self._multiprocessing_manager.shutdown()
self._process_pool.shutdown(wait=True, cancel_futures=True)


class PipelineProcess:
@@ -210,7 +228,6 @@ def __init__(
self._messages_queue = messages_queue
self._placeholder_map = placeholder_map
self._memoization_map = memoization_map
self._process = multiprocessing.Process(target=self._execute, daemon=True)

def _send_message(self, message_type: str, value: dict[Any, Any] | str) -> None:
self._messages_queue.put(Message(message_type, self._id, value))
@@ -236,8 +253,16 @@ def save_placeholder(self, placeholder_name: str, value: Any) -> None:
import torch

value = Image(value._image_tensor, torch.device("cpu"))
self._placeholder_map[placeholder_name] = value
placeholder_type = _get_placeholder_type(value)
if _is_deterministically_hashable(value) and _has_explicit_identity_memory(value):
value = ExplicitIdentityWrapperLazy.existing(value)
elif (
not _is_deterministically_hashable(value)
and _is_not_primitive(value)
and _has_explicit_identity_memory(value)
):
value = ExplicitIdentityWrapper.existing(value)
self._placeholder_map[placeholder_name] = value
self._send_message(
message_type_placeholder_type,
create_placeholder_description(placeholder_name, placeholder_type),
@@ -284,15 +309,23 @@ def _execute(self) -> None:
except BaseException as error: # noqa: BLE001
self._send_exception(error)
finally:
linecache.clearcache()
pipeline_finder.detach()

def execute(self) -> None:
def _catch_subprocess_error(self, error: BaseException) -> None:
# This is a callback to log an unexpected failure, executing this is never expected
logging.exception("Pipeline process unexpectedly failed", exc_info=error) # pragma: no cover

def execute(self, pool: ProcessPoolExecutor) -> None:
"""
Execute this pipeline in a newly created process.
Execute this pipeline in a process from the provided process pool.
Results, progress and errors are communicated back to the main process.
"""
self._process.start()
future = pool.submit(self._execute)
exception = future.exception()
if exception is not None:
self._catch_subprocess_error(exception) # pragma: no cover


# Pipeline process object visible in child process
73 changes: 36 additions & 37 deletions tests/safeds_runner/server/test_websocket_mock.py
@@ -2,7 +2,6 @@
import json
import logging
import multiprocessing
import os
import sys
import time
import typing
@@ -142,7 +141,8 @@
)
@pytest.mark.asyncio()
async def test_should_fail_message_validation_ws(websocket_message: str) -> None:
test_client = SafeDsServer().app.test_client()
sds_server = SafeDsServer()
test_client = sds_server.app.test_client()
async with test_client.websocket("/WSMain") as test_websocket:
await test_websocket.send(websocket_message)
disconnected = False
@@ -151,6 +151,7 @@ async def test_should_fail_message_validation_ws(websocket_message: str) -> None
except WebsocketDisconnectError as _disconnect:
disconnected = True
assert disconnected
sds_server.app_pipeline_manager.shutdown()


@pytest.mark.parametrize(
@@ -352,13 +353,6 @@ def test_should_fail_message_validation_reason_placeholder_query(
assert invalid_message == exception_message


@pytest.mark.skipif(
sys.platform.startswith("win") and os.getenv("COVERAGE_RCFILE") is not None,
reason=(
"skipping multiprocessing tests on windows if coverage is enabled, as pytest "
"causes Manager to hang, when using multiprocessing coverage"
),
)
@pytest.mark.parametrize(
argnames="message,expected_response_runtime_error",
argvalues=[
@@ -388,7 +382,8 @@ async def test_should_execute_pipeline_return_exception(
message: str,
expected_response_runtime_error: Message,
) -> None:
test_client = SafeDsServer().app.test_client()
sds_server = SafeDsServer()
test_client = sds_server.app.test_client()
async with test_client.websocket("/WSMain") as test_websocket:
await test_websocket.send(message)
received_message = await test_websocket.receive()
@@ -404,15 +399,9 @@ async def test_should_execute_pipeline_return_exception(
assert isinstance(frame["file"], str)
assert "line" in frame
assert isinstance(frame["line"], int)
sds_server.app_pipeline_manager.shutdown()


@pytest.mark.skipif(
sys.platform.startswith("win") and os.getenv("COVERAGE_RCFILE") is not None,
reason=(
"skipping multiprocessing tests on windows if coverage is enabled, as pytest "
"causes Manager to hang, when using multiprocessing coverage"
),
)
@pytest.mark.parametrize(
argnames="initial_messages,initial_execution_message_wait,appended_messages,expected_responses",
argvalues=[
@@ -426,11 +415,15 @@ async def test_should_execute_pipeline_return_exception(
"code": {
"": {
"gen_test_a": (
"import safeds_runner\nimport base64\nfrom safeds.data.image.containers import Image\n\ndef pipe():\n\tvalue1 ="
"import safeds_runner\nimport base64\nfrom safeds.data.image.containers import Image\nfrom safeds.data.tabular.containers import Table\nimport safeds_runner\nfrom safeds_runner.server._json_encoder import SafeDsEncoder\n\ndef pipe():\n\tvalue1 ="
" 1\n\tsafeds_runner.save_placeholder('value1',"
" value1)\n\tsafeds_runner.save_placeholder('obj',"
" object())\n\tsafeds_runner.save_placeholder('image',"
" Image.from_bytes(base64.b64decode('iVBORw0KGgoAAAANSUhEUgAAAAQAAAAECAYAAACp8Z5+AAAAD0lEQVQIW2NkQAOMpAsAAADuAAVDMQ2mAAAAAElFTkSuQmCC')))\n"
" Image.from_bytes(base64.b64decode('iVBORw0KGgoAAAANSUhEUgAAAAQAAAAECAYAAACp8Z5+AAAAD0lEQVQIW2NkQAOMpAsAAADuAAVDMQ2mAAAAAElFTkSuQmCC')))\n\t"
"table = safeds_runner.memoized_static_call(\"safeds.data.tabular.containers.Table.from_dict\", Table.from_dict, [{'a': [1, 2], 'b': [3, 4]}], [])\n\t"
"safeds_runner.save_placeholder('table',table)\n\t"
'object_mem = safeds_runner.memoized_static_call("random.object.call", SafeDsEncoder, [], [])\n\t'
"safeds_runner.save_placeholder('object_mem',object_mem)\n"
),
"gen_test_a_pipe": (
"from gen_test_a import pipe\n\nif __name__ == '__main__':\n\tpipe()"
@@ -442,10 +435,12 @@ async def test_should_execute_pipeline_return_exception(
},
),
],
4,
6,
[
# Query Placeholder
json.dumps({"type": "placeholder_query", "id": "abcdefg", "data": {"name": "value1", "window": {}}}),
# Query Placeholder (memoized type)
json.dumps({"type": "placeholder_query", "id": "abcdefg", "data": {"name": "table", "window": {}}}),
# Query not displayable Placeholder
json.dumps({"type": "placeholder_query", "id": "abcdefg", "data": {"name": "obj", "window": {}}}),
# Query invalid placeholder
@@ -456,6 +451,12 @@ async def test_should_execute_pipeline_return_exception(
Message(message_type_placeholder_type, "abcdefg", create_placeholder_description("value1", "Int")),
Message(message_type_placeholder_type, "abcdefg", create_placeholder_description("obj", "object")),
Message(message_type_placeholder_type, "abcdefg", create_placeholder_description("image", "Image")),
Message(message_type_placeholder_type, "abcdefg", create_placeholder_description("table", "Table")),
Message(
message_type_placeholder_type,
"abcdefg",
create_placeholder_description("object_mem", "SafeDsEncoder"),
),
# Validate Progress Information
Message(message_type_runtime_progress, "abcdefg", create_runtime_progress_done()),
# Query Result Valid
@@ -464,6 +465,12 @@ async def test_should_execute_pipeline_return_exception(
"abcdefg",
create_placeholder_value(MessageQueryInformation("value1"), "Int", 1),
),
# Query Result Valid (memoized)
Message(
message_type_placeholder_value,
"abcdefg",
create_placeholder_value(MessageQueryInformation("table"), "Table", {"a": [1, 2], "b": [3, 4]}),
),
# Query Result not displayable
Message(
message_type_placeholder_value,
@@ -489,7 +496,8 @@ async def test_should_execute_pipeline_return_valid_placeholder(
expected_responses: list[Message],
) -> None:
# Initial execution
test_client = SafeDsServer().app.test_client()
sds_server = SafeDsServer()
test_client = sds_server.app.test_client()
async with test_client.websocket("/WSMain") as test_websocket:
for message in initial_messages:
await test_websocket.send(message)
@@ -506,15 +514,9 @@ async def test_should_execute_pipeline_return_valid_placeholder(
received_message = await test_websocket.receive()
next_message = Message.from_dict(json.loads(received_message))
assert next_message == expected_responses.pop(0)
sds_server.app_pipeline_manager.shutdown()


@pytest.mark.skipif(
sys.platform.startswith("win") and os.getenv("COVERAGE_RCFILE") is not None,
reason=(
"skipping multiprocessing tests on windows if coverage is enabled, as pytest "
"causes Manager to hang, when using multiprocessing coverage"
),
)
@pytest.mark.parametrize(
argnames="messages,expected_response",
argvalues=[
@@ -576,22 +578,17 @@ async def test_should_execute_pipeline_return_valid_placeholder(
)
@pytest.mark.asyncio()
async def test_should_successfully_execute_simple_flow(messages: list[str], expected_response: Message) -> None:
test_client = SafeDsServer().app.test_client()
sds_server = SafeDsServer()
test_client = sds_server.app.test_client()
async with test_client.websocket("/WSMain") as test_websocket:
for message in messages:
await test_websocket.send(message)
received_message = await test_websocket.receive()
query_result_invalid = Message.from_dict(json.loads(received_message))
assert query_result_invalid == expected_response
sds_server.app_pipeline_manager.shutdown()


@pytest.mark.skipif(
sys.platform.startswith("win") and os.getenv("COVERAGE_RCFILE") is not None,
reason=(
"skipping multiprocessing tests on windows if coverage is enabled, as pytest "
"causes Manager to hang, when using multiprocessing coverage"
),
)
@pytest.mark.parametrize(
argnames="messages",
argvalues=[
@@ -613,10 +610,12 @@ def helper_should_shut_itself_down_run_in_subprocess(sub_messages: list[str]) ->


async def helper_should_shut_itself_down_run_in_subprocess_async(sub_messages: list[str]) -> None:
test_client = SafeDsServer().app.test_client()
sds_server = SafeDsServer()
test_client = sds_server.app.test_client()
async with test_client.websocket("/WSMain") as test_websocket:
for message in sub_messages:
await test_websocket.send(message)
sds_server.app_pipeline_manager.shutdown()


@pytest.mark.timeout(45)