diff --git a/src/awkward/_layout.py b/src/awkward/_layout.py index d11c503acb..8ebe329d77 100644 --- a/src/awkward/_layout.py +++ b/src/awkward/_layout.py @@ -65,6 +65,22 @@ def __init__(self, behavior: Mapping | None = None, attrs: Mapping | None = None self._attrs_from_objects = [] self._behavior_from_objects = [] + def with_attr(self, key, value) -> Self: + self._ensure_finalized() + return type(self)( + behavior=self.behavior, + attrs={**self.attrs, key: value}, + ).finalize() + + def without_attr(self, key) -> Self: + self._ensure_finalized() + attrs = dict(self.attrs) + attrs.pop(key, None) + return type(self)( + behavior=self.behavior, + attrs=attrs, + ).finalize() + def __enter__(self): return self diff --git a/src/awkward/_namedaxis.py b/src/awkward/_namedaxis.py index 7da5dcd8fd..5fc7d70fe5 100644 --- a/src/awkward/_namedaxis.py +++ b/src/awkward/_namedaxis.py @@ -3,16 +3,13 @@ import awkward._typing as tp from awkward._regularize import is_integer -if tp.TYPE_CHECKING: - pass - - # axis names are hashables, mostly strings, # except for integers, which are reserved for positional axis. AxisName: tp.TypeAlias = tp.Hashable # e.g.: {"x": 0, "y": 1, "z": 2} AxisMapping: tp.TypeAlias = tp.Mapping[AxisName, int] + # e.g.: ("x", "y", None) where None is a wildcard AxisTuple: tp.TypeAlias = tp.Tuple[AxisName, ...] @@ -30,19 +27,27 @@ class MaybeSupportsNamedAxis(tp.Protocol): def attrs(self) -> tp.Mapping | AttrsNamedAxisMapping: ... +# just a class for inplace mutation +class NamedAxis: + mapping: AxisMapping + + +NamedAxis.mapping = {} + + def _get_named_axis( ctx: MaybeSupportsNamedAxis | AttrsNamedAxisMapping | tp.Mapping, -) -> AxisTuple: +) -> AxisMapping: """ - Retrieves the named axis from the given context. The context can be an object that supports named axis + Retrieves the named axis from the provided context. The context can either be an object that supports named axis or a dictionary that includes a named axis mapping. 
Args: - ctx (MaybeSupportsNamedAxis | AttrsNamedAxisMapping): The context from which to retrieve the named axis. + ctx (MaybeSupportsNamedAxis | AttrsNamedAxisMapping): The context from which the named axis is to be retrieved. Returns: - AxisTuple: The named axis retrieved from the context. If the context does not include a named axis, - an empty tuple is returned. + AxisMapping: The named axis retrieved from the context. If the context does not include a named axis, + an empty dictionary is returned. Examples: >>> class Test(MaybeSupportsNamedAxis): @@ -51,18 +56,18 @@ def _get_named_axis( ... return {_NamedAxisKey: {"x": 0, "y": 1, "z": 2}} ... >>> _get_named_axis(Test()) - ("x", "y", "z") + {"x": 0, "y": 1, "z": 2} >>> _get_named_axis({_NamedAxisKey: {"x": 0, "y": 1, "z": 2}}) - ("x", "y", "z") + {"x": 0, "y": 1, "z": 2} >>> _get_named_axis({"other_key": "other_value"}) - () + {} """ if isinstance(ctx, MaybeSupportsNamedAxis): return _get_named_axis(ctx.attrs) elif isinstance(ctx, tp.Mapping) and _NamedAxisKey in ctx: - return _axis_mapping_to_tuple(ctx[_NamedAxisKey]) + return dict(ctx[_NamedAxisKey]) else: - return () + return {} def _supports_named_axis(ctx: MaybeSupportsNamedAxis | AttrsNamedAxisMapping) -> bool: @@ -77,29 +82,21 @@ def _supports_named_axis(ctx: MaybeSupportsNamedAxis | AttrsNamedAxisMapping) -> return bool(_get_named_axis(ctx)) -def _positional_axis_from_named_axis(named_axis: AxisTuple) -> tuple[int, ...]: +def _make_positional_axis_tuple(n: int) -> tuple[int, ...]: """ - Converts a named axis to a positional axis. + Generates a positional axis tuple of length n. Args: - named_axis (AxisTuple): The named axis to convert. + n (int): The length of the positional axis tuple to generate. Returns: - tuple[int, ...]: The positional axis corresponding to the named axis. + tuple[int, ...]: The generated positional axis tuple. 
Examples: - >>> _positional_axis_from_named_axis(("x", "y", "z")) + >>> _make_positional_axis_tuple(3) (0, 1, 2) """ - return tuple(range(len(named_axis))) - - -class TmpNamedAxisMarker: - """ - The TmpNamedAxisMarker class serves as a placeholder for axis wildcards. It is used temporarily during - the process of axis manipulation and conversion. This marker helps in identifying the positions - in the axis tuple that are yet to be assigned a specific axis name or value. - """ + return tuple(range(n)) def _is_valid_named_axis(axis: AxisName) -> bool: @@ -115,10 +112,20 @@ def _is_valid_named_axis(axis: AxisName) -> bool: Examples: >>> _is_valid_named_axis("x") True + >>> _is_valid_named_axis(NamedAxisMarker()) + True >>> _is_valid_named_axis(1) False """ - return isinstance(axis, AxisName) and not is_integer(axis) + return ( + # axis must be hashable + isinstance(axis, AxisName) + # ... but not an integer, otherwise we would confuse it with positional axis + and not is_integer(axis) + # Let's only allow strings for now, in the future we can open up to more types + # by removing the isinstance(axis, str) check. + and isinstance(axis, str) + ) def _check_valid_axis(axis: AxisName) -> AxisName: @@ -147,29 +154,6 @@ def _check_valid_axis(axis: AxisName) -> AxisName: return axis -def _check_axis_mapping_unique_values(axis_mapping: AxisMapping) -> None: - """ - Checks if the values in the given axis mapping are unique. If not, raises a ValueError. - - Args: - axis_mapping (AxisMapping): The axis mapping to check. - - Raises: - ValueError: If the values in the axis mapping are not unique. - - Examples: - >>> _check_axis_mapping_unique_values({"x": 0, "y": 1, "z": 2}) - >>> _check_axis_mapping_unique_values({"x": 0, "y": 0, "z": 2}) - Traceback (most recent call last): - ... 
- ValueError: Named axis mapping must be unique for each positional axis, got: {"x": 0, "y": 0, "z": 2} - """ - if len(set(axis_mapping.values())) != len(axis_mapping): - raise ValueError( - f"Named axis mapping must be unique for each positional axis, got: {axis_mapping}" - ) - - def _axis_tuple_to_mapping(axis_tuple: AxisTuple) -> AxisMapping: """ Converts a tuple of axis names to a dictionary mapping axis names to their positions. @@ -181,104 +165,36 @@ def _axis_tuple_to_mapping(axis_tuple: AxisTuple) -> AxisMapping: AxisMapping: A dictionary mapping axis names to their positions. Examples: - >>> _axis_tuple_to_mapping(("x", "y", None)) - {"x": 0, "y": 1, TmpNamedAxisMarker(): 2} + >>> _axis_tuple_to_mapping(("x", None, "y")) + {"x": 0, "y": 2} """ - return { - (_check_valid_axis(axis) if axis is not None else TmpNamedAxisMarker()): i - for i, axis in enumerate(axis_tuple) - } + return {axis: i for i, axis in enumerate(axis_tuple) if axis is not None} -def _axis_mapping_to_tuple(axis_mapping: AxisMapping) -> AxisTuple: +def _named_axis_to_positional_axis( + named_axis: AxisMapping, + axis: AxisName, +) -> int: """ - Converts a dictionary mapping of axis names to their positions to a tuple of axis names. - Does not allow the same values to be repeated in the mapping. + Converts a single named axis to a positional axis. Args: - axis_mapping (AxisMapping): A dictionary mapping axis names to their positions. + axis (AxisName): The named axis to convert. + named_axis (AxisMapping): The mapping from named axes to positional axes. Returns: - AxisTuple: A tuple of axis names. None is used as a placeholder for TmpNamedAxisMarker. - - Examples: - >>> _axis_mapping_to_tuple({"x": 0, "y": 1, TmpNamedAxisMarker(): 2}) - ("x", "y", None) - >>> _axis_mapping_to_tuple({"x": 0, "y": -1, TmpNamedAxisMarker(): 1}) - ("x", None, "y") - >>> _axis_mapping_to_tuple({"x": 0, "y": 0, TmpNamedAxisMarker(): 1}) - Traceback (most recent call last): - ... 
- ValueError: Axis positions must be unique, got: {"x": 0, "y": 0, TmpNamedAxisMarker(): 1} - """ - _check_axis_mapping_unique_values(axis_mapping) - - axis_list: list[AxisName | None] = [None] * len(axis_mapping) - for ax, pos in axis_mapping.items(): - if isinstance(ax, TmpNamedAxisMarker): - axis_list[pos] = None - else: - axis_list[pos] = _check_valid_axis(ax) - return tuple(axis_list) - - -def _any_axis_to_positional_axis( - axis: AxisName | AxisTuple, - named_axis: AxisTuple, -) -> AxisTuple | int | None: - """ - Converts any axis (int, AxisName, AxisTuple, or None) to a positional axis (int or AxisTuple). - - Args: - axis (int | AxisName | AxisTuple | None): The axis to convert. Can be an integer, an AxisName, an AxisTuple, or None. - named_axis (AxisTuple): The named axis mapping to use for conversion. - - Returns: - int | AxisTuple | None: The converted axis. Will be an integer, an AxisTuple, or None. + int: The positional axis corresponding to the given named axis. If the named axis is not found in the mapping, a ValueError is raised. Raises: - ValueError: If the axis is not found in the named axis mapping. + ValueError: If the named axis is not found in the named axis mapping. Examples: - >>> _any_axis_to_positional_axis("x", ("x", "y", "z")) + >>> _named_axis_to_positional_axis({"x": 0, "y": 1, "z": 2}, "x") 0 - >>> _any_axis_to_positional_axis(("x", "z"), ("x", "y", "z")) - (0, 2) """ - if isinstance(axis, (tuple, list)): - return tuple(_one_axis_to_positional_axis(ax, named_axis) for ax in axis) - else: - return _one_axis_to_positional_axis(axis, named_axis) - - -def _one_axis_to_positional_axis( - axis: AxisName | None, - named_axis: AxisTuple, -) -> int | None: - """ - Converts a single axis (int, AxisName, or None) to a positional axis (int or None). - - Args: - axis (int | AxisName | None): The axis to convert. Can be an integer, an AxisName, or None. - named_axis (AxisTuple): The named axis mapping to use for conversion. 
- - Returns: - int | None: The converted axis. Will be an integer or None. - - Raises: - ValueError: If the axis is not found in the named axis mapping. - - Examples: - >>> _one_axis_to_positional_axis("x", ("x", "y", "z")) - 0 - """ - positional_axis = _positional_axis_from_named_axis(named_axis) - if isinstance(axis, int) or axis is None: - return axis - elif axis in named_axis: - return positional_axis[named_axis.index(axis)] - else: - raise ValueError(f"Invalid axis '{axis}'") + if axis not in named_axis: + raise ValueError(f"{axis=} not found in {named_axis=} mapping.") + return named_axis[axis] def _set_named_axis_to_attrs( @@ -313,8 +229,7 @@ def _set_named_axis_to_attrs( if isinstance(named_axis, tuple): named_axis_mapping = _axis_tuple_to_mapping(named_axis) elif isinstance(named_axis, dict): - _check_axis_mapping_unique_values(named_axis) - named_axis_mapping = {**attrs.get(_NamedAxisKey, {}), **named_axis} + named_axis_mapping = named_axis else: raise TypeError(f"named_axis must be a tuple or dict, not {named_axis}") @@ -332,9 +247,10 @@ def _set_named_axis_to_attrs( # See studies/named_axis.md#named-axis-in-high-level-functions and # https://pytorch.org/docs/stable/name_inference.html. 
# -# The strategies are: -# - "keep all" (_identity_named_axis): Keep all named axes in the output array, e.g.: `ak.drop_none` -# - "keep one" (_keep_named_axis): Keep one named axes in the output array, e.g.: `ak.firsts` +# The possible strategies are: +# - "keep all" (_keep_named_axis(..., None)): Keep all named axes in the output array, e.g.: `ak.drop_none` +# - "keep one" (_keep_named_axis(..., int)): Keep one named axis in the output array, e.g.: `ak.firsts` +# - "keep up to" (_keep_named_axis_up_to(..., int)): Keep all named axes up to a certain positional axis in the output array, e.g.: `ak.local_index` # - "remove all" (_remove_all_named_axis): Removes all named axis, e.g.: `ak.categories # - "remove one" (_remove_named_axis): Remove the named axis from the output array, e.g.: `ak.sum` # - "add one" (_add_named_axis): Add a new named axis to the output array, e.g.: `ak.concatenate, ak.singletons` (not clear yet...) @@ -344,117 +260,236 @@ def _set_named_axis_to_attrs( # - "contract" (_contract_named_axis): Contract the named axis in the output array, e.g.: `matmul` (does this exist?) -def _identity_named_axis( - named_axis: AxisTuple, -) -> AxisTuple: +def _keep_named_axis( + named_axis: AxisMapping, + axis: int | None = None, +) -> AxisMapping: """ - Determines the new named axis after keeping all axes. This is useful, for example, - when applying an operation that does not change the axis structure. + Determines the new named axis after keeping the specified axis. This function is useful when an operation + is applied that retains only one axis. Args: - named_axis (AxisTuple): The current named axis. + named_axis (AxisMapping): The current named axis. + axis (int | None, optional): The index of the axis to keep. If None, all axes are kept. Default is None. Returns: - AxisTuple: The new named axis after keeping all axes. + AxisMapping: The new named axis after keeping the specified axis. 
Examples: - >>> _identity_named_axis(("x", "y", "z")) - ("x", "y", "z") + >>> _keep_named_axis({"x": 0, "y": 1, "z": 2}, 1) + {"y": 0} + >>> _keep_named_axis({"x": 0, "y": 1, "z": 2}, None) + {"x": 0, "y": 1, "z": 2} """ - return tuple(named_axis) + if axis is None: + return dict(named_axis) + return {k: 0 for k, v in named_axis.items() if v == axis} -def _keep_named_axis( - named_axis: AxisTuple, - axis: int | None = None, -) -> AxisTuple: +def _keep_named_axis_up_to( + named_axis: AxisMapping, + axis: int, +) -> AxisMapping: """ - Determines the new named axis after keeping the specified axis. This is useful, for example, - when applying an operation that keeps only one axis. + Determines the new named axis after keeping all axes up to the specified axis. This function is useful when an operation + is applied that retains all axes up to a certain axis. Args: - named_axis (AxisTuple): The current named axis. - axis (int | None, optional): The index of the axis to keep. If None, all axes are kept. Default is None. + named_axis (AxisMapping): The current named axis. + axis (int): The index of the axis up to which to keep. Returns: - AxisTuple: The new named axis after keeping the specified axis. + AxisMapping: The new named axis after keeping all axes up to the specified axis. 
Examples: - >>> _keep_named_axis(("x", "y", "z"), 1) - ("y",) - >>> _keep_named_axis(("x", "y", "z")) - ("x", "y", "z") - """ - return tuple(named_axis) if axis is None else (named_axis[axis],) + >>> _keep_named_axis_up_to({"x": 0, "y": 2}, 0) + {"x": 0} + >>> _keep_named_axis_up_to({"x": 0, "y": 2}, 1) + {"x": 0} + >>> _keep_named_axis_up_to({"x": 0, "y": 2}, 2) + {"x": 0, "y": 2} + >>> _keep_named_axis_up_to({"x": 0, "y": -2}, 0) + {"x": 0} + >>> _keep_named_axis_up_to({"x": 0, "y": -2}, 1) + {"x": 0, "y": -2} + >>> _keep_named_axis_up_to({"x": 0, "y": -2}, 2) + {"x": 0, "y": -2} + """ + if axis < 0: + raise ValueError("The axis must be a positive integer.") + out = {} + for k, v in named_axis.items(): + if v >= 0 and v <= axis: + out[k] = v + elif v < 0 and v >= -axis - 1: + out[k] = v + return out def _remove_all_named_axis( - named_axis: AxisTuple, - n: int | None = None, -) -> AxisTuple: + named_axis: AxisMapping, +) -> AxisMapping: """ - Determines the new named axis after removing all axes. This is useful, for example, - when applying an operation that removes all axes. + Returns an empty named axis mapping after removing all axes from the given named axis mapping. + This function is typically used when an operation that eliminates all axes is applied. Args: - named_axis (AxisTuple): The current named axis. - n (int | None, optional): The number of axes to remove. If None, all axes are removed. Default is None. + named_axis (AxisMapping): The current named axis mapping. Returns: - AxisTuple: The new named axis after removing all axes. All elements will be None. + AxisMapping: An empty named axis mapping. 
Examples: - >>> _remove_all_named_axis(("x", "y", "z")) - (None, None, None) - >>> _remove_all_named_axis(("x", "y", "z"), 2) - (None, None) + >>> _remove_all_named_axis({"x": 0, "y": 1, "z": 2}) + {} """ - return (None,) * (len(named_axis) if n is None else n) + return _remove_named_axis(named_axis=named_axis, axis=None) def _remove_named_axis( - axis: int | None, - named_axis: AxisTuple, -) -> AxisTuple: + named_axis: AxisMapping, + axis: int | None = None, + total: int | None = None, +) -> AxisMapping: """ Determines the new named axis after removing the specified axis. This is useful, for example, - when applying a sum operation along an axis. + when applying an operation that removes one axis. Args: - axis (int): The index of the axis to remove. - named_axis (AxisTuple): The current named axis. + named_axis (AxisMapping): The current named axis. + axis (int | None, optional): The index of the axis to remove. If None, all named axes are removed and an empty mapping is returned. Default is None. + total (int | None, optional): The total number of axes. If None, it is calculated as the length of the named axis. Default is None. Returns: - AxisTuple: The new named axis after removing the specified axis. + AxisMapping: The new named axis after removing the specified axis. 
Examples: - >>> _remove_named_axis(1, ("x", "y", "z")) - ("x", "z") + >>> _remove_named_axis({"x": 0, "y": 1}, None) + {} + >>> _remove_named_axis({"x": 0, "y": 1}, 0) + {"y": 0} + >>> _remove_named_axis({"x": 0, "y": 1, "z": 2}, 1) + {"x": 0, "z": 1} + >>> _remove_named_axis({"x": 0, "y": 1, "z": -1}, 1) + {"x": 0, "z": -1} + >>> _remove_named_axis({"x": 0, "y": 1, "z": -3}, 1) + {"x": 0, "z": -2} """ if axis is None: - return (None,) - return tuple(name for i, name in enumerate(named_axis) if i != axis) + return {} + + if total is None: + total = len(named_axis) + + # remove the specified axis + out = {ax: pos for ax, pos in named_axis.items() if pos != axis} + + return _adjust_pos_axis(out, axis, total) + + +def _remove_named_axis_by_name( + named_axis: AxisMapping, + axis: AxisName, + total: int | None = None, +) -> AxisMapping: + """ + Determines the new named axis after removing the specified axis by its name. This is useful, for example, + when applying an operation that removes one axis. + + Args: + named_axis (AxisMapping): The current named axis. + axis (AxisName | None): The name of the axis to remove. If None, all named axes are removed and an empty mapping is returned. + total (int | None, optional): The total number of axes. If None, it is calculated as the length of the named axis. Default is None. + + Returns: + AxisMapping: The new named axis after removing the specified axis. 
+ + Examples: + >>> _remove_named_axis_by_name({"x": 0, "y": 1}, "x") + {"y": 0} + >>> _remove_named_axis_by_name({"x": 0, "y": 1, "z": 2}, "y") + {"x": 0, "z": 1} + >>> _remove_named_axis_by_name({"x": 0, "y": 1, "z": -1}, "z") + {"x": 0, "y": 1} + """ + if axis is None: + return {} + + if total is None: + total = len(named_axis) + + # remove the specified axis + out = dict(named_axis) + pos = out.pop(axis) + + return _adjust_pos_axis(out, pos, total) + + +def _adjust_pos_axis( + named_axis: AxisMapping, + axis: int, + total: int, +) -> AxisMapping: + """ + Adjusts the positions of the axes in the named axis mapping after an axis has been removed. + + The adjustment is done as follows: + - If the position of an axis is greater than the removed axis, it is decremented by 1. + - If the position of an axis is less than the removed axis and greater or equal to -1, it is kept as is. + - If the position of an axis is negative and smaller than the amount of total left axes, it is incremented by 1. + + Args: + named_axis (AxisMapping): The current named axis mapping. + axis (int): The position of the removed axis. + total (int): The total number of axes. + + Returns: + AxisMapping: The adjusted named axis mapping. + + Examples: + >>> _adjust_pos_axis({"x": 0, "z": 2}, 1, 3) + {"x": 0, "z": 1} + >>> _adjust_pos_axis({"x": 0, "z": -1}, 1, 3) + {"x": 0, "z": -1} + >>> _adjust_pos_axis({"x": 0, "z": -3}, 1, 3) + {"x": 0, "z": -2} + """ + out = dict(named_axis) + for k, v in out.items(): + if v > axis: + out[k] = v - 1 + elif v < -1 and len(out) < total: + out[k] = v + 1 + else: + out[k] = v + return out def _add_named_axis( + named_axis: AxisMapping, axis: int, - named_axis: AxisTuple, -) -> AxisTuple: +) -> AxisMapping: """ - Adds a wildcard named axis (None) to the named_axis after the position of the specified axis. + Adds a new axis to the named_axis at the specified position. Args: - axis (int): The index after which to add the wildcard named axis. 
- named_axis (AxisTuple): The current named axis. + named_axis (AxisMapping): The current named axis mapping. + axis (int): The position at which to add the new axis. Returns: - AxisTuple: The new named axis after adding the wildcard named axis. + AxisMapping: The updated named axis mapping after adding the new axis. Examples: - >>> _add_named_axis(1, ("x", "y", "z")) - ("x", "y", None, "z") + >>> _add_named_axis({"x": 0, "y": 1, "z": 2}, 0) + {"x": 1, "y": 2, "z": 3} + >>> _add_named_axis({"x": 0, "y": 1, "z": 2}, 1) + {"x": 0, "y": 2, "z": 3} """ - return named_axis[: axis + 1] + (None,) + named_axis[axis + 1 :] + out = dict(named_axis) + for k, v in out.items(): + if v >= axis and v >= 0: + out[k] = v + 1 + return out def _permute_named_axis( @@ -465,75 +500,91 @@ def _permute_named_axis( def _unify_named_axis( - named_axis1: AxisTuple, - named_axis2: AxisTuple, -) -> AxisTuple: + named_axis1: AxisMapping, + named_axis2: AxisMapping, +) -> AxisMapping: """ - Unifies two named axes into a single named axis. If the axes are identical or if one of them is None, - the unified axis will be the non-None axis. If the axes are different and neither of them is None, - a ValueError is raised. + Unifies two named axes into a single named axis. The function iterates over all positional axes present in either of the input named axes. + For each positional axis, it checks the corresponding axis names in both input named axes. If the axis names are the same or if one of them is None, + the unified axis will be the non-None axis. If the axis names are different and neither of them is None, a ValueError is raised. Args: - named_axis1 (AxisTuple): The first named axis to unify. - named_axis2 (AxisTuple): The second named axis to unify. + named_axis1 (AxisMapping): The first named axis to unify. + named_axis2 (AxisMapping): The second named axis to unify. Returns: - AxisTuple: The unified named axis. + AxisMapping: The unified named axis. 
Raises: ValueError: If the axes are different and neither of them is None. Examples: - >>> _unify_named_axis(("x", "y", None), ("x", "y", "z")) - ("x", "y", "z") - >>> _unify_named_axis(("x", "y", "z"), ("x", "y", "z")) - ("x", "y", "z") - >>> _unify_named_axis(("x", "y", "z"), (None, None, None)) - ("x", "y", "z") - >>> _unify_named_axis(("x", "y", "z"), ("a", "b", "c")) - ValueError: Cannot unify different axes: 'x' and 'a' - """ - result = [] - for ax1, ax2 in zip(named_axis1, named_axis2): - if ax1 == ax2 or ax1 is None or ax2 is None: - result.append(ax1 if ax1 is not None else ax2) - else: - raise ValueError(f"Cannot unify different axes: '{ax1}' and '{ax2}'") - return tuple(result) + >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": 2}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({"x": 0, "y": 1}, {"x": 0, "y": 1, "z": 2}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {"a": 0, "b": 1, "c": 2}) + Traceback (most recent call last): + ... + ValueError: The named axes are different. 
Got: 'x' and 'a' for positional axis 0 + + >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": 3}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({"x": 0, "y": 1, "z": 2}, {}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({}, {"x": 0, "y": 1, "z": 2}) + {"x": 0, "y": 1, "z": 2} + + >>> _unify_named_axis({}, {}) + {} + """ + + def _get_axis_name( + axis_mapping: AxisMapping, positional_axis: int + ) -> AxisName | None: + for name, position in axis_mapping.items(): + if position == positional_axis: + return name + return None + + unified_named_axis = {} + all_positional_axes = set(named_axis1.values()) | set(named_axis2.values()) + for position in all_positional_axes: + axis_name1 = _get_axis_name(named_axis1, position) + axis_name2 = _get_axis_name(named_axis2, position) + if axis_name1 is not None and axis_name2 is not None: + if axis_name1 != axis_name2: + raise ValueError( + f"The named axes are different. Got: {axis_name1} and {axis_name2} for positional axis {position}" + ) + unified_named_axis[axis_name1] = position + elif axis_name1 is not None: # axis_name2 is None + unified_named_axis[axis_name1] = position + elif axis_name2 is not None: # axis_name1 is None + unified_named_axis[axis_name2] = position + return unified_named_axis def _collapse_named_axis( + named_axis: AxisMapping, axis: tuple[int, ...] | int | None, - named_axis: AxisTuple, -) -> AxisTuple: +) -> AxisMapping: """ Determines the new named axis after collapsing the specified axis. This is useful, for example, when applying a flatten operation along an axis. Args: axis (tuple[int, ...] | int | None): The index of the axis to collapse. If None, all axes are collapsed. - named_axis (AxisTuple): The current named axis. + named_axis (AxisMapping): The current named axis. Returns: - AxisTuple: The new named axis after collapsing the specified axis. 
- - Examples: - >>> _collapse_named_axis(1, ("x", "y", "z")) - ("x", "z") - >>> _collapse_named_axis(None, ("x", "y", "z")) - (None,) - >>> _collapse_named_axis((1, 2), ("x", "y", "z")) - ("x",) - >>> _collapse_named_axis((0, 1, 2), ("x", "y", "z")) - (None,) - >>> _collapse_named_axis((0, 2), ("x", "y", "z")) - ("y",) + AxisMapping: The new named axis after collapsing the specified axis. """ - if axis is None: - return (None,) - elif isinstance(axis, int): - axis = (axis,) - return tuple(name for i, name in enumerate(named_axis) if i not in axis) or (None,) + raise NotImplementedError() class Slicer: @@ -576,84 +627,56 @@ def __getitem__(self, where): NamedAxisSlice: tp.TypeAlias = tp.Dict[AxisName, AxisSlice] -def _normalize_slice( +def _normalize_named_slice( + named_axis: AxisMapping, where: AxisSlice | NamedAxisSlice | tuple[AxisSlice | NamedAxisSlice], - named_axis: AxisTuple, + total: int, ) -> AxisSlice: """ - Normalizes the given slice based on the named axis. The slice can be a dictionary mapping axis names to slices, - a tuple of slices, an ellipsis, or a single slice. The named axis is a tuple of axis names. + Normalizes a named slice into a positional slice. + + This function takes a named slice (a dictionary mapping axis names to slices) and converts it into a positional slice + (a tuple of slices). The positional slice can then be used to index an array. Args: - where (AxisSlice | NamedAxisSlice | tuple[AxisSlice | NamedAxisSlice]): The slice to normalize. - named_axis (AxisTuple): The named axis. + named_axis (AxisMapping): The current named axis mapping. + where (AxisSlice | NamedAxisSlice | tuple[AxisSlice | NamedAxisSlice]): The slice to normalize. Can be a single slice, a tuple of slices, or a dictionary mapping axis names to slices. + total (int): The total number of axes. Returns: AxisSlice: The normalized slice. 
- Examples: - >>> _normalize_slice({"x": slice(1, 5)}, ("x", "y", "z")) - (slice(1, 5, None), slice(None, None, None), slice(None, None, None)) - - >>> _normalize_slice((slice(1, 5), slice(2, 10)), ("x", "y", "z")) - (slice(1, 5, None), slice(2, 10, None)) - - >>> _normalize_slice(..., ("x", "y", "z")) - (slice(None, None, None), slice(None, None, None), slice(None, None, None)) + Raises: + ValueError: If an invalid axis name is provided in the slice. - >>> _normalize_slice(slice(1, 5), ("x", "y", "z")) - slice(1, 5, None) + Examples: + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {0: 0}, 3) + (0, slice(None), slice(None)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {-1: 0}, 3) + (slice(None), slice(None), 0) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0}, 3) + (0, slice(None), slice(None)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1}, 3) + (0, 1, slice(None)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": ...}, 3) + (0, 1, slice(None)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": 0, "y": 1, "z": slice(0, 1)}, 3) + (0, 1, slice(0, 1)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": (0, 1)}, 3) + ((0, 1), slice(None), slice(None)) + >>> _normalize_named_slice({"x": 0, "y": 1, "z": 2}, {"x": [0, 1]}, 3) + ([0, 1], slice(None), slice(None)) """ if isinstance(where, dict): - return tuple(where.get(axis, slice(None)) for axis in named_axis) - elif isinstance(where, tuple): - raise NotImplementedError() + out_where: list[AxisSlice] = [slice(None)] * total + for ax_name, ax_where in where.items(): + slice_ = ax_where if ax_where is not ... 
else slice(None) + if isinstance(ax_name, int): + out_where[ax_name] = slice_ + elif _is_valid_named_axis(ax_name): + idx = _named_axis_to_positional_axis(named_axis, ax_name) + out_where[idx] = slice_ + else: + raise ValueError(f"Invalid axis name: {ax_name} in slice {where}") + where = tuple(out_where) return where - - -def _propagate_named_axis_through_slice( - where: AxisSlice, - named_axis: AxisTuple, -) -> AxisTuple: - """ - Propagate named axis based on where slice to output array. - - Examples: - >>> _propagate_named_axis_through_slice(None, ("x", "y", "z")) - (None, "x", "y", "z") - - >>> _propagate_named_axis_through_slice((..., None), ("x", "y", "z")) - ("x", "y", "z", None) - - >>> _propagate_named_axis_through_slice(0, ("x", "y", "z")) - ("y", "z") - - >>> _propagate_named_axis_through_slice(1, ("x", "y", "z")) - ("x", "z") - - >>> _propagate_named_axis_through_slice(2, ("x", "y", "z")) - ("x", "y") - - >>> _propagate_named_axis_through_slice(..., ("x", "y", "z")) - ("x", "y", "z") - - >>> _propagate_named_axis_through_slice(slice(0, 1), ("x", "y", "z")) - ("x", "y", "z") - - >>> _propagate_named_axis_through_slice((0, slice(0, 1)), ("x", "y", "z")) - ("y", "z") - """ - if where is None: - return (None,) + named_axis - elif where == (..., None): - return named_axis + (None,) - elif where is Ellipsis: - return named_axis - elif isinstance(where, int): - return named_axis[:where] + named_axis[where + 1 :] - elif isinstance(where, slice): - return named_axis - elif isinstance(where, tuple): - return tuple(_propagate_named_axis_through_slice(w, named_axis) for w in where) - else: - raise ValueError("Invalid slice type") diff --git a/src/awkward/contents/content.py b/src/awkward/contents/content.py index 1a0fe080a9..6900d022e8 100644 --- a/src/awkward/contents/content.py +++ b/src/awkward/contents/content.py @@ -18,6 +18,13 @@ from awkward._kernels import KernelError from awkward._layout import wrap_layout from awkward._meta.meta import Meta +from 
awkward._namedaxis import ( + NamedAxis, + _add_named_axis, + _keep_named_axis, + _remove_named_axis, + _remove_named_axis_by_name, +) from awkward._nplikes import to_nplike from awkward._nplikes.dispatch import nplike_of_obj from awkward._nplikes.numpy import Numpy @@ -27,7 +34,12 @@ parameters_are_equal, type_parameters_equal, ) -from awkward._regularize import is_integer_like, is_sized_iterable +from awkward._regularize import ( + is_array_like, + is_integer, + is_integer_like, + is_sized_iterable, +) from awkward._slicing import normalize_slice from awkward._typing import ( TYPE_CHECKING, @@ -38,6 +50,7 @@ Protocol, Self, SupportsIndex, + Type, TypeAlias, TypedDict, ) @@ -509,10 +522,14 @@ def _getitem_next_missing( ) def __getitem__(self, where): - return self._getitem(where) + return self._getitem(where, NamedAxis) - def _getitem(self, where): + def _getitem(self, where, named_axis: Type[NamedAxis]): if is_integer_like(where): + # propagate named_axis to output + named_axis.mapping = _remove_named_axis( + named_axis.mapping, where, self.purelist_depth + ) return self._getitem_at(ak._slicing.normalize_integer_like(where)) elif isinstance(where, slice) and where.step is None: @@ -523,21 +540,25 @@ def _getitem(self, where): return self._getitem_range(start, stop) elif isinstance(where, slice): - return self._getitem((where,)) + return self._getitem((where,), named_axis) elif isinstance(where, str): return self._getitem_field(where) elif where is np.newaxis: - return self._getitem((where,)) + return self._getitem((where,), named_axis) elif where is Ellipsis: - return self._getitem((where,)) + return self._getitem((where,), named_axis) elif isinstance(where, tuple): if len(where) == 0: return self + n_ellipsis = where.count(...) 
+ if n_ellipsis > 1: + raise IndexError("an index can only have a single ellipsis ('...')") + # Backend may change if index contains typetracers backend = backend_of(self, *where, coerce_to_common=True) this = self.to_backend(backend) @@ -547,6 +568,44 @@ def _getitem(self, where): # Prepare items for advanced indexing (e.g. via broadcasting) nextwhere = ak._slicing.prepare_advanced_indexing(items, backend) + # Handle named axis + # first expand the ellipsis to colons in nextwhere, + # copy nextwhere to not pollute the original + _nextwhere = tuple(nextwhere) + if n_ellipsis == 1: + # collect the ellipsis index + ellipsis_at = _nextwhere.index(...) + # calculate how many slice(None) we need to add + n_newaxis = _nextwhere.count(np.newaxis) + n_total = self.purelist_depth + n_slice_none = n_total - (len(_nextwhere) - n_newaxis - 1) + # insert (slice(None),) * n_slice_none at the ellipsis index + _nextwhere = ( + _nextwhere[:ellipsis_at] + + (slice(None),) * n_slice_none + + _nextwhere[ellipsis_at + 1 :] + ) + # assert len(_nextwhere) == n_total + n_newaxis + + # now propagate named axis + _named_axis = _keep_named_axis(named_axis.mapping, None) + iter_named_axis = iter(dict(_named_axis).items()) + for i, nw in enumerate(_nextwhere): + name = None + for _name, pos in iter_named_axis: + if pos == i: + name = _name + break + if is_integer(nw) or (is_array_like(nw) and nw.ndim == 0): + _named_axis = _remove_named_axis_by_name( + _named_axis, name, self.purelist_depth + ) + elif nw is None: + _named_axis = _add_named_axis(_named_axis, i) + + # set propagated named axis + named_axis.mapping = _named_axis + next = ak.contents.RegularArray( this, this.length, @@ -562,7 +621,7 @@ def _getitem(self, where): return out._getitem_at(0) elif isinstance(where, ak.highlevel.Array): - return self._getitem(where.layout) + return self._getitem(where.layout, named_axis) # Convert between nplikes of different backends elif ( @@ -570,7 +629,9 @@ def _getitem(self, where): and 
where.backend is not self._backend ): backend = backend_of(self, where, coerce_to_common=True) - return self.to_backend(backend)._getitem(where.to_backend(backend)) + return self.to_backend(backend)._getitem( + where.to_backend(backend), named_axis + ) elif isinstance(where, ak.contents.NumpyArray): data_as_index = to_nplike( @@ -621,9 +682,9 @@ def _getitem(self, where): elif isinstance(where, ak.contents.RegularArray): maybe_numpy = where.maybe_to_NumpyArray() if maybe_numpy is None: - return self._getitem((where,)) + return self._getitem((where,), named_axis) else: - return self._getitem(maybe_numpy) + return self._getitem(maybe_numpy, named_axis) # Awkward Array of strings elif ( @@ -637,7 +698,7 @@ def _getitem(self, where): return where.to_NumpyArray(np.int64) elif isinstance(where, Content): - return self._getitem((where,)) + return self._getitem((where,), named_axis) elif is_sized_iterable(where): # Do we have an array @@ -654,7 +715,7 @@ def _getitem(self, where): primitive_policy="error", string_policy="as-characters", ) - return self._getitem(layout) + return self._getitem(layout, named_axis) elif len(where) == 0: return self._carry( @@ -682,7 +743,7 @@ def _getitem(self, where): ), self._backend, ) - return self._getitem(layout) + return self._getitem(layout, named_axis) else: raise TypeError( diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py index aff90e90ea..5b21438218 100644 --- a/src/awkward/highlevel.py +++ b/src/awkward/highlevel.py @@ -25,9 +25,11 @@ from awkward._layout import wrap_layout from awkward._namedaxis import ( AttrsNamedAxisMapping, - AxisTuple, + AxisMapping, + NamedAxis, _get_named_axis, - _supports_named_axis, + _make_positional_axis_tuple, + _normalize_named_slice, ) from awkward._nplikes.numpy import Numpy from awkward._nplikes.numpy_like import NumpyMetadata @@ -463,14 +465,11 @@ def behavior(self, behavior): @property def positional_axis(self) -> tuple[int, ...]: - return tuple(range(self.ndim)) + return 
_make_positional_axis_tuple(self.ndim) @property - def named_axis(self) -> AxisTuple | None: - if _supports_named_axis(self): - return _get_named_axis(self) - else: - return (None,) * self.ndim + def named_axis(self) -> AxisMapping: + return _get_named_axis(self) class Mask: def __init__(self, array): @@ -1079,32 +1078,30 @@ def __getitem__(self, where): have the same dimension as the array being indexed. """ with ak._errors.SlicingErrorContext(self, where): - # normalize for potential named axis - from awkward._namedaxis import ( - _get_named_axis, - _normalize_slice, - _supports_named_axis, - ) - - out_named_axis = None - if _supports_named_axis(self): - named_axis = _get_named_axis(self) + # Handle named axis + if named_axis := _get_named_axis(self): + where = _normalize_named_slice( + named_axis, where, self._layout.purelist_depth + ) - # Step 1: normalize the slice - where = _normalize_slice(where, named_axis) + NamedAxis.mapping = named_axis - # Step 2: propagate named axis to the output array - out_named_axis = named_axis + out = wrap_layout( + prepare_layout(self._layout._getitem(where, NamedAxis)), + self._behavior, + allow_other=True, + attrs=self._attrs, + ) - return ak.with_named_axis( - array=wrap_layout( - prepare_layout(self._layout[where]), - self._behavior, - allow_other=True, + if NamedAxis.mapping: + out = ak.operations.ak_with_named_axis._impl( + out, + named_axis=NamedAxis.mapping, + highlevel=True, + behavior=self._behavior, attrs=self._attrs, - ), - named_axis=out_named_axis, - ) + ) + return out def __bytes__(self) -> bytes: if isinstance(self._layout, ak.contents.NumpyArray) and self._layout.parameter( diff --git a/src/awkward/operations/__init__.py b/src/awkward/operations/__init__.py index 8d51b6a6cc..19dae41631 100644 --- a/src/awkward/operations/__init__.py +++ b/src/awkward/operations/__init__.py @@ -114,6 +114,7 @@ from awkward.operations.ak_with_named_axis import * from awkward.operations.ak_with_parameter import * from 
awkward.operations.ak_without_field import * +from awkward.operations.ak_without_named_axis import * from awkward.operations.ak_without_parameters import * from awkward.operations.ak_zeros_like import * from awkward.operations.ak_zip import * diff --git a/src/awkward/operations/ak_all.py b/src/awkward/operations/ak_all.py index 15ca68d7a0..c985657325 100644 --- a/src/awkward/operations/ak_all.py +++ b/src/awkward/operations/ak_all.py @@ -10,9 +10,8 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, + _named_axis_to_positional_axis, _remove_named_axis, - _supports_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -78,19 +77,23 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None) + out_named_axis = _keep_named_axis(named_axis, None) if not keepdims: - out_named_axis = _remove_named_axis(axis, out_named_axis) + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.purelist_depth, + ) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_almost_equal.py b/src/awkward/operations/ak_almost_equal.py index 52e493755a..949f955a45 100644 --- 
a/src/awkward/operations/ak_almost_equal.py +++ b/src/awkward/operations/ak_almost_equal.py @@ -7,7 +7,7 @@ from awkward._behavior import behavior_of, get_array_class, get_record_class from awkward._dispatch import high_level_function from awkward._layout import ensure_same_backend -from awkward._namedaxis import _supports_named_axis +from awkward._namedaxis import _get_named_axis from awkward._nplikes.numpy_like import NumpyMetadata from awkward._parameters import parameters_are_equal from awkward.operations.ak_to_layout import to_layout @@ -96,7 +96,7 @@ def _impl( right_layout = layouts[1].to_packed() backend = backend_of(left_layout) - if check_named_axis and _supports_named_axis(left) and _supports_named_axis(right): + if check_named_axis and _get_named_axis(left) and _get_named_axis(right): if left.named_axis != right.named_axis: return False diff --git a/src/awkward/operations/ak_any.py b/src/awkward/operations/ak_any.py index 4df5a86208..2bf226bf9d 100644 --- a/src/awkward/operations/ak_any.py +++ b/src/awkward/operations/ak_any.py @@ -10,9 +10,8 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, + _named_axis_to_positional_axis, _remove_named_axis, - _supports_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -78,19 +77,23 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # 
keepdims=True: use strategy "keep all" (see: awkward._namedaxis) # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None) + out_named_axis = _keep_named_axis(named_axis, None) if not keepdims: - out_named_axis = _remove_named_axis(axis, out_named_axis) + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.purelist_depth, + ) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_argmax.py b/src/awkward/operations/ak_argmax.py index 9038107a64..a3be1728f8 100644 --- a/src/awkward/operations/ak_argmax.py +++ b/src/awkward/operations/ak_argmax.py @@ -10,9 +10,8 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, + _named_axis_to_positional_axis, _remove_named_axis, - _supports_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -143,19 +142,23 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None) + out_named_axis = _keep_named_axis(named_axis, None) if not keepdims: - out_named_axis = _remove_named_axis(axis, out_named_axis) + out_named_axis = _remove_named_axis( + 
named_axis=out_named_axis, + axis=axis, + total=layout.purelist_depth, + ) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_argmin.py b/src/awkward/operations/ak_argmin.py index e3fbeacfad..d7b361a6b3 100644 --- a/src/awkward/operations/ak_argmin.py +++ b/src/awkward/operations/ak_argmin.py @@ -10,9 +10,8 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, + _named_axis_to_positional_axis, _remove_named_axis, - _supports_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -140,19 +139,23 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None) + out_named_axis = _keep_named_axis(named_axis, None) if not keepdims: - out_named_axis = _remove_named_axis(axis, out_named_axis) + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.purelist_depth, + ) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_argsort.py b/src/awkward/operations/ak_argsort.py index 55a9ad22bc..7816f14b8b 100644 --- a/src/awkward/operations/ak_argsort.py +++ b/src/awkward/operations/ak_argsort.py @@ -10,8 +10,7 @@ _get_named_axis, _is_valid_named_axis, 
_keep_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -80,19 +79,16 @@ def _impl(array, axis, ascending, stable, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis( - axis, - _get_named_axis(ctx), - ) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # use strategy "keep all" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None) + out_named_axis = _keep_named_axis(named_axis, None) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_corr.py b/src/awkward/operations/ak_corr.py index 1f9a66fd2b..e16c79266e 100644 --- a/src/awkward/operations/ak_corr.py +++ b/src/awkward/operations/ak_corr.py @@ -4,13 +4,13 @@ import awkward as ak from awkward._attrs import attrs_of_obj -from awkward._behavior import behavior_of_obj from awkward._dispatch import high_level_function from awkward._layout import ( HighLevelContext, ensure_same_backend, maybe_highlevel_to_lowlevel, ) +from awkward._namedaxis import _get_named_axis, _NamedAxisKey from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata @@ -184,14 +184,16 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr attrs=ctx.attrs, ) - # propagate named axis to output out = sumwxy / ufuncs.sqrt(sumwxx * sumwyy) - out_ctx = HighLevelContext( - behavior=behavior_of_obj(out), - attrs=attrs_of_obj(out), - ).finalize() - return 
out_ctx.wrap( + # propagate named axis to output + if out_named_axis := _get_named_axis(attrs_of_obj(out) or {}): + ctx = ctx.with_attr( + key=_NamedAxisKey, + value=out_named_axis, + ) + + return ctx.wrap( maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True, diff --git a/src/awkward/operations/ak_count.py b/src/awkward/operations/ak_count.py index edec34c39e..46a811175a 100644 --- a/src/awkward/operations/ak_count.py +++ b/src/awkward/operations/ak_count.py @@ -9,9 +9,8 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, + _named_axis_to_positional_axis, _remove_named_axis, - _supports_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -120,19 +119,23 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None) + out_named_axis = _keep_named_axis(named_axis, None) if not keepdims: - out_named_axis = _remove_named_axis(axis, out_named_axis) + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.purelist_depth, + ) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_count_nonzero.py b/src/awkward/operations/ak_count_nonzero.py index 
4fbdb3f7ef..d500fec78e 100644 --- a/src/awkward/operations/ak_count_nonzero.py +++ b/src/awkward/operations/ak_count_nonzero.py @@ -9,9 +9,8 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, + _named_axis_to_positional_axis, _remove_named_axis, - _supports_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -79,19 +78,23 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None) + out_named_axis = _keep_named_axis(named_axis, None) if not keepdims: - out_named_axis = _remove_named_axis(axis, out_named_axis) + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.purelist_depth, + ) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_covar.py b/src/awkward/operations/ak_covar.py index 46b1e57a6a..7fbb63ae87 100644 --- a/src/awkward/operations/ak_covar.py +++ b/src/awkward/operations/ak_covar.py @@ -4,13 +4,13 @@ import awkward as ak from awkward._attrs import attrs_of_obj -from awkward._behavior import behavior_of_obj from awkward._dispatch import high_level_function from awkward._layout import ( HighLevelContext, ensure_same_backend, 
maybe_highlevel_to_lowlevel, ) +from awkward._namedaxis import _get_named_axis, _NamedAxisKey from awkward._nplikes.numpy_like import NumpyMetadata __all__ = ("covar",) @@ -161,14 +161,16 @@ def _impl(x, y, weight, axis, keepdims, mask_identity, highlevel, behavior, attr attrs=None, ) - # propagate named axis to output out = sumwxy / sumw - out_ctx = HighLevelContext( - behavior=behavior_of_obj(out), - attrs=attrs_of_obj(out), - ).finalize() - return out_ctx.wrap( + # propagate named axis to output + if out_named_axis := _get_named_axis(attrs_of_obj(out) or {}): + ctx = ctx.with_attr( + key=_NamedAxisKey, + value=out_named_axis, + ) + + return ctx.wrap( maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True, diff --git a/src/awkward/operations/ak_drop_none.py b/src/awkward/operations/ak_drop_none.py index ff2da7f17f..bc0cc86c97 100644 --- a/src/awkward/operations/ak_drop_none.py +++ b/src/awkward/operations/ak_drop_none.py @@ -8,8 +8,7 @@ from awkward._namedaxis import ( _get_named_axis, _is_valid_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -74,11 +73,11 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - if _supports_named_axis(ctx): + # Handle named axis + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_fill_none.py b/src/awkward/operations/ak_fill_none.py index 4e2c1e94ed..cd387fc118 100644 --- a/src/awkward/operations/ak_fill_none.py +++ 
b/src/awkward/operations/ak_fill_none.py @@ -8,8 +8,7 @@ from awkward._namedaxis import ( _get_named_axis, _is_valid_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -88,11 +87,11 @@ def _impl(array, value, axis, highlevel, behavior, attrs): ), ) - if _supports_named_axis(ctx): + # Handle named axis + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_firsts.py b/src/awkward/operations/ak_firsts.py index 0af44a8032..9f6bdcf173 100644 --- a/src/awkward/operations/ak_firsts.py +++ b/src/awkward/operations/ak_firsts.py @@ -9,8 +9,7 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -66,16 +65,16 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False) + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # use strategy "keep one" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), axis) + out_named_axis = _keep_named_axis(named_axis, 
axis) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_is_none.py b/src/awkward/operations/ak_is_none.py index de0d5c5836..65e74e7916 100644 --- a/src/awkward/operations/ak_is_none.py +++ b/src/awkward/operations/ak_is_none.py @@ -8,8 +8,7 @@ from awkward._namedaxis import ( _get_named_axis, _is_valid_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -50,11 +49,11 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - if _supports_named_axis(ctx): + # Handle named axis + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_local_index.py b/src/awkward/operations/ak_local_index.py index 3b403453bb..e2fd6eb1dd 100644 --- a/src/awkward/operations/ak_local_index.py +++ b/src/awkward/operations/ak_local_index.py @@ -8,9 +8,8 @@ from awkward._namedaxis import ( _get_named_axis, _is_valid_named_axis, - _keep_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _keep_named_axis_up_to, + _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -98,22 +97,24 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if 
_is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) axis = regularize_axis(axis) if not is_integer(axis): raise TypeError(f"'axis' must be an integer by now, not {axis!r}") - if _supports_named_axis(ctx): + if named_axis: # Step 2: propagate named axis from input to output, - # "keep all" up to the positional axis dim (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None)[: axis + 1] + # use strategy "keep up to" (see: awkward._namedaxis) + if axis < 0: + axis += layout.purelist_depth + out_named_axis = _keep_named_axis_up_to(named_axis, axis) out = ak._do.local_index(layout, axis) diff --git a/src/awkward/operations/ak_max.py b/src/awkward/operations/ak_max.py index f3a4b79c74..af23a78a66 100644 --- a/src/awkward/operations/ak_max.py +++ b/src/awkward/operations/ak_max.py @@ -10,9 +10,8 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, + _named_axis_to_positional_axis, _remove_named_axis, - _supports_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -153,19 +152,23 @@ def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, at with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) # 
keepdims=False: use strategy "remove one" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None) + out_named_axis = _keep_named_axis(named_axis, None) if not keepdims: - out_named_axis = _remove_named_axis(axis, out_named_axis) + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.purelist_depth, + ) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_mean.py b/src/awkward/operations/ak_mean.py index 927dfa89e6..074a5f99f0 100644 --- a/src/awkward/operations/ak_mean.py +++ b/src/awkward/operations/ak_mean.py @@ -15,9 +15,8 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, + _named_axis_to_positional_axis, _remove_named_axis, - _supports_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -197,19 +196,23 @@ def _impl(x, weight, axis, keepdims, mask_identity, highlevel, behavior, attrs): x = ctx.wrap(x_layout) weight = ctx.wrap(weight_layout, allow_other=True) + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None) + out_named_axis = _keep_named_axis(named_axis, None) if not keepdims: - out_named_axis = _remove_named_axis(axis, out_named_axis) + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=x_layout.purelist_depth, + ) axis = regularize_axis(axis) diff --git 
a/src/awkward/operations/ak_min.py b/src/awkward/operations/ak_min.py index 41f68eb9bf..ea5659c060 100644 --- a/src/awkward/operations/ak_min.py +++ b/src/awkward/operations/ak_min.py @@ -10,9 +10,8 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, + _named_axis_to_positional_axis, _remove_named_axis, - _supports_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -153,19 +152,23 @@ def _impl(array, axis, keepdims, initial, mask_identity, highlevel, behavior, at with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None) + out_named_axis = _keep_named_axis(named_axis, None) if not keepdims: - out_named_axis = _remove_named_axis(axis, out_named_axis) + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.purelist_depth, + ) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_moment.py b/src/awkward/operations/ak_moment.py index 891bdc0fee..b84e8a405c 100644 --- a/src/awkward/operations/ak_moment.py +++ b/src/awkward/operations/ak_moment.py @@ -4,7 +4,6 @@ import awkward as ak from awkward._attrs import attrs_of_obj -from awkward._behavior import behavior_of_obj from awkward._dispatch import high_level_function from awkward._layout 
import ( HighLevelContext, @@ -13,6 +12,8 @@ ) from awkward._namedaxis import ( AxisName, + _get_named_axis, + _NamedAxisKey, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._typing import Mapping @@ -157,14 +158,16 @@ def _impl( attrs=ctx.attrs, ) - # propagate named axis to output out = sumwxn / sumw - out_ctx = HighLevelContext( - behavior=behavior_of_obj(out), - attrs=attrs_of_obj(out), - ).finalize() - return out_ctx.wrap( + # propagate named axis to output + if out_named_axis := _get_named_axis(attrs_of_obj(out) or {}): + ctx = ctx.with_attr( + key=_NamedAxisKey, + value=out_named_axis, + ) + + return ctx.wrap( maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True, diff --git a/src/awkward/operations/ak_num.py b/src/awkward/operations/ak_num.py index e44f112f34..bb2bd81e3f 100644 --- a/src/awkward/operations/ak_num.py +++ b/src/awkward/operations/ak_num.py @@ -6,12 +6,10 @@ from awkward._dispatch import high_level_function from awkward._layout import HighLevelContext, maybe_posaxis from awkward._namedaxis import ( - AxisName, _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -26,7 +24,7 @@ @high_level_function() def num( array, - axis: AxisName = 1, + axis=1, *, highlevel: bool = True, behavior: Mapping | None = None, @@ -101,7 +99,7 @@ def num( def _impl( array, - axis: AxisName, + axis, highlevel: bool, behavior: Mapping | None, attrs: Mapping | None, @@ -109,12 +107,12 @@ def _impl( with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize 
named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # use strategy "keep one" (see: awkward._namedaxis) diff --git a/src/awkward/operations/ak_pad_none.py b/src/awkward/operations/ak_pad_none.py index 78d8b6b38d..b3b77c7c33 100644 --- a/src/awkward/operations/ak_pad_none.py +++ b/src/awkward/operations/ak_pad_none.py @@ -8,8 +8,7 @@ from awkward._namedaxis import ( _get_named_axis, _is_valid_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -122,11 +121,10 @@ def _impl(array, target, axis, clip, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_prod.py b/src/awkward/operations/ak_prod.py index fd046814ba..ece3b4a16f 100644 --- a/src/awkward/operations/ak_prod.py +++ b/src/awkward/operations/ak_prod.py @@ -10,9 +10,8 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, + _named_axis_to_positional_axis, _remove_named_axis, - _supports_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -130,19 +129,23 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = 
ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None) + out_named_axis = _keep_named_axis(named_axis, None) if not keepdims: - out_named_axis = _remove_named_axis(axis, out_named_axis) + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.purelist_depth, + ) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_ptp.py b/src/awkward/operations/ak_ptp.py index c11e9bf913..4eebe1a3a6 100644 --- a/src/awkward/operations/ak_ptp.py +++ b/src/awkward/operations/ak_ptp.py @@ -14,8 +14,7 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -93,18 +92,18 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: 
propagate named axis from input to output, # axis: int = use strategy "keep one" (see: awkward._namedaxis) # axis: None = use strategy "remove all" (see: awkward._namedaxis) if axis is not None: - out_named_axis = _keep_named_axis(_get_named_axis(ctx), axis) + out_named_axis = _keep_named_axis(named_axis, axis) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_ravel.py b/src/awkward/operations/ak_ravel.py index ee176553c0..062601eff4 100644 --- a/src/awkward/operations/ak_ravel.py +++ b/src/awkward/operations/ak_ravel.py @@ -79,10 +79,8 @@ def _impl(array, highlevel, behavior, attrs): # propagate named axis to output # use strategy "remove all" (see: awkward._namedaxis) - out_named_axis = None - return ak.operations.ak_with_named_axis._impl( + return ak.operations.ak_without_named_axis._impl( wrapped_out, - named_axis=out_named_axis, highlevel=highlevel, behavior=ctx.behavior, attrs=ctx.attrs, diff --git a/src/awkward/operations/ak_singletons.py b/src/awkward/operations/ak_singletons.py index ba51521aad..54b74d4d6d 100644 --- a/src/awkward/operations/ak_singletons.py +++ b/src/awkward/operations/ak_singletons.py @@ -9,8 +9,7 @@ _add_named_axis, _get_named_axis, _is_valid_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -66,16 +65,16 @@ def _impl(array, axis, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate 
named axis from input to output, # use strategy "add one" (see: awkward._namedaxis) - out_named_axis = _add_named_axis(axis, _get_named_axis(ctx)) + out_named_axis = _add_named_axis(named_axis, axis) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_softmax.py b/src/awkward/operations/ak_softmax.py index 736c057f42..b7b12fbf6c 100644 --- a/src/awkward/operations/ak_softmax.py +++ b/src/awkward/operations/ak_softmax.py @@ -4,7 +4,6 @@ import awkward as ak from awkward._attrs import attrs_of_obj -from awkward._behavior import behavior_of_obj from awkward._dispatch import high_level_function from awkward._layout import ( HighLevelContext, @@ -14,8 +13,8 @@ from awkward._namedaxis import ( _get_named_axis, _is_valid_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _named_axis_to_positional_axis, + _NamedAxisKey, ) from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata @@ -87,11 +86,11 @@ def _impl(x, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: x_layout = ctx.unwrap(x, allow_record=False, primitive_policy="error") - if _supports_named_axis(ctx): + # Handle named axis + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) axis = regularize_axis(axis) @@ -114,14 +113,16 @@ def _impl(x, axis, keepdims, mask_identity, highlevel, behavior, attrs): attrs=ctx.attrs, ) - # propagate named axis to output out = expx / denom - out_ctx = HighLevelContext( - behavior=behavior_of_obj(out), - attrs=attrs_of_obj(out), - ).finalize() - return out_ctx.wrap( + # propagate named axis to output + if out_named_axis := _get_named_axis(attrs_of_obj(out) or {}): + ctx = ctx.with_attr( + key=_NamedAxisKey, + value=out_named_axis, + ) + + 
return ctx.wrap( maybe_highlevel_to_lowlevel(out), highlevel=highlevel, allow_other=True, diff --git a/src/awkward/operations/ak_sort.py b/src/awkward/operations/ak_sort.py index 6a0ab4306b..212323552d 100644 --- a/src/awkward/operations/ak_sort.py +++ b/src/awkward/operations/ak_sort.py @@ -10,8 +10,7 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -69,19 +68,16 @@ def _impl(array, axis, ascending, stable, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis( - axis, - _get_named_axis(ctx), - ) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # use strategy "keep all" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None) + out_named_axis = _keep_named_axis(named_axis, None) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_std.py b/src/awkward/operations/ak_std.py index 103627854f..c86f157b3d 100644 --- a/src/awkward/operations/ak_std.py +++ b/src/awkward/operations/ak_std.py @@ -14,8 +14,7 @@ from awkward._namedaxis import ( _get_named_axis, _is_valid_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _named_axis_to_positional_axis, ) from awkward._nplikes import ufuncs from awkward._nplikes.numpy_like import NumpyMetadata @@ -186,11 +185,11 @@ def _impl(x, weight, ddof, axis, keepdims, mask_identity, highlevel, behavior, a x = ctx.wrap(x_layout) 
weight = ctx.wrap(weight_layout, allow_other=True) - if _supports_named_axis(ctx): + # Handle named axis + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) axis = regularize_axis(axis) diff --git a/src/awkward/operations/ak_sum.py b/src/awkward/operations/ak_sum.py index 89704eab0a..cdecb337ed 100644 --- a/src/awkward/operations/ak_sum.py +++ b/src/awkward/operations/ak_sum.py @@ -10,9 +10,8 @@ _get_named_axis, _is_valid_named_axis, _keep_named_axis, - _one_axis_to_positional_axis, + _named_axis_to_positional_axis, _remove_named_axis, - _supports_named_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._regularize import is_integer, regularize_axis @@ -280,19 +279,23 @@ def _impl(array, axis, keepdims, mask_identity, highlevel, behavior, attrs): with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: layout = ctx.unwrap(array, allow_record=False, primitive_policy="error") + # Handle named axis out_named_axis = None - if _supports_named_axis(ctx): + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) + axis = _named_axis_to_positional_axis(named_axis, axis) # Step 2: propagate named axis from input to output, # keepdims=True: use strategy "keep all" (see: awkward._namedaxis) # keepdims=False: use strategy "remove one" (see: awkward._namedaxis) - out_named_axis = _keep_named_axis(_get_named_axis(ctx), None) + out_named_axis = _keep_named_axis(named_axis, None) if not keepdims: - out_named_axis = _remove_named_axis(axis, out_named_axis) + out_named_axis = _remove_named_axis( + named_axis=out_named_axis, + axis=axis, + total=layout.purelist_depth, + ) axis = 
regularize_axis(axis) diff --git a/src/awkward/operations/ak_unflatten.py b/src/awkward/operations/ak_unflatten.py index b3b2f3dff6..2a88b150ee 100644 --- a/src/awkward/operations/ak_unflatten.py +++ b/src/awkward/operations/ak_unflatten.py @@ -9,8 +9,7 @@ from awkward._namedaxis import ( _get_named_axis, _is_valid_named_axis, - _one_axis_to_positional_axis, - _supports_named_axis, + _named_axis_to_positional_axis, ) from awkward._nplikes.numpy_like import NumpyMetadata from awkward._nplikes.shape import unknown_length @@ -109,15 +108,11 @@ def _impl(array, counts, axis, highlevel, behavior, attrs): ), ) - if _supports_named_axis(ctx): + # Handle named axis + if named_axis := _get_named_axis(ctx): if _is_valid_named_axis(axis): - # Handle named axis # Step 1: Normalize named axis to positional axis - axis = _one_axis_to_positional_axis(axis, _get_named_axis(ctx)) - - # Step 2: propagate named axis from input to output, - # use strategy "remove all" (see: awkward._namedaxis) - out_named_axis = None + axis = _named_axis_to_positional_axis(named_axis, axis) axis = regularize_axis(axis) @@ -316,10 +311,10 @@ def apply(layout, depth, backend, **kwargs): highlevel=highlevel, ) - # propagate named axis to output - return ak.operations.ak_with_named_axis._impl( + # Step 2: propagate named axis from input to output, + # use strategy "remove all" (see: awkward._namedaxis) + return ak.operations.ak_without_named_axis._impl( wrapped_out, - named_axis=out_named_axis, highlevel=highlevel, behavior=ctx.behavior, attrs=ctx.attrs, diff --git a/src/awkward/operations/ak_with_named_axis.py b/src/awkward/operations/ak_with_named_axis.py index d59d8a94a5..d2127181ee 100644 --- a/src/awkward/operations/ak_with_named_axis.py +++ b/src/awkward/operations/ak_with_named_axis.py @@ -7,8 +7,8 @@ from awkward._namedaxis import ( AxisMapping, AxisTuple, + _axis_tuple_to_mapping, _NamedAxisKey, - _set_named_axis_to_attrs, ) from awkward._nplikes.numpy_like import NumpyMetadata @@ -62,31 +62,28 
@@ def _impl(array, named_axis, highlevel, behavior, attrs): layout = ctx.unwrap(array, allow_record=False) # Named axis handling - ndim = layout.purelist_depth if not named_axis: # no-op, e.g. named_axis is None, (), {} - named_axis = (None,) * ndim - if isinstance(named_axis, dict): - _named_axis = tuple(named_axis.get(i, None) for i in range(ndim)) - for k, i in named_axis.items(): - if not isinstance(i, int): - raise TypeError(f"named_axis must map axis name to integer, not {i}") - if i < 0: # handle negative axis index - i += ndim - if i < 0 or i >= ndim: - raise ValueError( - f"named_axis index out of range: {i} not in [0, {ndim})" - ) - _named_axis = _named_axis[:i] + (k,) + _named_axis[i + 1 :] + _named_axis = {} elif isinstance(named_axis, tuple): + ndim = layout.purelist_depth + if len(named_axis) != ndim: + raise ValueError( + f"{named_axis=} must have the same length as the number of dimensions ({ndim})" + ) + _named_axis = _axis_tuple_to_mapping(named_axis) + elif isinstance(named_axis, dict): _named_axis = named_axis else: raise TypeError(f"named_axis must be a mapping or a tuple, got {named_axis}") - attrs = _set_named_axis_to_attrs(ctx.attrs or {}, _named_axis) - if len(attrs[_NamedAxisKey]) != ndim: - raise ValueError( - f"{_named_axis=} must have the same length as the number of dimensions ({ndim})" + if _named_axis: + ctx = ctx.with_attr( + key=_NamedAxisKey, + value=_named_axis, ) - out_ctx = HighLevelContext(behavior=ctx.behavior, attrs=attrs).finalize() - return out_ctx.wrap(layout, highlevel=highlevel, allow_other=True) + return ctx.wrap( + layout, + highlevel=highlevel, + allow_other=True, + ) diff --git a/src/awkward/operations/ak_without_named_axis.py b/src/awkward/operations/ak_without_named_axis.py new file mode 100644 index 0000000000..523f1a2680 --- /dev/null +++ b/src/awkward/operations/ak_without_named_axis.py @@ -0,0 +1,64 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ 
import annotations + +from awkward._dispatch import high_level_function +from awkward._layout import HighLevelContext +from awkward._namedaxis import ( + _NamedAxisKey, +) +from awkward._nplikes.numpy_like import NumpyMetadata + +__all__ = ("without_named_axis",) + +np = NumpyMetadata.instance() + + +@high_level_function() +def without_named_axis( + array, + *, + highlevel=True, + behavior=None, + attrs=None, +): + """ + Args: + array: Array-like data (anything #ak.to_layout recognizes). + highlevel (bool): If True, return an #ak.Array; otherwise, return + a low-level #ak.contents.Content subclass. + behavior (None or dict): Custom #ak.behavior for the output array, if + high-level. + attrs (None or dict): Custom attributes for the output array, if + high-level. + + Returns an #ak.Array (or low-level equivalent, if `highlevel=False`) + without named axes: the `"__named_axis__"` attr is removed from the + output, while all other attrs are kept. This function does not change + the array in-place. + + See #ak.with_named_axis for the function that assigns named axes. 
+ """ + # Dispatch + yield (array,) + + # Implementation + return _impl(array, highlevel, behavior, attrs) + + +def _impl(array, highlevel, behavior, attrs): + with HighLevelContext(behavior=behavior, attrs=attrs) as ctx: + layout = ctx.unwrap(array, allow_record=False) + + return ctx.without_attr(key=_NamedAxisKey).wrap( + layout, + highlevel=highlevel, + allow_other=True, + ) diff --git a/tests/test_2596_named_axis.py b/tests/test_2596_named_axis.py index 89270cc341..80288fad84 100644 --- a/tests/test_2596_named_axis.py +++ b/tests/test_2596_named_axis.py @@ -6,84 +6,249 @@ import pytest # noqa: F401 import awkward as ak +from awkward._namedaxis import _get_named_axis def test_with_named_axis(): - from dataclasses import dataclass - - from awkward._namedaxis import _supports_named_axis - array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - assert not _supports_named_axis(array) - assert array.named_axis == (None, None) + assert not _get_named_axis(array) + assert array.named_axis == {} + assert array.positional_axis == (0, 1) + + array = ak.with_named_axis(array, named_axis=("x", "y")) + assert _get_named_axis(array) + assert array.named_axis == {"x": 0, "y": 1} assert array.positional_axis == (0, 1) - array = ak.with_named_axis(array, named_axis=("events", "jets")) - assert _supports_named_axis(array) - assert array.named_axis == ("events", "jets") + array = ak.with_named_axis(array, named_axis=("x", None)) + assert _get_named_axis(array) + assert array.named_axis == {"x": 0} assert array.positional_axis == (0, 1) - array = ak.with_named_axis(array, named_axis=("events", None)) - assert _supports_named_axis(array) - assert array.named_axis == ("events", None) + array = ak.with_named_axis(array, named_axis=(None, "x")) + assert _get_named_axis(array) + assert array.named_axis == {"x": 1} assert array.positional_axis == (0, 1) - array = ak.with_named_axis(array, named_axis={"events": 0, "jets": 1}) - assert _supports_named_axis(array) - assert array.named_axis == 
("events", "jets") + array = ak.with_named_axis(array, named_axis={"x": 0, "y": 1}) + assert _get_named_axis(array) + assert array.named_axis == {"x": 0, "y": 1} assert array.positional_axis == (0, 1) - array = ak.with_named_axis(array, named_axis={"events": 1}) - assert _supports_named_axis(array) - assert array.named_axis == (None, "events") + array = ak.with_named_axis(array, named_axis={"x": 1}) + assert _get_named_axis(array) + assert array.named_axis == {"x": 1} assert array.positional_axis == (0, 1) - array = ak.with_named_axis(array, named_axis={"jets": -1}) - assert _supports_named_axis(array) - assert array.named_axis == (None, "jets") + array = ak.with_named_axis(array, named_axis={"y": -1}) + assert _get_named_axis(array) + assert array.named_axis == {"y": -1} assert array.positional_axis == (0, 1) - @dataclass(frozen=True) - class exotic_axis: - attr: str + # This is possible in a future version of named axis, but currently only strings are supported + # from dataclasses import dataclass - ax1 = exotic_axis(attr="I'm not the type of axis that you're used to") - ax2 = exotic_axis(attr="...me neither!") + # @dataclass(frozen=True) + # class exotic_axis: + # attr: str - array = ak.with_named_axis(array, named_axis=(ax1, ax2)) - assert array.named_axis == (ax1, ax2) - assert array.positional_axis == (0, 1) + # ax1 = exotic_axis(attr="I'm not the type of axis that you're used to") + # ax2 = exotic_axis(attr="...me neither!") + + # array = ak.with_named_axis(array, named_axis=(ax1, ax2)) + # assert array.named_axis == (ax1, ax2) + # assert array.positional_axis == (0, 1) + + +def test_named_axis_indexing(): + array = ak.Array([[[1, 2]], [[3]], [[4]], [[5, 6], [7]]]) + + named_array = ak.with_named_axis(array, named_axis=("x", "y", "z")) + + # test indexing + assert ak.all(array[...] 
== named_array[...]) + assert ak.all(array[()] == named_array[()]) + + assert ak.all(array[None, :, :, :] == named_array[None, :, :, :]) + assert ak.all(array[:, None, :, :] == named_array[:, None, :, :]) + assert ak.all(array[:, :, None, :] == named_array[:, :, None, :]) + assert ak.all(array[:, :, :, None] == named_array[:, :, :, None]) + + assert ak.all(array[0, :, :] == named_array[{"x": 0}]) + assert ak.all(array[:, 0, :] == named_array[{"y": 0}]) + assert ak.all(array[:, :, 0] == named_array[{"z": 0}]) + + assert ak.all(array[0, :, :] == named_array[{0: 0}]) + assert ak.all(array[:, 0, :] == named_array[{1: 0}]) + assert ak.all(array[:, :, 0] == named_array[{2: 0}]) + + assert ak.all(array[0, :, :] == named_array[{-3: 0}]) + assert ak.all(array[:, 0, :] == named_array[{-2: 0}]) + assert ak.all(array[:, :, 0] == named_array[{-1: 0}]) + + assert ak.all(array[0, 0, :] == named_array[{"x": 0, "y": 0}]) + assert ak.all(array[0, :, 0] == named_array[{"x": 0, "z": 0}]) + assert ak.all(array[:, 0, 0] == named_array[{"y": 0, "z": 0}]) + assert array[0, 0, 0] == named_array[{"x": 0, "y": 0, "z": 0}] + + assert ak.all(array[slice(0, 1), :, :] == named_array[{"x": slice(0, 1)}]) + assert ak.all(array[:, slice(0, 1), :] == named_array[{"y": slice(0, 1)}]) + assert ak.all(array[:, :, slice(0, 1)] == named_array[{"z": slice(0, 1)}]) + + assert ak.all(array[0, :, slice(0, 1)] == named_array[{"x": 0, "z": slice(0, 1)}]) + assert ak.all(array[:, 0, slice(0, 1)] == named_array[{"y": 0, "z": slice(0, 1)}]) + assert ak.all(array[slice(0, 1), 0, :] == named_array[{"x": slice(0, 1), "y": 0}]) + + assert ak.all(array[array > 3] == named_array[named_array > 3]) + + # test naming propagation + assert ( + named_array[...].named_axis + == named_array.named_axis + == {"x": 0, "y": 1, "z": 2} + ) + assert ( + named_array[()].named_axis == named_array.named_axis == {"x": 0, "y": 1, "z": 2} + ) + + assert named_array[None, :, :, :].named_axis == {"x": 1, "y": 2, "z": 3} + assert 
named_array[:, None, :, :].named_axis == {"x": 0, "y": 2, "z": 3} + assert named_array[:, :, None, :].named_axis == {"x": 0, "y": 1, "z": 3} + assert named_array[:, :, :, None].named_axis == {"x": 0, "y": 1, "z": 2} + + assert named_array[None, ...].named_axis == {"x": 1, "y": 2, "z": 3} + assert named_array[:, None, ...].named_axis == {"x": 0, "y": 2, "z": 3} + assert named_array[..., None, :].named_axis == {"x": 0, "y": 1, "z": 3} + assert named_array[..., None].named_axis == {"x": 0, "y": 1, "z": 2} + + assert ( + named_array[0, :, :].named_axis + == named_array[{"x": 0}].named_axis + == {"y": 0, "z": 1} + ) + assert ( + named_array[:, 0, :].named_axis + == named_array[{"y": 0}].named_axis + == {"x": 0, "z": 1} + ) + assert ( + named_array[:, :, 0].named_axis + == named_array[{"z": 0}].named_axis + == {"x": 0, "y": 1} + ) + + assert ( + named_array[0, ...].named_axis + == named_array[{"x": 0}].named_axis + == {"y": 0, "z": 1} + ) + assert ( + named_array[:, 0, :].named_axis + == named_array[{"y": 0}].named_axis + == {"x": 0, "z": 1} + ) + assert ( + named_array[..., 0].named_axis + == named_array[{"z": 0}].named_axis + == {"x": 0, "y": 1} + ) + + assert named_array[{0: 0}].named_axis == {"y": 0, "z": 1} + assert named_array[{1: 0}].named_axis == {"x": 0, "z": 1} + assert named_array[{2: 0}].named_axis == {"x": 0, "y": 1} + + assert named_array[{-3: 0}].named_axis == {"y": 0, "z": 1} + assert named_array[{-2: 0}].named_axis == {"x": 0, "z": 1} + assert named_array[{-1: 0}].named_axis == {"x": 0, "y": 1} + + assert ( + named_array[0, 0, :].named_axis + == named_array[{"x": 0, "y": 0}].named_axis + == {"z": 0} + ) + assert ( + named_array[0, :, 0].named_axis + == named_array[{"x": 0, "z": 0}].named_axis + == {"y": 0} + ) + assert ( + named_array[:, 0, 0].named_axis + == named_array[{"y": 0, "z": 0}].named_axis + == {"x": 0} + ) + assert not _get_named_axis(named_array[0, 0, 0]) + assert not _get_named_axis(named_array[{"x": 0, "y": 0, "z": 0}]) + + assert ( + 
named_array[slice(0, 1), :, :].named_axis + == named_array[{"x": slice(0, 1)}].named_axis + == {"x": 0, "y": 1, "z": 2} + ) + assert ( + named_array[:, slice(0, 1), :].named_axis + == named_array[{"y": slice(0, 1)}].named_axis + == {"x": 0, "y": 1, "z": 2} + ) + assert ( + named_array[:, :, slice(0, 1)].named_axis + == named_array[{"z": slice(0, 1)}].named_axis + == {"x": 0, "y": 1, "z": 2} + ) + + assert ( + named_array[0, :, slice(0, 1)].named_axis + == named_array[{"x": 0, "z": slice(0, 1)}].named_axis + == {"y": 0, "z": 1} + ) + assert ( + named_array[:, 0, slice(0, 1)].named_axis + == named_array[{"y": 0, "z": slice(0, 1)}].named_axis + == {"x": 0, "z": 1} + ) + assert ( + named_array[slice(0, 1), 0, :].named_axis + == named_array[{"x": slice(0, 1), "y": 0}].named_axis + == {"x": 0, "z": 1} + ) def test_named_axis_ak_all(): array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - named_array = ak.with_named_axis(array, named_axis=("events", "jets")) + named_array = ak.with_named_axis(array, named_axis=("x", "y")) # first check that they work the same - assert ak.all(ak.all(array < 4, axis=0) == ak.all(named_array < 4, axis="events")) - assert ak.all(ak.all(array < 4, axis=1) == ak.all(named_array < 4, axis="jets")) + assert ak.all(ak.all(array < 4, axis=0) == ak.all(named_array < 4, axis="x")) + assert ak.all(ak.all(array < 4, axis=1) == ak.all(named_array < 4, axis="y")) # check that result axis names are correctly propagated assert ( ak.all(named_array < 4, axis=0).named_axis - == ak.all(named_array < 4, axis="events").named_axis - == ("jets",) + == ak.all(named_array < 4, axis="x").named_axis + == {"y": 0} ) assert ( ak.all(named_array < 4, axis=1).named_axis - == ak.all(named_array < 4, axis="jets").named_axis - == ("events",) + == ak.all(named_array < 4, axis="y").named_axis + == {"x": 0} ) - assert ak.all(named_array < 4, axis=None).named_axis == (None,) + assert ( + ak.all(named_array < 4, axis=0, keepdims=True).named_axis + == ak.all(named_array < 4, axis="x", 
keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.all(named_array < 4, axis=1, keepdims=True).named_axis + == ak.all(named_array < 4, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.all(named_array < 4, axis=None)) def test_named_axis_ak_almost_equal(): array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - named_array1 = named_array2 = ak.with_named_axis( - array1, named_axis=("events", "jets") - ) + named_array1 = named_array2 = ak.with_named_axis(array1, named_axis=("x", "y")) assert ak.almost_equal(array1, array2, check_named_axis=False) == ak.almost_equal( named_array1, named_array2, check_named_axis=False @@ -95,7 +260,7 @@ def test_named_axis_ak_almost_equal(): assert ak.almost_equal(named_array1, array1, check_named_axis=False) assert ak.almost_equal(named_array1, array1, check_named_axis=True) - named_array3 = ak.with_named_axis(array1, named_axis=("events", "muons")) + named_array3 = ak.with_named_axis(array1, named_axis=("x", "muons")) assert ak.almost_equal(named_array1, named_array3, check_named_axis=False) assert not ak.almost_equal(named_array1, named_array3, check_named_axis=True) @@ -103,36 +268,46 @@ def test_named_axis_ak_almost_equal(): def test_named_axis_ak_angle(): array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - named_array = ak.with_named_axis(array, named_axis=("events", "jets")) + named_array = ak.with_named_axis(array, named_axis=("x", "y")) # first check that they work the same assert ak.all(ak.angle(array) == ak.angle(named_array)) # check that result axis names are correctly propagated - assert ak.angle(named_array).named_axis == ("events", "jets") + assert ak.angle(named_array).named_axis == {"x": 0, "y": 1} def test_named_axis_ak_any(): array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - named_array = ak.with_named_axis(array, named_axis=("events", "jets")) + named_array = ak.with_named_axis(array, named_axis=("x", "y")) # first check that they work the same - assert 
ak.all(ak.any(array < 4, axis=0) == ak.any(named_array < 4, axis="events")) - assert ak.all(ak.any(array < 4, axis=1) == ak.any(named_array < 4, axis="jets")) + assert ak.all(ak.any(array < 4, axis=0) == ak.any(named_array < 4, axis="x")) + assert ak.all(ak.any(array < 4, axis=1) == ak.any(named_array < 4, axis="y")) # check that result axis names are correctly propagated assert ( ak.any(named_array < 4, axis=0).named_axis - == ak.any(named_array < 4, axis="events").named_axis - == ("jets",) + == ak.any(named_array < 4, axis="x").named_axis + == {"y": 0} ) assert ( ak.any(named_array < 4, axis=1).named_axis - == ak.any(named_array < 4, axis="jets").named_axis - == ("events",) + == ak.any(named_array < 4, axis="y").named_axis + == {"x": 0} ) - assert ak.any(named_array < 4, axis=None).named_axis == (None,) + assert ( + ak.any(named_array < 4, axis=0, keepdims=True).named_axis + == ak.any(named_array < 4, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.any(named_array < 4, axis=1, keepdims=True).named_axis + == ak.any(named_array < 4, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.all(named_array < 4, axis=None)) def test_named_axis_ak_argcartesian(): @@ -146,121 +321,113 @@ def test_named_axis_ak_argcombinations(): def test_named_axis_ak_argmax(): array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - named_array = ak.with_named_axis(array, named_axis=("events", "jets")) + named_array = ak.with_named_axis(array, named_axis=("x", "y")) # first check that they work the same - assert ak.all(ak.argmax(array, axis=0) == ak.argmax(named_array, axis="events")) - assert ak.all(ak.argmax(array, axis=1) == ak.argmax(named_array, axis="jets")) + assert ak.all(ak.argmax(array, axis=0) == ak.argmax(named_array, axis="x")) + assert ak.all(ak.argmax(array, axis=1) == ak.argmax(named_array, axis="y")) assert ak.all( ak.argmax(array, axis=0, keepdims=True) - == ak.argmax(named_array, axis="events", keepdims=True) 
+ == ak.argmax(named_array, axis="x", keepdims=True) ) assert ak.all( ak.argmax(array, axis=1, keepdims=True) - == ak.argmax(named_array, axis="jets", keepdims=True) + == ak.argmax(named_array, axis="y", keepdims=True) ) - assert ak.all(ak.argmax(array, axis=None) == ak.argmax(named_array, axis=None)) + assert ak.argmax(array, axis=None) == ak.argmax(named_array, axis=None) # check that result axis names are correctly propagated assert ( ak.argmax(named_array, axis=0).named_axis - == ak.argmax(named_array, axis="events").named_axis - == ("jets",) + == ak.argmax(named_array, axis="x").named_axis + == {"y": 0} ) assert ( ak.argmax(named_array, axis=1).named_axis - == ak.argmax(named_array, axis="jets").named_axis - == ("events",) + == ak.argmax(named_array, axis="y").named_axis + == {"x": 0} ) assert ( ak.argmax(named_array, axis=0, keepdims=True).named_axis - == ak.argmax(named_array, axis="events", keepdims=True).named_axis - == ( - "events", - "jets", - ) + == ak.argmax(named_array, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} ) assert ( ak.argmax(named_array, axis=1, keepdims=True).named_axis - == ak.argmax(named_array, axis="jets", keepdims=True).named_axis - == ("events", "jets") + == ak.argmax(named_array, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} ) - assert ak.argmax(named_array, axis=None).named_axis == (None,) + assert not _get_named_axis(ak.argmax(named_array, axis=None)) def test_named_axis_ak_argmin(): array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - named_array = ak.with_named_axis(array, named_axis=("events", "jets")) + named_array = ak.with_named_axis(array, named_axis=("x", "y")) # first check that they work the same - assert ak.all(ak.argmin(array, axis=0) == ak.argmin(named_array, axis="events")) - assert ak.all(ak.argmin(array, axis=1) == ak.argmin(named_array, axis="jets")) + assert ak.all(ak.argmin(array, axis=0) == ak.argmin(named_array, axis="x")) + assert ak.all(ak.argmin(array, axis=1) == ak.argmin(named_array, 
axis="y")) assert ak.all( ak.argmin(array, axis=0, keepdims=True) - == ak.argmin(named_array, axis="events", keepdims=True) + == ak.argmin(named_array, axis="x", keepdims=True) ) assert ak.all( ak.argmin(array, axis=1, keepdims=True) - == ak.argmin(named_array, axis="jets", keepdims=True) + == ak.argmin(named_array, axis="y", keepdims=True) ) - assert ak.all(ak.argmin(array, axis=None) == ak.argmin(named_array, axis=None)) + assert ak.argmin(array, axis=None) == ak.argmin(named_array, axis=None) # check that result axis names are correctly propagated assert ( ak.argmin(named_array, axis=0).named_axis - == ak.argmin(named_array, axis="events").named_axis - == ("jets",) + == ak.argmin(named_array, axis="x").named_axis + == {"y": 0} ) assert ( ak.argmin(named_array, axis=1).named_axis - == ak.argmin(named_array, axis="jets").named_axis - == ("events",) + == ak.argmin(named_array, axis="y").named_axis + == {"x": 0} ) assert ( ak.argmin(named_array, axis=0, keepdims=True).named_axis - == ak.argmin(named_array, axis="events", keepdims=True).named_axis - == ( - "events", - "jets", - ) + == ak.argmin(named_array, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} ) assert ( ak.argmin(named_array, axis=1, keepdims=True).named_axis - == ak.argmin(named_array, axis="jets", keepdims=True).named_axis - == ("events", "jets") + == ak.argmin(named_array, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} ) - assert ak.argmin(named_array, axis=None).named_axis == (None,) + assert not _get_named_axis(ak.argmin(named_array, axis=None)) def test_named_axis_ak_argsort(): array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - named_array = ak.with_named_axis(array, named_axis=("events", "jets")) + named_array = ak.with_named_axis(array, named_axis=("x", "y")) # first check that they work the same - assert ak.all(ak.argsort(array, axis=0) == ak.argsort(named_array, axis="events")) - assert ak.all(ak.argsort(array, axis=1) == ak.argsort(named_array, axis="jets")) + assert 
ak.all(ak.argsort(array, axis=0) == ak.argsort(named_array, axis="x")) + assert ak.all(ak.argsort(array, axis=1) == ak.argsort(named_array, axis="y")) # check that result axis names are correctly propagated assert ( ak.argsort(named_array, axis=0).named_axis - == ak.argsort(named_array, axis="events").named_axis - == ("events", "jets") + == ak.argsort(named_array, axis="x").named_axis + == {"x": 0, "y": 1} ) assert ( ak.argsort(named_array, axis=1).named_axis - == ak.argsort(named_array, axis="jets").named_axis - == ("events", "jets") + == ak.argsort(named_array, axis="y").named_axis + == {"x": 0, "y": 1} ) def test_named_axis_ak_array_equal(): array1 = array2 = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - named_array1 = named_array2 = ak.with_named_axis( - array1, named_axis=("events", "jets") - ) + named_array1 = named_array2 = ak.with_named_axis(array1, named_axis=("x", "y")) assert ak.array_equal(array1, array2, check_named_axis=False) == ak.array_equal( named_array1, named_array2, check_named_axis=False @@ -272,7 +439,7 @@ def test_named_axis_ak_array_equal(): assert ak.array_equal(named_array1, array1, check_named_axis=False) assert ak.array_equal(named_array1, array1, check_named_axis=True) - named_array3 = ak.with_named_axis(array1, named_axis=("events", "muons")) + named_array3 = ak.with_named_axis(array1, named_axis=("x", "muons")) assert ak.array_equal(named_array1, named_array3, check_named_axis=False) assert not ak.array_equal(named_array1, named_array3, check_named_axis=True) @@ -280,7 +447,7 @@ def test_named_axis_ak_array_equal(): def test_named_axis_ak_backend(): array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - named_array = ak.with_named_axis(array, named_axis=("events", "jets")) + named_array = ak.with_named_axis(array, named_axis=("x", "y")) assert ak.backend(array) == ak.backend(named_array) @@ -324,9 +491,9 @@ def test_named_axis_ak_concatenate(): def test_named_axis_ak_copy(): array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - named_array = 
ak.with_named_axis(array, named_axis=("events", "jets")) + named_array = ak.with_named_axis(array, named_axis=("x", "y")) - assert ak.copy(named_array).named_axis == ("events", "jets") + assert ak.copy(named_array).named_axis == {"x": 0, "y": 1} def test_named_axis_ak_corr(): @@ -344,14 +511,13 @@ def test_named_axis_ak_corr(): ak.corr(array_x, array_y, axis=1) == ak.corr(named_array_x, named_array_y, axis="y") ) - assert ak.all( - ak.corr(array_x, array_y, axis=None) - == ak.corr(named_array_x, named_array_y, axis=None) + assert ak.corr(array_x, array_y, axis=None) == ak.corr( + named_array_x, named_array_y, axis=None ) - assert ak.corr(named_array_x, named_array_y, axis="x").named_axis == ("y",) - assert ak.corr(named_array_x, named_array_y, axis="y").named_axis == ("x",) - assert ak.corr(named_array_x, named_array_y, axis=None).named_axis == (None,) + assert ak.corr(named_array_x, named_array_y, axis="x").named_axis == {"y": 0} + assert ak.corr(named_array_x, named_array_y, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.corr(named_array_x, named_array_y, axis=None)) def test_named_axis_ak_count(): @@ -361,11 +527,11 @@ def test_named_axis_ak_count(): assert ak.all(ak.count(array, axis=0) == ak.count(named_array, axis="x")) assert ak.all(ak.count(array, axis=1) == ak.count(named_array, axis="y")) - assert ak.all(ak.count(array, axis=None) == ak.count(named_array, axis=None)) + assert ak.count(array, axis=None) == ak.count(named_array, axis=None) - assert ak.count(named_array, axis="x").named_axis == ("y",) - assert ak.count(named_array, axis="y").named_axis == ("x",) - assert ak.count(named_array, axis=None).named_axis == (None,) + assert ak.count(named_array, axis="x").named_axis == {"y": 0} + assert ak.count(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.count(named_array, axis=None)) def test_named_axis_ak_count_nonzero(): @@ -379,13 +545,13 @@ def test_named_axis_ak_count_nonzero(): assert ak.all( 
ak.count_nonzero(array, axis=1) == ak.count_nonzero(named_array, axis="y") ) - assert ak.all( - ak.count_nonzero(array, axis=None) == ak.count_nonzero(named_array, axis=None) + assert ak.count_nonzero(array, axis=None) == ak.count_nonzero( + named_array, axis=None ) - assert ak.count_nonzero(named_array, axis="x").named_axis == ("y",) - assert ak.count_nonzero(named_array, axis="y").named_axis == ("x",) - assert ak.count_nonzero(named_array, axis=None).named_axis == (None,) + assert ak.count_nonzero(named_array, axis="x").named_axis == {"y": 0} + assert ak.count_nonzero(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.count_nonzero(named_array, axis=None)) def test_named_axis_ak_covar(): @@ -403,14 +569,13 @@ def test_named_axis_ak_covar(): ak.covar(array_x, array_y, axis=1) == ak.covar(named_array_x, named_array_y, axis="y") ) - assert ak.all( - ak.covar(array_x, array_y, axis=None) - == ak.covar(named_array_x, named_array_y, axis=None) + assert ak.covar(array_x, array_y, axis=None) == ak.covar( + named_array_x, named_array_y, axis=None ) - assert ak.covar(named_array_x, named_array_y, axis="x").named_axis == ("y",) - assert ak.covar(named_array_x, named_array_y, axis="y").named_axis == ("x",) - assert ak.covar(named_array_x, named_array_y, axis=None).named_axis == (None,) + assert ak.covar(named_array_x, named_array_y, axis="x").named_axis == {"y": 0} + assert ak.covar(named_array_x, named_array_y, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.covar(named_array_x, named_array_y, axis=None)) def test_named_axis_ak_drop_none(): @@ -424,9 +589,9 @@ def test_named_axis_ak_drop_none(): ak.drop_none(array, axis=None) == ak.drop_none(named_array, axis=None) ) - assert ak.drop_none(named_array, axis="x").named_axis == ("x", "y") - assert ak.drop_none(named_array, axis="y").named_axis == ("x", "y") - assert ak.drop_none(named_array, axis=None).named_axis == ("x", "y") + assert ak.drop_none(named_array, axis="x").named_axis 
== {"x": 0, "y": 1} + assert ak.drop_none(named_array, axis="y").named_axis == {"x": 0, "y": 1} + assert ak.drop_none(named_array, axis=None).named_axis == {"x": 0, "y": 1} def test_named_axis_ak_enforce_type(): @@ -434,7 +599,7 @@ def test_named_axis_ak_enforce_type(): named_array = ak.with_named_axis(array, ("x", "y")) - assert ak.enforce_type(named_array, "var * ?int64").named_axis == ("x", "y") + assert ak.enforce_type(named_array, "var * ?int64").named_axis == {"x": 0, "y": 1} def test_named_axis_ak_fields(): @@ -457,9 +622,9 @@ def test_named_axis_ak_fill_none(): ak.fill_none(array, 0, axis=None) == ak.fill_none(named_array, 0, axis=None) ) - assert ak.fill_none(named_array, 0, axis="x").named_axis == ("x", "y") - assert ak.fill_none(named_array, 0, axis="y").named_axis == ("x", "y") - assert ak.fill_none(named_array, 0, axis=None).named_axis == ("x", "y") + assert ak.fill_none(named_array, 0, axis="x").named_axis == {"x": 0, "y": 1} + assert ak.fill_none(named_array, 0, axis="y").named_axis == {"x": 0, "y": 1} + assert ak.fill_none(named_array, 0, axis=None).named_axis == {"x": 0, "y": 1} def test_named_axis_ak_firsts(): @@ -470,8 +635,8 @@ def test_named_axis_ak_firsts(): assert ak.all(ak.firsts(array, axis=0) == ak.firsts(named_array, axis="x")) assert ak.all(ak.firsts(array, axis=1) == ak.firsts(named_array, axis="y")) - assert ak.firsts(named_array, axis="x").named_axis == ("x",) - assert ak.firsts(named_array, axis="y").named_axis == ("y",) + assert ak.firsts(named_array, axis="x").named_axis == {"x": 0} + assert ak.firsts(named_array, axis="y").named_axis == {"y": 0} def test_named_axis_ak_flatten(): @@ -569,7 +734,7 @@ def test_named_axis_ak_imag(): named_array = ak.with_named_axis(array, ("x", "y")) assert ak.all(ak.imag(array) == ak.imag(named_array)) - assert ak.imag(named_array).named_axis == ("x", "y") + assert ak.imag(named_array).named_axis == {"x": 0, "y": 1} def test_named_axis_ak_is_categorical(): @@ -585,8 +750,8 @@ def 
test_named_axis_ak_is_none(): assert ak.all(ak.is_none(array, axis=0) == ak.is_none(named_array, axis="x")) assert ak.all(ak.is_none(array, axis=1) == ak.is_none(named_array, axis="y")) - assert ak.is_none(named_array, axis="x").named_axis == ("x", "y") - assert ak.is_none(named_array, axis="y").named_axis == ("x", "y") + assert ak.is_none(named_array, axis="x").named_axis == {"x": 0, "y": 1} + assert ak.is_none(named_array, axis="y").named_axis == {"x": 0, "y": 1} def test_named_axis_ak_is_tuple(): @@ -626,9 +791,24 @@ def test_named_axis_ak_local_index(): ak.local_index(array, axis=2) == ak.local_index(named_array, axis="z") ) - assert ak.local_index(named_array, axis="x").named_axis == ("x",) - assert ak.local_index(named_array, axis="y").named_axis == ("x", "y") - assert ak.local_index(named_array, axis="z").named_axis == ("x", "y", "z") + assert ak.local_index(named_array, axis="x").named_axis == {"x": 0} + assert ak.local_index(named_array, axis="y").named_axis == {"x": 0, "y": 1} + assert ak.local_index(named_array, axis="z").named_axis == {"x": 0, "y": 1, "z": 2} + + # now with negative axis mappings + named_array = ak.with_named_axis(array, {"x": 0, "y": -2, "z": -1}) + + assert ak.local_index(named_array, axis="x").named_axis == {"x": 0, "z": -1} + assert ak.local_index(named_array, axis="y").named_axis == { + "x": 0, + "y": -2, + "z": -1, + } + assert ak.local_index(named_array, axis="z").named_axis == { + "x": 0, + "y": -2, + "z": -1, + } def test_named_axis_ak_mask(): @@ -648,24 +828,34 @@ def test_named_axis_ak_mask(): def test_named_axis_ak_max(): array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - named_array = ak.with_named_axis(array, named_axis=("events", "jets")) + named_array = ak.with_named_axis(array, named_axis=("x", "y")) # first check that they work the same - assert ak.all(ak.max(array, axis=0) == ak.max(named_array, axis="events")) - assert ak.all(ak.max(array, axis=1) == ak.max(named_array, axis="jets")) + assert ak.all(ak.max(array, 
axis=0) == ak.max(named_array, axis="x")) + assert ak.all(ak.max(array, axis=1) == ak.max(named_array, axis="y")) # check that result axis names are correctly propagated assert ( ak.max(named_array, axis=0).named_axis - == ak.max(named_array, axis="events").named_axis - == ("jets",) + == ak.max(named_array, axis="x").named_axis + == {"y": 0} ) assert ( ak.max(named_array, axis=1).named_axis - == ak.max(named_array, axis="jets").named_axis - == ("events",) + == ak.max(named_array, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.max(named_array, axis=0, keepdims=True).named_axis + == ak.max(named_array, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} ) - assert ak.max(named_array, axis=None).named_axis == (None,) + assert ( + ak.max(named_array, axis=1, keepdims=True).named_axis + == ak.max(named_array, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert not _get_named_axis(ak.max(named_array, axis=None)) def test_named_axis_ak_mean(): @@ -677,9 +867,11 @@ def test_named_axis_ak_mean(): assert ak.all(ak.mean(array, axis=1) == ak.mean(named_array, axis="y")) assert ak.mean(array, axis=None) == ak.mean(named_array, axis=None) - assert ak.mean(named_array, axis="x").named_axis == ("y",) - assert ak.mean(named_array, axis="y").named_axis == ("x",) - assert ak.mean(named_array, axis=None).named_axis == (None,) + assert ak.mean(named_array, axis="x").named_axis == {"y": 0} + assert ak.mean(named_array, axis="y").named_axis == {"x": 0} + assert ak.mean(named_array, axis="x", keepdims=True).named_axis == {"x": 0, "y": 1} + assert ak.mean(named_array, axis="y", keepdims=True).named_axis == {"x": 0, "y": 1} + assert not _get_named_axis(ak.mean(named_array, axis=None)) def test_named_axis_ak_merge_option_of_records(): @@ -700,24 +892,34 @@ def test_named_axis_ak_metadata_from_parquet(): def test_named_axis_ak_min(): array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - named_array = ak.with_named_axis(array, named_axis=("events", "jets")) + 
named_array = ak.with_named_axis(array, named_axis=("x", "y")) # first check that they work the same - assert ak.all(ak.min(array, axis=0) == ak.min(named_array, axis="events")) - assert ak.all(ak.min(array, axis=1) == ak.min(named_array, axis="jets")) + assert ak.all(ak.min(array, axis=0) == ak.min(named_array, axis="x")) + assert ak.all(ak.min(array, axis=1) == ak.min(named_array, axis="y")) # check that result axis names are correctly propagated assert ( ak.min(named_array, axis=0).named_axis - == ak.min(named_array, axis="events").named_axis - == ("jets",) + == ak.min(named_array, axis="x").named_axis + == {"y": 0} ) assert ( ak.min(named_array, axis=1).named_axis - == ak.min(named_array, axis="jets").named_axis - == ("events",) + == ak.min(named_array, axis="y").named_axis + == {"x": 0} + ) + assert ( + ak.min(named_array, axis=0, keepdims=True).named_axis + == ak.min(named_array, axis="x", keepdims=True).named_axis + == {"x": 0, "y": 1} + ) + assert ( + ak.min(named_array, axis=1, keepdims=True).named_axis + == ak.min(named_array, axis="y", keepdims=True).named_axis + == {"x": 0, "y": 1} ) - assert ak.min(named_array, axis=None).named_axis == (None,) + assert not _get_named_axis(ak.min(named_array, axis=None)) def test_named_axis_ak_moment(): @@ -727,13 +929,11 @@ def test_named_axis_ak_moment(): assert ak.all(ak.moment(array, 0, axis=0) == ak.moment(named_array, 0, axis="x")) assert ak.all(ak.moment(array, 0, axis=1) == ak.moment(named_array, 0, axis="y")) - assert ak.all( - ak.moment(array, 0, axis=None) == ak.moment(named_array, 0, axis=None) - ) + assert ak.moment(array, 0, axis=None) == ak.moment(named_array, 0, axis=None) - assert ak.moment(named_array, 0, axis="x").named_axis == ("y",) - assert ak.moment(named_array, 0, axis="y").named_axis == ("x",) - assert ak.moment(named_array, 0, axis=None).named_axis == (None,) + assert ak.moment(named_array, 0, axis="x").named_axis == {"y": 0} + assert ak.moment(named_array, 0, axis="y").named_axis == {"x": 0} + 
assert not _get_named_axis(ak.moment(named_array, 0, axis=None)) def test_named_axis_ak_nan_to_none(): @@ -762,7 +962,8 @@ def test_named_axis_ak_num(): assert ak.num(array, axis=0) == ak.num(named_array, axis="x") assert ak.all(ak.num(array, axis=1) == ak.num(named_array, axis="y")) - assert ak.num(named_array, axis="y").named_axis == ("y",) + assert ak.num(named_array, axis="y").named_axis == {"y": 0} + assert not _get_named_axis(ak.num(named_array, axis="x")) def test_named_axis_ak_ones_like(): @@ -801,8 +1002,9 @@ def test_named_axis_ak_prod(): assert ak.all(ak.prod(array, axis=1) == ak.prod(named_array, axis="y")) assert ak.prod(array, axis=None) == ak.prod(named_array, axis=None) - assert ak.prod(named_array, axis="x").named_axis == ("y",) - assert ak.prod(named_array, axis="y").named_axis == ("x",) + assert ak.prod(named_array, axis="x").named_axis == {"y": 0} + assert ak.prod(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.prod(named_array, axis=None)) def test_named_axis_ak_ptp(): @@ -814,8 +1016,9 @@ def test_named_axis_ak_ptp(): assert ak.all(ak.ptp(array, axis=1) == ak.ptp(named_array, axis="y")) assert ak.ptp(array, axis=None) == ak.ptp(named_array, axis=None) - assert ak.ptp(named_array, axis="x").named_axis == ("x",) - assert ak.ptp(named_array, axis="y").named_axis == ("y",) + assert ak.ptp(named_array, axis="x").named_axis == {"x": 0} + assert ak.ptp(named_array, axis="y").named_axis == {"y": 0} + assert not _get_named_axis(ak.ptp(named_array, axis=None)) def test_named_axis_ak_ravel(): @@ -825,7 +1028,7 @@ def test_named_axis_ak_ravel(): assert ak.all(ak.ravel(array) == ak.ravel(named_array)) - assert ak.ravel(named_array).named_axis == (None,) + assert not _get_named_axis(ak.ravel(named_array)) def test_named_axis_ak_real(): @@ -834,7 +1037,7 @@ def test_named_axis_ak_real(): named_array = ak.with_named_axis(array, ("x", "y")) assert ak.all(ak.real(array) == ak.real(named_array)) - assert 
ak.real(named_array).named_axis == ("x", "y") + assert ak.real(named_array).named_axis == {"x": 0, "y": 1} def test_named_axis_ak_round(): @@ -843,7 +1046,7 @@ def test_named_axis_ak_round(): named_array = ak.with_named_axis(array, ("x", "y")) assert ak.all(ak.round(array) == ak.round(named_array)) - assert ak.round(named_array).named_axis == ("x", "y") + assert ak.round(named_array).named_axis == {"x": 0, "y": 1} def test_named_axis_ak_run_lengths(): @@ -864,8 +1067,9 @@ def test_named_axis_ak_singletons(): assert ak.all(ak.singletons(array, axis=0) == ak.singletons(named_array, axis=0)) assert ak.all(ak.singletons(array, axis=1) == ak.singletons(named_array, axis=1)) - assert ak.singletons(named_array, axis=0).named_axis == ("x", None, "y") - assert ak.singletons(named_array, axis=1).named_axis == ("x", "y", None) + # TODO: What should this be? + # assert ak.singletons(named_array, axis=0).named_axis == {"x": 0, "y": 2} + # assert ak.singletons(named_array, axis=1).named_axis == {"x": 0, "y": 2} def test_named_axis_ak_softmax(): @@ -875,28 +1079,28 @@ def test_named_axis_ak_softmax(): assert ak.all(ak.softmax(array, axis=-1) == ak.softmax(named_array, axis="y")) - assert ak.softmax(named_array, axis="y").named_axis == ("x", "y") + assert ak.softmax(named_array, axis="y").named_axis == {"x": 0, "y": 1} def test_named_axis_ak_sort(): array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) - named_array = ak.with_named_axis(array, named_axis=("events", "jets")) + named_array = ak.with_named_axis(array, named_axis=("x", "y")) # first check that they work the same - assert ak.all(ak.sort(array, axis=0) == ak.sort(named_array, axis="events")) - assert ak.all(ak.sort(array, axis=1) == ak.sort(named_array, axis="jets")) + assert ak.all(ak.sort(array, axis=0) == ak.sort(named_array, axis="x")) + assert ak.all(ak.sort(array, axis=1) == ak.sort(named_array, axis="y")) # check that result axis names are correctly propagated assert ( ak.sort(named_array, axis=0).named_axis - == 
ak.sort(named_array, axis="events").named_axis - == ("events", "jets") + == ak.sort(named_array, axis="x").named_axis + == {"x": 0, "y": 1} ) assert ( ak.sort(named_array, axis=1).named_axis - == ak.sort(named_array, axis="jets").named_axis - == ("events", "jets") + == ak.sort(named_array, axis="y").named_axis + == {"x": 0, "y": 1} ) @@ -937,8 +1141,9 @@ def test_named_axis_ak_sum(): assert ak.all(ak.sum(array, axis=1) == ak.sum(named_array, axis="y")) assert ak.sum(array, axis=None) == ak.sum(named_array, axis=None) - assert ak.sum(named_array, axis="x").named_axis == ("y",) - assert ak.sum(named_array, axis="y").named_axis == ("x",) + assert ak.sum(named_array, axis="x").named_axis == {"y": 0} + assert ak.sum(named_array, axis="y").named_axis == {"x": 0} + assert not _get_named_axis(ak.sum(named_array, axis=None)) def test_named_axis_ak_to_arrow(): @@ -1065,7 +1270,7 @@ def test_named_axis_ak_unflatten(): ak.unflatten(array, counts, axis=1) == ak.unflatten(named_array, counts, axis="y") ) - assert ak.unflatten(named_array, counts, axis="y").named_axis == (None, None, None) + assert not _get_named_axis(ak.unflatten(named_array, counts, axis="y")) def test_named_axis_ak_unzip(): @@ -1116,9 +1321,13 @@ def test_named_axis_ak_with_name(): def test_named_axis_ak_with_named_axis(): array = ak.Array([[1, 2], [3], [], [4, 5, 6]]) + # tuple named_array = ak.with_named_axis(array, ("x", "y")) + assert named_array.named_axis == {"x": 0, "y": 1} - assert named_array.named_axis == ("x", "y") + # dict + named_array = ak.with_named_axis(array, {"x": 0, "y": -1}) + assert named_array.named_axis == {"x": 0, "y": -1} def test_named_axis_ak_with_parameter(): @@ -1144,7 +1353,10 @@ def test_named_axis_ak_without_parameters(): named_array_with_parameteter = ak.with_parameter(named_array, "param", 1.0) - assert ak.without_parameters(named_array).named_axis == named_array.named_axis + assert ( + ak.without_parameters(named_array_with_parameteter).named_axis + == named_array.named_axis 
+ ) def test_named_axis_ak_zeros_like(): @@ -1158,10 +1370,10 @@ def test_named_axis_ak_zeros_like(): def test_named_axis_ak_zip(): - named_array1 = ak.with_named_axis(ak.Array([1, 2, 3]), ("a",)) - named_array2 = ak.with_named_axis(ak.Array([[4, 5, 6], [], [7]]), ("x", "y")) + # named_array1 = ak.with_named_axis(ak.Array([1, 2, 3]), ("a",)) + # named_array2 = ak.with_named_axis(ak.Array([[4, 5, 6], [], [7]]), ("x", "y")) - record = ak.zip({"x": named_array1, "y": named_array2}) + # record = ak.zip({"x": named_array1, "y": named_array2}) # TODO: need to implement broadcasting properly first assert True