From b82af7b170ff8f6ba6cc15d6c6874ee7f01da529 Mon Sep 17 00:00:00 2001 From: "Jonas G. Drange" Date: Tue, 23 Mar 2021 20:34:11 +0100 Subject: [PATCH] First pass of aio transmitter --- ert3/data/_record.py | 56 +++++++++++-------- ert3/evaluator/_evaluator.py | 29 +++++++--- .../ensemble_evaluator/entity/ensemble.py | 1 + .../entity/function_step.py | 25 ++++++--- .../ensemble_evaluator/entity/unix_step.py | 18 ++++-- .../test_prefect_ensemble.py | 43 +++++++------- tests/ert3/data/test_transmitters.py | 28 ++++++---- 7 files changed, 126 insertions(+), 74 deletions(-) diff --git a/ert3/data/_record.py b/ert3/data/_record.py index 9e9bfcd08e0..61d270f82b1 100644 --- a/ert3/data/_record.py +++ b/ert3/data/_record.py @@ -5,12 +5,18 @@ import uuid from abc import abstractmethod from enum import Enum, auto +from functools import partial, wraps from pathlib import Path from typing import Awaitable, List, Mapping, Tuple, Union +import aiofiles +from aiofiles.os import wrap from pydantic import BaseModel, root_validator +_copy = wrap(shutil.copy) + + class _DataElement(BaseModel): class Config: validate_all = True @@ -151,7 +157,7 @@ async def transmit( Path, Union[List[float], Mapping[int, float], Mapping[str, float], List[bytes]], ], - ) -> Awaitable[bool]: + ) -> None: pass @@ -169,46 +175,47 @@ def set_transmitted(self, uri: Path): super().set_transmitted() self._uri = str(uri) - def transmit( + async def transmit( self, data_or_file: typing.Union[ Path, List[float], Mapping[int, float], Mapping[str, float], List[bytes] ], mime="text/json", - ): + ) -> None: if self.is_transmitted(): raise RuntimeError("Record already transmitted") if isinstance(data_or_file, Path) or isinstance(data_or_file, str): - with open(data_or_file) as f: - record = Record(data=json.load(f)) + async with aiofiles.open(str(data_or_file), mode="r") as f: + contents = await f.read() + record = Record(data=json.loads(contents)) else: record = Record(data=data_or_file) storage_uri = self._storage_path / self._concrete_key - with open(storage_uri, "w") as f: + async with aiofiles.open(storage_uri, mode="w") as f: if mime == "text/json": - json.dump(record.data, f) + contents = json.dumps(record.data) + await f.write(contents) elif mime == "application/x-python-code": # XXX: An opaque record is a list of bytes... yes # sonso or dan or jond: do something about this - f.write(record.data[0].decode()) + await f.write(record.data[0].decode()) else: raise ValueError(f"unsupported mime {mime}") self.set_transmitted(storage_uri) - def load(self, mime="text/json"): - if mime != "text/json": - raise NotImplementedError("can't do {mime}, sorry") + async def load(self) -> Record: if self._state != RecordTransmitterState.transmitted: raise RuntimeError("cannot load untransmitted record") - with open(str(self._uri)) as f: - return Record(data=json.load(f)) + async with aiofiles.open(str(self._uri)) as f: + contents = await f.read() + return Record(data=json.loads(contents)) # TODO: should use Path - def dump(self, location: str, format: str = "json"): + async def dump(self, location: str): if self._state != RecordTransmitterState.transmitted: raise RuntimeError("cannot dump untransmitted record") - shutil.copy(self._uri, location) + await _copy(self._uri, location) class InMemoryRecordTransmitter(RecordTransmitter): @@ -223,7 +230,7 @@ def set_transmitted(self, record: Record): super().set_transmitted() self._record = record - def transmit( + async def transmit( self, data_or_file: typing.Union[ Path, List[float], Mapping[int, float], Mapping[str, float], List[bytes] @@ -233,27 +240,28 @@ def transmit( if self.is_transmitted(): raise RuntimeError("Record already transmitted") if isinstance(data_or_file, Path) or isinstance(data_or_file, str): - with open(data_or_file) as f: - record = Record(data=json.load(f)) + async with aiofiles.open(data_or_file) as f: + contents = await f.read() + record = Record(data=json.loads(contents)) else: record = Record(data=data_or_file) self.set_transmitted(record=record) - def load(self, mime="text/json"): + async def load(self): return self._record # TODO: should use Path - def dump(self, location: str, format: str = "text/json"): + async def dump(self, location: str, format: str = "text/json"): if format is None: format = "text/json" - if self._state != RecordTransmitterState.transmitted: + if not self.is_transmitted(): raise RuntimeError("cannot dump untransmitted record") - with open(location, "w") as f: + async with aiofiles.open(location, mode="w") as f: if format == "text/json": - json.dump(self._record.data, f) + await f.write(json.dumps(self._record.data)) elif format == "application/x-python-code": # XXX: An opaque record is a list of bytes... yes # sonso or dan or jond: do something about this - f.write(self._record.data[0].decode()) + await f.write(self._record.data[0].decode()) else: raise ValueError(f"unsupported mime {format}") diff --git a/ert3/evaluator/_evaluator.py b/ert3/evaluator/_evaluator.py index ca5b0da09da..15f99e86b1a 100644 --- a/ert3/evaluator/_evaluator.py +++ b/ert3/evaluator/_evaluator.py @@ -1,9 +1,12 @@ +import asyncio import os import pathlib import shutil import typing from collections import defaultdict +import aiofiles + import ert3 from ert3.config._stages_config import StagesConfig from ert_shared.ensemble_evaluator.config import EvaluatorServerConfig @@ -41,6 +44,7 @@ def _prepare_input( storage_config = ee_config["storage"] transmitters = defaultdict(dict) + futures = [] for input_ in step_config.input: for iens, record in enumerate(inputs.ensemble_records[input_.record].records): if storage_config.get("type") == "shared_disk": @@ -52,8 +56,9 @@ def _prepare_input( raise ValueError( f"Unsupported transmitter type: {storage_config.get('type')}" ) - transmitter.transmit(record.data) + futures.append(transmitter.transmit(record.data)) transmitters[iens][input_.record] = transmitter + asyncio.get_event_loop().run_until_complete(asyncio.gather(*futures)) for command in step_config.transportable_commands: for iens in range(0, ensemble_size): if storage_config.get("type") == "shared_disk": @@ -66,7 +71,9 @@ def _prepare_input( f"Unsupported transmitter type: {storage_config.get('type')}" ) with open(command.location, "rb") as f: - transmitter.transmit([f.read()], mime=command.mime) + asyncio.get_event_loop().run_until_complete( + transmitter.transmit([f.read()], mime=command.mime) + ) transmitters[iens][command.name] = transmitter return dict(transmitters) @@ -219,15 +226,22 @@ def _run(ensemble_evaluator): def _prepare_responses(raw_responses): - data_results = [] + async def _load(iens, record_key, transmitter): + record = await transmitter.load() + return (iens, record_key, record) + + futures = [] for iens in sorted(raw_responses.keys(), key=int): - real_data = {} for record, transmitter in raw_responses[iens].items(): - real_data[record] = transmitter.load() - data_results.append(real_data) + futures.append(_load(iens, record, transmitter)) + results = asyncio.get_event_loop().run_until_complete(asyncio.gather(*futures)) + + data_results = defaultdict(dict) + for res in results: + data_results[res[0]][res[1]] = res[2] responses = {response_name: [] for response_name in data_results[0]} - for realization in data_results: + for realization in data_results.values(): assert responses.keys() == realization.keys() for key in realization: responses[key].append(realization[key]) @@ -255,7 +269,6 @@ def evaluate( ee = EnsembleEvaluator(ensemble=ensemble, config=config, iter_=0) result = _run(ee) - responses = _prepare_responses(result) return responses diff --git a/ert_shared/ensemble_evaluator/entity/ensemble.py b/ert_shared/ensemble_evaluator/entity/ensemble.py index f441ad66010..6701a436264 100644 --- a/ert_shared/ensemble_evaluator/entity/ensemble.py +++ b/ert_shared/ensemble_evaluator/entity/ensemble.py @@ -35,6 +35,7 @@ def _sort_steps(steps: typing.List["_Step"]): graph[other.get_name()].add(step.get_name()) edged_nodes.add(step.get_name()) edged_nodes.add(other.get_name()) + null_nodes = set([step.get_name() for step in steps]) - edged_nodes [graph[node] for node in null_nodes] ts = TopologicalSorter(graph) diff --git a/ert_shared/ensemble_evaluator/entity/function_step.py b/ert_shared/ensemble_evaluator/entity/function_step.py index 237caa39ce1..a86b529f0a0 100644 --- a/ert_shared/ensemble_evaluator/entity/function_step.py +++ b/ert_shared/ensemble_evaluator/entity/function_step.py @@ -1,3 +1,4 @@ +import asyncio from typing import Dict, Optional from prefect import Task import prefect @@ -13,18 +14,28 @@ def __init__(self, step, output_transmitters, ee_id, *args, **kwargs) -> None: self._ee_id = ee_id def _attempt_execute(self, *, func, transmitters): - kwargs = { - input_.get_name(): transmitters[input_.get_name()].load().data - for input_ in self._step.get_inputs() - } + async def _load(io_name, transmitter): + record = await transmitter.load() + return (io_name, record) + + futures = [] + for input_ in self._step.get_inputs(): + futures.append(_load(input_.get_name(), transmitters[input_.get_name()])) + results = asyncio.get_event_loop().run_until_complete(asyncio.gather(*futures)) + kwargs = {result[0]: result[1].data for result in results} function_output = func(**kwargs) - transmitter_map = {} + async def _transmit(io_name, transmitter, data): + await transmitter.transmit(data) + return (io_name, transmitter) + + futures = [] for output in self._step.get_outputs(): name = output.get_name() transmitter = self._output_transmitters[name] - transmitter.transmit(function_output) - transmitter_map[name] = transmitter + futures.append(_transmit(name, transmitter, function_output)) + results = asyncio.get_event_loop().run_until_complete(asyncio.gather(*futures)) + transmitter_map = {result[0]: result[1] for result in results} return transmitter_map def run_job(self, job, transmitters: Dict[str, "RecordTransmitter"], client): diff --git a/ert_shared/ensemble_evaluator/entity/unix_step.py b/ert_shared/ensemble_evaluator/entity/unix_step.py index fb11ea351f1..02ea6a3d5fb 100644 --- a/ert_shared/ensemble_evaluator/entity/unix_step.py +++ b/ert_shared/ensemble_evaluator/entity/unix_step.py @@ -1,3 +1,4 @@ +import asyncio import os import stat import subprocess @@ -54,11 +55,16 @@ def run_jobs(self, client, run_path): ) def _load_and_dump_input(self, transmitters, runpath): + futures = [] for input_ in self._step.get_inputs(): # TODO: use Path - transmitters[input_.get_name()].dump( - os.path.join(runpath, input_.get_path()), input_.get_mime() + futures.append( + transmitters[input_.get_name()].dump( + os.path.join(runpath, input_.get_path()) + ) ) + asyncio.get_event_loop().run_until_complete(asyncio.gather(*futures)) + for input_ in self._step.get_inputs(): if input_.is_executable(): path = os.path.join(runpath, input_.get_path()) st = os.stat(path) @@ -76,6 +82,7 @@ def run(self, inputs=None): outputs = {} self.run_jobs(ee_client, run_path) + futures = [] for output in self._step.get_outputs(): if not os.path.exists(os.path.join(run_path, output.get_path())): raise FileNotFoundError( @@ -85,9 +92,12 @@ def run(self, inputs=None): outputs[output.get_name()] = self._output_transmitters[ output.get_name() ] - outputs[output.get_name()].transmit( - os.path.join(run_path, output.get_path()) + futures.append( + outputs[output.get_name()].transmit( + os.path.join(run_path, output.get_path()) + ) ) + asyncio.get_event_loop().run_until_complete(asyncio.gather(*futures)) ee_client.send_event( ev_type=ids.EVTYPE_FM_STEP_SUCCESS, diff --git a/tests/ensemble_evaluator/test_prefect_ensemble.py b/tests/ensemble_evaluator/test_prefect_ensemble.py index bc63a5ba45e..39ebfd61a28 100644 --- a/tests/ensemble_evaluator/test_prefect_ensemble.py +++ b/tests/ensemble_evaluator/test_prefect_ensemble.py @@ -1,29 +1,29 @@ -from pathlib import Path -import sys +import asyncio +import copy import os import os.path -import yaml -import pytest +import sys import threading -import copy +from collections import defaultdict from datetime import timedelta from functools import partial from itertools import permutations -from prefect import Flow -from prefect.run_configs import LocalRun -from collections import defaultdict -from tests.utils import SOURCE_DIR, tmp +from pathlib import Path + +import ert3 +import ert_shared.ensemble_evaluator.entity.ensemble as ee +import pytest +import yaml +from ert_shared.ensemble_evaluator.client import Client from ert_shared.ensemble_evaluator.config import EvaluatorServerConfig +from ert_shared.ensemble_evaluator.entity import identifiers as ids +from ert_shared.ensemble_evaluator.entity.unix_step import UnixTask from ert_shared.ensemble_evaluator.evaluator import EnsembleEvaluator from ert_shared.ensemble_evaluator.prefect_ensemble import PrefectEnsemble -import ert_shared.ensemble_evaluator.entity.ensemble as ee -from ert_shared.ensemble_evaluator.entity.unix_step import UnixTask - -from ert_shared.ensemble_evaluator.entity import identifiers as ids -from ert_shared.ensemble_evaluator.client import Client -import ert3 - +from prefect import Flow +from prefect.run_configs import LocalRun from tests.ensemble_evaluator.conftest import _mock_ws +from tests.utils import SOURCE_DIR, tmp def parse_config(path): @@ -36,7 +36,7 @@ def input_transmitter(name, data, storage_path): transmitter = ert3.data.SharedDiskRecordTransmitter( name=name, storage_path=Path(storage_path) ) - transmitter.transmit(data) + asyncio.get_event_loop().run_until_complete(transmitter.transmit(data)) return {name: transmitter} @@ -69,7 +69,9 @@ def script_transmitter(name, location, storage_path): name=name, storage_path=Path(storage_path) ) with open(location, "rb") as f: - transmitter.transmit([f.read()], mime="application/x-python-code") + asyncio.get_event_loop().run_until_complete( + transmitter.transmit([f.read()], mime="application/x-python-code") + ) return {name: transmitter} @@ -291,7 +293,10 @@ def sum_function(values): expected_uri = output_trans["output"]._uri output_uri = task_result.result["output"]._uri assert expected_uri == output_uri - transmitted_result = task_result.result["output"].load().data + transmitted_record = asyncio.get_event_loop().run_until_complete( + task_result.result["output"].load() + ) + transmitted_result = transmitted_record.data expected_result = sum_function(**test_values) assert expected_result == transmitted_result diff --git a/tests/ert3/data/test_transmitters.py b/tests/ert3/data/test_transmitters.py index c2b586d5aa9..17bbfccefa6 100644 --- a/tests/ert3/data/test_transmitters.py +++ b/tests/ert3/data/test_transmitters.py @@ -64,9 +64,10 @@ def in_memory_factory(name: str) -> InMemoryRecordTransmitter: ) +@pytest.mark.asyncio @simple_records @factory_params -def test_simple_record_transmit( +async def test_simple_record_transmit( record_transmitter_factory_context: ContextManager[ Callable[[str], RecordTransmitter] ], @@ -75,15 +76,16 @@ def test_simple_record_transmit( ): with record_transmitter_factory_context() as record_transmitter_factory: transmitter = record_transmitter_factory(name="some_name") - transmitter.transmit(data_in) + await transmitter.transmit(data_in) assert transmitter.is_transmitted() with pytest.raises(RuntimeError, match="Record already transmitted"): - transmitter.transmit(data_or_file=[1, 2, 3]) + await transmitter.transmit(data_or_file=[1, 2, 3]) +@pytest.mark.asyncio @simple_records @factory_params -def test_simple_record_transmit_and_load( +async def test_simple_record_transmit_and_load( record_transmitter_factory_context: ContextManager[ Callable[[str], RecordTransmitter] ], @@ -92,16 +94,17 @@ def test_simple_record_transmit_and_load( ): with record_transmitter_factory_context() as record_transmitter_factory: transmitter = record_transmitter_factory(name="some_name") - transmitter.transmit(data_in) + await transmitter.transmit(data_in) - record = transmitter.load() + record = await transmitter.load() assert record.data == expected_data +@pytest.mark.asyncio @simple_records @factory_params @tmpdir(None) -def test_simple_record_transmit_and_dump( +async def test_simple_record_transmit_and_dump( record_transmitter_factory_context: ContextManager[ Callable[[str], RecordTransmitter] ], @@ -110,17 +113,18 @@ def test_simple_record_transmit_and_dump( ): with record_transmitter_factory_context() as record_transmitter_factory: transmitter = record_transmitter_factory(name="some_name") - transmitter.transmit(data_in) + await transmitter.transmit(data_in) - transmitter.dump("record.json") + await transmitter.dump("record.json") with open("record.json") as f: expected_data = json.loads(json.dumps(expected_data)) assert expected_data == json.load(f) +@pytest.mark.asyncio @simple_records @factory_params -def test_simple_record_transmit_pickle_and_load( +async def test_simple_record_transmit_pickle_and_load( record_transmitter_factory_context: ContextManager[ Callable[[str], RecordTransmitter] ], @@ -130,8 +134,8 @@ def test_simple_record_transmit_pickle_and_load( with record_transmitter_factory_context() as record_transmitter_factory: transmitter = record_transmitter_factory(name="some_name") transmitter = pickle.loads(cloudpickle.dumps(transmitter)) - transmitter.transmit(data_in) + await transmitter.transmit(data_in) transmitter = pickle.loads(cloudpickle.dumps(transmitter)) - record = transmitter.load() + record = await transmitter.load() assert record.data == expected_data