Skip to content

Commit

Permalink
PR #200 finetune VectorCube.apply_dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Aug 1, 2023
1 parent 30a422e commit 3c2f0e9
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 33 deletions.
70 changes: 37 additions & 33 deletions openeo_driver/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@

from openeo.metadata import CollectionMetadata
from openeo.util import ensure_dir, str_truncate
import openeo.udf
from openeo_driver.datastructs import SarBackscatterArgs, ResolutionMergeArgs, StacAsset
from openeo_driver.errors import FeatureUnsupportedException, InternalException
from openeo_driver.util.geometry import GeometryBufferer, validate_geojson_coordinates
from openeo_driver.util.ioformats import IOFORMATS
from openeo_driver.util.pgparsing import SingleRunUDFProcessGraph
from openeo_driver.util.utm import area_in_square_meters
from openeo_driver.utils import EvalEnv
from openeogeotrellis.backend import SingleNodeUDFProcessGraphVisitor

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -248,38 +249,6 @@ def with_cube(self, cube: xarray.DataArray, flatten_prefix: str = FLATTEN_PREFIX
geometries=self._geometries, cube=cube, flatten_prefix=flatten_prefix
)

def apply_dimension(
    self,
    process: dict,
    *,
    dimension: str,
    target_dimension: Optional[str] = None,
    context: Optional[dict] = None,
    env: EvalEnv,
) -> "DriverVectorCube":
    """
    Apply a process along a dimension of this vector cube.

    Only the special case of a single `run_udf` node applied along the
    "bands" dimension (no target dimension) is supported; anything else
    raises :py:class:`FeatureUnsupportedException`.

    :param process: callback process graph (must be a single `run_udf` node)
    :param dimension: dimension to apply along (only "bands" supported)
    :param target_dimension: not supported (must be None)
    :param context: optional user context passed to the UDF
    :param env: evaluation environment (provides the UDF runner)
    :raises ValueError: when the UDF result cannot be converted back to a vector cube
    """
    # Use identity comparison for None (PEP 8), not `== None`.
    if (
        dimension == "bands"
        and target_dimension is None
        and len(process) == 1
        and next(iter(process.values())).get("process_id") == "run_udf"
    ):
        visitor = SingleNodeUDFProcessGraphVisitor().accept_process_graph(process)
        udf = visitor.udf_args.get("udf", None)

        from openeo.udf import FeatureCollection, UdfData

        collection = FeatureCollection(id="VectorCollection", data=self._as_geopandas_df())
        data = UdfData(
            proj={"EPSG": self._geometries.crs.to_epsg()},
            feature_collection_list=[collection],
            user_context=context,
        )

        log.info(f"[run_udf] Running UDF {str_truncate(udf, width=256)!r} on {data!r}")
        result_data = env.backend_implementation.processing.run_udf(udf, data)
        log.info(f"[run_udf] UDF resulted in {result_data!r}")

        if isinstance(result_data, UdfData):
            feature_collections = result_data.get_feature_collection_list()
            # Only a single feature collection can be mapped back to a vector cube.
            if feature_collections is not None and len(feature_collections) == 1:
                return DriverVectorCube(geometries=feature_collections[0].data)

        raise ValueError(f"Could not handle UDF result: {result_data}")

    else:
        raise FeatureUnsupportedException()

@classmethod
def from_fiona(
cls,
Expand Down Expand Up @@ -537,6 +506,41 @@ def buffer_points(self, distance: float = 10) -> "DriverVectorCube":
]
)

def apply_dimension(
    self,
    process: dict,
    *,
    dimension: str,
    target_dimension: Optional[str] = None,
    context: Optional[dict] = None,
    env: EvalEnv,
) -> "DriverVectorCube":
    """
    Apply a process along a dimension of this vector cube.

    Only the special case of a single `run_udf` node applied to a vector cube
    without "cube" data along the geometries dimension is supported; all other
    combinations raise :py:class:`FeatureUnsupportedException`.

    :param process: callback process graph (must be a single `run_udf` node)
    :param dimension: dimension to apply along (only ``self.DIM_GEOMETRIES`` supported)
    :param target_dimension: not supported (must be None)
    :param context: optional user context passed to the UDF
    :param env: evaluation environment (provides the UDF runner)
    :raises ValueError: when the UDF result cannot be converted back to a vector cube
    :raises FeatureUnsupportedException: for any unsupported process/dimension combination
    """
    single_run_udf = SingleRunUDFProcessGraph.parse_or_none(process)

    if single_run_udf:
        # Process with single "run_udf" node
        # Fix: original compared against truncated name `Non` (NameError at runtime).
        if self._cube is None and dimension == self.DIM_GEOMETRIES and target_dimension is None:
            # TODO: this is non-standard special case: vector cube with only geometries, but no "cube" data
            feature_collection = openeo.udf.FeatureCollection(id="_", data=self._as_geopandas_df())
            udf_data = openeo.udf.UdfData(
                proj={"EPSG": self._geometries.crs.to_epsg()},
                feature_collection_list=[feature_collection],
                user_context=context,
            )
            log.info(f"[run_udf] Running UDF {str_truncate(single_run_udf.udf, width=256)!r} on {udf_data!r}")
            result_data = env.backend_implementation.processing.run_udf(udf=single_run_udf.udf, data=udf_data)
            log.info(f"[run_udf] UDF resulted in {result_data!r}")

            if isinstance(result_data, openeo.udf.UdfData):
                result_features = result_data.get_feature_collection_list()
                if result_features and len(result_features) == 1:
                    return DriverVectorCube(geometries=result_features[0].data)
            raise ValueError(f"Could not handle UDF result: {result_data}")

    # Fix: previously only the non-UDF branch raised, so an unsupported
    # dimension/target combination with a valid `run_udf` node silently
    # returned None. Raise explicitly for every unhandled path.
    raise FeatureUnsupportedException()



class DriverMlModel:
"""Base class for driver-side 'ml-model' data structures"""
Expand Down
45 changes: 45 additions & 0 deletions openeo_driver/util/pgparsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import dataclasses
from typing import Optional


class NotASingleRunUDFProcessGraph(ValueError):
    """Raised when a process graph is not a single `run_udf` node."""


@dataclasses.dataclass(frozen=True)
class SingleRunUDFProcessGraph:
    """
    Container (and parser) for a callback process graph containing only a single `run_udf` node.
    """

    # `run_udf` arguments: "data" (the input reference), the UDF source code,
    # the UDF runtime id, plus optional runtime version and user context.
    data: dict
    udf: str
    runtime: str
    version: Optional[str] = None
    context: Optional[dict] = None

    @classmethod
    def parse(cls, process_graph: dict) -> "SingleRunUDFProcessGraph":
        """
        Parse a process graph that consists of a single `run_udf` result node.

        :raises NotASingleRunUDFProcessGraph: when the graph has more than one
            node, the node is not `run_udf`, is not flagged as result, or
            required arguments are missing.
        """
        try:
            (node,) = process_graph.values()
            # Explicit checks instead of `assert`: asserts are stripped under
            # `python -O`, which would silently disable this validation.
            if node["process_id"] != "run_udf":
                raise ValueError(f"Not a 'run_udf' node: {node['process_id']!r}")
            if node["result"] is not True:
                raise ValueError("Node is not flagged as result node")
            arguments = node["arguments"]
            missing = {"data", "udf", "runtime"}.difference(arguments.keys())
            if missing:
                raise ValueError(f"Missing 'run_udf' arguments: {missing}")

            return cls(
                data=arguments["data"],
                udf=arguments["udf"],
                runtime=arguments["runtime"],
                version=arguments.get("version"),
                context=arguments.get("context") or {},
            )
        except Exception as e:
            # Normalize any parsing failure (KeyError, unpack error, ...) to
            # the dedicated exception, preserving the cause chain.
            raise NotASingleRunUDFProcessGraph(str(e)) from e

    @classmethod
    def parse_or_none(cls, process_graph: dict) -> Optional["SingleRunUDFProcessGraph"]:
        """Like :py:meth:`parse`, but return None instead of raising on failure."""
        # Fix: return annotation referenced nonexistent `SingleNodeRunUDFProcessGraph`.
        try:
            return cls.parse(process_graph=process_graph)
        except NotASingleRunUDFProcessGraph:
            return None
78 changes: 78 additions & 0 deletions tests/util/test_pgparsing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pytest

from openeo_driver.util.pgparsing import SingleRunUDFProcessGraph, NotASingleRunUDFProcessGraph


class TestSingleRunUDFProcessGraph:
    """Unit tests for parsing single-`run_udf` callback process graphs."""

    def test_parse_basic(self):
        pg = {
            "runudf1": {
                "process_id": "run_udf",
                "arguments": {
                    "data": {"from_parameter": "data"},
                    "udf": "print('Hello world')",
                    "runtime": "Python",
                },
                "result": True,
            }
        }
        parsed = SingleRunUDFProcessGraph.parse(pg)
        # Frozen dataclass: compare against the fully expected value in one go.
        assert parsed == SingleRunUDFProcessGraph(
            data={"from_parameter": "data"},
            udf="print('Hello world')",
            runtime="Python",
            version=None,
            context={},
        )

    @pytest.mark.parametrize(
        "pg",
        [
            # Wrong process_id
            {
                "runudf1": {
                    "process_id": "run_udffffffffffffffff",
                    "arguments": {"data": {"from_parameter": "data"}, "udf": "x = 4", "runtime": "Python"},
                    "result": True,
                }
            },
            # Missing "data" argument
            {
                "runudf1": {
                    "process_id": "run_udf",
                    "arguments": {"udf": "x = 4", "runtime": "Python"},
                    "result": True,
                }
            },
            # Missing "udf" argument
            {
                "runudf1": {
                    "process_id": "run_udf",
                    "arguments": {"data": {"from_parameter": "data"}, "runtime": "Python"},
                    "result": True,
                }
            },
            # Missing "runtime" argument
            {
                "runudf1": {
                    "process_id": "run_udf",
                    "arguments": {"data": {"from_parameter": "data"}, "udf": "x = 4"},
                    "result": True,
                }
            },
            # No "result" flag
            {
                "runudf1": {
                    "process_id": "run_udf",
                    "arguments": {"data": {"from_parameter": "data"}, "udf": "x = 4", "runtime": "Python"},
                }
            },
            # More than one node in the graph
            {
                "runudf1": {
                    "process_id": "run_udf",
                    "arguments": {"data": {"from_parameter": "data"}, "udf": "x = 4", "runtime": "Python"},
                    "result": True,
                },
                "runudf2": {
                    "process_id": "run_udf",
                    "arguments": {"data": {"from_parameter": "data"}, "udf": "x = 4", "runtime": "Python"},
                },
            },
        ],
    )
    def test_parse_invalid(self, pg):
        with pytest.raises(NotASingleRunUDFProcessGraph):
            _ = SingleRunUDFProcessGraph.parse(pg)

0 comments on commit 3c2f0e9

Please sign in to comment.