feat(delay): Implement a NRML delay power spectrum estimator option
ljgray authored and jmaceachern committed Jun 13, 2024
1 parent fd0236f commit 7d557df
Showing 1 changed file with 83 additions and 23 deletions.
106 changes: 83 additions & 23 deletions draco/analysis/delay.py
@@ -12,6 +12,7 @@

from ..core import containers, io, task
from ..util import random, tools
from .delayopt import delay_power_spectrum_maxpost


class DelayFilter(task.SingleTask):
@@ -642,22 +643,34 @@ class DelayGibbsSamplerBase(DelayTransformBase, random.RandomTask):
Attributes
----------
nsamp : int, optional
The number of Gibbs samples to draw.
If maxpost=False, the number of Gibbs samples to draw. If maxpost=True,
the number of iterations allowed in the call to scipy.optimize.minimize
in the maximum-likelihood estimator.
initial_amplitude : float, optional
The Gibbs sampler will be initialized with a flat power spectrum with
this amplitude. Default: 10.
this amplitude. Unused if maxpost=True (flat spectrum is a bad initial
guess for the max-likelihood estimator). Default: 10.
save_samples : bool, optional
The entire chain of samples will be saved rather than just the final
result. Default: False
initial_sample_path : str, optional
File path to load an initial power spectrum sample. If no file is given,
start with a flat power spectrum (Gibbs) or an inverse-FFT guess
(max-likelihood). Default: None
maxpost : bool, optional
If True, the NRML maximum-likelihood delay spectrum estimator is used
instead of the Gibbs sampler. Default: False
maxpost_tol : float, optional
Only used if maxpost=True. The convergence tolerance used by
scipy.optimize.minimize in the maximum-likelihood estimator. Default: 1e-3
"""

nsamp = config.Property(proptype=int, default=20)
initial_amplitude = config.Property(proptype=float, default=10.0)
save_samples = config.Property(proptype=bool, default=False)
initial_sample_path = config.Property(proptype=str, default=None)
maxpost = config.Property(proptype=bool, default=False)
maxpost_tol = config.Property(proptype=float, default=1e-3)
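
The config properties above are what a pipeline sets to switch this task from Gibbs sampling to the NRML maximum-likelihood estimator. As a hedged illustration (not part of this commit), the relevant options might be collected as follows; the task name, construction call, and values are assumptions, not taken from the source:

# Illustrative only: task name, construction call, and values are assumptions.
maxpost_options = {
    "maxpost": True,        # use the NRML maximum-likelihood estimator
    "nsamp": 100,           # with maxpost=True, caps scipy.optimize.minimize iterations
    "maxpost_tol": 1e-3,    # convergence tolerance handed to the minimizer
    "save_samples": False,  # keep only the final spectrum, not the full chain
}
# task = SomeDelayPowerSpectrumEstimator.from_config(maxpost_options)  # caput-style construction (assumed)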

def _create_output(
self,
Expand Down Expand Up @@ -702,6 +715,12 @@ def _create_output(
if self.save_samples:
delay_spec.add_dataset("spectrum_samples")

# If estimating delay spectrum w/ max-likelihood, initialize a mask dataset
# to record the baselines for which the estimator did/didn't converge.
if self.maxpost:
delay_spec.add_dataset("spectrum_mask")
delay_spec.datasets["spectrum_mask"][:] = 0

# Save the frequency axis of the input data as an attribute in the output
# container
delay_spec.attrs["freq"] = ss.freq
@@ -729,8 +748,16 @@ def _get_initial_S(self, nbase, ndelay, dtype):
initial_S = cont.spectrum[:].local_array
bl_ax = cont.spectrum.attrs["axis"].tolist().index("baseline")
initial_S = np.moveaxis(initial_S, bl_ax, 0)
else:
# Gibbs case.
elif not self.maxpost:
initial_S = np.ones((nbase, ndelay), dtype=dtype) * self.initial_amplitude
# Max-likelihood case.
else:
# Flat spectrum is a bad initial guess for max-likelihood.
# Passing None as the initial guess to the max-likelihood
# estimator will cause it to use an inverse FFT as the
# initial guess, which works well in practice.
initial_S = np.full(nbase, None)

return initial_S
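
The comment above explains that, in the max-likelihood case, passing None per baseline makes the estimator fall back to an inverse FFT of the data as its starting point. As a rough sketch of that idea only, not the implementation inside delay_power_spectrum_maxpost (which is not shown in this diff), such a starting spectrum could be formed along these lines:

import numpy as np


def inverse_fft_initial_guess(vis, ndelay):
    """Illustrative inverse-FFT style starting spectrum (assumed form, not the library's).

    Transform the frequency-domain data to delay space and use the squared
    magnitude as an initial power spectrum estimate.
    """
    d = np.fft.ifft(vis, n=ndelay, axis=-1)
    return np.abs(d) ** 2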

@@ -773,30 +800,63 @@ def _evaluate(self, data_view, weight_view, out_cont, delays, channel_ind):
weight = weight_view.local_array[lbi]

# Apply the cuts to the data
t = self._cut_data(data, weight, channel_ind)
t = self._cut_data(data, weight)
if t is None:
continue
data, weight, non_zero_channel = t
data, weight, nzf, _ = t

spec = delay_power_spectrum_gibbs(
data,
ndelay,
weight,
initial_S[lbi],
window=self.window if self.apply_window else None,
fsel=non_zero_channel,
niter=self.nsamp,
rng=rng,
complex_timedomain=self.complex_timedomain,
)
if self.maxpost:
spec, success = delay_power_spectrum_maxpost(
data,
ndelay,
weight,
initial_S[lbi],
window=self.window if self.apply_window else None,
fsel=channel_ind[nzf],
maxiter=self.nsamp,
tol=self.maxpost_tol,
)

# Take an average over the last half of the delay spectrum samples
# (presuming that removes the burn-in)
spec_av = np.median(spec[-(self.nsamp // 2) :], axis=0)
out_cont.spectrum[bi] = np.fft.fftshift(spec_av)
# If the max-likelihood estimator didn't converge within the allowed
# number of iterations, record this in the mask. Note that indexing a
# MemDatasetDistributed object with the global index bi actually ends up
# (under the hood) indexing the underlying MPIArray with the local index.
if not success:
out_cont.datasets["spectrum_mask"][bi] = 1

if self.save_samples:
out_cont.datasets["spectrum_samples"][:, bi] = spec
out_cont.spectrum[bi] = np.fft.fftshift(spec[-1])

if self.save_samples:
nsamp = len(spec)
out_cont.datasets["spectrum_samples"][:, bi] = 0.0
out_cont.datasets["spectrum_samples"][-nsamp:, bi] = np.array(spec)

else:
spec = delay_power_spectrum_gibbs(
data,
ndelay,
weight,
initial_S[lbi],
window=self.window if self.apply_window else None,
fsel=channel_ind[nzf],
niter=self.nsamp,
rng=rng,
complex_timedomain=self.complex_timedomain,
)

# Take an average over the last half of the delay spectrum samples
# (presuming that removes the burn-in)
spec_av = np.median(spec[-(self.nsamp // 2) :], axis=0)
out_cont.spectrum[bi] = np.fft.fftshift(spec_av)

if self.save_samples:
out_cont.datasets["spectrum_samples"][:, bi] = spec

if self.maxpost:
# Record number of converged baselines for debugging info.
n_conv = nbase - out_cont.datasets["spectrum_mask"][:].allgather().sum()
self.log.debug(
f"{n_conv}/{nbase} baselines converged in maximum-likelihood estimate of delay power spectrum."
)

return out_cont
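
For reference, the spectrum_mask dataset written above holds 0 for baselines whose maximum-likelihood estimate converged and 1 for those that did not. A minimal downstream sketch (not part of this commit; shapes of (baseline,) for the mask and (baseline, delay) for the spectrum are assumed) could discard the non-converged baselines like this:

import numpy as np


def drop_nonconverged(spectrum, mask):
    """Zero out delay spectra for baselines flagged as non-converged (mask == 1)."""
    return np.where(mask[..., np.newaxis] == 0, spectrum, 0.0)

# e.g. cleaned = drop_nonconverged(out_cont.spectrum[:], out_cont.datasets["spectrum_mask"][:])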
