Skip to content

Commit

Permalink
First pass of aio transmitter
Browse files Browse the repository at this point in the history
  • Loading branch information
jondequinor authored and sondreso committed Mar 24, 2021
1 parent 8f139f8 commit b82af7b
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 74 deletions.
56 changes: 32 additions & 24 deletions ert3/data/_record.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,18 @@
import uuid
from abc import abstractmethod
from enum import Enum, auto
from functools import partial, wraps
from pathlib import Path
from typing import Awaitable, List, Mapping, Tuple, Union

import aiofiles
from aiofiles.os import wrap
from pydantic import BaseModel, root_validator


_copy = wrap(shutil.copy)


class _DataElement(BaseModel):
class Config:
validate_all = True
Expand Down Expand Up @@ -151,7 +157,7 @@ async def transmit(
Path,
Union[List[float], Mapping[int, float], Mapping[str, float], List[bytes]],
],
) -> Awaitable[bool]:
) -> None:
pass


Expand All @@ -169,46 +175,47 @@ def set_transmitted(self, uri: Path):
super().set_transmitted()
self._uri = str(uri)

def transmit(
async def transmit(
self,
data_or_file: typing.Union[
Path, List[float], Mapping[int, float], Mapping[str, float], List[bytes]
],
mime="text/json",
):
) -> None:
if self.is_transmitted():
raise RuntimeError("Record already transmitted")
if isinstance(data_or_file, Path) or isinstance(data_or_file, str):
with open(data_or_file) as f:
record = Record(data=json.load(f))
async with aiofiles.open(str(data_or_file), mode="r") as f:
contents = await f.read()
record = Record(data=json.loads(contents))
else:
record = Record(data=data_or_file)

storage_uri = self._storage_path / self._concrete_key
with open(storage_uri, "w") as f:
async with aiofiles.open(storage_uri, mode="w") as f:
if mime == "text/json":
json.dump(record.data, f)
contents = json.dumps(record.data)
await f.write(contents)
elif mime == "application/x-python-code":
# XXX: An opaque record is a list of bytes... yes
# sonso or dan or jond: do something about this
f.write(record.data[0].decode())
await f.write(record.data[0].decode())
else:
raise ValueError(f"unsupported mime {mime}")
self.set_transmitted(storage_uri)

def load(self, mime="text/json"):
if mime != "text/json":
raise NotImplementedError("can't do {mime}, sorry")
async def load(self) -> Record:
if self._state != RecordTransmitterState.transmitted:
raise RuntimeError("cannot load untransmitted record")
with open(str(self._uri)) as f:
return Record(data=json.load(f))
async with aiofiles.open(str(self._uri)) as f:
contents = await f.read()
return Record(data=json.loads(contents))

# TODO: should use Path
def dump(self, location: str, format: str = "json"):
async def dump(self, location: str):
if self._state != RecordTransmitterState.transmitted:
raise RuntimeError("cannot dump untransmitted record")
shutil.copy(self._uri, location)
await _copy(self._uri, location)


class InMemoryRecordTransmitter(RecordTransmitter):
Expand All @@ -223,7 +230,7 @@ def set_transmitted(self, record: Record):
super().set_transmitted()
self._record = record

def transmit(
async def transmit(
self,
data_or_file: typing.Union[
Path, List[float], Mapping[int, float], Mapping[str, float], List[bytes]
Expand All @@ -233,27 +240,28 @@ def transmit(
if self.is_transmitted():
raise RuntimeError("Record already transmitted")
if isinstance(data_or_file, Path) or isinstance(data_or_file, str):
with open(data_or_file) as f:
record = Record(data=json.load(f))
async with aiofiles.open(data_or_file) as f:
contents = await f.read()
record = Record(data=json.loads(contents))
else:
record = Record(data=data_or_file)
self.set_transmitted(record=record)

def load(self, mime="text/json"):
async def load(self):
return self._record

# TODO: should use Path
def dump(self, location: str, format: str = "text/json"):
async def dump(self, location: str, format: str = "text/json"):
if format is None:
format = "text/json"
if self._state != RecordTransmitterState.transmitted:
if not self.is_transmitted():
raise RuntimeError("cannot dump untransmitted record")
with open(location, "w") as f:
async with aiofiles.open(location, mode="w") as f:
if format == "text/json":
json.dump(self._record.data, f)
await f.write(json.dumps(self._record.data))
elif format == "application/x-python-code":
# XXX: An opaque record is a list of bytes... yes
# sonso or dan or jond: do something about this
f.write(self._record.data[0].decode())
await f.write(self._record.data[0].decode())
else:
raise ValueError(f"unsupported mime {format}")
29 changes: 21 additions & 8 deletions ert3/evaluator/_evaluator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import asyncio
import os
import pathlib
import shutil
import typing
from collections import defaultdict

import aiofiles

import ert3
from ert3.config._stages_config import StagesConfig
from ert_shared.ensemble_evaluator.config import EvaluatorServerConfig
Expand Down Expand Up @@ -41,6 +44,7 @@ def _prepare_input(
storage_config = ee_config["storage"]
transmitters = defaultdict(dict)

futures = []
for input_ in step_config.input:
for iens, record in enumerate(inputs.ensemble_records[input_.record].records):
if storage_config.get("type") == "shared_disk":
Expand All @@ -52,8 +56,9 @@ def _prepare_input(
raise ValueError(
f"Unsupported transmitter type: {storage_config.get('type')}"
)
transmitter.transmit(record.data)
futures.append(transmitter.transmit(record.data))
transmitters[iens][input_.record] = transmitter
asyncio.get_event_loop().run_until_complete(asyncio.gather(*futures))
for command in step_config.transportable_commands:
for iens in range(0, ensemble_size):
if storage_config.get("type") == "shared_disk":
Expand All @@ -66,7 +71,9 @@ def _prepare_input(
f"Unsupported transmitter type: {storage_config.get('type')}"
)
with open(command.location, "rb") as f:
transmitter.transmit([f.read()], mime=command.mime)
asyncio.get_event_loop().run_until_complete(
transmitter.transmit([f.read()], mime=command.mime)
)
transmitters[iens][command.name] = transmitter
return dict(transmitters)

Expand Down Expand Up @@ -219,15 +226,22 @@ def _run(ensemble_evaluator):


def _prepare_responses(raw_responses):
data_results = []
async def _load(iens, record_key, transmitter):
record = await transmitter.load()
return (iens, record_key, record)

futures = []
for iens in sorted(raw_responses.keys(), key=int):
real_data = {}
for record, transmitter in raw_responses[iens].items():
real_data[record] = transmitter.load()
data_results.append(real_data)
futures.append(_load(iens, record, transmitter))
results = asyncio.get_event_loop().run_until_complete(asyncio.gather(*futures))

data_results = defaultdict(dict)
for res in results:
data_results[res[0]][res[1]] = res[2]

responses = {response_name: [] for response_name in data_results[0]}
for realization in data_results:
for realization in data_results.values():
assert responses.keys() == realization.keys()
for key in realization:
responses[key].append(realization[key])
Expand Down Expand Up @@ -255,7 +269,6 @@ def evaluate(

ee = EnsembleEvaluator(ensemble=ensemble, config=config, iter_=0)
result = _run(ee)

responses = _prepare_responses(result)

return responses
Expand Down
1 change: 1 addition & 0 deletions ert_shared/ensemble_evaluator/entity/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def _sort_steps(steps: typing.List["_Step"]):
graph[other.get_name()].add(step.get_name())
edged_nodes.add(step.get_name())
edged_nodes.add(other.get_name())

null_nodes = set([step.get_name() for step in steps]) - edged_nodes
[graph[node] for node in null_nodes]
ts = TopologicalSorter(graph)
Expand Down
25 changes: 18 additions & 7 deletions ert_shared/ensemble_evaluator/entity/function_step.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Dict, Optional
from prefect import Task
import prefect
Expand All @@ -13,18 +14,28 @@ def __init__(self, step, output_transmitters, ee_id, *args, **kwargs) -> None:
self._ee_id = ee_id

def _attempt_execute(self, *, func, transmitters):
kwargs = {
input_.get_name(): transmitters[input_.get_name()].load().data
for input_ in self._step.get_inputs()
}
async def _load(io_name, transmitter):
record = await transmitter.load()
return (io_name, record)

futures = []
for input_ in self._step.get_inputs():
futures.append(_load(input_.get_name(), transmitters[input_.get_name()]))
results = asyncio.get_event_loop().run_until_complete(asyncio.gather(*futures))
kwargs = {result[0]: result[1].data for result in results}
function_output = func(**kwargs)

transmitter_map = {}
async def _transmit(io_name, transmitter, data):
await transmitter.transmit(data)
return (io_name, transmitter)

futures = []
for output in self._step.get_outputs():
name = output.get_name()
transmitter = self._output_transmitters[name]
transmitter.transmit(function_output)
transmitter_map[name] = transmitter
futures.append(_transmit(name, transmitter, function_output))
results = asyncio.get_event_loop().run_until_complete(asyncio.gather(*futures))
transmitter_map = {result[0]: result[1] for result in results}
return transmitter_map

def run_job(self, job, transmitters: Dict[str, "RecordTransmitter"], client):
Expand Down
18 changes: 14 additions & 4 deletions ert_shared/ensemble_evaluator/entity/unix_step.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import os
import stat
import subprocess
Expand Down Expand Up @@ -54,11 +55,16 @@ def run_jobs(self, client, run_path):
)

def _load_and_dump_input(self, transmitters, runpath):
futures = []
for input_ in self._step.get_inputs():
# TODO: use Path
transmitters[input_.get_name()].dump(
os.path.join(runpath, input_.get_path()), input_.get_mime()
futures.append(
transmitters[input_.get_name()].dump(
os.path.join(runpath, input_.get_path())
)
)
asyncio.get_event_loop().run_until_complete(asyncio.gather(*futures))
for input_ in self._step.get_inputs():
if input_.is_executable():
path = os.path.join(runpath, input_.get_path())
st = os.stat(path)
Expand All @@ -76,6 +82,7 @@ def run(self, inputs=None):
outputs = {}
self.run_jobs(ee_client, run_path)

futures = []
for output in self._step.get_outputs():
if not os.path.exists(os.path.join(run_path, output.get_path())):
raise FileNotFoundError(
Expand All @@ -85,9 +92,12 @@ def run(self, inputs=None):
outputs[output.get_name()] = self._output_transmitters[
output.get_name()
]
outputs[output.get_name()].transmit(
os.path.join(run_path, output.get_path())
futures.append(
outputs[output.get_name()].transmit(
os.path.join(run_path, output.get_path())
)
)
asyncio.get_event_loop().run_until_complete(asyncio.gather(*futures))

ee_client.send_event(
ev_type=ids.EVTYPE_FM_STEP_SUCCESS,
Expand Down
43 changes: 24 additions & 19 deletions tests/ensemble_evaluator/test_prefect_ensemble.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
from pathlib import Path
import sys
import asyncio
import copy
import os
import os.path
import yaml
import pytest
import sys
import threading
import copy
from collections import defaultdict
from datetime import timedelta
from functools import partial
from itertools import permutations
from prefect import Flow
from prefect.run_configs import LocalRun
from collections import defaultdict
from tests.utils import SOURCE_DIR, tmp
from pathlib import Path

import ert3
import ert_shared.ensemble_evaluator.entity.ensemble as ee
import pytest
import yaml
from ert_shared.ensemble_evaluator.client import Client
from ert_shared.ensemble_evaluator.config import EvaluatorServerConfig
from ert_shared.ensemble_evaluator.entity import identifiers as ids
from ert_shared.ensemble_evaluator.entity.unix_step import UnixTask
from ert_shared.ensemble_evaluator.evaluator import EnsembleEvaluator
from ert_shared.ensemble_evaluator.prefect_ensemble import PrefectEnsemble
import ert_shared.ensemble_evaluator.entity.ensemble as ee
from ert_shared.ensemble_evaluator.entity.unix_step import UnixTask

from ert_shared.ensemble_evaluator.entity import identifiers as ids
from ert_shared.ensemble_evaluator.client import Client
import ert3

from prefect import Flow
from prefect.run_configs import LocalRun
from tests.ensemble_evaluator.conftest import _mock_ws
from tests.utils import SOURCE_DIR, tmp


def parse_config(path):
Expand All @@ -36,7 +36,7 @@ def input_transmitter(name, data, storage_path):
transmitter = ert3.data.SharedDiskRecordTransmitter(
name=name, storage_path=Path(storage_path)
)
transmitter.transmit(data)
asyncio.get_event_loop().run_until_complete(transmitter.transmit(data))
return {name: transmitter}


Expand Down Expand Up @@ -69,7 +69,9 @@ def script_transmitter(name, location, storage_path):
name=name, storage_path=Path(storage_path)
)
with open(location, "rb") as f:
transmitter.transmit([f.read()], mime="application/x-python-code")
asyncio.get_event_loop().run_until_complete(
transmitter.transmit([f.read()], mime="application/x-python-code")
)

return {name: transmitter}

Expand Down Expand Up @@ -291,7 +293,10 @@ def sum_function(values):
expected_uri = output_trans["output"]._uri
output_uri = task_result.result["output"]._uri
assert expected_uri == output_uri
transmitted_result = task_result.result["output"].load().data
transmitted_record = asyncio.get_event_loop().run_until_complete(
task_result.result["output"].load()
)
transmitted_result = transmitted_record.data
expected_result = sum_function(**test_values)
assert expected_result == transmitted_result

Expand Down
Loading

0 comments on commit b82af7b

Please sign in to comment.