PoC load_stac implementation (very ad-hoc) #402
Download this cube to test it:

data_cube = (connection
             .load_stac(url="https://tamn.snapplanet.io/collections/S2",
                        spatial_extent={"west": -87.83465281740789, "south": 42.57836607418331, "east": -87.80890361086492, "north": 42.59100512331456},
                        temporal_extent=["2022-05-10", "2022-05-10"],
                        bands=["B04", "B03", "B02"])
             .save_result("GTiff"))
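
For reference, a minimal end-to-end sketch with the openeo Python client; the backend URL and output filename are placeholders, and authentication requirements depend on the deployment:

import openeo

# Placeholder URL: any openEO backend that ships this PoC.
connection = openeo.connect("https://openeo.example.org")

data_cube = (connection
             .load_stac(url="https://tamn.snapplanet.io/collections/S2",
                        spatial_extent={"west": -87.83465281740789, "south": 42.57836607418331,
                                        "east": -87.80890361086492, "north": 42.59100512331456},
                        temporal_extent=["2022-05-10", "2022-05-10"],
                        bands=["B04", "B03", "B02"]))

# Runs the process graph synchronously and saves the result locally.
data_cube.download("s2_rgb.tif", format="GTiff")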
bossie committed May 23, 2023
1 parent b5ea93e commit fc8df8b
Showing 1 changed file with 101 additions and 0 deletions.
101 changes: 101 additions & 0 deletions openeogeotrellis/backend.py
@@ -812,6 +812,107 @@ def intersects_spatial_extent(item) -> bool:

        return cube

    def load_stac(self, url: str, load_params: LoadParameters, env: EvalEnv) -> GeopysparkDataCube:
        logger.info("load_stac from url {u!r} with load params {p!r}".format(u=url, p=load_params))

        # current assumption: url points to a STAC API Collection
        collection = pystac.Collection.from_file(href=url)
        collection_id = collection.id

        root_catalog = collection.get_root()
        # FIXME: check conformsTo item-search (TBC)
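        # (per the STAC API spec, the item-search endpoint is advertised as a rel="search" link on the landing page)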
        search_url = root_catalog.get_single_link(rel="search").href
        self_url = root_catalog.get_self_href()

        # TODO: instantiate PyramidFactory(FileLayerProvider(STACClient(self_url)))
        jvm = get_jvm()

        is_utm = False
        date_regex = None
        bands = None
        stac_api_client = jvm.org.openeo.opensearch.OpenSearchClient.apply(
            self_url, is_utm, date_regex, bands, "stac"
        )  # FIXME: pass search_url instead of self_url
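        # the trailing "stac" argument presumably selects the STAC-API flavour of OpenSearchClient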

        root_path = None
        cell_size = jvm.geotrellis.raster.CellSize(10.0, 10.0)  # FIXME: get it from the band metadata?
        experimental = False
        pyramid_factory = jvm.org.openeo.geotrellis.file.PyramidFactory(
            stac_api_client, collection_id, load_params.bands, root_path, cell_size, experimental
        )

        single_level = env.get('pyramid_levels', 'all') != 'all'
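        # only the single-level path is implemented in this PoC; 'pyramid_levels' == 'all' ends in the NotImplementedError below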

        if single_level:
            requested_bbox = BoundingBox.from_dict_or_none(
                load_params.spatial_extent, default_crs="EPSG:4326"
            )
            collection_bbox = BoundingBox.from_wsen_tuple(
                collection.extent.spatial.bboxes[0], crs="EPSG:4326"
            )

            target_bbox = requested_bbox or collection_bbox
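            # best_utm() picks an appropriate UTM zone (as an EPSG code) for the bbox,
            # so the cube is laid out in a metric CRS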
            target_epsg = target_bbox.best_utm()

            extent = jvm.geotrellis.vector.Extent(*target_bbox.as_wsen_tuple())
            extent_crs = target_bbox.crs

            projected_polygons = jvm.org.openeo.geotrellis.ProjectedPolygons.fromExtent(
                extent, target_bbox.crs
            )
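            # "ProjectedPolygons$" is the JVM name of the Scala companion object and
            # MODULE$ its singleton instance: the Py4J way to reach its reproject()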
            projected_polygons = getattr(
                getattr(jvm.org.openeo.geotrellis, "ProjectedPolygons$"), "MODULE$"
            ).reproject(projected_polygons, target_epsg)

            temporal_extent = load_params.temporal_extent
            from_date, to_date = normalize_temporal_extent(temporal_extent)

            metadata_properties = {}
            correlation_id = env.get('correlation_id', '')

            data_cube_parameters = jvm.org.openeo.geotrelliscommon.DataCubeParameters()
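            # Scala generates a JVM-level `layoutScheme_$eq` setter for the
            # `layoutScheme` var; getattr is needed because of the "$" in the name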
            getattr(data_cube_parameters, "layoutScheme_$eq")("FloatingLayoutScheme")

            pyramid = pyramid_factory.datacube_seq(
                projected_polygons, from_date, to_date,
                metadata_properties, correlation_id, data_cube_parameters
            )
        else:
            raise NotImplementedError("pyramid")

        band_names = [b["name"] for b in collection.extra_fields["summaries"]["eo:bands"]]

        metadata = GeopysparkCubeMetadata(metadata={}, dimensions=[
            # TODO: detect actual dimensions instead of this simple default?
            SpatialDimension(name="x", extent=[]), SpatialDimension(name="y", extent=[]),
            TemporalDimension(name='t', extent=[]),
            BandDimension(name="bands", bands=[Band(band_name) for band_name in band_names])
        ])
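        # from_date/to_date and extent/extent_crs were assigned in the single-level
        # branch; the other branch raises, so they are guaranteed to exist here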

        metadata = metadata.filter_temporal(from_date, to_date)

        metadata = metadata.filter_bbox(
            west=extent.xmin(),
            south=extent.ymin(),
            east=extent.xmax(),
            north=extent.ymax(),
            crs=extent_crs,
        )

        if load_params.bands:
            metadata = metadata.filter_bands(load_params.bands)

        temporal_tiled_raster_layer = jvm.geopyspark.geotrellis.TemporalTiledRasterLayer
        option = jvm.scala.Option

        # noinspection PyProtectedMember
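        # each pyramid entry is a Scala (zoomLevel, layer) pair: _1() is the zoom
        # level (used as key), _2() the JVM layer, wrapped in a geopyspark TiledRasterLayer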
        levels = {
            pyramid.apply(index)._1(): TiledRasterLayer(
                LayerType.SPACETIME,
                temporal_tiled_raster_layer(option.apply(pyramid.apply(index)._1()),
                                            pyramid.apply(index)._2())
            )
            for index in range(0, pyramid.size())
        }

        return GeopysparkDataCube(pyramid=gps.Pyramid(levels), metadata=metadata)

    def load_ml_model(self, model_id: str) -> 'JavaObject':

        # Trick to make sure IDE infers right type of `self.batch_jobs` and can resolve `get_job_output_dir`