From 4dd6097dcdb7808398a765703108bee7d68cbb58 Mon Sep 17 00:00:00 2001
From: Samuel
Date: Thu, 16 Apr 2020 14:04:36 +0530
Subject: [PATCH] [TOPI][PYTORCH]Logical & Bitwise operator support (#5341)

---
 docs/api/python/topi.rst                      |  2 +
 docs/langref/relay_op.rst                     |  1 +
 python/tvm/relay/frontend/pytorch.py          | 66 ++++++++++++-
 python/tvm/relay/op/_tensor.py                |  2 +
 python/tvm/relay/op/tensor.py                 | 17 ++++
 src/relay/op/tensor/binary.cc                 |  6 ++
 tests/python/frontend/pytorch/test_forward.py | 95 ++++++++++++++++++-
 topi/include/topi/broadcast.h                 | 13 +++
 topi/python/topi/broadcast.py                 | 19 ++++
 topi/src/broadcast.cc                         |  1 +
 topi/tests/python/test_topi_broadcast.py      |  2 +
 11 files changed, 222 insertions(+), 2 deletions(-)

diff --git a/docs/api/python/topi.rst b/docs/api/python/topi.rst
index e6a2c38b515a2..cef2999bef525 100644
--- a/docs/api/python/topi.rst
+++ b/docs/api/python/topi.rst
@@ -99,6 +99,7 @@ List of operators
    topi.logical_and
    topi.logical_or
    topi.logical_not
+   topi.logical_xor
    topi.arange
    topi.stack
    topi.repeat
@@ -193,6 +194,7 @@ topi
 .. autofunction:: topi.logical_and
 .. autofunction:: topi.logical_or
 .. autofunction:: topi.logical_not
+.. autofunction:: topi.logical_xor

 topi.nn
 ~~~~~~~
diff --git a/docs/langref/relay_op.rst b/docs/langref/relay_op.rst
index f1d7d442a14ce..798d440f74258 100644
--- a/docs/langref/relay_op.rst
+++ b/docs/langref/relay_op.rst
@@ -150,6 +150,7 @@ This level enables additional math and transform operators.
    tvm.relay.logical_and
    tvm.relay.logical_or
    tvm.relay.logical_not
+   tvm.relay.logical_xor
    tvm.relay.maximum
    tvm.relay.minimum
    tvm.relay.power
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 0acebe488f882..2a95de243c1f1 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -1168,7 +1168,6 @@ def _impl(inputs, input_types):

 def _clamp():
     def _impl(inputs, input_types):
-        print(inputs, input_types)
         data = inputs[0]
         amin = inputs[1] if inputs[1] else np.finfo(np.float32).min
         amax = inputs[2] if inputs[2] else np.finfo(np.float32).max
@@ -1298,6 +1297,67 @@ def _impl(inputs, input_types):
     return _impl


+def _bitwise_not():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        # The input tensor must be of integral or Boolean types.
+        # For bool tensors, it computes the logical NOT
+        if input_types[0] == "bool":
+            out = _op.logical_not(_op.cast(data, "bool"))
+        else:
+            out = _op.bitwise_not(_op.cast(data, "int"))
+
+        return out
+    return _impl
+
+
+def _bitwise_xor():
+    def _impl(inputs, input_types):
+        lhs = inputs[0]
+
+        import torch
+        if isinstance(inputs[1], _expr.Var):
+            rhs = inputs[1]
+        elif isinstance(inputs[1], torch.Tensor):
+            rhs = _wrap_const(inputs[1].numpy())
+        else:
+            msg = "Data type %s could not be parsed in bitwise_xor operator." % (type(inputs[1]))
+            raise AssertionError(msg)
+
+        lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
+        rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")
+
+        return _op.bitwise_xor(lhs, rhs)
+    return _impl
+
+
+def _logical_not():
+    def _impl(inputs, input_types):
+        data = inputs[0]
+
+        return _op.logical_not(_op.cast(data, "bool"))
+    return _impl
+
+
+def _logical_xor():
+    def _impl(inputs, input_types):
+        lhs = _op.cast(inputs[0], "bool")
+
+        import torch
+        if isinstance(inputs[1], _expr.Var):
+            rhs = inputs[1]
+        elif isinstance(inputs[1], torch.Tensor):
+            rhs = _wrap_const(inputs[1].numpy())
+        else:
+            msg = "Data type %s could not be parsed in logical_xor operator." % (type(inputs[1]))
+            raise AssertionError(msg)
+
+        rhs = _op.cast(rhs, "bool")
+
+        return _op.logical_xor(lhs, rhs)
+    return _impl
+
+
 def _isfinite():
     def _impl(inputs, input_types):
         return _op.isfinite(inputs[0])
@@ -1524,6 +1584,10 @@ def _get_convert_map(prelude):
         "aten::ge"                              : _elemwise("greater_equal"),
         "aten::ne"                              : _elemwise("not_equal"),
         "aten::eq"                              : _elemwise("equal"),
+        "aten::logical_not"                     : _logical_not(),
+        "aten::logical_xor"                     : _logical_xor(),
+        "aten::bitwise_not"                     : _bitwise_not(),
+        "aten::bitwise_xor"                     : _bitwise_xor(),
         "aten::isfinite"                        : _isfinite(),
         "aten::isnan"                           : _isnan(),
         "aten::Bool"                            : _Bool(),
diff --git a/python/tvm/relay/op/_tensor.py b/python/tvm/relay/op/_tensor.py
index 79a623d34c4a8..6bddaa1337f63 100644
--- a/python/tvm/relay/op/_tensor.py
+++ b/python/tvm/relay/op/_tensor.py
@@ -53,6 +53,7 @@
 register_broadcast_schedule("logical_not")
 register_broadcast_schedule("logical_and")
 register_broadcast_schedule("logical_or")
+register_broadcast_schedule("logical_xor")
 register_broadcast_schedule("bitwise_not")
 register_broadcast_schedule("bitwise_and")
 register_broadcast_schedule("bitwise_or")
@@ -205,6 +206,7 @@ def elemwise_shape_func(attrs, inputs, _):
 register_shape_func("floor_mod", False, broadcast_shape_func)
 register_shape_func("logical_and", False, broadcast_shape_func)
 register_shape_func("logical_or", False, broadcast_shape_func)
+register_shape_func("logical_xor", False, broadcast_shape_func)
 register_shape_func("bitwise_not", False, broadcast_shape_func)
 register_shape_func("bitwise_and", False, broadcast_shape_func)
 register_shape_func("bitwise_or", False, broadcast_shape_func)
diff --git a/python/tvm/relay/op/tensor.py b/python/tvm/relay/op/tensor.py
index f6024075d9259..162f83b1f52ad 100644
--- a/python/tvm/relay/op/tensor.py
+++ b/python/tvm/relay/op/tensor.py
@@ -537,6 +537,23 @@ def logical_or(lhs, rhs):
     return _make.logical_or(lhs, rhs)


+def logical_xor(lhs, rhs):
+    """logical XOR with numpy-style broadcasting.
+
+    Parameters
+    ----------
+    lhs : relay.Expr
+        The left hand side input data
+    rhs : relay.Expr
+        The right hand side input data
+
+    Returns
+    -------
+    result : relay.Expr
+        The computed result.
+    """
+    return _make.logical_xor(lhs, rhs)
+
+
 def bitwise_and(lhs, rhs):
     """bitwise AND with numpy-style broadcasting.
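For reference, a minimal sketch of exercising the new relay.logical_xor binding added above. This is not part of the patch; the variable names are illustrative, and create_executor / asnumpy follow the TVM Python API of this period (both have changed in later releases):

import numpy as np
import tvm
from tvm import relay

# Wrap the new op in a small Relay function (illustrative names).
x = relay.var("x", shape=(3,), dtype="bool")
y = relay.var("y", shape=(3,), dtype="bool")
func = relay.Function([x, y], relay.logical_xor(x, y))
mod = tvm.IRModule.from_expr(func)

a = np.array([True, True, False])
b = np.array([False, True, False])
# Run on CPU through the graph executor and compare against numpy.
res = relay.create_executor("graph", mod=mod, target="llvm").evaluate()(a, b)
np.testing.assert_array_equal(res.asnumpy(), np.logical_xor(a, b))
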
diff --git a/src/relay/op/tensor/binary.cc b/src/relay/op/tensor/binary.cc
index 58221ae66f6e3..0f47c9aa25534 100644
--- a/src/relay/op/tensor/binary.cc
+++ b/src/relay/op/tensor/binary.cc
@@ -123,6 +123,12 @@ RELAY_REGISTER_BINARY_OP("logical_or")
 .set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));


+RELAY_REGISTER_BINARY_OP("logical_xor")
+.describe("Elementwise logical XOR with broadcasting")
+.set_support_level(4)
+.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor));
+
+
 RELAY_REGISTER_BINARY_OP("bitwise_and")
 .describe("Elementwise bitwise AND with broadcasting")
 .set_support_level(4)
diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py
index c562fce547f87..796b5a8535859 100644
--- a/tests/python/frontend/pytorch/test_forward.py
+++ b/tests/python/frontend/pytorch/test_forward.py
@@ -159,7 +159,7 @@ def verify_model(model_name, input_data=[],
     if isinstance(baseline_outputs, tuple):
         baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
     else:
-        baseline_outputs = (baseline_outputs.float().cpu().numpy(),)
+        baseline_outputs = (baseline_outputs.cpu().numpy(),)

     trace = torch.jit.trace(baseline_model, baseline_input).float().eval()

@@ -1600,6 +1600,95 @@ def forward(self, *args):
     verify_model(Topk6().float().eval(), input_data=input_data)


+def test_forward_logical_not():
+    torch.set_grad_enabled(False)
+
+    class LogicalNot1(Module):
+        def forward(self, *args):
+            return torch.logical_not(args[0])
+
+    input_data = torch.tensor([True, False])
+    verify_model(LogicalNot1().float().eval(), input_data=input_data)
+
+    input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
+    verify_model(LogicalNot1().float().eval(), input_data=input_data)
+
+    input_data = torch.tensor([0., 1.5, -10.], dtype=torch.double)
+    verify_model(LogicalNot1().float().eval(), input_data=input_data)
+
+    input_data = torch.tensor([0., 1., -10.], dtype=torch.int32)
+    verify_model(LogicalNot1().float().eval(), input_data=input_data)
+
+
+def test_forward_bitwise_not():
+    torch.set_grad_enabled(False)
+
+    class BitwiseNot1(Module):
+        def forward(self, *args):
+            return torch.bitwise_not(args[0])
+
+    input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
+    verify_model(BitwiseNot1().float().eval(), input_data=input_data)
+
+    input_data = torch.tensor([0., 1., -10.], dtype=torch.int32)
+    verify_model(BitwiseNot1().float().eval(), input_data=input_data)
+
+    input_data = torch.tensor([True, False])
+    verify_model(BitwiseNot1().float().eval(), input_data=input_data)
+
+
+def test_forward_bitwise_xor():
+    torch.set_grad_enabled(False)
+
+    class BitwiseXor1(Module):
+        def forward(self, *args):
+            return torch.bitwise_xor(args[0], args[1])
+
+    class BitwiseXor2(Module):
+        def forward(self, *args):
+            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+            if torch.cuda.is_available():
+                rhs = rhs.cuda()
+            return torch.bitwise_xor(args[0], rhs)
+
+    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+    rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+    verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])
+
+    lhs = torch.tensor([True, True, False])
+    rhs = torch.tensor([False, True, False])
+    verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])
+
+    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+    verify_model(BitwiseXor2().float().eval(), input_data=[lhs])
+
+
+def test_forward_logical_xor():
+    torch.set_grad_enabled(False)
+
+    class LogicalXor1(Module):
+        def forward(self, *args):
+            return torch.logical_xor(args[0], args[1])
+
+    class LogicalXor2(Module):
+        def forward(self, *args):
+            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+            if torch.cuda.is_available():
+                rhs = rhs.cuda()
+            return torch.logical_xor(args[0], rhs)
+
+    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+    rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
+    verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])
+
+    lhs = torch.tensor([True, True, False])
+    rhs = torch.tensor([False, True, False])
+    verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])
+
+    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
+    verify_model(LogicalXor2().float().eval(), input_data=[lhs])
+
+
 if __name__ == "__main__":
     # Single operator tests
     test_forward_add()
@@ -1663,6 +1752,10 @@ def forward(self, *args):
     test_forward_clamp()
     test_forward_floor()
     test_forward_round()
+    test_forward_logical_not()
+    test_forward_bitwise_not()
+    test_forward_bitwise_xor()
+    test_forward_logical_xor()
     test_forward_isfinite()
     test_forward_isnan()
     test_forward_isinf()
diff --git a/topi/include/topi/broadcast.h b/topi/include/topi/broadcast.h
index c9b12d3862393..98614c3d49031 100644
--- a/topi/include/topi/broadcast.h
+++ b/topi/include/topi/broadcast.h
@@ -140,6 +140,19 @@ TOPI_DEFINE_OP_OVERLOAD(operator&&, logical_and);
 TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
 TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or);

+/*!
+ * \fn logical_xor
+ * \brief Compute A ^ B with auto-broadcasting.
+ *
+ * \param A The first tensor, or Expr
+ * \param B The second tensor, or Expr
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return The result.
+ */
+TOPI_DEFINE_BCAST_OP(logical_xor, { return a ^ b; });
+
 /*!
  * \fn bitwise_and
  * \brief Compute A & B with auto-broadcasting.
diff --git a/topi/python/topi/broadcast.py b/topi/python/topi/broadcast.py
index 39b2841da8540..cc36637993c36 100644
--- a/topi/python/topi/broadcast.py
+++ b/topi/python/topi/broadcast.py
@@ -420,6 +420,25 @@ def logical_or(lhs, rhs):
     return _cpp.logical_or(lhs, rhs)


+def logical_xor(lhs, rhs):
+    """Compute element-wise logical xor of data.
+
+    Parameters
+    ----------
+    lhs : tvm.te.Tensor or Expr
+        The left operand
+    rhs : tvm.te.Tensor or Expr
+        The right operand
+
+    Returns
+    -------
+    ret : tvm.te.Tensor or Expr
+        Returns Expr if both operands are Expr.
+        Otherwise returns Tensor.
+    """
+    return _cpp.logical_xor(lhs, rhs)
+
+
 def bitwise_and(lhs, rhs):
     """Compute element-wise bitwise and of data.
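As a usage sketch of the frontend path these tests exercise, a traced torch.logical_xor module can be imported into Relay. This snippet is not part of the patch; the graph input names in shape_list are an assumption, since the expected input-info format of from_pytorch varies across TVM versions:

import torch
from tvm import relay

class XorModel(torch.nn.Module):  # illustrative module
    def forward(self, lhs, rhs):
        return torch.logical_xor(lhs, rhs)

lhs = torch.tensor([True, True, False])
rhs = torch.tensor([False, True, False])
script_module = torch.jit.trace(XorModel().eval(), (lhs, rhs))

# The ("input0", ...) naming is assumed; match your traced graph's input names.
shape_list = [("input0", list(lhs.shape)), ("input1", list(rhs.shape))]
mod, params = relay.frontend.from_pytorch(script_module, shape_list)
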
diff --git a/topi/src/broadcast.cc b/topi/src/broadcast.cc index 0f0241f0f975d..b14754573c648 100644 --- a/topi/src/broadcast.cc +++ b/topi/src/broadcast.cc @@ -65,6 +65,7 @@ TOPI_REGISTER_BCAST_OP("topi.power", topi::power); TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift); TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and); TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or); +TOPI_REGISTER_BCAST_OP("topi.logical_xor", topi::logical_xor); TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and); TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or); TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor); diff --git a/topi/tests/python/test_topi_broadcast.py b/topi/tests/python/test_topi_broadcast.py index 2fe00c7d4ec9e..27b66e04e3947 100644 --- a/topi/tests/python/test_topi_broadcast.py +++ b/topi/tests/python/test_topi_broadcast.py @@ -355,6 +355,8 @@ def check_device(device): test_apply(topi.logical_and, "logical_and", np.logical_and, [True, False], [False, False]) test_apply(topi.logical_or, "logical_or", np.logical_or, True, False) test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False]) + test_apply(topi.logical_xor, "logical_xor", np.logical_xor, True, False) + test_apply(topi.logical_xor, "logical_xor", np.logical_xor, [True, False], [False, False]) def test_bitwise_and():
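Finally, a TOPI-level sketch mirroring the broadcast test above, showing how the registered topi.logical_xor compute can be built and run directly. Not part of the patch; it assumes the pre-0.7 source layout used here, where topi is a standalone package (later versions use "from tvm import topi"):

import numpy as np
import tvm
from tvm import te
import topi

# Declare boolean placeholders and apply the broadcast op.
A = te.placeholder((3,), name="A", dtype="bool")
B = te.placeholder((3,), name="B", dtype="bool")
C = topi.logical_xor(A, B)

s = te.create_schedule(C.op)  # the default schedule suffices on CPU
f = tvm.build(s, [A, B, C], "llvm")

ctx = tvm.cpu(0)
a = tvm.nd.array(np.array([True, True, False]), ctx)
b = tvm.nd.array(np.array([False, True, False]), ctx)
c = tvm.nd.array(np.zeros(3, dtype="bool"), ctx)
f(a, b, c)
np.testing.assert_array_equal(c.asnumpy(), np.logical_xor(a.asnumpy(), b.asnumpy()))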