diff --git a/draco/core/io.py b/draco/core/io.py index c88489b6c..2c603ccab 100644 --- a/draco/core/io.py +++ b/draco/core/io.py @@ -45,6 +45,8 @@ from caput import pipeline from caput import config +from cora.util import units + from . import task from ..util.truncate import bit_truncate_weights, bit_truncate_fixed from .containers import SiderealStream, TimeStream, TrackBeam @@ -204,6 +206,101 @@ def next(self): return map_stack +class LoadFITSCatalog(task.SingleTask): + """Load an SDSS-style FITS source catalog. + + Catalogs are given as one, or a list of `File Groups` (see + :mod:`draco.core.io`). Catalogs within the same group are combined together + before being passed on. + + Attributes + ---------- + catalogs : list or dict + A dictionary specifying a file group, or a list of them. + z_range : list, optional + Select only sources with a redshift within the given range. + freq_range : list, optional + Select only sources with a 21cm line freq within the given range. Overrides + `z_range`. + """ + + catalogs = config.Property(proptype=_list_of_filegroups) + z_range = config.list_type(type_=float, length=2, default=None) + freq_range = config.list_type(type_=float, length=2, default=None) + + def process(self): + """Load the groups of catalogs from disk, concatenate them and pass them on. + + Returns + ------- + catalog : :class:`containers.SpectroscopicCatalog` + """ + + from astropy.io import fits + from . import containers + + # Exit this task if we have eaten all the file groups + if len(self.catalogs) == 0: + raise pipeline.PipelineStopIteration + + group = self.catalogs.pop(0) + + # Set the redshift selection + if self.freq_range: + zl = units.nu21 / self.freq_range[1] - 1 + zh = units.nu21 / self.freq_range[0] - 1 + self.z_range = (zl, zh) + + if self.z_range: + zl, zh = self.z_range + self.log.info(f"Applying redshift selection {zl:.2f} <= z <= {zh:.2f}") + + # Load the data only on rank=0 and then broadcast + if self.comm.rank == 0: + # Iterate over all the files in the group, load them into a Map + # container and add them all together + catalog_stack = [] + for cfile in group["files"]: + + self.log.debug("Loading file %s", cfile) + + # TODO: read out the weights from the catalogs + with fits.open(cfile, mode="readonly") as cat: + pos = np.array([cat[1].data[col] for col in ["RA", "DEC", "Z"]]) + + # Apply any redshift selection to the objects + if self.z_range: + zsel = (pos[2] >= self.z_range[0]) & (pos[2] <= self.z_range[1]) + pos = pos[:, zsel] + + catalog_stack.append(pos) + + catalog_array = np.concatenate(catalog_stack, axis=-1).astype(np.float64) + num_objects = catalog_array.shape[-1] + else: + num_objects = None + catalog_array = None + + # Broadcast the size of the catalog to all ranks, create the target array and + # broadcast into it + num_objects = self.comm.bcast(num_objects, root=0) + self.log.debug(f"Constructing catalog with {num_objects} objects.") + if self.comm.rank != 0: + catalog_array = np.zeros((3, num_objects), dtype=np.float64) + self.comm.Bcast(catalog_array, root=0) + + catalog = containers.SpectroscopicCatalog(object_id=num_objects) + catalog["position"]["ra"] = catalog_array[0] + catalog["position"]["dec"] = catalog_array[1] + catalog["redshift"]["z"] = catalog_array[2] + catalog["redshift"]["z_error"] = 0 + + # Assign a tag to the stack of maps + catalog.attrs["tag"] = group["tag"] + + return catalog + + class LoadFilesFromParams(task.SingleTask): """Load data from files given in the tasks parameters.