Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jul 4, 2023
1 parent 3d896db commit e26d144
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 15 deletions.
39 changes: 29 additions & 10 deletions xesmf/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,14 +480,27 @@ def __call__(self, indata, keep_attrs=False, skipna=False, na_thres=1.0, output_
"""
if isinstance(indata, dask_array_type + (np.ndarray,)):
return self.regrid_array(
indata, self.weights.data, skipna=skipna, na_thres=na_thres, output_chunks=output_chunks)
indata,
self.weights.data,
skipna=skipna,
na_thres=na_thres,
output_chunks=output_chunks,
)
elif isinstance(indata, xr.DataArray):
return self.regrid_dataarray(
indata, keep_attrs=keep_attrs, skipna=skipna, na_thres=na_thres, output_chunks=output_chunks
indata,
keep_attrs=keep_attrs,
skipna=skipna,
na_thres=na_thres,
output_chunks=output_chunks,
)
elif isinstance(indata, xr.Dataset):
return self.regrid_dataset(
indata, keep_attrs=keep_attrs, skipna=skipna, na_thres=na_thres, output_chunks=output_chunks
indata,
keep_attrs=keep_attrs,
skipna=skipna,
na_thres=na_thres,
output_chunks=output_chunks,
)
else:
raise TypeError('input must be numpy array, dask array, xarray DataArray or Dataset!')
Expand Down Expand Up @@ -528,19 +541,21 @@ def regrid_array(self, indata, weights, skipna=False, na_thres=1.0, output_chunk

if isinstance(indata, dask_array_type): # dask
if output_chunks is None:
weights = da.from_array(self.w.data, chunks=(indata.chunksize[-2:] + indata.chunksize[-2:]))
weights = da.from_array(
self.w.data, chunks=(indata.chunksize[-2:] + indata.chunksize[-2:])
)
elif output_chunks is not None:
if len(output_chunks) != len(self.shape_out):
raise ValueError(
f'output_chunks must have same dimension as ds_out,'
f' output_chunks dimension ({len(output_chunks)}) does not '
f'match ds_out dimension ({len(self.shape_out)})'
)
)
weights = da.from_array(self.w.data, chunks=(output_chunks + indata.chunksize[-2:]))

outdata = self._regrid(indata, weights, **kwargs)
else: # numpy
weights = self.w.data # 4D weights
weights = self.w.data # 4D weights
outdata = self._regrid(indata, weights, **kwargs)
return outdata

Expand All @@ -558,11 +573,13 @@ def regrid_dask(self, indata, **kwargs):
)
return self.regrid_array(indata, self.weights.data, **kwargs)

def regrid_dataarray(self, dr_in, keep_attrs=False, skipna=False, na_thres=1.0, output_chunks=None):
def regrid_dataarray(
self, dr_in, keep_attrs=False, skipna=False, na_thres=1.0, output_chunks=None
):
"""See __call__()."""

input_horiz_dims, temp_horiz_dims = self._parse_xrinput(dr_in)
kwargs = dict(skipna=skipna, na_thres=na_thres,output_chunks=output_chunks)
kwargs = dict(skipna=skipna, na_thres=na_thres, output_chunks=output_chunks)
dr_out = xr.apply_ufunc(
self.regrid_array,
dr_in,
Expand All @@ -576,13 +593,15 @@ def regrid_dataarray(self, dr_in, keep_attrs=False, skipna=False, na_thres=1.0,

return self._format_xroutput(dr_out, temp_horiz_dims)

def regrid_dataset(self, ds_in, keep_attrs=False, skipna=False, na_thres=1.0, output_chunks=None):
def regrid_dataset(
self, ds_in, keep_attrs=False, skipna=False, na_thres=1.0, output_chunks=None
):
"""See __call__()."""

# get the first data variable to infer input_core_dims
input_horiz_dims, temp_horiz_dims = self._parse_xrinput(ds_in)

kwargs = dict(skipna=skipna, na_thres=na_thres,output_chunks=output_chunks)
kwargs = dict(skipna=skipna, na_thres=na_thres, output_chunks=output_chunks)

non_regriddable = [
name
Expand Down
8 changes: 6 additions & 2 deletions xesmf/smm.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,14 @@ def apply_weights(weights, indata, shape_in, shape_out):
nb.from_dtype(indata.dtype)
nb.from_dtype(weights.dtype)
except (NotImplementedError, nb.core.errors.NumbaError):
indata = indata.astype('<f8') # On the fly conversion
indata = indata.astype('<f8') # On the fly conversion

# Dot product
outdata = np.tensordot(indata, weights, axes=((indata.ndim-2,indata.ndim-1),(weights.ndim-2,weights.ndim-1)))
outdata = np.tensordot(
indata,
weights,
axes=((indata.ndim - 2, indata.ndim - 1), (weights.ndim - 2, weights.ndim - 1)),
)

# Ensure same dtype as the input.
outdata = outdata.astype(indata_dtype)
Expand Down
2 changes: 1 addition & 1 deletion xesmf/tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def test_regrid():
weights = read_weights(filename, lon_in.size, lon_out.size).data
shape_in = lon_in.shape
shape_out = lon_out.shape
w = weights.reshape(shape_out + shape_in) # 4D weights
w = weights.reshape(shape_out + shape_in) # 4D weights
data_out_scipy = apply_weights(w, data_in, shape_in, shape_out)

# must be almost exactly the same as esmpy's result!
Expand Down
7 changes: 5 additions & 2 deletions xesmf/tests/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

# use non-divisible chunk size to catch edge cases
ds_in_chunked = ds_in.chunk({'time': 3, 'lev': 2})
ds_spatial_chunked = ds_in.chunk({'time':3,'lev':2, 'y':5,'x':9})
ds_spatial_chunked = ds_in.chunk({'time': 3, 'lev': 2, 'y': 5, 'x': 9})

ds_locs = xr.Dataset()
ds_locs['lat'] = xr.DataArray(data=[-20, -10, 0, 10], dims=('locations',))
Expand Down Expand Up @@ -592,12 +592,13 @@ def test_regrid_dataarray_dask_from_locstream(request, scheduler):
outdata = regridder(ds_locs.chunk()['lat'])
assert dask.is_dask_collection(outdata)


def test_dask_output_chunks():
regridder = xe.Regridder(ds_in, ds_out, 'conservative')

test_output_chunks = (10, 12)

indata = ds_spatial_chunked['data4D'].data # Data chunked along spatial dims
indata = ds_spatial_chunked['data4D'].data # Data chunked along spatial dims
# Use ridiculous small chunk size value to be sure it _isn't_ impacting computation.
with dask.config.set({'array.chunk-size': '1MiB'}):
outdata = regridder(indata)
Expand All @@ -613,6 +614,8 @@ def test_dask_output_chunks():
# Verify that we get specified outputchunks when the argument is provided
assert outdata_spec.shape == indata.shape[:-2] + horiz_shape_out
assert outdata_spec.chunksize == indata.chunksize[:-2] + test_output_chunks


def test_regrid_dataset():
# xarray.Dataset containing in-memory numpy array
regridder = xe.Regridder(ds_in, ds_out, 'conservative')
Expand Down

0 comments on commit e26d144

Please sign in to comment.