Skip to content

Commit

Permalink
Merge pull request #112 from zmoon/remap
Browse files Browse the repository at this point in the history
Fix `.monet.remap_xesmf()` for Dataset-to-Dataset
  • Loading branch information
zmoon authored Sep 2, 2022
2 parents 4506291 + 7abe677 commit 9d8fb80
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 4 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*~
.DS_Store
docs/_build/
monet_xesmf_regrid_file.nc


# Default GitHub .gitignore for Python below:
Expand Down
25 changes: 25 additions & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
name: monet-dev
channels:
- conda-forge
- nodefaults
dependencies:
- python=3.9
#
# core
- cartopy
- dask
- matplotlib
- netcdf4
- pydecorate
- pandas
- seaborn
- xarray
#
# optional
- pyresample
- xesmf # non-Windows only!
#
# test/dev
- ipykernel
- ipython
- pytest
7 changes: 4 additions & 3 deletions monet/monet_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,13 +1477,14 @@ def remap_xesmf(self, data, **kwargs):
raise TypeError
except TypeError:
print("data must be an xarray.DataArray or xarray.Dataset")
# TODO: raise

def _remap_xesmf_dataset(self, dset, filename="monet_xesmf_regrid_file.nc", **kwargs):
skip_keys = ["latitude", "longitude", "time", "TFLAG"]
vars = pd.Series(dset.variables)
skip_keys = ["lat", "lon", "time", "TFLAG"]
vars = pd.Series(list(dset.variables))
loop_vars = vars.loc[~vars.isin(skip_keys)]
dataarray = dset[loop_vars[0]]
da = self._remap_xesmf_dataarray(dataarray, self._obj, filename=filename, **kwargs)
da = self._remap_xesmf_dataarray(dataarray, filename=filename, **kwargs)
self._obj[da.name] = da
das = {}
das[da.name] = da
Expand Down
5 changes: 4 additions & 1 deletion monet/util/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,7 @@ def resample_xesmf(source_da, target_da, cleanup=False, **kwargs):
ds.attrs = source_da.attrs
return ds
else:
return regridder(source_da)
da = regridder(source_da)
if da.name is None:
da.name = source_da.name
return da
8 changes: 8 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,11 @@ line-length = 100
known_first_party = "monet"
profile = "black"
line_length = 100

[tool.pytest.ini_options]
filterwarnings = [
"ignore:distutils Version classes:DeprecationWarning::",
"""ignore:elementwise comparison failed; returning scalar instead:\
FutureWarning:xarray.core.dataarray:""",
# ^ looks to be coming from xESMF
]
48 changes: 48 additions & 0 deletions tests/test_remap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
import numpy as np
import pytest
import xarray as xr

import monet # noqa: F401


def test_remap_ds_ds():
# Barry noted a problem with this

lonmin, latmin, lonmax, latmax = [0, 0, 10, 10]

def make_ds(*, nx=10, ny=10):
data = np.arange(nx * ny).reshape((ny, nx))
assert data.flags["C_CONTIGUOUS"], "xESMF wants this"

return xr.Dataset(
data_vars={"data": (("y", "x"), data)},
coords={
"latitude": ("y", np.linspace(latmin, latmax, ny)),
"longitude": ("x", np.linspace(lonmin, lonmax, nx)),
},
)

target = make_ds()
source = make_ds(nx=5)
# When we call `target.monet.remap_xesmf()`,
# data on the source grid is regridded to the target grid
# and added as a new variable.

# Check for cf accessor
assert hasattr(target, "cf")
with pytest.raises(KeyError, match="No results found for 'latitude'."):
target.cf.get_bounds("latitude")
assert hasattr(target.monet._obj, "cf")

# On the data DataArray directly
target.monet.remap_xesmf(source["data"])
ds1 = target.copy(deep=True)

# On the Dataset
# Note conservative methods don't work here because need cell bounds
target = target.drop_vars("data_y")
target.monet.remap_xesmf(source, method="nearest_d2s")
ds2 = target.copy(deep=True)

assert np.all(ds1.data == ds2.data), "original data should be same"
assert not np.all(ds1.data_y == ds2.data_y), "remapped data should be different"

0 comments on commit 9d8fb80

Please sign in to comment.