From 9d678ebce486f3e4e3b5c00a3dff8513493c9045 Mon Sep 17 00:00:00 2001 From: Richard Shaw Date: Wed, 12 Aug 2020 14:44:15 -0700 Subject: [PATCH] fix(containers): could not make selections along inherited axes This fixes an issue making selections along inherited axes and straightens out some issues with `_make_selections` being an instance method instead of a class method. --- draco/core/containers.py | 30 +++++++++++++++++++++--------- 1 file changed, 21 insertions(+), 9 deletions(-) diff --git a/draco/core/containers.py b/draco/core/containers.py index a5c1ea141..e9a525b46 100644 --- a/draco/core/containers.py +++ b/draco/core/containers.py @@ -312,30 +312,42 @@ def dataset_spec(self): # Ensure that the dataset_spec is the same order on all ranks return {k: ddict[k] for k in sorted(ddict)} - @property - def axes(self): - """Return the set of axes for this container.. + @classmethod + def _class_axes(cls): + """Return the set of axes for this container defined by this class and the base classes. """ axes = set() # Iterate over the reversed MRO and look for _table_spec attributes # which get added to a temporary dict. We go over the reversed MRO so # that the `tdict.update` overrides tables in base classes. - for cls in inspect.getmro(self.__class__)[::-1]: + for c in inspect.getmro(cls)[::-1]: try: - axes |= set(cls._axes) + axes |= set(c._axes) except AttributeError: pass - # Add in any axes found on the instance + # This must be the same order on all ranks, so we need to explicitly sort to get around the + # hash randomization + return tuple(sorted(axes)) + + @property + def axes(self): + """The set of axes for this container including any defined on the instance. + """ + axes = set(self._class_axes()) + + # Add in any axes found on the instance (this is needed to support the table classes where + # the axes get added at run time) axes |= set(self.__dict__.get("_axes", [])) # This must be the same order on all ranks, so we need to explicitly sort to get around the # hash randomization return tuple(sorted(axes)) - def _make_selections(self, sel_args): + @classmethod + def _make_selections(cls, sel_args): """ Match down-selection arguments to axes of datasets. @@ -355,12 +367,12 @@ def _make_selections(self, sel_args): """ # Check if all those axes exist for axis in sel_args.keys(): - if axis not in self._axes: + if axis not in cls._class_axes(): raise RuntimeError("No '{}' axis found to select from.".format(axis)) # Build selections dict selections = {} - for name, dataset in self._dataset_spec.items(): + for name, dataset in cls._dataset_spec.items(): ds_axes = dataset["axes"] sel = [] ds_relevant = False