Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a new RFI Flagging task #265

Merged
merged 4 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions draco/analysis/delay.py
Original file line number Diff line number Diff line change
Expand Up @@ -1861,15 +1861,24 @@ def delay_spectrum_wiener_filter(
return y_spec


def null_delay_filter(freq, max_delay, mask, num_delay=200, tol=1e-8, window=True):
def null_delay_filter(
freq,
delay_cut,
mask,
num_delay=200,
tol=1e-8,
window=True,
type_="high",
lapack_driver="gesvd",
):
"""Take frequency data and null out any delays below some value.

Parameters
----------
freq : np.ndarray[freq]
Frequencies we have data at.
max_delay : float
Maximum delay to keep.
delay_cut : float
Delay cut to apply.
mask : np.ndarray[freq]
Frequencies to mask out.
num_delay : int, optional
Expand All @@ -1878,17 +1887,27 @@ def null_delay_filter(freq, max_delay, mask, num_delay=200, tol=1e-8, window=Tru
Cut off value for singular values.
window : bool, optional
Apply a window function to the data while filtering.
type_ : str, optional
Whether to apply a high-pass or low-pass filter. Options are
`high` or `low`. Default is `high`.
lapack_driver : str, optional
Which lapack driver to use in the SVD. Options are 'gesvd' or 'gesdd'.
'gesdd' is generally faster, but seems to experience convergence issues.
Default is 'gesvd'.

Returns
-------
filter : np.ndarray[freq, freq]
The filter as a 2D matrix.
"""
if type_ not in {"high", "low"}:
raise ValueError(f"Filter type must be one of [high, low]. Got {type_}")

# Construct the window function
x = (freq - freq.min()) / freq.ptp()
w = tools.window_generalised(x, window="nuttall")

delay = np.linspace(-max_delay, max_delay, num_delay)
delay = np.linspace(-delay_cut, delay_cut, num_delay)

# Construct the Fourier matrix
F = mask[:, np.newaxis] * np.exp(
Expand All @@ -1904,9 +1923,14 @@ def null_delay_filter(freq, max_delay, mask, num_delay=200, tol=1e-8, window=Tru
# to be the fault of MKL (see https://github.com/scipy/scipy/issues/10032 and links
# therein). This seems to be limited to the `gesdd` LAPACK routine, so we can get
# around it by switching to `gesvd`.
u, sig, vh = la.svd(F, lapack_driver="gesvd")
u, sig, vh = la.svd(F, full_matrices=False, lapack_driver=lapack_driver)
nmodes = np.sum(sig > tol * sig.max())
p = u[:, :nmodes]

# Select the modes to null out based on the filter type
if type_ == "high":
p = u[:, :nmodes]
elif type_ == "low":
p = u[:, nmodes:]

# Construct a projection matrix for the filter
proj = np.identity(len(freq)) - np.dot(p, p.T.conj())
Expand Down
303 changes: 303 additions & 0 deletions draco/analysis/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@
import numpy as np
import scipy.signal
from caput import config, mpiarray, weighted_median
from cora.util import units
from skimage.filters import apply_hysteresis_threshold

from ..analysis import delay
from ..core import containers, io, task
from ..util import rfi, tools

Expand Down Expand Up @@ -997,6 +1000,306 @@ def process(
return mask_cont


class RFIStokesIMask(task.SingleTask):
    """Two-stage RFI filter based on Stokes I visibilities.

    Tries to independently target transient and persistent RFI.

    Stage 1 is applied to each frequency independently. A high-pass
    filter is applied in RA to isolate transient RFI. The high-pass
    filtered visibilities are beamformed, and a MAD filter is applied
    to the resulting map. A time/RA sample is then flagged if some
    fraction of beams exceed the MAD threshold for that sample.

    Stage 2 is applied across frequencies. A low-pass filter is applied
    in RA to reduce transient sky sources. The average visibility power
    is taken over 2+ cylinder separation baselines to obtain a single
    1D array per frequency. These powers are gathered across all
    frequencies and a basic background subtraction is applied. A high-sigma
    MAD flag is used during the daytime and bright transits, and the
    sumthreshold algorithm is used everywhere else.

    Attributes
    ----------
    mad_base_size : list of int, optional
        Median absolute deviations base window. Default is [1, 101].
    mad_dev_size : list of int, optional
        Median absolute deviation median deviation window.
        Default is [1, 51].
    sigma_high : float, optional
        Median absolute deviations sigma threshold. Default is 8.0.
    sigma_low : float, optional
        Median absolute deviations low sigma threshold. A value above
        this threshold is masked only if it is either larger than `sigma_high`
        or it is larger than `sigma_low` AND connected to a region larger
        than `sigma_high`. Default is 2.0.
    frac_samples : float, optional
        Fraction of flagged samples in map space above which the entire
        time sample will be flagged. Default is 0.01.
    st_max_m : int, optional
        Maximum size of the SumThreshold window. Default is 32.
    sigma_day : float, optional
        Sigma threshold for the MAD mask applied to bright source transits
        and daytime data, which is required to avoid masking out transits.
        Generally this should be quite high, as the SumThreshold mask is
        applied to the per-frequency median of the data in these regions,
        and will catch most bright frequency bands. Default is 10.0.
    lowpass_ang : float, optional
        Angular cutoff of the ra lowpass filter. Default is 7.5, which
        corresponds to about 30 minutes of observation time.
    include_multi_channel : bool, optional
        If True, include second-stage multi-channel flagging. This should
        generally always be included. Default is True.
    """

    mad_base_size = config.list_type(int, length=2, default=[1, 101])
    mad_dev_size = config.list_type(int, length=2, default=[1, 51])
    sigma_high = config.Property(proptype=float, default=8.0)
    sigma_low = config.Property(proptype=float, default=2.0)
    frac_samples = config.Property(proptype=float, default=0.01)

    st_max_m = config.Property(proptype=int, default=32)
    sigma_day = config.Property(proptype=float, default=10.0)
    lowpass_ang = config.Property(proptype=float, default=7.5)

    include_multi_channel = config.Property(proptype=bool, default=True)

    def setup(self, telescope):
        """Load and store the telescope object.

        Parameters
        ----------
        telescope : TransitTelescope
            The telescope object to use
        """
        self.telescope = io.get_telescope(telescope)

    def process(self, stream):
        """Make a mask from the data.

        Parameters
        ----------
        stream : dcontainers.TimeStream | dcontainers.SiderealStream
            Data to use when masking. Axes should be frequency, stack,
            and time-like.

        Returns
        -------
        mask : dcontainers.RFIMask | dcontainers.SiderealRFIMask
            Time-frequency mask, where values marked `True` are flagged.

        Raises
        ------
        ValueError
            If the stream has neither an `lsd` nor a `csd` attribute.
        TypeError
            If the stream has neither a `time` nor an `ra` axis.
        """
        stream.redistribute("freq")

        # Prefer `lsd`, falling back to `csd`
        csd = stream.attrs.get("lsd", stream.attrs.get("csd"))

        if csd is None:
            raise ValueError("Dataset does not have a `csd` or `lsd` attribute.")

        if "time" in stream.index_map:
            times = stream.time
        elif "ra" in stream.index_map:
            times = self.telescope.lsd_to_unix(csd + stream.ra / 360.0)
        else:
            raise TypeError(
                f"Expected data with `time` or `ra` axis. Got {type(stream)}."
            )

        # Sample positions in radians, measured from the start of this day
        ra = 2 * np.pi * (self.telescope.unix_to_lsd(times) - csd)
        freq = stream.freq[stream.vis[:].local_bounds]

        # Get stokes I and redistribute over frequency. Axes are rearranged
        # in order (baseline, freq, time)
        vis, weight, baselines = delay.stokes_I(stream, self.telescope)
        vis = vis.redistribute(1).local_array
        weight = weight.redistribute(1).local_array

        # Set up the initial mask: samples with zero weight on every baseline,
        # plus any statically-masked frequency channels
        mask = np.all(weight == 0, axis=0)
        mask |= self._static_rfi_mask_hook(freq, times[0])[:, np.newaxis]
        self.log.debug(f"{100.0 * mask.mean():.2f}% of data initially flagged.")

        # Mask scattered transient rfi for each frequency independently
        # Also get the average power per frequency after applying a low-pass filter
        mask, power = self.mask_single_channel(vis, weight, mask, freq, baselines, ra)

        # Gather the entire mask and power arrays
        mask = mpiarray.MPIArray.wrap(mask, axis=0).allgather()
        power = mpiarray.MPIArray.wrap(power, axis=0).allgather()

        if self.include_multi_channel:
            # Mask high power across frequencies
            mask |= self.mask_multi_channel(power, mask, times)

        self.log.debug(f"{100.0 * mask.mean():.2f}% of data flagged.")

        if "ra" in stream.index_map:
            output = containers.SiderealRFIMask(axes_from=stream, attrs_from=stream)
        else:
            output = containers.RFIMask(axes_from=stream, attrs_from=stream)

        output.mask[:] = mask

        return output

    def mask_single_channel(self, vis, weight, mask, freq, baselines, ra):
        """Mask scattered rfi.

        Each frequency is processed independently. Frequencies which are
        already fully masked are skipped and keep zero power.

        Parameters
        ----------
        vis : np.ndarray[baseline, freq, time]
            Stokes I visibilities.
        weight : np.ndarray[baseline, freq, time]
            Weights corresponding to `vis`.
        mask : np.ndarray[freq, time]
            Initial mask. This array is updated in place.
        freq : np.ndarray[freq]
            Frequencies, passed on to the cut hooks.
        baselines : np.ndarray
            Baseline vectors. Only the first column is used here
            (presumably the east-west separation - confirm against
            `delay.stokes_I`).
        ra : np.ndarray[time]
            Sample positions used to derive the filter sampling rate.

        Returns
        -------
        mask : np.ndarray[freq, time]
            The updated mask (same object as the input `mask`).
        power : np.ndarray[freq, time]
            Mean low-pass filtered visibility amplitude over the
            selected baselines.
        """
        # Get the per-frequency high-pass and low-pass cuts
        hpf_cut = self._hpf_cut_hook(freq, baselines)
        lpf_cut = self._lpf_cut_hook(freq, baselines)
        # Select cylinders to include in static power estimation.
        # Choose baselines which should not contain much sky structure
        bl_sel = baselines[:, 0] > 2.0 * self.telescope.u_width
        # Set up an array to store mean power from non-sky sources
        power = np.zeros_like(weight[0], dtype=np.float64)

        # Iterate over frequencies
        for fsel in range(vis.shape[1]):
            if np.all(mask[fsel]):
                # Frequency is already masked
                continue

            # Apply a high-pass mmode filter. Scattered emission appears
            # similar to an impulse function in time, so its fourier transform
            # should extend to high m
            v_hpf = self.apply_filter(
                vis[:, fsel], weight[:, fsel], ra, hpf_cut[fsel], type_="high"
            )

            # MAD filter flags scattered emission after beamforming
            map_hpf = abs(np.fft.fft(v_hpf, axis=0))
            mad_mask = np.zeros_like(v_hpf, dtype=bool) | mask[fsel][np.newaxis]
            mad_ = mad(map_hpf, mad_mask, self.mad_base_size, self.mad_dev_size)
            # Hysteresis threshold mask flags anything above `sigma_high` or
            # anything above `sigma_low` ONLY if it is connected to a region
            # above `sigma_high`
            mad_mask |= apply_hysteresis_threshold(
                mad_, self.sigma_low, self.sigma_high
            )
            # Collapse over baselines and flag
            mean_flagged = np.mean(mad_mask, axis=0)

            # Apply a low pass filter, giving zero weight to samples where
            # at least half of the map is flagged
            lp_win = (mean_flagged < 0.5)[np.newaxis]
            v_lpf = self.apply_filter(
                vis[:, fsel], weight[:, fsel] * lp_win, ra, lpf_cut[fsel], type_="low"
            )

            # Take the average over selected baselines
            power[fsel] = np.mean(abs(v_lpf)[bl_sel], axis=0)
            # Apply the hp mask
            mask[fsel] |= mean_flagged > self.frac_samples

        return mask, power

    def mask_multi_channel(self, power, mask, times):
        """Mask slow-moving narrow-band RFI.

        Parameters
        ----------
        power : np.ndarray[freq, time]
            Average power per frequency and time sample.
        mask : np.ndarray[freq, time]
            Existing mask. Used as the starting flag and as (inverted)
            weights for the medians. Not modified here.
        times : np.ndarray[time]
            Timestamps, passed on to `_source_flag_hook`.

        Returns
        -------
        summask : np.ndarray[freq, time]
            Boolean mask, where `True` marks a flagged sample.
        """
        # Find times where there are bright sources transiting
        source_flag = self._source_flag_hook(times)

        # Get a median for each frequency
        med = weighted_median.weighted_median(power, (~mask).astype(power.dtype))
        # Set power to median when bright sources are in the sky
        p = power.copy()
        p[:, source_flag] = med[:, np.newaxis]

        # Subtract out a background, assuming that the type of
        # rfi we're looking for is very localised in frequency
        p_med = weighted_median.moving_weighted_median(
            p, (~mask).astype(p.dtype), size=(11, 3)
        )

        # Mask bright data with bright sources removed
        summask = rfi.sumthreshold(abs(p - p_med), start_flag=mask, max_m=self.st_max_m)
        # Expand the mask in time only. Expanding in frequency generally ends
        # up being too aggressive, and the single-channel flagging does a fine
        # job at catching broad-spectrum transient rfi
        summask |= rfi.sir((summask & ~mask)[:, np.newaxis], only_time=True)[:, 0]

        # Extra masking over bright sources. Skipped when nothing is flagged
        # as a source (which is what the default hook returns), so the MAD
        # filter is never run on a zero-width array
        if np.any(source_flag):
            mad_ = mad(power[:, source_flag], summask[:, source_flag])
            # Combine with the sumthreshold mask
            summask[:, source_flag] |= mad_ > self.sigma_day

        # Expand the mask to try to fill small holes in heavily masked areas.
        # The values used here are determined experimentally
        kf = scipy.signal.windows.gaussian(11, std=5)[:, np.newaxis]
        kt = scipy.signal.windows.gaussian(51, std=7)[np.newaxis]
        kernel = (kf * kt) ** 0.5

        mm = scipy.signal.oaconvolve(summask, kernel, mode="same")
        summask |= mm > 0.75 * kernel.sum()

        return summask

    @staticmethod
    def apply_filter(vis, weight, samples, fcut, type_="high"):
        """Apply a high-pass or low-pass mmode filter.

        Parameters
        ----------
        vis : np.ndarray
            2D array of visibilities. The filter runs along the last axis.
        weight : np.ndarray
            Weights with the same shape as `vis`.
        samples : np.ndarray
            Sample positions along the filtered axis. The sampling rate
            is the inverse of their median spacing.
        fcut : float or single-element np.ndarray
            Cutoff frequency, in the same units as the sampling rate.
        type_ : str, optional
            If `high`, return the high-pass filtered data. Any other value
            returns the low-pass filtered data. Default is `high`.

        Returns
        -------
        filtered : np.ndarray
            Filtered visibilities with the same shape as `vis`.
        """
        # The cut may arrive as a single-element array. Collapse it to a plain
        # float so that `int` below never receives an array with ndim > 0,
        # which is deprecated since NumPy 1.25
        fcut = np.asarray(fcut, dtype=np.float64).item()

        # Median sampling rate
        fs = 1 / np.median(abs(np.diff(samples)))
        # Order is sample frequency over cutoff frequency. Ensure order is odd
        order = int(np.ceil(fs / fcut) // 2 * 2 + 1)
        # Make the window. Flattop seems to work well here
        kernel = scipy.signal.firwin(order, fcut, window="flattop", fs=fs)[np.newaxis]

        # Low-pass filter the visibilities. `oaconvolve` is significantly
        # faster than the standard convolve method
        vw_lp = scipy.signal.oaconvolve(vis * weight, kernel, mode="same")
        ww_lp = scipy.signal.oaconvolve(weight, kernel, mode="same")
        vis_lp = vw_lp * tools.invert_no_zero(ww_lp)

        if type_ == "high":
            # The high-pass result is the residual of the low-pass filter
            return vis - vis_lp

        return vis_lp

    def _hpf_cut_hook(self, freq, baselines):
        """Get a high-pass fringe rate cut for each frequency.

        Parameters
        ----------
        freq : np.ndarray[nfreq]
            Frequencies in MHz.
        baselines : np.ndarray
            Baseline vectors. Only the maximum of the first column is used.

        Returns
        -------
        cut : np.ndarray[nfreq, 1]
            High-pass cut for each frequency.
        """
        dec = np.deg2rad(self.telescope.latitude)
        # Inverse wavelength; assumes `units.c` is in m/s - confirm
        lambda_inv = freq[:, np.newaxis] * 1e6 / units.c

        # Maximum cut per frequency
        return lambda_inv * baselines[:, 0].max() / np.cos(dec)

    def _lpf_cut_hook(self, freq, baselines):
        """Get a low-pass fringe rate cut for each frequency.

        The same cut, derived from `lowpass_ang`, is used at every
        frequency. `baselines` is not used by this implementation.

        Parameters
        ----------
        freq : np.ndarray[nfreq]
            Frequencies. Only the length is used.
        baselines : np.ndarray
            Baseline vectors (unused).

        Returns
        -------
        cut : np.ndarray[nfreq]
            Low-pass cut for each frequency.
        """
        cut = 1 / np.deg2rad(self.lowpass_ang)

        return np.ones(len(freq), dtype=np.float64) * cut

    def _static_rfi_mask_hook(self, freq, timestamp=None):
        """Override to mask entire frequency channels.

        The default implementation masks nothing.

        Parameters
        ----------
        freq : np.ndarray[nfreq]
            1D array of frequencies in the data (in MHz).

        timestamp : float, optional
            Start observing time (in unix time)

        Returns
        -------
        mask : np.ndarray[nfreq]
            Mask array. True will mask a frequency channel.
        """
        return np.zeros_like(freq, dtype=bool)

    def _source_flag_hook(self, times):
        """Override to mask out bright point sources.

        The default implementation flags nothing.

        Parameters
        ----------
        times : np.ndarray[float]
            Array of timestamps.

        Returns
        -------
        mask : np.ndarray[bool]
            Mask array. True will mask out a time sample.
        """
        return np.zeros_like(times, dtype=bool)


class RFISensitivityMask(task.SingleTask):
"""Slightly less crappy RFI masking.

Expand Down
Loading
Loading