Merge pull request #40 from Open-EO/EP-3399-aggregate-polygon
zonal_statistics: add vector_file support for "histogram" and "sd"
soxofaan committed May 19, 2020
2 parents 3ec4b69 + 021fdb8 commit 075d8ea
Showing 12 changed files with 148 additions and 90 deletions.
139 changes: 57 additions & 82 deletions openeogeotrellis/GeotrellisImageCollection.py
@@ -14,6 +14,7 @@
 import geopyspark as gps
 import numpy as np
 import pandas as pd
+from py4j.java_gateway import JVMView
 import pyproj
 import pytz
 from geopyspark import TiledRasterLayer, TMS, Pyramid, Tile, SpaceTimeKey, SpatialKey, Metadata
@@ -55,6 +56,10 @@ def __init__(self, pyramid: Pyramid, service_registry: AbstractServiceRegistry,
         # TODO get rid of this _band_index stuff. See https://github.com/Open-EO/openeo-geopyspark-driver/issues/29
         self._band_index = 0
 
+    def _get_jvm(self) -> JVMView:
+        # TODO: cache this?
+        return gps.get_spark_context()._gateway.jvm
+
     def apply_to_levels(self, func):
         """
         Applies a function to each level of the pyramid. The argument provided to the function is of type TiledRasterLayer
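
The `_get_jvm` helper added above centralizes the py4j gateway access that the hunks below previously inlined. A minimal standalone sketch of the same pattern, assuming a running SparkContext with the openeo-geotrellis jars on its classpath; the ZooKeeper connection string is hypothetical:

```python
import geopyspark as gps
from py4j.java_gateway import JVMView


def get_jvm() -> JVMView:
    # The active SparkContext's py4j gateway exposes JVM classes as
    # attribute chains, e.g. jvm.org.openeo.geotrellis.<ClassName>.
    return gps.get_spark_context()._gateway.jvm


jvm = get_jvm()
adapter = jvm.org.openeo.geotrellis.ComputeStatsGeotrellisAdapter(
    "zk1:2181,zk2:2181",      # hypothetical ZooKeeper nodes
    "hdp-accumulo-instance",  # Accumulo instance name used in this diff
)
```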
@@ -74,7 +79,7 @@ def _apply_to_levels_geotrellis_rdd(self, func):
         """
 
         def create_tilelayer(contextrdd, layer_type, zoom_level):
-            jvm = gps.get_spark_context()._gateway.jvm
+            jvm = self._get_jvm()
             spatial_tiled_raster_layer = jvm.geopyspark.geotrellis.SpatialTiledRasterLayer
             temporal_tiled_raster_layer = jvm.geopyspark.geotrellis.TemporalTiledRasterLayer
 
@@ -531,131 +536,101 @@ def timeseries(self, x, y, srs="EPSG:4326") -> Dict:
 
         return result
 
-    def zonal_statistics(self, regions, func) -> AggregatePolygonResult:
+    def zonal_statistics(self, regions: Union[str, GeometryCollection, Polygon, MultiPolygon], func) -> AggregatePolygonResult:
         # TODO: rename to aggregate_polygon?
         # TODO eliminate code duplication
         _log.info("zonal_statistics with {f!r}, {r}".format(f=func, r=type(regions)))
 
         def insert_timezone(instant):
             return instant.replace(tzinfo=pytz.UTC) if instant.tzinfo is None else instant
 
         from_vector_file = isinstance(regions, str)
         multiple_geometries = from_vector_file or isinstance(regions, GeometryCollection)
 
-        if func == 'histogram' or func == 'sd':
+        if func in ['histogram', 'sd', 'median']:
             highest_level = self.pyramid.levels[self.pyramid.max_zoom]
             layer_metadata = highest_level.layer_metadata
 
             scala_data_cube = highest_level.srdd.rdd()
 
-            polygon_wkts = [str(ob) for ob in regions] if multiple_geometries else [str(regions)]
-            polygons_srs = 'EPSG:4326'
+            polygons = self._compute_stats_geotrellis_projected_polygons(regions)
             from_date = insert_timezone(layer_metadata.bounds.minKey.instant)
             to_date = insert_timezone(layer_metadata.bounds.maxKey.instant)
 
+            # TODO also add dumping results first to temp json file like with "mean"
             if func == 'histogram':
-                if multiple_geometries:
-                    implementation = self._compute_stats_geotrellis().compute_histograms_time_series_from_datacube
-                    polygon_wkts = (str(ob) for ob in regions)
-                else:
-                    implementation = self._compute_stats_geotrellis().compute_histogram_time_series_from_datacube
-                    polygon_wkts = str(regions)
+                stats = self._compute_stats_geotrellis().compute_histograms_time_series_from_datacube(
+                    scala_data_cube, polygons, from_date.isoformat(), to_date.isoformat(), self._band_index
+                )
             elif func == 'sd':
-                implementation = self._compute_stats_geotrellis().compute_sd_time_series_from_datacube
-            else:
-                raise ValueError(func)
-
-            stats = implementation(
-                scala_data_cube,
-                polygon_wkts,
-                polygons_srs,
-                from_date.isoformat(),
-                to_date.isoformat(),
-                self._band_index
-            )
-
-            return AggregatePolygonResult(
-                timeseries=self._as_python(stats),
-                regions=regions if multiple_geometries else GeometryCollection([regions]),
-            )
-        elif func == 'median':
-            highest_level = self.pyramid.levels[self.pyramid.max_zoom]
-            layer_metadata = highest_level.layer_metadata
-
-            scala_data_cube = highest_level.srdd.rdd()
-
-            from_date = insert_timezone(layer_metadata.bounds.minKey.instant)
-            to_date = insert_timezone(layer_metadata.bounds.maxKey.instant)
-
-            if from_vector_file:
-                stats = self._compute_stats_geotrellis().compute_median_time_series_from_datacube(
-                    scala_data_cube,
-                    regions,
-                    from_date.isoformat(),
-                    to_date.isoformat(),
-                    self._band_index
+                stats = self._compute_stats_geotrellis().compute_sd_time_series_from_datacube(
+                    scala_data_cube, polygons, from_date.isoformat(), to_date.isoformat(), self._band_index
                 )
-            else:
-                polygon_wkts = [str(ob) for ob in regions] if multiple_geometries else [str(regions)]
-                polygons_srs = 'EPSG:4326'
-
+            elif func == 'median':
                 stats = self._compute_stats_geotrellis().compute_median_time_series_from_datacube(
-                    scala_data_cube,
-                    polygon_wkts,
-                    polygons_srs,
-                    from_date.isoformat(),
-                    to_date.isoformat(),
-                    self._band_index
+                    scala_data_cube, polygons, from_date.isoformat(), to_date.isoformat(), self._band_index
                 )
+            else:
+                raise ValueError(func)
 
             return AggregatePolygonResult(
                 timeseries=self._as_python(stats),
+                # TODO: regions can also be a string (path to vector file) instead of geometry object
                 regions=regions if multiple_geometries else GeometryCollection([regions]),
             )
-        else:  # defaults to mean, historically
+        elif func == "mean":
             if multiple_geometries:
                 highest_level = self.pyramid.levels[self.pyramid.max_zoom]
                 layer_metadata = highest_level.layer_metadata
                 scala_data_cube = highest_level.srdd.rdd()
-                polygons_srs = 'EPSG:4326'
+                polygons = self._compute_stats_geotrellis_projected_polygons(regions)
                 from_date = insert_timezone(layer_metadata.bounds.minKey.instant)
                 to_date = insert_timezone(layer_metadata.bounds.maxKey.instant)
 
                 with tempfile.NamedTemporaryFile(suffix=".json.tmp") as temp_file:
-                    if from_vector_file:
-                        self._compute_stats_geotrellis().compute_average_timeseries_from_datacube(
-                            scala_data_cube,
-                            regions,
-                            from_date.isoformat(),
-                            to_date.isoformat(),
-                            self._band_index,
-                            temp_file.name
-                        )
-
-                    else:
-                        self._compute_stats_geotrellis().compute_average_timeseries_from_datacube(
-                            scala_data_cube,
-                            [str(ob) for ob in regions],
-                            polygons_srs,
-                            from_date.isoformat(),
-                            to_date.isoformat(),
-                            self._band_index,
-                            temp_file.name
-                        )
-
+                    self._compute_stats_geotrellis().compute_average_timeseries_from_datacube(
+                        scala_data_cube,
+                        polygons,
+                        from_date.isoformat(),
+                        to_date.isoformat(),
+                        self._band_index,
+                        temp_file.name
+                    )
                     with open(temp_file.name, encoding='utf-8') as f:
                         timeseries = json.load(f)
                 return AggregatePolygonResult(
                     timeseries=timeseries,
                     # TODO: regions can also be a string (path to vector file) instead of geometry object
                     regions=regions,
                 )
             else:
                 return AggregatePolygonResult(
                     timeseries=self.polygonal_mean_timeseries(regions),
                     regions=GeometryCollection([regions]),
                 )
+        else:
+            raise ValueError(func)
 
     def _compute_stats_geotrellis(self):
-        jvm = gps.get_spark_context()._gateway.jvm
         accumulo_instance_name = 'hdp-accumulo-instance'
-        return jvm.org.openeo.geotrellis.ComputeStatsGeotrellisAdapter(self._zookeepers(), accumulo_instance_name)
+        return self._get_jvm().org.openeo.geotrellis.ComputeStatsGeotrellisAdapter(self._zookeepers(), accumulo_instance_name)
 
+    def _compute_stats_geotrellis_projected_polygons(self, *args):
+        """Construct ProjectedPolygon instance"""
+        jvm = self._get_jvm()
+        if len(args) == 1 and isinstance(args[0], (str, pathlib.Path)):
+            # Vector file
+            return jvm.org.openeo.geotrellis.ProjectedPolygons.fromVectorFile(str(args[0]))
+        elif 1 <= len(args) <= 2 and isinstance(args[0], GeometryCollection):
+            # Multiple polygons
+            polygon_wkts = [str(x) for x in args[0]]
+            polygons_srs = args[1] if len(args) >= 2 else 'EPSG:4326'
+            return jvm.org.openeo.geotrellis.ProjectedPolygons.fromWkt(polygon_wkts, polygons_srs)
+        elif 1 <= len(args) <= 2 and isinstance(args[0], (Polygon, MultiPolygon)):
+            # Single polygon
+            polygon_wkts = [str(args[0])]
+            polygons_srs = args[1] if len(args) >= 2 else 'EPSG:4326'
+            return jvm.org.openeo.geotrellis.ProjectedPolygons.fromWkt(polygon_wkts, polygons_srs)
+        else:
+            raise ValueError(args)
+
     def _zookeepers(self):
         return ','.join(ConfigParams().zookeepernodes)
@@ -805,7 +780,7 @@ def write_tiff(item):
 
 
     def _save_stitched(self, spatial_rdd, path, crop_bounds=None):
-        jvm = gps.get_spark_context()._gateway.jvm
+        jvm = self._get_jvm()
 
         max_compression = jvm.geotrellis.raster.io.geotiff.compression.DeflateCompression(9)
 
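The core of this change is the new `_compute_stats_geotrellis_projected_polygons` helper: every supported region type (vector file path, `GeometryCollection`, single `Polygon`/`MultiPolygon`) is normalized into one JVM-side `ProjectedPolygons` object, so each `func` branch of `zonal_statistics` can pass the same handle to the Scala adapter. A pure-Python sketch of that dispatch, returning mocked values instead of JVM objects (note: it iterates via `.geoms`, which shapely 2.x requires; the code above iterates the collection directly, which works on shapely 1.x):

```python
import pathlib
from shapely.geometry import GeometryCollection, MultiPolygon, Polygon, box


def projected_polygons_call(*args):
    # Mirrors the dispatch in _compute_stats_geotrellis_projected_polygons,
    # but returns (method, arguments) tuples instead of calling into
    # org.openeo.geotrellis.ProjectedPolygons.
    if len(args) == 1 and isinstance(args[0], (str, pathlib.Path)):
        # Vector file: geometry parsing and CRS handling are delegated to the JVM side.
        return ("fromVectorFile", str(args[0]))
    elif 1 <= len(args) <= 2 and isinstance(args[0], GeometryCollection):
        # Multiple polygons: serialize each geometry to WKT.
        srs = args[1] if len(args) == 2 else "EPSG:4326"
        return ("fromWkt", [str(g) for g in args[0].geoms], srs)
    elif 1 <= len(args) <= 2 and isinstance(args[0], (Polygon, MultiPolygon)):
        # Single polygon: wrap in a one-element WKT list.
        srs = args[1] if len(args) == 2 else "EPSG:4326"
        return ("fromWkt", [str(args[0])], srs)
    else:
        raise ValueError(args)


assert projected_polygons_call("polygons01.shp") == ("fromVectorFile", "polygons01.shp")
assert projected_polygons_call(box(0, 0, 1, 1))[0] == "fromWkt"
```

This is why `zonal_statistics` no longer needs separate `polygon_wkts`/`polygons_srs` code paths per statistic: every branch hands the same `ProjectedPolygons` handle to `ComputeStatsGeotrellisAdapter`.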
Empty file added tests/__init__.py
8 changes: 8 additions & 0 deletions tests/data/__init__.py
@@ -0,0 +1,8 @@
+from pathlib import Path
+
+TEST_DATA_ROOT = Path(__file__).parent
+
+
+def get_test_data_file(path: str) -> Path:
+    """Get path of test data by relative path."""
+    return TEST_DATA_ROOT / path
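Test modules in the package can then resolve fixtures relative to `tests/data`; the new timeseries tests below use exactly this:

```python
from .data import get_test_data_file

shapefile = str(get_test_data_file("geometries/polygons01.shp"))
```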
1 change: 1 addition & 0 deletions tests/data/geometries/polygons01.cpg
@@ -0,0 +1 @@
+UTF-8
Binary file added tests/data/geometries/polygons01.dbf
9 changes: 9 additions & 0 deletions tests/data/geometries/polygons01.geojson
@@ -0,0 +1,9 @@
+{
+"type": "FeatureCollection",
+"name": "polygons01",
+"crs": { "type": "name", "properties": { "name": "urn:ogc:def:crs:OGC:1.3:CRS84" } },
+"features": [
+{ "type": "Feature", "properties": { "id": null, "name": "Area1" }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 0.5, 0.5 ], [ 0.4, 1.0 ], [ 1.0, 1.1 ], [ 2.1, 1.5 ], [ 2.2, 0.4 ], [ 0.5, 0.5 ] ] ] } },
+{ "type": "Feature", "properties": { "id": null, "name": "Area2" }, "geometry": { "type": "Polygon", "coordinates": [ [ [ 2.2, 1.0 ], [ 2.3, 2.2 ], [ 1.5, 3.7 ], [ 3.8, 3.8 ], [ 3.7, 1.1 ], [ 2.2, 1.0 ] ] ] } }
+]
+}
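For reference, a minimal way to inspect this fixture outside the test suite, using only the standard `json` module and shapely (the path assumes the repository root as working directory):

```python
import json

from shapely.geometry import shape

with open("tests/data/geometries/polygons01.geojson", encoding="utf-8") as f:
    collection = json.load(f)

# Build shapely geometries from the two GeoJSON features.
polygons = [shape(feature["geometry"]) for feature in collection["features"]]
print([feature["properties"]["name"] for feature in collection["features"]])  # ['Area1', 'Area2']
print([polygon.is_valid for polygon in polygons])  # [True, True]
```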
1 change: 1 addition & 0 deletions tests/data/geometries/polygons01.prj
@@ -0,0 +1 @@
+GEOGCS["GCS_WGS_1984",DATUM["D_WGS_1984",SPHEROID["WGS_1984",6378137,298.257223563]],PRIMEM["Greenwich",0],UNIT["Degree",0.017453292519943295]]
1 change: 1 addition & 0 deletions tests/data/geometries/polygons01.qpj
@@ -0,0 +1 @@
+GEOGCS["WGS 84",DATUM["WGS_1984",SPHEROID["WGS 84",6378137,298.257223563,AUTHORITY["EPSG","7030"]],AUTHORITY["EPSG","6326"]],PRIMEM["Greenwich",0,AUTHORITY["EPSG","8901"]],UNIT["degree",0.0174532925199433,AUTHORITY["EPSG","9122"]],AUTHORITY["EPSG","4326"]]
Binary file added tests/data/geometries/polygons01.shp
Binary file added tests/data/geometries/polygons01.shx
76 changes: 69 additions & 7 deletions tests/test_timeseries.py
@@ -1,21 +1,21 @@
 import datetime
+import json
+from tempfile import NamedTemporaryFile
 from unittest import TestCase
 
-from tempfile import NamedTemporaryFile
-import json
-from shapely.geometry import mapping
 import geopyspark as gps
-import pytz
+import numpy as np
 from geopyspark.geotrellis import (SpaceTimeKey, Tile, _convert_to_unix_time)
 from geopyspark.geotrellis.constants import LayerType
 from geopyspark.geotrellis.layer import TiledRasterLayer
-import numpy as np
 from pyspark import SparkContext
-from shapely.geometry import Point
-from shapely.geometry import Polygon, GeometryCollection, MultiPolygon
+import pytest
+import pytz
+from shapely.geometry import mapping, Point, Polygon, GeometryCollection, MultiPolygon, box
 
 from openeogeotrellis.GeotrellisImageCollection import GeotrellisTimeSeriesImageCollection
 from openeogeotrellis.service_registry import InMemoryServiceRegistry
+from .data import get_test_data_file
 
 
 class TestTimeSeries(TestCase):
@@ -261,3 +261,65 @@ def test_zonal_statistics_for_unsigned_byte_layer(self):
                 "values": [220.0]
             }
         }
+
+
+def _build_cube():
+    # TODO: avoid instantiating TestTimeSeries? e.g. use pytest fixtures or simple builder functions.
+    layer = TestTimeSeries().create_spacetime_layer()
+    cube = GeotrellisTimeSeriesImageCollection(gps.Pyramid({0: layer}), InMemoryServiceRegistry())
+    return cube
+
+
+@pytest.mark.parametrize(["func", "expected"], [
+    ("mean", {'2017-09-25T11:37:00': [[1.0, 2.0]]}),
+    ("median", {'2017-09-25T11:37:00Z': [[1.0, 2.0]]}),
+    ("histogram", {'2017-09-25T11:37:00Z': [[{1.0: 4}, {2.0: 4}]]}),
+    ("sd", {'2017-09-25T11:37:00Z': [[0.0, 0.0]]})
+])
+def test_zonal_statistics_single_polygon(func, expected):
+    cube = _build_cube()
+    polygon = box(0.0, 0.0, 1.0, 1.0)
+    result = cube.zonal_statistics(polygon, func=func)
+    assert result.data == expected
+
+
+@pytest.mark.parametrize(["func", "expected"], [
+    ("mean", {'2017-09-25T11:37:00Z': [[1.0, 2.0], [1.0, 2.0]]}),
+    ("median", {'2017-09-25T11:37:00Z': [[1.0, 2.0], [1.0, 2.0]]}),
+    ("histogram", {'2017-09-25T11:37:00Z': [[{1.0: 4}, {2.0: 4}], [{1.0: 23}, {2.0: 23}]]}),
+    ("sd", {'2017-09-25T11:37:00Z': [[0.0, 0.0], [0.0, 0.0]]})
+])
+def test_zonal_statistics_geometry_collection(func, expected):
+    cube = _build_cube()
+    geometry = GeometryCollection([
+        box(0.5, 0.5, 1.5, 1.5),
+        MultiPolygon([box(2.0, 0.5, 4.0, 1.5), box(1.5, 2, 4.0, 3.5)])
+    ])
+    result = cube.zonal_statistics(geometry, func=func)
+    assert result.data == expected
+
+
+@pytest.mark.parametrize(["func", "expected"], [
+    ("mean", {'2017-09-25T11:37:00Z': [[1.0, 2.0], [1.0, 2.0]]}),
+    ("median", {'2017-09-25T11:37:00Z': [[1.0, 2.0], [1.0, 2.0]]}),
+    ("histogram", {'2017-09-25T11:37:00Z': [[{1.0: 4}, {2.0: 4}], [{1.0: 19}, {2.0: 19}]]}),
+    ("sd", {'2017-09-25T11:37:00Z': [[0.0, 0.0], [0.0, 0.0]]})
+])
+def test_zonal_statistics_shapefile(func, expected):
+    cube = _build_cube()
+    shapefile = str(get_test_data_file("geometries/polygons01.shp"))
+    result = cube.zonal_statistics(regions=shapefile, func=func)
+    assert result.data == expected
+
+
+@pytest.mark.parametrize(["func", "expected"], [
+    ("mean", {'2017-09-25T11:37:00Z': [[1.0, 2.0], [1.0, 2.0]]}),
+    ("median", {'2017-09-25T11:37:00Z': [[1.0, 2.0], [1.0, 2.0]]}),
+    ("histogram", {'2017-09-25T11:37:00Z': [[{1.0: 4}, {2.0: 4}], [{1.0: 19}, {2.0: 19}]]}),
+    ("sd", {'2017-09-25T11:37:00Z': [[0.0, 0.0], [0.0, 0.0]]})
+])
+def test_zonal_statistics_geojson(func, expected):
+    cube = _build_cube()
+    shapefile = str(get_test_data_file("geometries/polygons01.geojson"))
+    result = cube.zonal_statistics(regions=shapefile, func=func)
+    assert result.data == expected
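The new parametrized tests can be run in isolation; one way to invoke them programmatically (equivalent to `pytest tests/test_timeseries.py -k zonal_statistics` on the command line, and assuming a local Spark/geopyspark environment is set up):

```python
import pytest

# Select only the zonal_statistics tests, verbose output.
pytest.main(["tests/test_timeseries.py", "-k", "zonal_statistics", "-v"])
```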
3 changes: 2 additions & 1 deletion tests/test_views.py
@@ -15,6 +15,7 @@
 from openeogeotrellis.backend import GpsBatchJobs
 import openeogeotrellis.job_registry
 from openeogeotrellis.testing import KazooClientMock
+from .data import TEST_DATA_ROOT
 
 
 @pytest.fixture(params=["0.4.0", "1.0.0"])
@@ -30,7 +31,7 @@ def client():
 
 
 class ApiTester(openeo_driver.testing.ApiTester):
-    data_root = Path(__file__).parent / "data"
+    data_root = TEST_DATA_ROOT
 
 
 @pytest.fixture