diff --git a/draco/core/task.py b/draco/core/task.py index 0126cac5a..b849945cf 100644 --- a/draco/core/task.py +++ b/draco/core/task.py @@ -259,10 +259,11 @@ class SingleTask(MPILoggedTask, pipeline.BasicContMixin): compression : dict or bool, optional Set compression options for each dataset. Provided as a dict with the dataset names as keys and values for `chunks`, `compression`, and `compression_opts`. - If set to `False`, chunks and compression will be disabled for all datasets. - Otherwise, the default parameters set in the dataset spec are used. - Note that this will modify these parameters on the container itself, such that - if it is written out again downstream in the pipeline these will be used. + If set to `False` (or anything that evaluates to `False`, such as an empty dict), + chunks and compression will be disabled for all datasets. Otherwise, the default + parameters set in the dataset spec are used. Note that this will modify these + parameters on the container itself, such that if it is written out again downstream + in the pipeline these will be used. output_root : string Pipeline settable parameter giving the first part of the output path. Deprecated in favour of `output_name`. @@ -291,7 +292,7 @@ class SingleTask(MPILoggedTask, pipeline.BasicContMixin): output_name = config.Property(default="{output_root}{tag}.h5", proptype=str) output_format = config.file_format() compression = config.Property( - default={}, proptype=lambda x: x if isinstance(x, dict) else bool(x) + default=True, proptype=lambda x: x if isinstance(x, dict) else bool(x) ) nan_check = config.Property(default=True, proptype=bool) @@ -444,12 +445,8 @@ def walk_dset_tree(grp, root=""): datasets.append(root + key) return datasets - if not self.compression: - for ds in walk_dset_tree(output): - output._data._storage_root[ds].chunks = None - output._data._storage_root[ds].compression = None - output._data._storage_root[ds].compression_opts = None - else: + if isinstance(self.compression, dict): + # We want to overwrite some compression settings datasets = walk_dset_tree(output) for ds in self.compression: if ds in datasets: @@ -459,7 +456,10 @@ def walk_dset_tree(grp, root=""): ) setattr(output._data._storage_root[ds], key, val) # shorthand for bitshuffle - if output[ds].compression in ("bitshuffle", fileformats.H5FILTER): + if output[ds].compression in ( + "bitshuffle", + fileformats.H5FILTER, + ): output[ds].compression = fileformats.H5FILTER if output[ds].compression_opts is None: output._data._storage_root[ds].compression_opts = ( @@ -470,6 +470,12 @@ def walk_dset_tree(grp, root=""): self.log.warning( f"Ignoring config entry in `compression` for non-existing dataset `{ds}`" ) + elif not self.compression: + # Disable compression + for ds in walk_dset_tree(output): + output._data._storage_root[ds].chunks = None + output._data._storage_root[ds].compression = None + output._data._storage_root[ds].compression_opts = None # Routine to write output if needed. if self.save: