diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 1a29c2fc5c80..d19118259d56 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -132,12 +132,16 @@ def _impl(inputs, input_types): return get_relay_op(name)(data0, data1) return _impl -def _abs(): + +def _unary(name): def _impl(inputs, input_types): - data = inputs[0] - return _op.abs(data) + input_type = input_types[0] + data = _convert_elemwise_input(inputs[0], input_type) + + return get_relay_op(name)(data) return _impl + def _arange(): def _impl(inputs, input_types): if len(inputs) == 5: @@ -1260,26 +1264,6 @@ def _impl(inputs, input_types): return _op.nn.pad(data, pad_width, pad_value) return _impl -def _sqrt(): - def _impl(inputs, input_types): - data = inputs[0] - return _op.tensor.sqrt(data) - return _impl - - -def _rsqrt(): - def _impl(inputs, input_types): - data = inputs[0] - return _op.tensor.rsqrt(data) - return _impl - - -def _ceil(): - def _impl(inputs, input_types): - data = inputs[0] - return _op.ceil(data) - return _impl - def _clamp(): def _impl(inputs, input_types): @@ -1290,20 +1274,6 @@ def _impl(inputs, input_types): return _impl -def _floor(): - def _impl(inputs, input_types): - data = inputs[0] - return _op.floor(data) - return _impl - - -def _round(): - def _impl(inputs, input_types): - data = inputs[0] - return _op.round(data) - return _impl - - def _to(): def _impl(inputs, input_types): data = inputs[0] @@ -1381,17 +1351,6 @@ def _impl(inputs, input_types): return inputs[0] return _impl -def _neg(): - def _impl(inputs, input_types): - data = inputs[0] - return _op.tensor.negative(data) - return _impl - -def _tanh(): - def _impl(inputs, input_types): - data = inputs[0] - return _op.tensor.tanh(data) - return _impl def _Bool(): def _impl(inputs, input_types): @@ -1473,18 +1432,6 @@ def _impl(inputs, input_types): return _impl -def _isfinite(): - def _impl(inputs, input_types): - return _op.isfinite(inputs[0]) - return _impl - - -def _isnan(): - def _impl(inputs, input_types): - return _op.isnan(inputs[0]) - return _impl - - def _list_getitem(prelude): def _impl(inputs, input_types): return prelude.nth(inputs[0], _wrap_const(inputs[1])) @@ -1607,7 +1554,6 @@ def _get_convert_map(prelude): "aten::mul" : _elemwise("multiply"), "aten::mul_" : _elemwise("multiply"), "aten::pow" : _elemwise("power"), - "aten::abs" : _abs(), "aten::arange" : _arange(), "aten::div" : _elemwise("divide"), "aten::div_" : _elemwise("divide"), @@ -1689,12 +1635,26 @@ def _get_convert_map(prelude): "aten::argmax" : _reduce("argmax"), "aten::std" : _std(), "aten::var" : _variance(), - "aten::sqrt" : _sqrt(), - "aten::rsqrt" : _rsqrt(), - "aten::ceil" : _ceil(), + "aten::abs" : _unary("abs"), + "aten::neg" : _unary("negative"), + "aten::cos" : _unary("cos"), + "aten::sin" : _unary("sin"), + "aten::tan" : _unary("tan"), + "aten::tanh" : _unary("tanh"), + "aten::atan" : _unary("atan"), + "aten::log" : _unary("log"), + "aten::exp" : _unary("exp"), + "aten::erf" : _unary("erf"), + "aten::trunc" : _unary("trunc"), + "aten::sign" : _unary("sign"), + "aten::sqrt" : _unary("sqrt"), + "aten::rsqrt" : _unary("rsqrt"), + "aten::ceil" : _unary("ceil"), + "aten::floor" : _unary("floor"), + "aten::round" : _unary("round"), + "aten::isfinite" : _unary("isfinite"), + "aten::isnan" : _unary("isnan"), "aten::clamp" : _clamp(), - "aten::floor" : _floor(), - "aten::round" : _round(), "aten::detach" : _identity(), "aten::upsample_bilinear2d" : _upsample("bilinear"), "aten::upsample_nearest2d" : _upsample("nearest_neighbor"), @@ -1709,12 +1669,8 @@ def _get_convert_map(prelude): "aten::logical_xor" : _logical_xor(), "aten::bitwise_not" : _bitwise_not(), "aten::bitwise_xor" : _bitwise_xor(), - "aten::isfinite" : _isfinite(), - "aten::isnan" : _isnan(), "aten::Bool" : _Bool(), "aten::Float" : _Float(), - "aten::neg" : _neg(), - "aten::tanh" : _tanh(), "aten::adaptive_avg_pool3d" : _adaptive_avg_pool_3d(), "aten::adaptive_max_pool3d" : _adaptive_max_pool_3d(), "aten::mm" : _matmul(), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 1ca6fd24eebc..0ba473add06e 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1508,30 +1508,6 @@ def forward(self, *args): verify_model(IsInf1().float().eval(), input_data=input_data) -def test_forward_rsqrt(): - torch.set_grad_enabled(False) - input_shape = [1, 3, 10, 10] - - class Rsqrt1(Module): - def forward(self, *args): - return torch.rsqrt(args[0]) - - input_data = torch.rand(input_shape).float() - verify_model(Rsqrt1().float().eval(), input_data=input_data) - - -def test_forward_ceil(): - torch.set_grad_enabled(False) - input_shape = [1, 3, 10, 10] - - class Ceil1(Module): - def forward(self, *args): - return torch.ceil(args[0]) - - input_data = torch.rand(input_shape).float() - verify_model(Ceil1().float().eval(), input_data=input_data) - - def test_forward_clamp(): torch.set_grad_enabled(False) input_shape = [1, 3, 10, 10] @@ -1554,30 +1530,6 @@ def forward(self, *args): verify_model(Clamp3().float().eval(), input_data=input_data) -def test_forward_floor(): - torch.set_grad_enabled(False) - input_shape = [1, 3, 10, 10] - - class Floor1(Module): - def forward(self, *args): - return torch.floor(args[0]) - - input_data = torch.rand(input_shape).float() - verify_model(Floor1().float().eval(), input_data=input_data) - - -def test_forward_round(): - torch.set_grad_enabled(False) - input_shape = [1, 3, 10, 10] - - class Round1(Module): - def forward(self, *args): - return torch.round(args[0]) - - input_data = torch.rand(input_shape).float() - verify_model(Round1().float().eval(), input_data=input_data) - - def test_forward_ones(): torch.set_grad_enabled(False) @@ -1860,6 +1812,93 @@ def forward(self, *args): verify_model(LogicalXor2().float().eval(), input_data=[lhs]) +def test_forward_unary(): + torch.set_grad_enabled(False) + + class Sqrt1(Module): + def forward(self, *args): + return torch.sqrt(args[0]) + + class RSqrt1(Module): + def forward(self, *args): + return torch.rsqrt(args[0]) + + class Ceil1(Module): + def forward(self, *args): + return torch.ceil(args[0]) + + class Floor1(Module): + def forward(self, *args): + return torch.floor(args[0]) + + class Round1(Module): + def forward(self, *args): + return torch.round(args[0]) + + class Cos1(Module): + def forward(self, *args): + return torch.cos(args[0]) + + class Sin1(Module): + def forward(self, *args): + return torch.sin(args[0]) + + class Tan1(Module): + def forward(self, *args): + return torch.tan(args[0]) + + class Tanh1(Module): + def forward(self, *args): + return torch.tanh(args[0]) + + class ATanh1(Module): + def forward(self, *args): + return torch.atan(args[0]) + + class Log1(Module): + def forward(self, *args): + return torch.log(args[0]) + + class Exp1(Module): + def forward(self, *args): + return torch.exp(args[0]) + + class Erf1(Module): + def forward(self, *args): + return torch.erf(args[0]) + + class Trunc1(Module): + def forward(self, *args): + return torch.trunc(args[0]) + + class Sign1(Module): + def forward(self, *args): + return torch.sign(args[0]) + + class Neg1(Module): + def forward(self, *args): + return torch.neg(args[0]) + + input_shape = [1, 3, 10, 10] + input_data = torch.rand(input_shape).float() + verify_model(Sqrt1().float().eval(), input_data=input_data) + verify_model(RSqrt1().float().eval(), input_data=input_data) + verify_model(Ceil1().float().eval(), input_data=input_data) + verify_model(Floor1().float().eval(), input_data=input_data) + verify_model(Round1().float().eval(), input_data=input_data) + verify_model(Cos1().float().eval(), input_data=input_data) + verify_model(Sin1().float().eval(), input_data=input_data) + verify_model(Tan1().float().eval(), input_data=input_data) + verify_model(Tanh1().float().eval(), input_data=input_data) + verify_model(ATanh1().float().eval(), input_data=input_data) + verify_model(Log1().float().eval(), input_data=input_data) + verify_model(Exp1().float().eval(), input_data=input_data) + verify_model(Erf1().float().eval(), input_data=input_data) + verify_model(Trunc1().float().eval(), input_data=input_data) + verify_model(Sign1().float().eval(), input_data=input_data) + verify_model(Neg1().float().eval(), input_data=input_data) + + if __name__ == "__main__": # Single operator tests test_forward_add() @@ -1918,12 +1957,8 @@ def forward(self, *args): test_forward_mean() test_forward_expand() test_forward_pow() - test_forward_abs() - test_forward_rsqrt() - test_forward_ceil() + test_forward_unary() test_forward_clamp() - test_forward_floor() - test_forward_round() test_forward_logical_not() test_forward_bitwise_not() test_forward_bitwise_xor()