diff --git a/openeo/internal/graph_building.py b/openeo/internal/graph_building.py index d92a496b7..1f1c6211a 100644 --- a/openeo/internal/graph_building.py +++ b/openeo/internal/graph_building.py @@ -14,7 +14,7 @@ import sys from contextlib import nullcontext from pathlib import Path -from typing import Any, Dict, Optional, Tuple, Union +from typing import Any, Dict, Iterator, Optional, Tuple, Union from openeo.api.process import Parameter from openeo.internal.process_graph_visitor import ( @@ -225,6 +225,24 @@ def from_flat_graph(flat_graph: dict, parameters: Optional[dict] = None) -> PGNo return PGNodeGraphUnflattener.unflatten(flat_graph=flat_graph, parameters=parameters) + def walk_nodes(self) -> Iterator[PGNode]: + """Walk this node and all its parents""" + # TODO: option to do deep walk (walk through child graphs too)? + yield self + + def walk(x) -> Iterator[PGNode]: + if isinstance(x, PGNode): + yield from x.walk_nodes() + elif isinstance(x, dict): + for v in x.values(): + yield from walk(v) + elif isinstance(x, (list, tuple)): + for v in x: + yield from walk(v) + + yield from walk(self.arguments) + + def as_flat_graph(x: Union[dict, FlatGraphableMixin, Path, Any]) -> Dict[str, dict]: """ Convert given object to a internal flat dict graph representation. diff --git a/openeo/rest/datacube.py b/openeo/rest/datacube.py index d90db5c7e..da19e6e84 100644 --- a/openeo/rest/datacube.py +++ b/openeo/rest/datacube.py @@ -2097,8 +2097,11 @@ def save_result( def _ensure_save_result( self, + *, format: Optional[str] = None, options: Optional[dict] = None, + weak_format: Optional[str] = None, + method: str, ) -> DataCube: """ Make sure there is a (final) `save_result` node in the process graph. 
@@ -2110,25 +2113,19 @@ def _ensure_save_result( :return: """ # TODO #401 Unify with VectorCube._ensure_save_result and move to generic data cube parent class (not only for raster cubes, but also vector cubes) - result_node = self.result_node() - if result_node.process_id == "save_result": - # There is already a `save_result` node: - # check if it is consistent with given format/options (if any) - args = result_node.arguments - if format is not None and format.lower() != args["format"].lower(): - raise ValueError( - f"Existing `save_result` node with different format {args['format']!r} != {format!r}" - ) - if options is not None and options != args["options"]: - raise ValueError( - f"Existing `save_result` node with different options {args['options']!r} != {options!r}" - ) - cube = self - else: + save_result_nodes = [n for n in self.result_node().walk_nodes() if n.process_id == "save_result"] + + cube = self + if not save_result_nodes: # No `save_result` node yet: automatically add it. - cube = self.save_result( - format=format or self._DEFAULT_RASTER_FORMAT, options=options + cube = cube.save_result(format=format or weak_format or self._DEFAULT_RASTER_FORMAT, options=options) + elif format or options: + raise OpenEoClientException( + f"{method} with explicit output {'format' if format else 'options'} {format or options!r}," + f" but the process graph already has `save_result` node(s)" + f" which is ambiguous and should not be combined." ) + return cube def download( @@ -2152,10 +2149,8 @@ def download( (overruling the connection's ``auto_validate`` setting). :return: None if the result is stored to disk, or a bytes object returned by the backend. """ - if format is None and outputfile: - # TODO #401/#449 don't guess/override format if there is already a save_result with format? 
- format = guess_format(outputfile) - cube = self._ensure_save_result(format=format, options=options) + weak_format = guess_format(outputfile) if outputfile else None + cube = self._ensure_save_result(format=format, options=options, weak_format=weak_format, method="Download") return self._connection.download(cube.flat_graph(), outputfile, validate=validate) def validate(self) -> List[dict]: @@ -2321,7 +2316,7 @@ def create_job( # TODO: add option to also automatically start the job? # TODO: avoid using all kwargs as format_options # TODO: centralize `create_job` for `DataCube`, `VectorCube`, `MlModel`, ... - cube = self._ensure_save_result(format=out_format, options=format_options or None) + cube = self._ensure_save_result(format=out_format, options=format_options or None, method="Creating job") return self._connection.create_job( process_graph=cube.flat_graph(), title=title, diff --git a/tests/internal/test_graphbuilding.py b/tests/internal/test_graphbuilding.py index f9d8d0acd..1bdf56015 100644 --- a/tests/internal/test_graphbuilding.py +++ b/tests/internal/test_graphbuilding.py @@ -379,3 +379,36 @@ def test_parameter_substitution_undefined(self): } with pytest.raises(ProcessGraphVisitException, match="No substitution value for parameter 'increment'"): _ = PGNodeGraphUnflattener.unflatten(flat_graph, parameters={"other": 100}) + + +def test_walk_nodes_basic(): + node = PGNode("foo") + walk = node.walk_nodes() + assert next(walk) is node + with pytest.raises(StopIteration): + next(walk) + + +def test_walk_nodes_args(): + data = PGNode("load") + geometry = PGNode("vector") + node = PGNode("foo", data=data, geometry=geometry) + + walk = node.walk_nodes() + assert next(walk) is node + rest = list(walk) + assert rest == [data, geometry] or rest == [geometry, data] + + +def test_walk_nodes_nested(): + node = PGNode( + "foo", + cubes=[PGNode("load1"), PGNode("load2")], + size={ + "x": PGNode("add", x=PGNode("five"), y=3), + "y": PGNode("max"), + }, + ) + walk = 
list(node.walk_nodes()) + assert all(isinstance(n, PGNode) for n in walk) + assert set(n.process_id for n in walk) == {"load1", "max", "foo", "load2", "add", "five"} diff --git a/tests/rest/datacube/test_datacube.py b/tests/rest/datacube/test_datacube.py index 33b32dd73..167180ce8 100644 --- a/tests/rest/datacube/test_datacube.py +++ b/tests/rest/datacube/test_datacube.py @@ -7,6 +7,7 @@ import contextlib import pathlib +import re from datetime import date, datetime from unittest import mock @@ -710,13 +711,14 @@ def test_create_job_out_format( @pytest.mark.parametrize( ["save_result_format", "execute_format", "expected"], [ - ("GTiff", "GTiff", "GTiff"), + (None, None, "GTiff"), + (None, "GTiff", "GTiff"), ("GTiff", None, "GTiff"), - ("NetCDF", "NetCDF", "NetCDF"), + (None, "NetCDF", "NetCDF"), ("NetCDF", None, "NetCDF"), ], ) - def test_create_job_existing_save_result( + def test_save_result_and_create_job_at_most_one_with_format( self, s2cube, get_create_job_pg, @@ -724,7 +726,10 @@ def test_create_job_existing_save_result( execute_format, expected, ): - cube = s2cube.save_result(format=save_result_format) + cube = s2cube + if save_result_format: + cube = cube.save_result(format=save_result_format) + cube.create_job(out_format=execute_format) pg = get_create_job_pg() assert set(pg.keys()) == {"loadcollection1", "saveresult1"} @@ -740,13 +745,21 @@ def test_create_job_existing_save_result( @pytest.mark.parametrize( ["save_result_format", "execute_format"], - [("NetCDF", "GTiff"), ("GTiff", "NetCDF")], + [ + ("NetCDF", "NetCDF"), + ("GTiff", "NetCDF"), + ], ) - def test_create_job_existing_save_result_incompatible( - self, s2cube, save_result_format, execute_format - ): + def test_save_result_and_create_job_both_with_format(self, s2cube, save_result_format, execute_format): cube = s2cube.save_result(format=save_result_format) - with pytest.raises(ValueError): + with pytest.raises( + OpenEoClientException, + match=re.escape( + "Creating job with explicit output 
format 'NetCDF'," + " but the process graph already has `save_result` node(s)" + " which is ambiguous and should not be combined." + ), + ): cube.create_job(out_format=execute_format) def test_execute_batch_defaults(self, s2cube, get_create_job_pg, recwarn, caplog): @@ -808,13 +821,14 @@ def test_execute_batch_out_format_from_output_file( @pytest.mark.parametrize( ["save_result_format", "execute_format", "expected"], [ - ("GTiff", "GTiff", "GTiff"), + (None, None, "GTiff"), + (None, "GTiff", "GTiff"), ("GTiff", None, "GTiff"), - ("NetCDF", "NetCDF", "NetCDF"), + (None, "NetCDF", "NetCDF"), ("NetCDF", None, "NetCDF"), ], ) - def test_execute_batch_existing_save_result( + def test_save_result_and_execute_batch_at_most_one_with_format( self, s2cube, get_create_job_pg, @@ -822,7 +836,9 @@ def test_execute_batch_existing_save_result( execute_format, expected, ): - cube = s2cube.save_result(format=save_result_format) + cube = s2cube + if save_result_format: + cube = cube.save_result(format=save_result_format) cube.execute_batch(out_format=execute_format) pg = get_create_job_pg() assert set(pg.keys()) == {"loadcollection1", "saveresult1"} @@ -838,13 +854,23 @@ def test_execute_batch_existing_save_result( @pytest.mark.parametrize( ["save_result_format", "execute_format"], - [("NetCDF", "GTiff"), ("GTiff", "NetCDF")], + [ + ("NetCDF", "NetCDF"), + ("GTiff", "NetCDF"), + ], ) def test_execute_batch_existing_save_result_incompatible( self, s2cube, save_result_format, execute_format ): cube = s2cube.save_result(format=save_result_format) - with pytest.raises(ValueError): + with pytest.raises( + OpenEoClientException, + match=re.escape( + "Creating job with explicit output format 'NetCDF'," + " but the process graph already has `save_result` node(s)" + " which is ambiguous and should not be combined." 
+ ), + ): cube.execute_batch(out_format=execute_format) def test_save_result_format_options_vs_create_job(elf, s2cube, get_create_job_pg): diff --git a/tests/rest/datacube/test_datacube100.py b/tests/rest/datacube/test_datacube100.py index c3eabe797..72683b83c 100644 --- a/tests/rest/datacube/test_datacube100.py +++ b/tests/rest/datacube/test_datacube100.py @@ -3257,14 +3257,7 @@ def test_apply_append_math_keep_context(con100): ({}, "result.nc", {}, b"this is netCDF data"), ({"format": "GTiff"}, "result.tiff", {}, b"this is GTiff data"), ({"format": "GTiff"}, "result.tif", {}, b"this is GTiff data"), - ( - {"format": "GTiff"}, - "result.nc", - {}, - ValueError( - "Existing `save_result` node with different format 'GTiff' != 'netCDF'" - ), - ), + ({"format": "GTiff"}, "result.nc", {}, b"this is GTiff data"), ({}, "result.tiff", {"format": "GTiff"}, b"this is GTiff data"), ({}, "result.nc", {"format": "netCDF"}, b"this is netCDF data"), ({}, "result.meh", {"format": "netCDF"}, b"this is netCDF data"), @@ -3272,20 +3265,24 @@ def test_apply_append_math_keep_context(con100): {"format": "GTiff"}, "result.tiff", {"format": "GTiff"}, - b"this is GTiff data", + OpenEoClientException( + "Download with explicit output format 'GTiff', but the process graph already has `save_result` node(s) which is ambiguous and should not be combined." + ), ), ( {"format": "netCDF"}, "result.tiff", {"format": "NETCDF"}, - b"this is netCDF data", + OpenEoClientException( + "Download with explicit output format 'NETCDF', but the process graph already has `save_result` node(s) which is ambiguous and should not be combined." + ), ), ( {"format": "netCDF"}, "result.json", {"format": "JSON"}, - ValueError( - "Existing `save_result` node with different format 'netCDF' != 'JSON'" + OpenEoClientException( + "Download with explicit output format 'JSON', but the process graph already has `save_result` node(s) which is ambiguous and should not be combined." 
), ), ({"options": {}}, "result.tiff", {}, b"this is GTiff data"), @@ -3293,14 +3290,16 @@ def test_apply_append_math_keep_context(con100): {"options": {"quality": "low"}}, "result.tiff", {"options": {"quality": "low"}}, - b"this is GTiff data", + OpenEoClientException( + "Download with explicit output options {'quality': 'low'}, but the process graph already has `save_result` node(s) which is ambiguous and should not be combined." + ), ), ( {"options": {"colormap": "jet"}}, "result.tiff", {"options": {"quality": "low"}}, - ValueError( - "Existing `save_result` node with different options {'colormap': 'jet'} != {'quality': 'low'}" + OpenEoClientException( + "Download with explicit output options {'quality': 'low'}, but the process graph already has `save_result` node(s) which is ambiguous and should not be combined." ), ), ], @@ -3328,8 +3327,8 @@ def post_result(request, context): cube = cube.save_result(**save_result_kwargs) path = tmp_path / download_filename - if isinstance(expected, ValueError): - with pytest.raises(ValueError, match=str(expected)): + if isinstance(expected, Exception): + with pytest.raises(type(expected), match=re.escape(str(expected))): cube.download(str(path), **download_kwargs) assert post_result_mock.call_count == 0 else: