Skip to content

Commit

Permalink
ENH: Allow index inputs for read_timeseries (#139)
Browse files Browse the repository at this point in the history
  • Loading branch information
szwiep authored Aug 7, 2024
1 parent 04a84e3 commit 38a175d
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 42 deletions.
3 changes: 3 additions & 0 deletions src/spectral_recovery/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

REQ_DIMS = ["band", "time", "y", "x"]

# Index configurations
STANDARD_BANDS = list(spx.bands)
SUPPORTED_DOMAINS = ["vegetation", "burn"]
SUPPORTED_INDICES = [ix for ix in list(spx.indices) if spx.indices[ix].application_domain in SUPPORTED_DOMAINS] + ["GCI", "TCW", "TCG"]

VALID_YEAR = re.compile(r"^\d{4}$")
2 changes: 1 addition & 1 deletion src/spectral_recovery/indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
from typing import List, Dict

from spectral_recovery._utils import maintain_rio_attrs
from spectral_recovery._config import SUPPORTED_DOMAINS

# Set up global index configurations:
# 1. Only support vegetation and burn indices
# 2. Init index-specific constant defaults
SUPPORTED_DOMAINS = ["vegetation", "burn"]
with pkg_resources.open_text(
"spectral_recovery.resources", "constant_defaults.json"
) as f:
Expand Down
48 changes: 7 additions & 41 deletions src/spectral_recovery/io/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import xarray as xr

from spectral_recovery._utils import bands_pretty_table, common_and_long_to_short
from spectral_recovery._config import SUPPORTED_INDICES
from rasterio._err import CPLE_AppDefinedError

from spectral_recovery._config import (
Expand Down Expand Up @@ -227,16 +228,17 @@ def _to_standard_band_names(in_names: List[str]) -> Tuple[List[str], List[str]]:
converted = True
standard_names.append(COMMON_LONG_SHORT_DICT[given_name])
attr_names.append(given_name)

elif given_name in STANDARD_BANDS:
converted = True
standard_names.append(given_name)
elif given_name in SUPPORTED_INDICES:
converted = True
standard_names.append(given_name)

if not converted:
raise ValueError(
"Band must be named standard, common, or long name. Could not find"
f" '{given_name}' in catalogue. See table below for accepted names:"
f" \n\n {BANDS_TABLE} \n\n"
"Band must be named standard, common, or long name for a spectral band, or a spectral index. Could not find"
f" '{given_name}' in supported bands or indices."
).with_traceback(None) from None

return (standard_names, attr_names)
Expand All @@ -249,40 +251,4 @@ def _mask_stack(stack: xr.DataArray, mask: xr.DataArray, fill=np.nan) -> xr.Data
f"Only 2D masks are supported. {len(mask.dims)}D mask provided."
)
masked_stack = stack.where(mask, fill)
return masked_stack


def _metrics_to_tifs(
metric: xr.DataArray,
out_dir: str,
) -> None:
"""
Write a DataArray of metrics to TIFs.
Parameters
----------
metric : xr.DataArray
The metric to write to TIFs. Must have dimensions: 'metric', 'band', 'y', and 'x'.
out_dir : str
Path to directory to write TIFs.
"""
# NOTE: out_raster MUST be all null otherwise merging of rasters will fail
out_raster = xr.full_like(metric[0, 0, :, :], np.nan)
for m in metric["metric"].values:
xa_dataset = xr.Dataset()
for band in metric["band"].values:
out_metric = metric.sel(metric=m, band=band)

merged = out_metric.combine_first(out_raster)
xa_dataset[str(band)] = merged
try:
filename = f"{out_dir}/{str(m)}.tif"
xa_dataset.rio.to_raster(raster_path=filename)
# TODO: Probably shouldn't except on an error hidden from API users...
except CPLE_AppDefinedError:
raise PermissionError(
f"Permission denied to overwrite {filename}. Is the existing TIF"
" open in an application (e.g QGIS)? If so, try closing it before"
" your next run to avoid this error."
) from None
return masked_stack
24 changes: 24 additions & 0 deletions src/tests/unit/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,7 +473,31 @@ def test_no_band_desc_or_band_names_throws_value_err(
path_to_tifs="a/dir",
array_type="numpy",
)

@patch(
"rioxarray.open_rasterio",
)
@patch("spectral_recovery.io.raster._get_tifs_from_dir")
def test_index_band_names_accepted(
self, mocked_get_tifs, mocked_rasterio_open, filenames
):
expected_bands = ["GCI", "NBR", "NDVI"]
rasterio_return = xr.DataArray(
[[[[0]]], [[[0]]], [[[0]]]],
dims=["band", "time", "y", "x"],
coords={"band": [1, 2, 3]},
)
mocked_get_tifs.return_value = filenames
mocked_rasterio_open.return_value = rasterio_return

stacked_tifs = read_timeseries(
path_to_tifs="a/dir",
band_names={1: "GCI", 2: "NBR", 3: "NDVI"},
array_type="numpy",
)
# assert
print(stacked_tifs["band"].data, expected_bands)
assert_array_equal(stacked_tifs["band"].data, expected_bands)

class TestReadTimeseriesDictInput:

Expand Down

0 comments on commit 38a175d

Please sign in to comment.