[TOPI][PYTORCH]Logical & Bitwise operator support (apache#5341)
siju-samuel authored and dpankratz committed Apr 24, 2020
1 parent a77b429 commit 4dd6097
Showing 11 changed files with 222 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/api/python/topi.rst
@@ -99,6 +99,7 @@ List of operators
   topi.logical_and
   topi.logical_or
   topi.logical_not
   topi.logical_xor
   topi.arange
   topi.stack
   topi.repeat
@@ -193,6 +194,7 @@ topi
.. autofunction:: topi.logical_and
.. autofunction:: topi.logical_or
.. autofunction:: topi.logical_not
.. autofunction:: topi.logical_xor

topi.nn
~~~~~~~
1 change: 1 addition & 0 deletions docs/langref/relay_op.rst
@@ -150,6 +150,7 @@ This level enables additional math and transform operators.
   tvm.relay.logical_and
   tvm.relay.logical_or
   tvm.relay.logical_not
   tvm.relay.logical_xor
   tvm.relay.maximum
   tvm.relay.minimum
   tvm.relay.power
66 changes: 65 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
@@ -1168,7 +1168,6 @@ def _impl(inputs, input_types):

def _clamp():
    def _impl(inputs, input_types):
-        print(inputs, input_types)
        data = inputs[0]
        amin = inputs[1] if inputs[1] else np.finfo(np.float32).min
        amax = inputs[2] if inputs[2] else np.finfo(np.float32).max
@@ -1298,6 +1297,67 @@ def _impl(inputs, input_types):
    return _impl


def _bitwise_not():
    def _impl(inputs, input_types):
        data = inputs[0]
        # The input tensor must be of integral or Boolean types.
        # For bool tensors, it computes the logical NOT.
        if input_types[0] == "bool":
            out = _op.logical_not(_op.cast(data, "bool"))
        else:
            out = _op.bitwise_not(_op.cast(data, "int"))

        return out
    return _impl


def _bitwise_xor():
    def _impl(inputs, input_types):
        lhs = inputs[0]

        import torch
        if isinstance(inputs[1], _expr.Var):
            rhs = inputs[1]
        elif isinstance(inputs[1], torch.Tensor):
            rhs = _wrap_const(inputs[1].numpy())
        else:
            msg = "Data type %s could not be parsed in bitwise_xor operator." % (type(inputs[1]))
            raise AssertionError(msg)

        lhs = _op.cast(lhs, "bool") if input_types[0] == "bool" else _op.cast(lhs, "int")
        rhs = _op.cast(rhs, "bool") if input_types[1] == "bool" else _op.cast(rhs, "int")

        return _op.bitwise_xor(lhs, rhs)
    return _impl


def _logical_not():
    def _impl(inputs, input_types):
        data = inputs[0]

        return _op.logical_not(_op.cast(data, "bool"))
    return _impl


def _logical_xor():
    def _impl(inputs, input_types):
        lhs = _op.cast(inputs[0], "bool")

        import torch
        if isinstance(inputs[1], _expr.Var):
            rhs = inputs[1]
        elif isinstance(inputs[1], torch.Tensor):
            rhs = _wrap_const(inputs[1].numpy())
        else:
            msg = "Data type %s could not be parsed in logical_xor operator." % (type(inputs[1]))
            raise AssertionError(msg)

        rhs = _op.cast(rhs, "bool")

        return _op.logical_xor(lhs, rhs)
    return _impl


def _isfinite():
    def _impl(inputs, input_types):
        return _op.isfinite(inputs[0])
@@ -1524,6 +1584,10 @@ def _get_convert_map(prelude):
"aten::ge" : _elemwise("greater_equal"),
"aten::ne" : _elemwise("not_equal"),
"aten::eq" : _elemwise("equal"),
"aten::logical_not" : _logical_not(),
"aten::logical_xor" : _logical_xor(),
"aten::bitwise_not" : _bitwise_not(),
"aten::bitwise_xor" : _bitwise_xor(),
"aten::isfinite" : _isfinite(),
"aten::isnan" : _isnan(),
"aten::Bool" : _Bool(),
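
As a quick end-to-end sketch of how these converters are exercised (not part of the diff; the input names and shapes are illustrative assumptions):

import torch
import tvm
from tvm import relay

class XorModel(torch.nn.Module):
    def forward(self, x, y):
        # Traces to "aten::logical_xor", which dispatches to _logical_xor() above.
        return torch.logical_xor(x, y)

lhs = torch.tensor([True, True, False])
rhs = torch.tensor([False, True, False])
scripted = torch.jit.trace(XorModel().eval(), (lhs, rhs))

mod, params = relay.frontend.from_pytorch(
    scripted, [("lhs", list(lhs.shape)), ("rhs", list(rhs.shape))])
print(mod["main"])
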
2 changes: 2 additions & 0 deletions python/tvm/relay/op/_tensor.py
@@ -53,6 +53,7 @@
register_broadcast_schedule("logical_not")
register_broadcast_schedule("logical_and")
register_broadcast_schedule("logical_or")
register_broadcast_schedule("logical_xor")
register_broadcast_schedule("bitwise_not")
register_broadcast_schedule("bitwise_and")
register_broadcast_schedule("bitwise_or")
@@ -205,6 +206,7 @@ def elemwise_shape_func(attrs, inputs, _):
register_shape_func("floor_mod", False, broadcast_shape_func)
register_shape_func("logical_and", False, broadcast_shape_func)
register_shape_func("logical_or", False, broadcast_shape_func)
register_shape_func("logical_xor", False, broadcast_shape_func)
register_shape_func("bitwise_not", False, broadcast_shape_func)
register_shape_func("bitwise_and", False, broadcast_shape_func)
register_shape_func("bitwise_or", False, broadcast_shape_func)
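
With the broadcast schedule and shape function registered, the new op compiles like any other broadcast op. A minimal sketch, assuming an LLVM-enabled build:

import numpy as np
from tvm import relay

x = relay.var("x", shape=(3,), dtype="bool")
y = relay.var("y", shape=(3,), dtype="bool")
f = relay.Function([x, y], relay.logical_xor(x, y))

# The generic broadcast schedule registered above is what gets used here.
ex = relay.create_executor(kind="graph")
print(ex.evaluate(f)(np.array([True, True, False]),
                     np.array([False, True, False])))  # [ True False False]
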
17 changes: 17 additions & 0 deletions python/tvm/relay/op/tensor.py
@@ -537,6 +537,23 @@ def logical_or(lhs, rhs):
    return _make.logical_or(lhs, rhs)


def logical_xor(lhs, rhs):
    """logical XOR with numpy-style broadcasting.

    Parameters
    ----------
    lhs : relay.Expr
        The left hand side input data

    rhs : relay.Expr
        The right hand side input data

    Returns
    -------
    result : relay.Expr
        The computed result.
    """
    return _make.logical_xor(lhs, rhs)

def bitwise_and(lhs, rhs):
    """bitwise AND with numpy-style broadcasting.
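
A short usage sketch of the new wrapper, showing the numpy-style broadcasting the docstring promises (shapes are illustrative):

from tvm import relay

x = relay.var("x", shape=(2, 1), dtype="bool")
y = relay.var("y", shape=(1, 3), dtype="bool")
z = relay.logical_xor(x, y)  # broadcasts to shape (2, 3)
print(relay.Function([x, y], z))
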
6 changes: 6 additions & 0 deletions src/relay/op/tensor/binary.cc
@@ -123,6 +123,12 @@ RELAY_REGISTER_BINARY_OP("logical_or")
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_or));


RELAY_REGISTER_BINARY_OP("logical_xor")
.describe("Elementwise logical XOR with broadcasting")
.set_support_level(4)
.set_attr<FTVMCompute>("FTVMCompute", RELAY_BINARY_COMPUTE(topi::logical_xor));


RELAY_REGISTER_BINARY_OP("bitwise_and")
.describe("Elementwise bitwise AND with broadcasting")
.set_support_level(4)
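
One way to sanity-check this registration from Python (a hedged sketch; it only reads the op registry):

from tvm import relay

op = relay.op.get("logical_xor")
print(op.name)           # logical_xor
print(op.support_level)  # 4, matching set_support_level above
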
95 changes: 94 additions & 1 deletion tests/python/frontend/pytorch/test_forward.py
@@ -159,7 +159,7 @@ def verify_model(model_name, input_data=[],
    if isinstance(baseline_outputs, tuple):
        baseline_outputs = tuple(out.cpu().numpy() for out in baseline_outputs)
    else:
-        baseline_outputs = (baseline_outputs.float().cpu().numpy(),)
+        baseline_outputs = (baseline_outputs.cpu().numpy(),)

    trace = torch.jit.trace(baseline_model, baseline_input).float().eval()

@@ -1600,6 +1600,95 @@ def forward(self, *args):
    verify_model(Topk6().float().eval(), input_data=input_data)


def test_forward_logical_not():
    torch.set_grad_enabled(False)

    class LogicalNot1(Module):
        def forward(self, *args):
            return torch.logical_not(args[0])

    input_data = torch.tensor([True, False])
    verify_model(LogicalNot1().float().eval(), input_data=input_data)

    input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
    verify_model(LogicalNot1().float().eval(), input_data=input_data)

    input_data = torch.tensor([0., 1.5, -10.], dtype=torch.double)
    verify_model(LogicalNot1().float().eval(), input_data=input_data)

    input_data = torch.tensor([0., 1., -10.], dtype=torch.int32)
    verify_model(LogicalNot1().float().eval(), input_data=input_data)


def test_forward_bitwise_not():
    torch.set_grad_enabled(False)

    class BitwiseNot1(Module):
        def forward(self, *args):
            return torch.bitwise_not(args[0])

    input_data = torch.tensor([0, 1, -10], dtype=torch.int8)
    verify_model(BitwiseNot1().float().eval(), input_data=input_data)

    input_data = torch.tensor([0., 1., -10.], dtype=torch.int32)
    verify_model(BitwiseNot1().float().eval(), input_data=input_data)

    input_data = torch.tensor([True, False])
    verify_model(BitwiseNot1().float().eval(), input_data=input_data)


def test_forward_bitwise_xor():
    torch.set_grad_enabled(False)

    class BitwiseXor1(Module):
        def forward(self, *args):
            return torch.bitwise_xor(args[0], args[1])

    class BitwiseXor2(Module):
        def forward(self, *args):
            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
            if torch.cuda.is_available():
                rhs = rhs.cuda()
            return torch.bitwise_xor(args[0], rhs)

    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
    verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])

    lhs = torch.tensor([True, True, False])
    rhs = torch.tensor([False, True, False])
    verify_model(BitwiseXor1().float().eval(), input_data=[lhs, rhs])

    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    verify_model(BitwiseXor2().float().eval(), input_data=[lhs])


def test_forward_logical_xor():
    torch.set_grad_enabled(False)

    class LogicalXor1(Module):
        def forward(self, *args):
            return torch.logical_xor(args[0], args[1])

    class LogicalXor2(Module):
        def forward(self, *args):
            rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
            if torch.cuda.is_available():
                rhs = rhs.cuda()
            return torch.logical_xor(args[0], rhs)

    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    rhs = torch.tensor([1, 0, 3], dtype=torch.int8)
    verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])

    lhs = torch.tensor([True, True, False])
    rhs = torch.tensor([False, True, False])
    verify_model(LogicalXor1().float().eval(), input_data=[lhs, rhs])

    lhs = torch.tensor([-1, -2, 3], dtype=torch.int8)
    verify_model(LogicalXor2().float().eval(), input_data=[lhs])


if __name__ == "__main__":
    # Single operator tests
    test_forward_add()
@@ -1663,6 +1752,10 @@ def forward(self, *args):
    test_forward_clamp()
    test_forward_floor()
    test_forward_round()
    test_forward_logical_not()
    test_forward_bitwise_not()
    test_forward_bitwise_xor()
    test_forward_logical_xor()
    test_forward_isfinite()
    test_forward_isnan()
    test_forward_isinf()
13 changes: 13 additions & 0 deletions topi/include/topi/broadcast.h
@@ -140,6 +140,19 @@ TOPI_DEFINE_OP_OVERLOAD(operator&&, logical_and);
TOPI_DEFINE_BCAST_OP(logical_or, { return a || b; });
TOPI_DEFINE_OP_OVERLOAD(operator||, logical_or);

/*!
 * \fn logical_xor
 * \brief Compute A ^ B with auto-broadcasting.
 *
 * \param A The first tensor, or Expr
 * \param B The second tensor, or Expr
 * \param name The name of the operation
 * \param tag The tag to mark the operation
 *
 * \return The result.
 */
TOPI_DEFINE_BCAST_OP(logical_xor, { return a ^ b; });

/*!
 * \fn bitwise_and
 * \brief Compute A & B with auto-broadcasting.
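
For bool operands, the "a ^ b" body above matches NumPy's logical xor semantics; a small sanity sketch:

import numpy as np

a = np.array([True, True, False])
b = np.array([False, True, False])
assert (a ^ b == np.logical_xor(a, b)).all()
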
19 changes: 19 additions & 0 deletions topi/python/topi/broadcast.py
@@ -420,6 +420,25 @@ def logical_or(lhs, rhs):
    return _cpp.logical_or(lhs, rhs)


def logical_xor(lhs, rhs):
    """Compute element-wise logical xor of data.

    Parameters
    ----------
    lhs : tvm.te.Tensor or Expr
        The left operand

    rhs : tvm.te.Tensor or Expr
        The right operand

    Returns
    -------
    ret : tvm.te.Tensor or Expr
        Returns Expr if both operands are Expr.
        Otherwise returns Tensor.
    """
    return _cpp.logical_xor(lhs, rhs)


def bitwise_and(lhs, rhs):
    """Compute element-wise bitwise and of data.
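
A minimal sketch of driving the new TOPI op directly through te (assumes an LLVM-enabled build; tensor names are illustrative):

import tvm
from tvm import te
import topi  # top-level package in this source tree

A = te.placeholder((4,), name="A", dtype="bool")
B = te.placeholder((4,), name="B", dtype="bool")
C = topi.logical_xor(A, B)

s = te.create_schedule(C.op)
f = tvm.build(s, [A, B, C], target="llvm")
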
1 change: 1 addition & 0 deletions topi/src/broadcast.cc
@@ -65,6 +65,7 @@ TOPI_REGISTER_BCAST_OP("topi.power", topi::power);
TOPI_REGISTER_BCAST_OP("topi.left_shift", topi::left_shift);
TOPI_REGISTER_BCAST_OP("topi.logical_and", topi::logical_and);
TOPI_REGISTER_BCAST_OP("topi.logical_or", topi::logical_or);
TOPI_REGISTER_BCAST_OP("topi.logical_xor", topi::logical_xor);
TOPI_REGISTER_BCAST_OP("topi.bitwise_and", topi::bitwise_and);
TOPI_REGISTER_BCAST_OP("topi.bitwise_or", topi::bitwise_or);
TOPI_REGISTER_BCAST_OP("topi.bitwise_xor", topi::bitwise_xor);
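
The TOPI_REGISTER_BCAST_OP line exposes the compute over the FFI; a quick hedged check that the packed function is reachable:

import tvm

f = tvm.get_global_func("topi.logical_xor")
print(f)  # a PackedFunc wrapping topi::logical_xor
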
2 changes: 2 additions & 0 deletions topi/tests/python/test_topi_broadcast.py
@@ -355,6 +355,8 @@ def check_device(device):
    test_apply(topi.logical_and, "logical_and", np.logical_and, [True, False], [False, False])
    test_apply(topi.logical_or, "logical_or", np.logical_or, True, False)
    test_apply(topi.logical_or, "logical_or", np.logical_or, [True, False], [False, False])
    test_apply(topi.logical_xor, "logical_xor", np.logical_xor, True, False)
    test_apply(topi.logical_xor, "logical_xor", np.logical_xor, [True, False], [False, False])


def test_bitwise_and():
