Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(delay): compatibility with RingMap and GridBeam containers #140

Merged
merged 2 commits into from
Jun 8, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 56 additions & 20 deletions draco/analysis/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,9 @@ class DelayFilterBase(task.SingleTask):
The main axis to iterate over. The delay cut can be varied for each element
of this axis. If not set, a suitable default is picked for the container
type.
dataset : str
Apply the delay filter to this dataset. If not set, a suitable default
is picked for the container type.

Notes
-----
Expand All @@ -170,6 +173,7 @@ class DelayFilterBase(task.SingleTask):
delay_cut = config.Property(proptype=float, default=0.1)
window = config.Property(proptype=bool, default=False)
axis = config.Property(proptype=str, default=None)
dataset = config.Property(proptype=str, default=None)

def setup(self, telescope: io.TelescopeConvertible):
"""Set the telescope needed to obtain baselines.
Expand Down Expand Up @@ -221,6 +225,15 @@ def process(self, ss: FreqContainerType) -> FreqContainerType:
_default_axis = {
containers.SiderealStream: "stack",
containers.HybridVisMModes: "m",
containers.RingMap: "el",
containers.GridBeam: "theta",
}

_default_dataset = {
containers.SiderealStream: "vis",
containers.HybridVisMModes: "vis",
containers.RingMap: "map",
containers.GridBeam: "beam",
}

axis = self.axis
Expand All @@ -233,26 +246,36 @@ def process(self, ss: FreqContainerType) -> FreqContainerType:
else:
raise ValueError(f"No default axis know for {type(ss)} container.")

dset = self.dataset

if self.dataset is None:
for cls, dataset in _default_dataset.items():
if isinstance(ss, cls):
dset = dataset
break
else:
raise ValueError(f"No default dataset know for {type(ss)} container.")

ss.redistribute(axis)

freq = ss.freq[:]
bandwidth = np.ptp(freq)

# Get views of the relevant datasets, but make sure that the weights have the
# same number of axes as the visibilities (inserting length-1 axes as needed)
ssv = ss.vis[:].view(np.ndarray)
ssw = match_axes(ss.vis, ss.weight).view(np.ndarray)
ssv = ss.datasets[dset][:].view(np.ndarray)
ssw = match_axes(ss.datasets[dset], ss.weight).view(np.ndarray)

dist_axis_pos = list(ss.vis.attrs["axis"]).index(axis)
freq_axis_pos = list(ss.vis.attrs["axis"]).index("freq")
dist_axis_pos = list(ss.datasets[dset].attrs["axis"]).index(axis)
freq_axis_pos = list(ss.datasets[dset].attrs["axis"]).index("freq")

# Once we have selected elements of dist_axis the location of freq_axis_pos may
# be one lower
sel_freq_axis_pos = (
freq_axis_pos if freq_axis_pos < dist_axis_pos else freq_axis_pos - 1
)

for lbi, bi in ss.vis[:].enumerate(axis=dist_axis_pos):
for lbi, bi in ss.datasets[dset][:].enumerate(axis=dist_axis_pos):

# Extract the part of the array that we are processing, and
# transpose/reshape to make a 2D array with frequency as axis=0
Expand Down Expand Up @@ -476,6 +499,8 @@ class DelaySpectrumEstimatorBase(task.SingleTask, random.RandomTask):
skip_nyquist : bool, optional
Whether the Nyquist frequency is included in the data. This is `True` by
default to align with the output of CASPER PFBs.
dataset : str
Calculate the delay spectrum of this dataset (e.g., "vis", "map", "beam").
average_axis : str
Name of the axis to take the average over.
"""
Expand All @@ -486,6 +511,7 @@ class DelaySpectrumEstimatorBase(task.SingleTask, random.RandomTask):
nfreq = config.Property(proptype=int, default=None)
skip_nyquist = config.Property(proptype=bool, default=True)

dataset = config.Property(proptype=str, default="vis")
average_axis = config.Property(proptype=str)

def setup(self, telescope: io.TelescopeConvertible):
Expand All @@ -512,9 +538,15 @@ def process(self, ss: FreqContainerType) -> containers.DelaySpectrum:
"""
ss.redistribute("freq")

if self.dataset not in ss.datasets:
raise ValueError(
f"Specified dataset to delay transform ({self.dataset}) not in "
f"container of type {type(ss)}."
)

if (
self.average_axis not in ss.axes
or self.average_axis not in ss.vis.attrs["axis"]
or self.average_axis not in ss.datasets[self.dataset].attrs["axis"]
):
raise ValueError(
f"Specified axis to average over ({self.average_axis}) not in "
Expand Down Expand Up @@ -543,24 +575,28 @@ def process(self, ss: FreqContainerType) -> containers.DelaySpectrum:
delays = np.fft.fftshift(np.fft.fftfreq(ndelay, d=self.freq_spacing)) # in us

# Find the relevant axis positions
vis_axes = ss.vis.attrs["axis"]
freq_axis_pos = list(vis_axes).index("freq")
average_axis_pos = list(vis_axes).index(self.average_axis)
data_axes = ss.datasets[self.dataset].attrs["axis"]
freq_axis_pos = list(data_axes).index("freq")
average_axis_pos = list(data_axes).index(self.average_axis)

# Create a view of the visibility dataset with the relevant axes at the back,
# Create a view of the dataset with the relevant axes at the back,
# and all other axes compressed
vis_view = np.moveaxis(
ss.vis[:].view(np.ndarray), [average_axis_pos, freq_axis_pos], [-2, -1]
data_view = np.moveaxis(
ss.datasets[self.dataset][:].view(np.ndarray),
[average_axis_pos, freq_axis_pos],
[-2, -1],
)
vis_view = vis_view.reshape(-1, vis_view.shape[-2], vis_view.shape[-1])
vis_view = mpiarray.MPIArray.wrap(vis_view, axis=2, comm=ss.comm)
nbase = int(np.prod(vis_view.shape[:-2]))
vis_view = vis_view.redistribute(axis=0)
data_view = data_view.reshape(-1, data_view.shape[-2], data_view.shape[-1])
data_view = mpiarray.MPIArray.wrap(data_view, axis=2, comm=ss.comm)
nbase = int(np.prod(data_view.shape[:-2]))
data_view = data_view.redistribute(axis=0)

# ... do the same for the weights, but we also need to make the weights full
# size
weight_full = np.zeros(ss.vis[:].shape, dtype=ss.weight.dtype)
weight_full[:] = match_axes(ss.vis, ss.weight)
weight_full = np.zeros(
ss.datasets[self.dataset][:].shape, dtype=ss.weight.dtype
)
weight_full[:] = match_axes(ss.datasets[self.dataset], ss.weight)
weight_view = np.moveaxis(
weight_full, [average_axis_pos, freq_axis_pos], [-2, -1]
)
Expand All @@ -576,7 +612,7 @@ def process(self, ss: FreqContainerType) -> containers.DelaySpectrum:
delay_spec = containers.DelaySpectrum(baseline=nbase, delay=delays)
delay_spec.redistribute("baseline")
delay_spec.spectrum[:] = 0.0
bl_axes = [va for va in vis_axes if va not in [self.average_axis, "freq"]]
bl_axes = [da for da in data_axes if da not in [self.average_axis, "freq"]]

# Copy the index maps for all the flattened axes into the output container, and
# write out their order into an attribute so we can reconstruct this easily
Expand All @@ -593,7 +629,7 @@ def process(self, ss: FreqContainerType) -> containers.DelaySpectrum:
self.log.debug(f"Delay transforming baseline {bi}/{nbase}")

# Get the local selections
data = vis_view[lbi].view(np.ndarray)
data = data_view[lbi].view(np.ndarray)
weight = weight_view[lbi].view(np.ndarray)

# Mask out data with completely zero'd weights and generate time
Expand Down