diff --git a/draco/analysis/transform.py b/draco/analysis/transform.py index 042834269..4401c880e 100644 --- a/draco/analysis/transform.py +++ b/draco/analysis/transform.py @@ -879,14 +879,20 @@ class ShiftRA(task.SingleTask): ---------- delta : float The shift to *add* to the RA axis. + periodic : bool, optional + If True, wrap any time sample that is shifted to RA > 360 deg around to its + 360-degree-periodic counterpart, and likewise for any sample that is shifted + to RA < 0 deg. This wrapping is applied to the RA index_map along with any + dataset with an `ra` axis. Default: False. """ delta = config.Property(proptype=float) + periodic = config.Property(proptype=bool, default=False) def process( self, sscont: containers.SiderealContainer ) -> containers.SiderealContainer: - """Add a shift to the input sidereal cont. + """Add a shift to the input sidereal container. Parameters ---------- @@ -904,8 +910,39 @@ def process( f"Expected a SiderealContainer, got {type(sscont)} instead." ) + # Shift RA coordinates by delta sscont.ra[:] += self.delta + if self.periodic: + # If shift is positive, subtract 360 deg from any sample shifted to + # > 360 deg. Same idea if shift is negative, for samples shifted to < 0 deg + if self.delta > 0: + sscont.ra[sscont.ra[:] >= 360] -= 360 + else: + sscont.ra[sscont.ra[:] < 0] += 360 + + # Get indices that sort shifted RA axis in ascending order, and apply sort + ascending_ra_idx = np.argsort(sscont.ra[:]) + sscont.ra[:] = sscont.ra[ascending_ra_idx] + + # Loop over datasets in container + for name, dset in sscont.datasets.items(): + if "ra" in dset.attrs["axis"]: + # If dataset has RA axis, identify which axis it is + ra_axis_idx = np.where(dset.attrs["axis"] == "ra")[0][0] + + # Make sure dataset isn't distributed in RA. If it is, redistribute + # along another (somewhat arbitrarily chosen) axis. (This should + # usually not be necessary.) + if dset.distributed and dset.distributed_axis == ra_axis_idx: + redist_axis = max(ra_axis_idx - 1, 0) + dset.redistribute(redist_axis) + + # Apply RA-sorting from earlier to the appropriate axis + slc = [slice(None)] * len(dset.attrs["axis"]) + slc[ra_axis_idx] = ascending_ra_idx + dset[:] = dset[:][tuple(slc)] + return sscont