diff --git a/xesmf/frontend.py b/xesmf/frontend.py index 130fe710..c1acfbf6 100644 --- a/xesmf/frontend.py +++ b/xesmf/frontend.py @@ -954,111 +954,114 @@ def __init__( self.out_coords = {lat_out.name: lat_out, lon_out.name: lon_out} if parallel: - # Check if we have bounds as variable and not coords, and add them to coords in both datasets - if 'lon_b' in ds_out.data_vars: - ds_out = ds_out.set_coords(['lon_b', 'lat_b']) - if 'lon_b' in ds_in.data_vars: - ds_in = ds_in.set_coords(['lon_b', 'lat_b']) - # Drop everything in ds_out except mask or create mask if None. This is to prevent map_blocks loading unnecessary large data - if locstream_out: - ds_out_dims_drop = set(ds_out.variables).difference(ds_out.data_vars) - ds_out = ds_out.drop_dims(ds_out_dims_drop) + self._init_para_regrid(ds_in, ds_out, kwargs) + + def _init_para_regrid(self, ds_in, ds_out, kwargs): + # Check if we have bounds as variable and not coords, and add them to coords in both datasets + if 'lon_b' in ds_out.data_vars: + ds_out = ds_out.set_coords(['lon_b', 'lat_b']) + if 'lon_b' in ds_in.data_vars: + ds_in = ds_in.set_coords(['lon_b', 'lat_b']) + # Drop everything in ds_out except mask or create mask if None. This is to prevent map_blocks loading unnecessary large data + if self.sequence_out: + ds_out_dims_drop = set(ds_out.variables).difference(ds_out.data_vars) + ds_out = ds_out.drop_dims(ds_out_dims_drop) + else: + if 'mask' in ds_out: + mask = ds_out.mask + ds_out = ds_out.coords.to_dataset() + ds_out['mask'] = mask else: - if 'mask' in ds_out: - mask = ds_out.mask - ds_out = ds_out.coords.to_dataset() - ds_out['mask'] = mask - else: - ds_out_chunks = tuple([ds_out.chunksizes[i] for i in output_dims]) - ds_out = ds_out.coords.to_dataset() - mask = da.ones(shape_out, dtype=bool, chunks=ds_out_chunks) - ds_out['mask'] = (output_dims, mask) - - ds_out_dims_drop = set(ds_out.cf.coordinates.keys()).difference( - ['longitude', 'latitude'] - ) - ds_out = ds_out.cf.drop_dims(ds_out_dims_drop) + ds_out_chunks = tuple([ds_out.chunksizes[i] for i in self.out_horiz_dims]) + ds_out = ds_out.coords.to_dataset() + mask = da.ones(self.shape_out, dtype=bool, chunks=ds_out_chunks) + ds_out['mask'] = (self.out_horiz_dims, mask) - # Drop unnecessary variables in ds_in to save memory - if not locstream_in: - # Drop unnecessary dims - ds_in_dims_drop = set(ds_in.cf.coordinates.keys()).difference( - ['longitude', 'latitude'] - ) - ds_in = ds_in.cf.drop_dims(ds_in_dims_drop) - - # Drop unnecessary vars - ds_in = ds_in.coords.to_dataset() - - # Ensure ds_in is not dask-backed - if xr.core.pycompat.is_dask_collection(ds_in): - ds_in = ds_in.compute() - - # if bounds in ds_out, we switch to cf bounds for map_blocks - if 'lon_b' in ds_out and (ds_out.lon_b.ndim == ds_out.cf['longitude'].ndim): - ds_out = ds_out.assign_coords( - lon_bounds=cfxr.vertices_to_bounds( - ds_out.lon_b, ('bounds', *ds_out.cf['longitude'].dims) - ), - lat_bounds=cfxr.vertices_to_bounds( - ds_out.lat_b, ('bounds', *ds_out.cf['latitude'].dims) - ), - ) - # Make cf-xarray aware of the new bounds - ds_out[ds_out.cf['longitude'].name].attrs['bounds'] = 'lon_bounds' - ds_out[ds_out.cf['latitude'].name].attrs['bounds'] = 'lat_bounds' - ds_out = ds_out.drop_dims(ds_out.lon_b.dims + ds_out.lat_b.dims) - # rename dims to avoid map_blocks confusing ds_in and ds_out dims. - if locstream_in: - ds_in = ds_in.rename({self.in_horiz_dims[0]: 'x_in'}) - else: - ds_in = ds_in.rename({self.in_horiz_dims[0]: 'y_in', self.in_horiz_dims[1]: 'x_in'}) + ds_out_dims_drop = set(ds_out.cf.coordinates.keys()).difference( + ['longitude', 'latitude'] + ) + ds_out = ds_out.cf.drop_dims(ds_out_dims_drop) - if locstream_out: - ds_out = ds_out.rename({self.out_horiz_dims[1]: 'x_out'}) - out_chunks = ds_out.chunks.get('x_out') - else: - ds_out = ds_out.rename( - {self.out_horiz_dims[0]: 'y_out', self.out_horiz_dims[1]: 'x_out'} - ) - out_chunks = [ds_out.chunks.get(k) for k in ['y_out', 'x_out']] - - weights_dims = ('y_out', 'x_out', 'y_in', 'x_in') - templ = sps.zeros((shape_out + shape_in)) - w_templ = xr.DataArray(templ, dims=weights_dims).chunk( - out_chunks - ) # template has same chunks as ds_out - - w = xr.map_blocks( - subset_regridder, - ds_out, - args=[ - ds_in, - method, - self.in_horiz_dims, - self.out_horiz_dims, - locstream_in, - locstream_out, - periodic, - ], - kwargs=kwargs, - template=w_templ, + # Drop unnecessary variables in ds_in to save memory + if not self.sequence_in: + # Drop unnecessary dims + ds_in_dims_drop = set(ds_in.cf.coordinates.keys()).difference( + ['longitude', 'latitude'] ) - w = w.compute(scheduler='processes') - weights = w.stack(out_dim=weights_dims[:2], in_dim=weights_dims[2:]) - weights.name = 'weights' - self.weights = weights + ds_in = ds_in.cf.drop_dims(ds_in_dims_drop) - # follows legacy logic of writing weights if filename is provided - if 'filename' in kwargs: - filename = kwargs['filename'] - else: - filename = None - if filename is not None and not self.reuse_weights: - self.to_netcdf(filename=filename) + # Drop unnecessary vars + ds_in = ds_in.coords.to_dataset() - # set default weights filename if none given - self.filename = self._get_default_filename() if filename is None else filename + # Ensure ds_in is not dask-backed + if xr.core.pycompat.is_dask_collection(ds_in): + ds_in = ds_in.compute() + + # if bounds in ds_out, we switch to cf bounds for map_blocks + if 'lon_b' in ds_out and (ds_out.lon_b.ndim == ds_out.cf['longitude'].ndim): + ds_out = ds_out.assign_coords( + lon_bounds=cfxr.vertices_to_bounds( + ds_out.lon_b, ('bounds', *ds_out.cf['longitude'].dims) + ), + lat_bounds=cfxr.vertices_to_bounds( + ds_out.lat_b, ('bounds', *ds_out.cf['latitude'].dims) + ), + ) + # Make cf-xarray aware of the new bounds + ds_out[ds_out.cf['longitude'].name].attrs['bounds'] = 'lon_bounds' + ds_out[ds_out.cf['latitude'].name].attrs['bounds'] = 'lat_bounds' + ds_out = ds_out.drop_dims(ds_out.lon_b.dims + ds_out.lat_b.dims) + # rename dims to avoid map_blocks confusing ds_in and ds_out dims. + if self.sequence_in: + ds_in = ds_in.rename({self.in_horiz_dims[0]: 'x_in'}) + else: + ds_in = ds_in.rename({self.in_horiz_dims[0]: 'y_in', self.in_horiz_dims[1]: 'x_in'}) + + if self.sequence_out: + ds_out = ds_out.rename({self.out_horiz_dims[1]: 'x_out'}) + out_chunks = ds_out.chunks.get('x_out') + else: + ds_out = ds_out.rename( + {self.out_horiz_dims[0]: 'y_out', self.out_horiz_dims[1]: 'x_out'} + ) + out_chunks = [ds_out.chunks.get(k) for k in ['y_out', 'x_out']] + + weights_dims = ('y_out', 'x_out', 'y_in', 'x_in') + templ = sps.zeros((self.shape_out + self.shape_in)) + w_templ = xr.DataArray(templ, dims=weights_dims).chunk( + out_chunks + ) # template has same chunks as ds_out + + w = xr.map_blocks( + subset_regridder, + ds_out, + args=[ + ds_in, + self.method, + self.in_horiz_dims, + self.out_horiz_dims, + self.sequence_in, + self.sequence_out, + self.periodic, + ], + kwargs=kwargs, + template=w_templ, + ) + w = w.compute(scheduler='processes') + weights = w.stack(out_dim=weights_dims[:2], in_dim=weights_dims[2:]) + weights.name = 'weights' + self.weights = weights + + # follows legacy logic of writing weights if filename is provided + if 'filename' in kwargs: + filename = kwargs['filename'] + else: + filename = None + if filename is not None and not self.reuse_weights: + self.to_netcdf(filename=filename) + + # set default weights filename if none given + self.filename = self._get_default_filename() if filename is None else filename def _format_xroutput(self, out, new_dims=None): if new_dims is not None: