Skip to content

Commit

Permalink
fix(flagging): add a proper noise estimation to sumthreshold
Browse files Browse the repository at this point in the history
  • Loading branch information
ljgray committed Jun 10, 2024
1 parent 6c51753 commit ad55e0d
Showing 1 changed file with 53 additions and 30 deletions.
83 changes: 53 additions & 30 deletions draco/analysis/flagging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,12 @@
from typing import Union, overload

import numpy as np
import scipy.signal
from caput import config, mpiarray, weighted_median
from cora.util import units
from scipy.signal import convolve, firwin, oaconvolve
from skimage.filters import apply_hysteresis_threshold

from ..analysis import delay
from ..analysis import delay, transform
from ..core import containers, io, task
from ..util import rfi, tools

Expand Down Expand Up @@ -393,7 +393,7 @@ def process(self, data):
mask_extended = np.zeros_like(mask)
for ind in np.ndindex(*shp):
mask_extended[ind] = (
scipy.signal.convolve(
convolve(
mask[ind].astype(np.float32),
kernel,
mode="same",
Expand Down Expand Up @@ -1000,7 +1000,7 @@ def process(
return mask_cont


class RFIStokesIMask(task.SingleTask):
class RFIStokesIMask(transform.ReduceVar):
"""Two-stage RFI filter based on Stokes I visibilities.
Tries to independently target transient and persistent RFI.
Expand All @@ -1016,7 +1016,9 @@ class RFIStokesIMask(task.SingleTask):
is taken over 2+ cylinder separation baselines to obtain a single
1D array per frequency. These powers are gathered across all
frequencies and a basic background subtraction is applied. Sumthreshold
algorithm is then used for flagging.
algorithm is then used for flagging, with a variance estimate used to
boost the expected noise during the daytime and bright point source
transits.
Attributes
----------
Expand All @@ -1042,7 +1044,15 @@ class RFIStokesIMask(task.SingleTask):
Maximum size of the SumThreshold window. Default is 16.
nsigma : float, optional
Initial threshold for SumThreshold. Default is 5.0.
lowpass_ang : float, optional
bg_win_size : list, optional
The size of the window used to estimate the background sky, provided
as (number of frequency channels, number of time samples).
Default is [11, 3].
var_win_size : list, optional
The size of the window used when estimating the variance, provided
as (number of frequency channels, number of time samples).
Default is [3, 101].
lowpass_cutoff : float, optional
Angular cutoff of the ra lowpass filter, in degrees. Default is 7.5,
which corresponds to about 30 minutes of observation time.
"""
Expand All @@ -1056,7 +1066,9 @@ class RFIStokesIMask(task.SingleTask):
include_sumthreshold = config.Property(proptype=bool, default=True)
max_m = config.Property(proptype=int, default=16)
nsigma = config.Property(proptype=float, default=5.0)
lowpass_ang = config.Property(proptype=float, default=7.5)
bg_win_size = config.list_type(int, length=2, default=[11, 3])
var_win_size = config.list_type(int, length=2, default=[3, 101])
lowpass_cutoff = config.Property(proptype=float, default=7.5)

def setup(self, telescope):
"""Set up the baseline selections and ordering.
Expand All @@ -1067,6 +1079,8 @@ def setup(self, telescope):
The telescope object to use
"""
self.telescope = io.get_telescope(telescope)
# Set the parent class attribute to use the correct weighting
self.weighting = "weighted"

def process(self, stream):
"""Make a mask from the data.
Expand Down Expand Up @@ -1206,31 +1220,40 @@ def mask_single_channel(self, vis, weight, mask, freq, baselines, ra):

def mask_multi_channel(self, power, mask, times):
    """Mask slow-moving narrow-band RFI.

    A background estimate is subtracted from the power and the residual is
    flagged with SumThreshold. The expected noise passed to SumThreshold is
    a per-frequency variance, scaled per sample by how bright the smoothed
    background is relative to its median over time.

    Parameters
    ----------
    power
        Power to be flagged. The reduction over ``axis=1`` is treated as a
        reduction over time, so this is presumably a 2D array ordered as
        (frequency, time) -- TODO confirm against the caller. It is not
        modified in place.
    mask
        Boolean array broadcastable against `power`; True marks samples
        that are already flagged. Used as the starting flag for
        SumThreshold and excluded from every estimate made here.
    times
        Time axis handed to `_source_flag_hook` to find the samples to
        leave out of the variance estimate.

    Returns
    -------
    summask
        Mask returned by `rfi.sumthreshold`, extended along the time axis
        by `rfi.sir`.
    """
    # Find times where there are bright sources transiting
    source_flag = self._source_flag_hook(times)

    # Avoid bright parts of the sky in the variance estimation
    weight = ~mask & ~source_flag[np.newaxis]

    # Calculate the weighted variance over time, excluding daytime
    # and times where bright sources are transiting
    wvar, ws = self.reduction(power, weight, axis=1)
    # Get a smoothed estimate of the per-frequency variance. Frequencies
    # with no weight at all (ws == 0) are passed as the mask argument.
    wvar = tools.arPLS_1d(wvar, ws == 0, lam=1e1)[:, np.newaxis]

    # Get a background estimate of the sky, assuming that the
    # type of rfi we're looking for is very localised in frequency
    p_med = medfilt(power, mask, size=self.bg_win_size)

    # Create an estimate of the variance for each sample. Find the
    # ratio of a rolling median of the background sky to the overall
    # median in time and multiply this ratio by the per-frequency
    # variance estimate
    med = weighted_median.weighted_median(p_med, (~mask).astype(p_med.dtype))
    rmed = medfilt(p_med, mask, size=self.var_win_size)
    var = wvar * rmed * tools.invert_no_zero(med)[:, np.newaxis]

    # Generate an RFI mask from the background-subtracted data.
    # `power - p_med` allocates a new array, so the input is left intact
    # and no defensive copy is needed.
    summask = rfi.sumthreshold(
        power - p_med,
        start_flag=mask,
        max_m=self.max_m,
        threshold1=self.nsigma,
        variance=var,
    )

    # Expand the mask in time only. Expanding in frequency generally ends
    # up being too aggressive. Only newly flagged samples are expanded.
    summask |= rfi.sir((summask & ~mask)[:, np.newaxis], only_time=True)[:, 0]

    return summask
Expand All @@ -1243,12 +1266,12 @@ def apply_filter(vis, weight, samples, fcut, type_="high"):
# Order is sample frequency over cutoff frequency. Ensure order is odd
order = int(np.ceil(fs / fcut) // 2 * 2 + 1)
# Make the window. Flattop seems to work well here
kernel = scipy.signal.firwin(order, fcut, window="flattop", fs=fs)[np.newaxis]
kernel = firwin(order, fcut, window="flattop", fs=fs)[np.newaxis]

# Low-pass filter the visibilities. `oaconvolve` is significantly
# faster than the standard convolve method
vw_lp = scipy.signal.oaconvolve(vis * weight, kernel, mode="same")
ww_lp = scipy.signal.oaconvolve(weight, kernel, mode="same")
vw_lp = oaconvolve(vis * weight, kernel, mode="same")
ww_lp = oaconvolve(weight, kernel, mode="same")
vis_lp = vw_lp * tools.invert_no_zero(ww_lp)

if type_ == "high":
Expand All @@ -1266,7 +1289,7 @@ def _hpf_cut_hook(self, freq, baselines):

def _lpf_cut_hook(self, freq, baselines):
    """Get a low-pass fringe rate cut for each frequency.

    Parameters
    ----------
    freq
        Sequence of frequencies. Only its length is used here.
    baselines
        Unused in this implementation; kept so the hook signature stays
        the same for overriding implementations.

    Returns
    -------
    cut : np.ndarray[float64]
        One entry per frequency, each equal to the inverse of
        `lowpass_cutoff` converted from degrees to radians.
    """
    # `lowpass_cutoff` is the only cutoff property read here; the stale
    # read of the renamed `lowpass_ang` property has been removed.
    cut = 1 / np.deg2rad(self.lowpass_cutoff)

    # The same cut applies at every frequency
    return np.full(len(freq), cut, dtype=np.float64)

Expand Down

0 comments on commit ad55e0d

Please sign in to comment.