From ad55e0d2942f28ef132e4ee6c50280667534b2e8 Mon Sep 17 00:00:00 2001 From: Liam Gray Date: Wed, 5 Jun 2024 18:58:34 -0700 Subject: [PATCH] fix(flagging): add a proper noise estimation to sumthreshold --- draco/analysis/flagging.py | 83 ++++++++++++++++++++++++-------------- 1 file changed, 53 insertions(+), 30 deletions(-) diff --git a/draco/analysis/flagging.py b/draco/analysis/flagging.py index b889be236..673b71eb3 100644 --- a/draco/analysis/flagging.py +++ b/draco/analysis/flagging.py @@ -11,12 +11,12 @@ from typing import Union, overload import numpy as np -import scipy.signal from caput import config, mpiarray, weighted_median from cora.util import units +from scipy.signal import convolve, firwin, oaconvolve from skimage.filters import apply_hysteresis_threshold -from ..analysis import delay +from ..analysis import delay, transform from ..core import containers, io, task from ..util import rfi, tools @@ -393,7 +393,7 @@ def process(self, data): mask_extended = np.zeros_like(mask) for ind in np.ndindex(*shp): mask_extended[ind] = ( - scipy.signal.convolve( + convolve( mask[ind].astype(np.float32), kernel, mode="same", @@ -1000,7 +1000,7 @@ def process( return mask_cont -class RFIStokesIMask(task.SingleTask): +class RFIStokesIMask(transform.ReduceVar): """Two-stage RFI filter based on Stokes I visibilities. Tries to independently target transient and persistant RFI. @@ -1016,7 +1016,9 @@ class RFIStokesIMask(task.SingleTask): is taken over 2+ cylinder separation baselines to obtain a single 1D array per frequency. These powers are gathered across all frequencies and a basic background subtraction is applied. Sumthreshold - algorithm is then used for flagging. + algorithm is then used for flagging, with a variance estimate used to + boost the expected noise during the daytime and bright point source + transits. Attributes ---------- @@ -1042,7 +1044,15 @@ class RFIStokesIMask(task.SingleTask): Maximum size of the SumThreshold window. Default is 16. nsigma : float, optional Initial threshold for SumThreshold. Default is 5.0. - lowpass_ang : float, optional + bg_win_size : list, optional + The size of the window used to estimate the background sky, provided + as (number of frequency channels, number of time samples). + Default is [11, 3]. + var_win_size : list, optional + The size of the window used when estimating the variance, provided + as (number of frequency channels, number of time samples). + Default is [3, 101]. + lowpass_cutoff : float, optional Angular cutoff of the ra lowpass filter. Default is 7.5, which corresponds to about 30 minutes of observation time. """ @@ -1056,7 +1066,9 @@ class RFIStokesIMask(task.SingleTask): include_sumthreshold = config.Property(proptype=bool, default=True) max_m = config.Property(proptype=int, default=16) nsigma = config.Property(proptype=float, default=5.0) - lowpass_ang = config.Property(proptype=float, default=7.5) + bg_win_size = config.list_type(int, length=2, default=[11, 3]) + var_win_size = config.list_type(int, length=2, default=[3, 101]) + lowpass_cutoff = config.Property(proptype=float, default=7.5) def setup(self, telescope): """Set up the baseline selections and ordering. @@ -1067,6 +1079,8 @@ def setup(self, telescope): The telescope object to use """ self.telescope = io.get_telescope(telescope) + # Set the parent class attribute to use the correct weighting + self.weighting = "weighted" def process(self, stream): """Make a mask from the data. @@ -1206,31 +1220,40 @@ def mask_single_channel(self, vis, weight, mask, freq, baselines, ra): def mask_multi_channel(self, power, mask, times): """Mask slow-moving narrow-band RFI.""" - # Make a copy of the power dataset since it will be - # modified in place - power = power.copy() # Find times where there are bright sources transiting source_flag = self._source_flag_hook(times) - - # Get a median for each frequency - med = weighted_median.weighted_median(power, (~mask).astype(power.dtype)) - # Set power to median when bright sources are in the sky - power[:, source_flag] = med[:, np.newaxis] - - # Subtract out a background, assuming that the type of - # rfi we're looking for is very localised in frequency - power -= weighted_median.moving_weighted_median( - power, (~mask).astype(power.dtype), size=(11, 3) - ) - - # Mask bright data with bright sources and daytime removed + # Avoid bright parts of the sky in the variance estimation + weight = ~mask & ~source_flag[np.newaxis] + + # Calculate the weighted variance over time, excluding daytime + # and times where bright sources are transiting + wvar, ws = self.reduction(power, weight, axis=1) + # Get a smoothed estimate of the per-frequency variance + wvar = tools.arPLS_1d(wvar, ws == 0, lam=1e1)[:, np.newaxis] + + # Get a background estimate of the sky, assuming that the + # type of rfi we're looking for is very localised in frequency + p_med = medfilt(power, mask, size=self.bg_win_size) + + # Create an estimate of the variance for each sample. Find the + # ratio of a rolling median of the background sky to the overall + # median in time and multiply this ratio by the per-frequency + # variance estimate + med = weighted_median.weighted_median(p_med, (~mask).astype(p_med.dtype)) + rmed = medfilt(p_med, mask, size=self.var_win_size) + var = wvar * rmed * tools.invert_no_zero(med)[:, np.newaxis] + + # Generate an RFI mask from the background-subtracted data summask = rfi.sumthreshold( - power, start_flag=mask, max_m=self.max_m, threshold1=self.nsigma + power - p_med, + start_flag=mask, + max_m=self.max_m, + threshold1=self.nsigma, + variance=var, ) # Expand the mask in time only. Expanding in frequency generally ends - # up being too aggressive, and the single-channel flagging does a fine - # job at catching broad-spectrum transient rfi + # up being too aggressive summask |= rfi.sir((summask & ~mask)[:, np.newaxis], only_time=True)[:, 0] return summask @@ -1243,12 +1266,12 @@ def apply_filter(vis, weight, samples, fcut, type_="high"): # Order is sample frequency over cutoff frequency. Ensure order is odd order = int(np.ceil(fs / fcut) // 2 * 2 + 1) # Make the window. Flattop seems to work well here - kernel = scipy.signal.firwin(order, fcut, window="flattop", fs=fs)[np.newaxis] + kernel = firwin(order, fcut, window="flattop", fs=fs)[np.newaxis] # Low-pass filter the visibilities. `oaconvolve` is significantly # faster than the standard convolve method - vw_lp = scipy.signal.oaconvolve(vis * weight, kernel, mode="same") - ww_lp = scipy.signal.oaconvolve(weight, kernel, mode="same") + vw_lp = oaconvolve(vis * weight, kernel, mode="same") + ww_lp = oaconvolve(weight, kernel, mode="same") vis_lp = vw_lp * tools.invert_no_zero(ww_lp) if type_ == "high": @@ -1266,7 +1289,7 @@ def _hpf_cut_hook(self, freq, baselines): def _lpf_cut_hook(self, freq, baselines): """Get a low-pass fringe rate cut for each frequency.""" - cut = 1 / np.deg2rad(self.lowpass_ang) + cut = 1 / np.deg2rad(self.lowpass_cutoff) return np.ones(len(freq), dtype=np.float64) * cut