Skip to content

Commit

Permalink
Improve auto-adding of save_result (#623, #401, #583, #391)
Browse files Browse the repository at this point in the history
- Check whole process graph for pre-existing `save_result` nodes, not just final node
- Disallow ambiguity of combining explicit `save_result` and download/create_job with format
  • Loading branch information
soxofaan committed Sep 24, 2024
1 parent 26bef79 commit 5aa28fb
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 55 deletions.
20 changes: 19 additions & 1 deletion openeo/internal/graph_building.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import sys
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Iterator, Optional, Tuple, Union

from openeo.api.process import Parameter
from openeo.internal.process_graph_visitor import (
Expand Down Expand Up @@ -225,6 +225,24 @@ def from_flat_graph(flat_graph: dict, parameters: Optional[dict] = None) -> PGNo
return PGNodeGraphUnflattener.unflatten(flat_graph=flat_graph, parameters=parameters)


def walk_nodes(self) -> Iterator[PGNode]:
    """
    Walk this node and all its parents.

    The "parents" are the :class:`PGNode` instances found in this node's
    arguments, searched recursively through nested dicts, lists and tuples,
    and then through the arguments of each node found that way.

    Nodes are yielded depth-first, starting with this node itself.
    A node that is reachable through more than one argument path
    (e.g. the same node used as both ``x`` and ``y`` of another node)
    is yielded only once, at its first occurrence.

    :return: iterator over the unique nodes, this node first.
    """
    # TODO: option to do deep walk (walk through child graphs too)?

    # Track visited nodes by identity (`id`) instead of putting the nodes
    # themselves in a set, so this does not depend on nodes being hashable
    # or on how they compare for equality.
    # Without this bookkeeping a shared node is re-walked once per path
    # leading to it, which is exponential for stacked "diamond" shapes
    # and never terminates if the arguments contain a cycle.
    seen_ids = {id(self)}
    yield self

    def walk(x) -> Iterator[PGNode]:
        if isinstance(x, PGNode):
            if id(x) not in seen_ids:
                seen_ids.add(id(x))
                yield x
                yield from walk(x.arguments)
        elif isinstance(x, dict):
            for v in x.values():
                yield from walk(v)
        elif isinstance(x, (list, tuple)):
            for v in x:
                yield from walk(v)
        # Anything else (numbers, strings, None, ...) can not contain nodes.

    yield from walk(self.arguments)


def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, Any]) -> Dict[str, dict]:
"""
Convert given object to an internal flat dict graph representation.
Expand Down
43 changes: 21 additions & 22 deletions openeo/rest/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -2095,10 +2095,16 @@ def save_result(
}
)

def _get_save_result_nodes(self) -> List[PGNode]:
    """
    Collect the `save_result` nodes of this cube's process graph.

    Walks the nodes produced by ``walk_nodes()`` on the result node
    and keeps the ones whose process id is ``save_result``.

    :return: list of matching nodes (empty if there are none).
    """
    found = []
    for node in self.result_node().walk_nodes():
        if node.process_id == "save_result":
            found.append(node)
    return found

def _ensure_save_result(
self,
*,
format: Optional[str] = None,
options: Optional[dict] = None,
weak_format: Optional[str] = None,
method: str,
) -> DataCube:
"""
Make sure there is a (final) `save_result` node in the process graph.
Expand All @@ -2110,25 +2116,19 @@ def _ensure_save_result(
:return:
"""
# TODO #401 Unify with VectorCube._ensure_save_result and move to generic data cube parent class (not only for raster cubes, but also vector cubes)
result_node = self.result_node()
if result_node.process_id == "save_result":
# There is already a `save_result` node:
# check if it is consistent with given format/options (if any)
args = result_node.arguments
if format is not None and format.lower() != args["format"].lower():
raise ValueError(
f"Existing `save_result` node with different format {args['format']!r} != {format!r}"
)
if options is not None and options != args["options"]:
raise ValueError(
f"Existing `save_result` node with different options {args['options']!r} != {options!r}"
)
cube = self
else:
save_result_nodes = self._get_save_result_nodes()

cube = self
if not save_result_nodes:
# No `save_result` node yet: automatically add it.
cube = self.save_result(
format=format or self._DEFAULT_RASTER_FORMAT, options=options
cube = cube.save_result(format=format or weak_format or self._DEFAULT_RASTER_FORMAT, options=options)
elif format or options:
raise OpenEoClientException(
f"{method} with explicit output {'format' if format else 'options'} {format or options!r},"
f" but the process graph already has `save_result` node(s)"
f" which is ambiguous and should not be combined."
)

return cube

def download(
Expand All @@ -2152,10 +2152,9 @@ def download(
(overruling the connection's ``auto_validate`` setting).
:return: None if the result is stored to disk, or a bytes object returned by the backend.
"""
if format is None and outputfile:
# TODO #401/#449 don't guess/override format if there is already a save_result with format?
format = guess_format(outputfile)
cube = self._ensure_save_result(format=format, options=options)
# TODO #401/#449 don't guess/override format if there is already a save_result with format?
weak_format = guess_format(outputfile) if outputfile else None
cube = self._ensure_save_result(format=format, options=options, weak_format=weak_format, method="Download")
return self._connection.download(cube.flat_graph(), outputfile, validate=validate)

def validate(self) -> List[dict]:
Expand Down Expand Up @@ -2321,7 +2320,7 @@ def create_job(
# TODO: add option to also automatically start the job?
# TODO: avoid using all kwargs as format_options
# TODO: centralize `create_job` for `DataCube`, `VectorCube`, `MlModel`, ...
cube = self._ensure_save_result(format=out_format, options=format_options or None)
cube = self._ensure_save_result(format=out_format, options=format_options or None, method="Creating job")
return self._connection.create_job(
process_graph=cube.flat_graph(),
title=title,
Expand Down
33 changes: 33 additions & 0 deletions tests/internal/test_graphbuilding.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,3 +379,36 @@ def test_parameter_substitution_undefined(self):
}
with pytest.raises(ProcessGraphVisitException, match="No substitution value for parameter 'increment'"):
_ = PGNodeGraphUnflattener.unflatten(flat_graph, parameters={"other": 100})


def test_walk_nodes_basic():
    """A node without arguments yields exactly one item: itself."""
    node = PGNode("foo")
    walked = list(node.walk_nodes())
    assert len(walked) == 1
    assert walked[0] is node


def test_walk_nodes_args():
    """Nodes passed as direct arguments are walked after the node itself."""
    data = PGNode("load")
    geometry = PGNode("vector")
    node = PGNode("foo", data=data, geometry=geometry)

    first, *rest = node.walk_nodes()
    assert first is node
    # The relative order of the two argument nodes is not pinned down.
    assert rest in ([data, geometry], [geometry, data])


def test_walk_nodes_nested():
    """Nodes nested inside list and dict arguments are found as well."""
    cubes = [PGNode("load1"), PGNode("load2")]
    size = {
        "x": PGNode("add", x=PGNode("five"), y=3),
        "y": PGNode("max"),
    }
    node = PGNode("foo", cubes=cubes, size=size)

    walked = list(node.walk_nodes())
    for item in walked:
        assert isinstance(item, PGNode)
    process_ids = {item.process_id for item in walked}
    assert process_ids == {"load1", "max", "foo", "load2", "add", "five"}
56 changes: 41 additions & 15 deletions tests/rest/datacube/test_datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import contextlib
import pathlib
import re
from datetime import date, datetime
from unittest import mock

Expand Down Expand Up @@ -710,21 +711,25 @@ def test_create_job_out_format(
@pytest.mark.parametrize(
["save_result_format", "execute_format", "expected"],
[
("GTiff", "GTiff", "GTiff"),
(None, None, "GTiff"),
(None, "GTiff", "GTiff"),
("GTiff", None, "GTiff"),
("NetCDF", "NetCDF", "NetCDF"),
(None, "NetCDF", "NetCDF"),
("NetCDF", None, "NetCDF"),
],
)
def test_create_job_existing_save_result(
def test_save_result_and_create_job_at_most_one_with_format(
self,
s2cube,
get_create_job_pg,
save_result_format,
execute_format,
expected,
):
cube = s2cube.save_result(format=save_result_format)
cube = s2cube
if save_result_format:
cube = cube.save_result(format=save_result_format)

cube.create_job(out_format=execute_format)
pg = get_create_job_pg()
assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
Expand All @@ -740,13 +745,21 @@ def test_create_job_existing_save_result(

@pytest.mark.parametrize(
["save_result_format", "execute_format"],
[("NetCDF", "GTiff"), ("GTiff", "NetCDF")],
[
("NetCDF", "NetCDF"),
("GTiff", "NetCDF"),
],
)
def test_create_job_existing_save_result_incompatible(
self, s2cube, save_result_format, execute_format
):
def test_save_result_and_create_job_both_with_format(self, s2cube, save_result_format, execute_format):
cube = s2cube.save_result(format=save_result_format)
with pytest.raises(ValueError):
with pytest.raises(
OpenEoClientException,
match=re.escape(
"Creating job with explicit output format 'NetCDF',"
" but the process graph already has `save_result` node(s)"
" which is ambiguous and should not be combined."
),
):
cube.create_job(out_format=execute_format)

def test_execute_batch_defaults(self, s2cube, get_create_job_pg, recwarn, caplog):
Expand Down Expand Up @@ -808,21 +821,24 @@ def test_execute_batch_out_format_from_output_file(
@pytest.mark.parametrize(
["save_result_format", "execute_format", "expected"],
[
("GTiff", "GTiff", "GTiff"),
(None, None, "GTiff"),
(None, "GTiff", "GTiff"),
("GTiff", None, "GTiff"),
("NetCDF", "NetCDF", "NetCDF"),
(None, "NetCDF", "NetCDF"),
("NetCDF", None, "NetCDF"),
],
)
def test_execute_batch_existing_save_result(
def test_save_result_and_execute_batch_at_most_one_with_format(
self,
s2cube,
get_create_job_pg,
save_result_format,
execute_format,
expected,
):
cube = s2cube.save_result(format=save_result_format)
cube = s2cube
if save_result_format:
cube = cube.save_result(format=save_result_format)
cube.execute_batch(out_format=execute_format)
pg = get_create_job_pg()
assert set(pg.keys()) == {"loadcollection1", "saveresult1"}
Expand All @@ -838,13 +854,23 @@ def test_execute_batch_existing_save_result(

@pytest.mark.parametrize(
["save_result_format", "execute_format"],
[("NetCDF", "GTiff"), ("GTiff", "NetCDF")],
[
("NetCDF", "NetCDF"),
("GTiff", "NetCDF"),
],
)
def test_execute_batch_existing_save_result_incompatible(
self, s2cube, save_result_format, execute_format
):
cube = s2cube.save_result(format=save_result_format)
with pytest.raises(ValueError):
with pytest.raises(
OpenEoClientException,
match=re.escape(
"Creating job with explicit output format 'NetCDF',"
" but the process graph already has `save_result` node(s)"
" which is ambiguous and should not be combined."
),
):
cube.execute_batch(out_format=execute_format)

def test_save_result_format_options_vs_create_job(elf, s2cube, get_create_job_pg):
Expand Down
33 changes: 16 additions & 17 deletions tests/rest/datacube/test_datacube100.py
Original file line number Diff line number Diff line change
Expand Up @@ -3257,50 +3257,49 @@ def test_apply_append_math_keep_context(con100):
({}, "result.nc", {}, b"this is netCDF data"),
({"format": "GTiff"}, "result.tiff", {}, b"this is GTiff data"),
({"format": "GTiff"}, "result.tif", {}, b"this is GTiff data"),
(
{"format": "GTiff"},
"result.nc",
{},
ValueError(
"Existing `save_result` node with different format 'GTiff' != 'netCDF'"
),
),
({"format": "GTiff"}, "result.nc", {}, b"this is GTiff data"),
({}, "result.tiff", {"format": "GTiff"}, b"this is GTiff data"),
({}, "result.nc", {"format": "netCDF"}, b"this is netCDF data"),
({}, "result.meh", {"format": "netCDF"}, b"this is netCDF data"),
(
{"format": "GTiff"},
"result.tiff",
{"format": "GTiff"},
b"this is GTiff data",
OpenEoClientException(
"Download with explicit output format 'GTiff', but the process graph already has `save_result` node(s) which is ambiguous and should not be combined."
),
),
(
{"format": "netCDF"},
"result.tiff",
{"format": "NETCDF"},
b"this is netCDF data",
OpenEoClientException(
"Download with explicit output format 'NETCDF', but the process graph already has `save_result` node(s) which is ambiguous and should not be combined."
),
),
(
{"format": "netCDF"},
"result.json",
{"format": "JSON"},
ValueError(
"Existing `save_result` node with different format 'netCDF' != 'JSON'"
OpenEoClientException(
"Download with explicit output format 'JSON', but the process graph already has `save_result` node(s) which is ambiguous and should not be combined."
),
),
({"options": {}}, "result.tiff", {}, b"this is GTiff data"),
(
{"options": {"quality": "low"}},
"result.tiff",
{"options": {"quality": "low"}},
b"this is GTiff data",
OpenEoClientException(
"Download with explicit output options {'quality': 'low'}, but the process graph already has `save_result` node(s) which is ambiguous and should not be combined."
),
),
(
{"options": {"colormap": "jet"}},
"result.tiff",
{"options": {"quality": "low"}},
ValueError(
"Existing `save_result` node with different options {'colormap': 'jet'} != {'quality': 'low'}"
OpenEoClientException(
"Download with explicit output options {'quality': 'low'}, but the process graph already has `save_result` node(s) which is ambiguous and should not be combined."
),
),
],
Expand Down Expand Up @@ -3328,8 +3327,8 @@ def post_result(request, context):
cube = cube.save_result(**save_result_kwargs)

path = tmp_path / download_filename
if isinstance(expected, ValueError):
with pytest.raises(ValueError, match=str(expected)):
if isinstance(expected, Exception):
with pytest.raises(type(expected), match=re.escape(str(expected))):
cube.download(str(path), **download_kwargs)
assert post_result_mock.call_count == 0
else:
Expand Down

0 comments on commit 5aa28fb

Please sign in to comment.