Skip to content

Commit

Permalink
fix(containers): could not make selections along inherited axes
Browse files Browse the repository at this point in the history
This fixes an issue making selections along inherited axes and
straightens out some issues with `_make_selections` being an instance
method instead of a class method.
  • Loading branch information
jrs65 committed Aug 14, 2020
1 parent ddca417 commit 9d678eb
Showing 1 changed file with 21 additions and 9 deletions.
30 changes: 21 additions & 9 deletions draco/core/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,30 +312,42 @@ def dataset_spec(self):
# Ensure that the dataset_spec is the same order on all ranks
return {k: ddict[k] for k in sorted(ddict)}

@property
def axes(self):
"""Return the set of axes for this container..
@classmethod
def _class_axes(cls):
"""Return the set of axes for this container defined by this class and the base classes.
"""
axes = set()

# Iterate over the reversed MRO and look for _table_spec attributes
# which get added to a temporary dict. We go over the reversed MRO so
# that the `tdict.update` overrides tables in base classes.
for cls in inspect.getmro(self.__class__)[::-1]:
for c in inspect.getmro(cls)[::-1]:

try:
axes |= set(cls._axes)
axes |= set(c._axes)
except AttributeError:
pass

# Add in any axes found on the instance
# This must be the same order on all ranks, so we need to explicitly sort to get around the
# hash randomization
return tuple(sorted(axes))

@property
def axes(self):
"""The set of axes for this container including any defined on the instance.
"""
axes = set(self._class_axes())

# Add in any axes found on the instance (this is needed to support the table classes where
# the axes get added at run time)
axes |= set(self.__dict__.get("_axes", []))

# This must be the same order on all ranks, so we need to explicitly sort to get around the
# hash randomization
return tuple(sorted(axes))

def _make_selections(self, sel_args):
@classmethod
def _make_selections(cls, sel_args):
"""
Match down-selection arguments to axes of datasets.
Expand All @@ -355,12 +367,12 @@ def _make_selections(self, sel_args):
"""
# Check if all those axes exist
for axis in sel_args.keys():
if axis not in self._axes:
if axis not in cls._class_axes():
raise RuntimeError("No '{}' axis found to select from.".format(axis))

# Build selections dict
selections = {}
for name, dataset in self._dataset_spec.items():
for name, dataset in cls._dataset_spec.items():
ds_axes = dataset["axes"]
sel = []
ds_relevant = False
Expand Down

0 comments on commit 9d678eb

Please sign in to comment.