Commit ffae767

Moved para regridding code outside of init and into its own method to keep init from getting too complex

charlesgauthier-udm committed Aug 31, 2023
1 parent 0b45750 commit ffae767
Showing 1 changed file with 102 additions and 99 deletions.
xesmf/frontend.py: 102 additions & 99 deletions
@@ -954,111 +954,114 @@ def __init__(
         self.out_coords = {lat_out.name: lat_out, lon_out.name: lon_out}
 
         if parallel:
-            # Check if we have bounds as variable and not coords, and add them to coords in both datasets
-            if 'lon_b' in ds_out.data_vars:
-                ds_out = ds_out.set_coords(['lon_b', 'lat_b'])
-            if 'lon_b' in ds_in.data_vars:
-                ds_in = ds_in.set_coords(['lon_b', 'lat_b'])
-            # Drop everything in ds_out except mask or create mask if None. This is to prevent map_blocks loading unnecessary large data
-            if locstream_out:
-                ds_out_dims_drop = set(ds_out.variables).difference(ds_out.data_vars)
-                ds_out = ds_out.drop_dims(ds_out_dims_drop)
-            else:
-                if 'mask' in ds_out:
-                    mask = ds_out.mask
-                    ds_out = ds_out.coords.to_dataset()
-                    ds_out['mask'] = mask
-                else:
-                    ds_out_chunks = tuple([ds_out.chunksizes[i] for i in output_dims])
-                    ds_out = ds_out.coords.to_dataset()
-                    mask = da.ones(shape_out, dtype=bool, chunks=ds_out_chunks)
-                    ds_out['mask'] = (output_dims, mask)
+            self._init_para_regrid(ds_in, ds_out, kwargs)
+
+    def _init_para_regrid(self, ds_in, ds_out, kwargs):
+        # Check if we have bounds as variable and not coords, and add them to coords in both datasets
+        if 'lon_b' in ds_out.data_vars:
+            ds_out = ds_out.set_coords(['lon_b', 'lat_b'])
+        if 'lon_b' in ds_in.data_vars:
+            ds_in = ds_in.set_coords(['lon_b', 'lat_b'])
+        # Drop everything in ds_out except mask or create mask if None. This is to prevent map_blocks loading unnecessary large data
+        if self.sequence_out:
+            ds_out_dims_drop = set(ds_out.variables).difference(ds_out.data_vars)
+            ds_out = ds_out.drop_dims(ds_out_dims_drop)
+        else:
+            if 'mask' in ds_out:
+                mask = ds_out.mask
+                ds_out = ds_out.coords.to_dataset()
+                ds_out['mask'] = mask
+            else:
+                ds_out_chunks = tuple([ds_out.chunksizes[i] for i in self.out_horiz_dims])
+                ds_out = ds_out.coords.to_dataset()
+                mask = da.ones(self.shape_out, dtype=bool, chunks=ds_out_chunks)
+                ds_out['mask'] = (self.out_horiz_dims, mask)
 
-                ds_out_dims_drop = set(ds_out.cf.coordinates.keys()).difference(
-                    ['longitude', 'latitude']
-                )
-                ds_out = ds_out.cf.drop_dims(ds_out_dims_drop)
+            ds_out_dims_drop = set(ds_out.cf.coordinates.keys()).difference(
+                ['longitude', 'latitude']
+            )
+            ds_out = ds_out.cf.drop_dims(ds_out_dims_drop)
 
-            # Drop unnecessary variables in ds_in to save memory
-            if not locstream_in:
-                # Drop unnecessary dims
-                ds_in_dims_drop = set(ds_in.cf.coordinates.keys()).difference(
-                    ['longitude', 'latitude']
-                )
-                ds_in = ds_in.cf.drop_dims(ds_in_dims_drop)
+        # Drop unnecessary variables in ds_in to save memory
+        if not self.sequence_in:
+            # Drop unnecessary dims
+            ds_in_dims_drop = set(ds_in.cf.coordinates.keys()).difference(
+                ['longitude', 'latitude']
+            )
+            ds_in = ds_in.cf.drop_dims(ds_in_dims_drop)
 
-            # Drop unnecessary vars
-            ds_in = ds_in.coords.to_dataset()
+        # Drop unnecessary vars
+        ds_in = ds_in.coords.to_dataset()
 
-            # Ensure ds_in is not dask-backed
-            if xr.core.pycompat.is_dask_collection(ds_in):
-                ds_in = ds_in.compute()
+        # Ensure ds_in is not dask-backed
+        if xr.core.pycompat.is_dask_collection(ds_in):
+            ds_in = ds_in.compute()
 
-            # if bounds in ds_out, we switch to cf bounds for map_blocks
-            if 'lon_b' in ds_out and (ds_out.lon_b.ndim == ds_out.cf['longitude'].ndim):
-                ds_out = ds_out.assign_coords(
-                    lon_bounds=cfxr.vertices_to_bounds(
-                        ds_out.lon_b, ('bounds', *ds_out.cf['longitude'].dims)
-                    ),
-                    lat_bounds=cfxr.vertices_to_bounds(
-                        ds_out.lat_b, ('bounds', *ds_out.cf['latitude'].dims)
-                    ),
-                )
-                # Make cf-xarray aware of the new bounds
-                ds_out[ds_out.cf['longitude'].name].attrs['bounds'] = 'lon_bounds'
-                ds_out[ds_out.cf['latitude'].name].attrs['bounds'] = 'lat_bounds'
-                ds_out = ds_out.drop_dims(ds_out.lon_b.dims + ds_out.lat_b.dims)
-            # rename dims to avoid map_blocks confusing ds_in and ds_out dims.
-            if locstream_in:
-                ds_in = ds_in.rename({self.in_horiz_dims[0]: 'x_in'})
-            else:
-                ds_in = ds_in.rename({self.in_horiz_dims[0]: 'y_in', self.in_horiz_dims[1]: 'x_in'})
+        # if bounds in ds_out, we switch to cf bounds for map_blocks
+        if 'lon_b' in ds_out and (ds_out.lon_b.ndim == ds_out.cf['longitude'].ndim):
+            ds_out = ds_out.assign_coords(
+                lon_bounds=cfxr.vertices_to_bounds(
+                    ds_out.lon_b, ('bounds', *ds_out.cf['longitude'].dims)
+                ),
+                lat_bounds=cfxr.vertices_to_bounds(
+                    ds_out.lat_b, ('bounds', *ds_out.cf['latitude'].dims)
+                ),
+            )
+            # Make cf-xarray aware of the new bounds
+            ds_out[ds_out.cf['longitude'].name].attrs['bounds'] = 'lon_bounds'
+            ds_out[ds_out.cf['latitude'].name].attrs['bounds'] = 'lat_bounds'
+            ds_out = ds_out.drop_dims(ds_out.lon_b.dims + ds_out.lat_b.dims)
+        # rename dims to avoid map_blocks confusing ds_in and ds_out dims.
+        if self.sequence_in:
+            ds_in = ds_in.rename({self.in_horiz_dims[0]: 'x_in'})
+        else:
+            ds_in = ds_in.rename({self.in_horiz_dims[0]: 'y_in', self.in_horiz_dims[1]: 'x_in'})
 
-            if locstream_out:
-                ds_out = ds_out.rename({self.out_horiz_dims[1]: 'x_out'})
-                out_chunks = ds_out.chunks.get('x_out')
-            else:
-                ds_out = ds_out.rename(
-                    {self.out_horiz_dims[0]: 'y_out', self.out_horiz_dims[1]: 'x_out'}
-                )
-                out_chunks = [ds_out.chunks.get(k) for k in ['y_out', 'x_out']]
+        if self.sequence_out:
+            ds_out = ds_out.rename({self.out_horiz_dims[1]: 'x_out'})
+            out_chunks = ds_out.chunks.get('x_out')
+        else:
+            ds_out = ds_out.rename(
+                {self.out_horiz_dims[0]: 'y_out', self.out_horiz_dims[1]: 'x_out'}
+            )
+            out_chunks = [ds_out.chunks.get(k) for k in ['y_out', 'x_out']]
 
-            weights_dims = ('y_out', 'x_out', 'y_in', 'x_in')
-            templ = sps.zeros((shape_out + shape_in))
-            w_templ = xr.DataArray(templ, dims=weights_dims).chunk(
-                out_chunks
-            )  # template has same chunks as ds_out
+        weights_dims = ('y_out', 'x_out', 'y_in', 'x_in')
+        templ = sps.zeros((self.shape_out + self.shape_in))
+        w_templ = xr.DataArray(templ, dims=weights_dims).chunk(
+            out_chunks
+        )  # template has same chunks as ds_out
 
-            w = xr.map_blocks(
-                subset_regridder,
-                ds_out,
-                args=[
-                    ds_in,
-                    method,
-                    self.in_horiz_dims,
-                    self.out_horiz_dims,
-                    locstream_in,
-                    locstream_out,
-                    periodic,
-                ],
-                kwargs=kwargs,
-                template=w_templ,
-            )
-            w = w.compute(scheduler='processes')
-            weights = w.stack(out_dim=weights_dims[:2], in_dim=weights_dims[2:])
-            weights.name = 'weights'
-            self.weights = weights
+        w = xr.map_blocks(
+            subset_regridder,
+            ds_out,
+            args=[
+                ds_in,
+                self.method,
+                self.in_horiz_dims,
+                self.out_horiz_dims,
+                self.sequence_in,
+                self.sequence_out,
+                self.periodic,
+            ],
+            kwargs=kwargs,
+            template=w_templ,
+        )
+        w = w.compute(scheduler='processes')
+        weights = w.stack(out_dim=weights_dims[:2], in_dim=weights_dims[2:])
+        weights.name = 'weights'
+        self.weights = weights
 
-            # follows legacy logic of writing weights if filename is provided
-            if 'filename' in kwargs:
-                filename = kwargs['filename']
-            else:
-                filename = None
-            if filename is not None and not self.reuse_weights:
-                self.to_netcdf(filename=filename)
+        # follows legacy logic of writing weights if filename is provided
+        if 'filename' in kwargs:
+            filename = kwargs['filename']
+        else:
+            filename = None
+        if filename is not None and not self.reuse_weights:
+            self.to_netcdf(filename=filename)
 
-            # set default weights filename if none given
-            self.filename = self._get_default_filename() if filename is None else filename
+        # set default weights filename if none given
+        self.filename = self._get_default_filename() if filename is None else filename
 
     def _format_xroutput(self, out, new_dims=None):
         if new_dims is not None:
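For context, the parallel path refactored here is reached by passing parallel=True to the Regridder constructor: _init_para_regrid maps subset_regridder over each chunk of the output grid with xr.map_blocks and stacks the per-chunk results into the regridder's weight matrix. A minimal usage sketch follows, assuming a version of xESMF that includes this parallel weight generation; the file names and chunk sizes are illustrative only, not part of the commit.

    import xarray as xr
    import xesmf as xe

    # Source grid: loaded eagerly. _init_para_regrid computes ds_in if it
    # is dask-backed, since every worker needs the full input grid.
    ds_in = xr.open_dataset('input_grid.nc')

    # Destination grid: spatially chunked, so that each chunk becomes one
    # map_blocks task when the weights are generated.
    ds_out = xr.open_dataset('target_grid.nc', chunks={'lat': 240, 'lon': 240})

    # parallel=True routes construction through _init_para_regrid: weights
    # for each output chunk are computed with the 'processes' scheduler and
    # stacked into the usual weight matrix.
    regridder = xe.Regridder(ds_in, ds_out, 'bilinear', parallel=True)

    # Applying the regridder is unchanged; it works on any data defined on
    # the source grid.
    ds_regridded = regridder(ds_in)

Note that only ds_out needs to be chunked: the chunking of the output grid is what determines how the weight computation is split across processes.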
