Skip to content

Commit

Permalink
chore: linters
Browse files Browse the repository at this point in the history
  • Loading branch information
senysenyseny16 committed Jul 30, 2024
1 parent d4671a0 commit 35beb2a
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 58 deletions.
52 changes: 26 additions & 26 deletions onnx2torch/node_converters/arg_extrema.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
# pylint: disable=missing-docstring
__all__ = [
"OnnxArgExtremumOld",
"OnnxArgExtremum",
'OnnxArgExtremumOld',
'OnnxArgExtremum',
]

from typing import Optional

import torch
import torch.nn.functional as F
from torch import nn

from onnx2torch.node_converters.registry import add_converter
Expand All @@ -21,31 +19,31 @@
DEFAULT_SELECT_LAST_INDEX = 0

_TORCH_FUNCTION_FROM_ONNX_TYPE = {
"ArgMax": torch.argmax,
"ArgMin": torch.argmin,
'ArgMax': torch.argmax,
'ArgMin': torch.argmin,
}


class OnnxArgExtremumOld(nn.Module, OnnxToTorchModule): # pylint: disable=missing-docstring
class OnnxArgExtremumOld(nn.Module, OnnxToTorchModule):
def __init__(self, operation_type: str, axis: int, keepdims: int):
super().__init__()
self.axis = axis
self.keepdims = bool(keepdims)
self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type]

def forward(self, data: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
def forward(self, data: torch.Tensor) -> torch.Tensor:
return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims)


class OnnxArgExtremum(nn.Module, OnnxToTorchModule): # pylint: disable=missing-class-docstring
class OnnxArgExtremum(nn.Module, OnnxToTorchModule):
def __init__(self, operation_type: str, axis: int, keepdims: int, select_last_index: int):
super().__init__()
self.axis = axis
self.keepdims = bool(keepdims)
self.select_last_index = bool(select_last_index)
self.extremum_function = _TORCH_FUNCTION_FROM_ONNX_TYPE[operation_type]

def forward(self, data: torch.Tensor) -> torch.Tensor: # pylint: disable=missing-function-docstring
def forward(self, data: torch.Tensor) -> torch.Tensor:
if self.select_last_index:
# torch's argmax does not handle the select_last_index attribute from Onnx.
# We flip the data, call the normal argmax, then map it back to the original
Expand All @@ -54,34 +52,36 @@ def forward(self, data: torch.Tensor) -> torch.Tensor: # pylint: disable=missin
extremum_index_flipped = self.extremum_function(flipped, dim=self.axis, keepdim=self.keepdims)
extremum_index_original = data.size(dim=self.axis) - 1 - extremum_index_flipped
return extremum_index_original
else:
return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims)

return self.extremum_function(data, dim=self.axis, keepdim=self.keepdims)


@add_converter(operation_type="ArgMax", version=12)
@add_converter(operation_type="ArgMax", version=13)
@add_converter(operation_type="ArgMin", version=12)
@add_converter(operation_type="ArgMin", version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
@add_converter(operation_type='ArgMax', version=12)
@add_converter(operation_type='ArgMax', version=13)
@add_converter(operation_type='ArgMin', version=12)
@add_converter(operation_type='ArgMin', version=13)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
del graph
return OperationConverterResult(
torch_module=OnnxArgExtremum(
operation_type=node.operation_type,
axis=node.attributes.get("axis", DEFAULT_AXIS),
keepdims=node.attributes.get("keepdims", DEFAULT_KEEPDIMS),
select_last_index=node.attributes.get("select_last_index", DEFAULT_SELECT_LAST_INDEX),
axis=node.attributes.get('axis', DEFAULT_AXIS),
keepdims=node.attributes.get('keepdims', DEFAULT_KEEPDIMS),
select_last_index=node.attributes.get('select_last_index', DEFAULT_SELECT_LAST_INDEX),
),
onnx_mapping=onnx_mapping_from_node(node=node),
)


@add_converter(operation_type="ArgMax", version=11)
@add_converter(operation_type="ArgMin", version=11)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult: # pylint: disable=unused-argument
@add_converter(operation_type='ArgMax', version=11)
@add_converter(operation_type='ArgMin', version=11)
def _(node: OnnxNode, graph: OnnxGraph) -> OperationConverterResult:
del graph
return OperationConverterResult(
torch_module=OnnxArgExtremumOld(
operation_type=node.operation_type,
axis=node.attributes.get("axis", DEFAULT_AXIS),
keepdims=node.attributes.get("keepdims", DEFAULT_KEEPDIMS),
axis=node.attributes.get('axis', DEFAULT_AXIS),
keepdims=node.attributes.get('keepdims', DEFAULT_KEEPDIMS),
),
onnx_mapping=onnx_mapping_from_node(node=node),
)
2 changes: 1 addition & 1 deletion onnx2torch/utils/custom_export_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def export(cls, forward_function: Callable, *args) -> Any:
return cls.apply(*args)

@staticmethod
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: # pylint: disable=unused-argument
def forward(ctx: Any, *args: Any, **kwargs: Any) -> Any: # pylint: disable=unused-argument, arguments-differ
"""Applies custom forward function."""
if CustomExportToOnnx._NEXT_FORWARD_FUNCTION is None:
raise RuntimeError('Forward function is not set')
Expand Down
40 changes: 11 additions & 29 deletions tests/node_converters/arg_extrema_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# pylint: disable=missing-docstring
from pathlib import Path

import numpy as np
import onnx
from onnx.helper import make_tensor_value_info
import pytest
import torch
from onnx.helper import make_tensor_value_info

from tests.utils.common import check_onnx_model
from tests.utils.common import make_model_from_nodes
Expand Down Expand Up @@ -51,7 +52,7 @@
"select_last_index",
(0, 1),
)
def test_arg_max_arg_min( # pylint: disable=missing-function-docstring
def test_arg_max_arg_min(
op_type: str,
opset_version: int,
dims: int,
Expand Down Expand Up @@ -95,7 +96,7 @@ class ArgMaxModel(torch.nn.Module):
def __init__(self, axis: int, keepdims: bool):
super().__init__()
self.axis = axis
self.keepdims = bool(keepdims)
self.keepdims = keepdims

def forward(self, data: torch.Tensor) -> torch.Tensor:
return torch.argmax(data, dim=self.axis, keepdim=self.keepdims)
Expand All @@ -105,29 +106,16 @@ class ArgMinModel(torch.nn.Module):
def __init__(self, axis: int, keepdims: bool):
super().__init__()
self.axis = axis
self.keepdims = bool(keepdims)
self.keepdims = keepdims

def forward(self, data: torch.Tensor) -> torch.Tensor:
return torch.argmin(data, dim=self.axis, keepdim=self.keepdims)


@pytest.mark.parametrize("op_type", ["ArgMax", "ArgMin"])
@pytest.mark.parametrize("opset_version", [11, 12, 13])
@pytest.mark.parametrize(
"op_type",
(
"ArgMax",
"ArgMin",
),
)
@pytest.mark.parametrize(
"opset_version",
(
11,
12,
13,
),
)
@pytest.mark.parametrize(
"dims,axis",
"dims, axis",
(
(1, 0),
(2, 0),
Expand All @@ -141,19 +129,13 @@ def forward(self, data: torch.Tensor) -> torch.Tensor:
(4, 3),
),
)
@pytest.mark.parametrize(
"keepdims",
(
0,
1,
),
)
@pytest.mark.parametrize("keepdims", [True, False])
def test_start_from_torch_module(
op_type: str,
opset_version: int,
dims: int,
axis: int,
keepdims: int,
keepdims: bool,
tmp_path: Path,
) -> None:
"""
Expand All @@ -179,7 +161,7 @@ def test_start_from_torch_module(
input_names=input_names,
output_names=output_names,
do_constant_folding=False,
training=torch._C._onnx.TrainingMode.TRAINING,
opset_version=opset_version,
)

# load the exported onnx file
Expand Down
5 changes: 3 additions & 2 deletions tests/node_converters/conv_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from itertools import chain
from itertools import product
from typing import Literal
from typing import Tuple

import numpy as np
Expand All @@ -10,7 +11,7 @@


def _test_conv(
op_type: str,
op_type: Literal['Conv', 'ConvTranspose'],
in_channels: int,
out_channels: int,
kernel_shape: Tuple[int, int],
Expand All @@ -23,7 +24,7 @@ def _test_conv(
x = np.random.uniform(low=-1.0, high=1.0, size=x_shape).astype(np.float32)
if op_type == 'Conv':
weights_shape = (out_channels, in_channels // group) + kernel_shape
elif op_type == 'ConvTranspose':
else: # ConvTranspose
weights_shape = (in_channels, out_channels // group) + kernel_shape
weights = np.random.uniform(low=-1.0, high=1.0, size=weights_shape).astype(np.float32)

Expand Down

0 comments on commit 35beb2a

Please sign in to comment.