[PYTORCH] Take, Topk op support (apache#5332)
* [PYTORCH] take, topk op support

* CI failure fix
siju-samuel authored and Trevor Morris committed Apr 16, 2020
1 parent d1a0062 commit 0f9644b
Showing 2 changed files with 92 additions and 0 deletions.
35 changes: 35 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
@@ -272,6 +272,39 @@ def _impl(inputs, input_types):
        return _op.transform.take(data, index, axis=dim)
    return _impl

def _take():
    def _impl(inputs, input_types):
        data = inputs[0]
        import torch

        if isinstance(inputs[1], _expr.Var):
            indices = _op.cast(inputs[1], "int32")
        elif isinstance(inputs[1], torch.Tensor):
            indices = _wrap_const(inputs[1].numpy())
        else:
            msg = "Data type %s could not be parsed in take operator." % (type(inputs[1]))
            raise AssertionError(msg)

        return _op.transform.take(data, indices=indices)
    return _impl
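
For context (not part of the diff): torch.take indexes the input as if it were flattened to one dimension, and Relay's take behaves the same way when no axis is given, which is why the converter can omit the axis argument. A minimal sketch of the semantics being mapped:

    import torch

    x = torch.tensor([[1, 2], [3, 4]])    # flattened view: [1, 2, 3, 4]
    idx = torch.tensor([[0, 0], [1, 0]])
    print(torch.take(x, idx))             # tensor([[1, 1], [2, 1]])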

def _topk():
    def _impl(inputs, input_types):
        data = inputs[0]
        k = int(inputs[1])
        axis = int(inputs[2])
        is_ascend = not bool(inputs[3])
        sort = bool(inputs[4])

        if not sort:
            msg = "Currently supports only sorted output for topk operator."
            raise AssertionError(msg)

        outs = _op.topk(data, k=k, axis=axis, is_ascend=is_ascend, ret_type="both")

        return outs[0], outs[1]
    return _impl
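
For context (not part of the diff): aten::topk receives its arguments positionally as (data, k, dim, largest, sorted); the converter inverts largest into Relay's is_ascend and always requests both values and indices. A minimal illustration of the PyTorch-side behavior being mapped:

    import torch

    x = torch.tensor([1.0, 3.0, 2.0, 5.0, 4.0])
    values, indices = torch.topk(x, k=3)                 # largest=True -> is_ascend=False
    print(values, indices)   # tensor([5., 4., 3.]) tensor([3, 4, 1])
    values, indices = torch.topk(x, k=3, largest=False)  # largest=False -> is_ascend=True
    print(values, indices)   # tensor([1., 2., 3.]) tensor([0, 2, 1])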

def _reciprocal():
    def _impl(inputs, input_types):
        data = inputs[0]
@@ -1416,6 +1449,8 @@ def _get_convert_map(prelude):
"aten::split" : _split(),
"aten::split_with_sizes" : _split_with_sizes(),
"aten::select" : _select(),
"aten::take" : _take(),
"aten::topk" : _topk(),
"aten::relu" : _relu(),
"aten::relu_" : _relu(),
"aten::prelu" : _prelu(),
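
For context (not part of the diff): during conversion the frontend walks the TorchScript graph and dispatches each node through this map, so the two entries above are what make aten::take and aten::topk nodes convertible. A simplified, hypothetical sketch of that dispatch (names are illustrative; the real loop lives elsewhere in this file):

    # hypothetical, simplified dispatch loop
    convert_map = _get_convert_map(prelude)
    for node in graph.nodes():
        op_name = node.kind()          # e.g. "aten::topk"
        if op_name in convert_map:
            out = convert_map[op_name](inputs, input_types)
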
57 changes: 57 additions & 0 deletions tests/python/frontend/pytorch/test_forward.py
@@ -1545,6 +1545,61 @@ def forward(self, *args):
    verify_model(Round1().float().eval(), input_data=input_data)


def test_forward_take():
    torch.set_grad_enabled(False)
    class Take1(Module):
        def forward(self, *args):
            indices = torch.tensor([[0,0],[1,0]])
            if torch.cuda.is_available():
                indices = indices.cuda()
            return torch.take(args[0], indices)

    class Take2(Module):
        def forward(self, *args):
            return torch.take(args[0], args[1])

    input_data = torch.tensor([[1,2],[3,4]])
    verify_model(Take1().float().eval(), input_data=input_data)
    indices = torch.tensor([[0,0],[1,0]])
    verify_model(Take2().float().eval(), input_data=[input_data, indices])


def test_forward_topk():
    torch.set_grad_enabled(False)
    class Topk1(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3)

    class Topk2(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, dim=-2)

    class Topk3(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, dim=3)

    class Topk4(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, largest=True)

    class Topk5(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, largest=False)

    class Topk6(Module):
        def forward(self, *args):
            return torch.topk(args[0], k=3, sorted=True)

    input_shape = [1, 3, 10, 10]
    input_data = torch.rand(input_shape).float()
    verify_model(Topk1().float().eval(), input_data=input_data)
    verify_model(Topk2().float().eval(), input_data=input_data)
    verify_model(Topk3().float().eval(), input_data=input_data)
    verify_model(Topk4().float().eval(), input_data=input_data)
    verify_model(Topk5().float().eval(), input_data=input_data)
    verify_model(Topk6().float().eval(), input_data=input_data)
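
For context (not part of the diff): verify_model, defined earlier in this file, traces the module with torch.jit.trace, imports the trace through the Relay frontend, and compares the TVM output against PyTorch. A minimal sketch of that flow, assuming the harness's "input0" naming convention for graph inputs:

    import torch
    from tvm import relay

    model = Topk1().float().eval()
    inp = torch.rand([1, 3, 10, 10]).float()
    scripted = torch.jit.trace(model, inp)           # TorchScript module for the frontend
    shape_list = [("input0", list(inp.shape))]       # "input0" is an assumed input name
    mod, params = relay.frontend.from_pytorch(scripted, shape_list)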


if __name__ == "__main__":
    # Single operator tests
    test_forward_add()
@@ -1587,6 +1642,8 @@ def forward(self, *args):
    test_forward_size()
    test_forward_view()
    test_forward_select()
    test_forward_take()
    test_forward_topk()
    test_forward_clone()
    test_forward_softplus()
    test_forward_softsign()
