Skip to content

Commit

Permalink
feat: ArgMax and ArgMin
Browse files Browse the repository at this point in the history
  • Loading branch information
ajstarna authored and senysenyseny16 committed Jul 30, 2024
1 parent 9a58fec commit 536a299
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 1 deletion.
1 change: 1 addition & 0 deletions onnx2torch/node_converters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from onnx2torch.node_converters.activations import *
from onnx2torch.node_converters.arg_extrema import *
from onnx2torch.node_converters.average_pool import *
from onnx2torch.node_converters.batch_norm import *
from onnx2torch.node_converters.binary_math_operations import *
Expand Down
87 changes: 87 additions & 0 deletions onnx2torch/node_converters/arg_extrema.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
__all__ = [
"OnnxArgExtremumOld",
"OnnxArgExtremum",
]

from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from onnx2torch.node_converters.registry import add_converter
from onnx2torch.onnx_graph import OnnxGraph
from onnx2torch.onnx_node import OnnxNode
from onnx2torch.utils.common import OnnxToTorchModule
from onnx2torch.utils.common import OperationConverterResult
from onnx2torch.utils.common import onnx_mapping_from_node

DEFAULT_AXIS = 0
DEFAULT_KEEPDIMS = 1
DEFAULT_SELECT_LAST_INDEX = 0

_TORCH_FUNCTION_FROM_ONNX_TYPE = {
"ArgMax": torch.argmax,
"ArgMin": torch.argmin,
}


class OnnxArgExtremumOld(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring
def __init__(self, operation_type: str, axis: int, keepdims: int):
super().__init__()
self.axis = axis
self.keepdims = bool(keepdims)
self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type]

def forward(self, data: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims)


class OnnxArgExtremum(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring
def __init__(self, operation_type: str, axis: int, keepdims: int, select_last_index: int):
super().__init__()
self.axis = axis
self.keepdims = bool(keepdims)
self.select_last_index = bool(select_last_index)
self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type]

def forward(self, data: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
if self.select_last_index:
# torch's argmax does not handle the select_last_index attribute from Onnx.
# We flip the data, call the normal argmax, then map it back to the original
flipped = torch.flip(data, dims=[self.axis])

extremum_index_flipped = self.extremum_function(flipped, dim=self.axis, keepdim=self.keepdims)
extremum_index_original = data.size(dim=self.axis) - 1 - extremum_index_flipped
return extremum_index_original
else:
return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims)


@add_converter(operation_type="ArgMax", version=12)
@add_converter(operation_type="ArgMax", version=13)
@add_converter(operation_type="ArgMin", version=12)
@add_converter(operation_type="ArgMin", version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
return OperationConverterResult(
torch_module=OnnxArgExtremum(
operation_type=node.operation_type,
axis=node.attributes.get("axis", DEFAULT_AXIS),
keepdims=node.attributes.get("keepdims", DEFAULT_KEEPDIMS),
select_last_index=node.attributes.get("select_last_index", DEFAULT_SELECT_LAST_INDEX),
),
onnx_mapping=onnx_mapping_from_node(node=node),
)


@add_converter(operation_type="ArgMax", version=11)
@add_converter(operation_type="ArgMin", version=11)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
return OperationConverterResult(
torch_module=OnnxArgExtremumOld(
operation_type=node.operation_type,
axis=node.attributes.get("axis", DEFAULT_AXIS),
keepdims=node.attributes.get("keepdims", DEFAULT_KEEPDIMS),
),
onnx_mapping=onnx_mapping_from_node(node=node),
)
190 changes: 190 additions & 0 deletions tests/node_converters/arg_extrema_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
from pathlib import Path

Check failure on line 1 in tests/node_converters/arg_extrema_test.py

View workflow job for this annotation

GitHub Actions / Python format

Imports are incorrectly sorted and/or formatted.

import numpy as np
import onnx
from onnx.helper import make_tensor_value_info
import pytest
import torch

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes


@pytest.mark.parametrize(
"op_type",
(
"ArgMax",
"ArgMin",
),
)
@pytest.mark.parametrize(
"opset_version",
(
11,
12,
13,
),
)
@pytest.mark.parametrize(
"dims,axis",
(
(1, 0),
(2, 0),
(2, 1),
(3, 0),
(3, 1),
(3, 2),
(4, 0),
(4, 1),
(4, 2),
(4, 3),
),
)
@pytest.mark.parametrize(
"keepdims",
(
0,
1,
),
)
@pytest.mark.parametrize(
"select_last_index",
(0, 1),
)
def test_arg_max_arg_min( # pylint: disable=missing-function-docstring
op_type: str,
opset_version: int,
dims: int,
axis: int,
keepdims: int,
select_last_index: int,
) -> None:
input_shape = [3] * dims # arbitrary magnitude in each dimension
test_inputs = {"data": np.random.randn(*input_shape).astype(np.float32)}

kwargs = {"keepdims": keepdims, "axis": axis}
if opset_version >= 12:
# since opset_version 12, we can specify whether to return the LAST index
# of the max/min (respectively) occurance
kwargs["select_last_index"] = select_last_index

node = onnx.helper.make_node(op_type=op_type, inputs=["data"], outputs=["reduced"], **kwargs)

# we need to specify outputs_info, since the required output type for arg max (int64)
# is different than the input type
outputs_info = [make_tensor_value_info(name="reduced", elem_type=onnx.TensorProto.INT64, shape=None)]

model = make_model_from_nodes(
nodes=node,
initializers={},
inputs_example=test_inputs,
outputs_info=outputs_info,
opset_version=opset_version,
)

check_onnx_model(model, test_inputs)

# Test once again with input we know to all be the same.
# This is a way to force the testing of the select_last_index attribute.
# We need the min/max index to occur more than once.
test_inputs2 = {"data": np.ones_like(test_inputs["data"])}
check_onnx_model(model, test_inputs2)


class ArgMaxModel(torch.nn.Module):
def __init__(self, axis: int, keepdims: bool):
super().__init__()
self.axis = axis
self.keepdims = bool(keepdims)

def forward(self, data: torch.Tensor) -> torch.Tensor:
return torch.argmax(data, dim=self.axis, keepdim=self.keepdims)


class ArgMinModel(torch.nn.Module):
def __init__(self, axis: int, keepdims: bool):
super().__init__()
self.axis = axis
self.keepdims = bool(keepdims)

def forward(self, data: torch.Tensor) -> torch.Tensor:
return torch.argmin(data, dim=self.axis, keepdim=self.keepdims)


@pytest.mark.parametrize(
"op_type",
(
"ArgMax",
"ArgMin",
),
)
@pytest.mark.parametrize(
"opset_version",
(
11,
12,
13,
),
)
@pytest.mark.parametrize(
"dims,axis",
(
(1, 0),
(2, 0),
(2, 1),
(3, 0),
(3, 1),
(3, 2),
(4, 0),
(4, 1),
(4, 2),
(4, 3),
),
)
@pytest.mark.parametrize(
"keepdims",
(
0,
1,
),
)
def test_start_from_torch_module(
op_type: str,
opset_version: int,
dims: int,
axis: int,
keepdims: int,
tmp_path: Path,
) -> None:
"""
Test starting from a torch module, export to Onnx, then converting back to torch.
"""
if op_type == "ArgMax":
model = ArgMaxModel(axis=axis, keepdims=keepdims)
else:
model = ArgMinModel(axis=axis, keepdims=keepdims)

input_shape = [3] * dims # arbitrary magnitude in each dimension

# export the pytorch model to onnx
dummy_data = {"data": torch.randn(*input_shape)}
input_names = ["data"]
output_names = ["indices"]
model_path = tmp_path / "model.onnx"
torch.onnx.export(
model,
(dummy_data,),
str(model_path),
export_params=True,
input_names=input_names,
output_names=output_names,
do_constant_folding=False,
training=torch._C._onnx.TrainingMode.TRAINING,
)

# load the exported onnx file
model = onnx.load(model_path)
onnx.checker.check_model(model, False)

test_inputs = {"data": np.random.randn(*input_shape).astype(np.float32)}
check_onnx_model(model, test_inputs)
2 changes: 1 addition & 1 deletion tests/utils/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _check_onnx_model(

onnx_torch_check_function(ort_outputs, torch_outputs)

if torch_cpu_cuda_check_function is not None:
if torch_cpu_cuda_check_function is not None and torch.cuda.is_available():
torch_cuda_outputs = calc_torch_outputs(onnx_model, onnx_inputs, device='cuda')
torch_cpu_cuda_check_function(torch_outputs, torch_cuda_outputs)

Expand Down

0 comments on commit 536a299

Please sign in to comment.