From 0f9644b3f8d6050682ceee5461bf60850e85b3d7 Mon Sep 17 00:00:00 2001 From: Samuel Date: Wed, 15 Apr 2020 15:48:03 +0530 Subject: [PATCH] [PYTORCH]Take, Topk op support (#5332) * [PYTORCH]take, topk op support * Ci Failure fix --- python/tvm/relay/frontend/pytorch.py | 35 ++++++++++++ tests/python/frontend/pytorch/test_forward.py | 57 +++++++++++++++++++ 2 files changed, 92 insertions(+) diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 38a811d1d558..0acebe488f88 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -272,6 +272,39 @@ def _impl(inputs, input_types): return _op.transform.take(data, index, axis=dim) return _impl +def _take(): + def _impl(inputs, input_types): + data = inputs[0] + import torch + + if isinstance(inputs[1], _expr.Var): + indices = _op.cast(inputs[1], "int32") + elif isinstance(inputs[1], torch.Tensor): + indices = _wrap_const(inputs[1].numpy()) + else: + msg = "Data type %s could not be parsed in take operator." % (type(inputs[1])) + raise AssertionError(msg) + + return _op.transform.take(data, indices=indices) + return _impl + +def _topk(): + def _impl(inputs, input_types): + data = inputs[0] + k = int(inputs[1]) + axis = int(inputs[2]) + is_ascend = not bool(inputs[3]) + sort = bool(inputs[4]) + + if not sort: + msg = "Currently supports only sorted output for topk operator." + raise AssertionError(msg) + + outs = _op.topk(data, k=k, axis=axis, is_ascend=is_ascend, ret_type="both") + + return outs[0], outs[1] + return _impl + def _reciprocal(): def _impl(inputs, input_types): data = inputs[0] @@ -1416,6 +1449,8 @@ def _get_convert_map(prelude): "aten::split" : _split(), "aten::split_with_sizes" : _split_with_sizes(), "aten::select" : _select(), + "aten::take" : _take(), + "aten::topk" : _topk(), "aten::relu" : _relu(), "aten::relu_" : _relu(), "aten::prelu" : _prelu(), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index d9d280f25a70..c562fce547f8 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1545,6 +1545,61 @@ def forward(self, *args): verify_model(Round1().float().eval(), input_data=input_data) +def test_forward_take(): + torch.set_grad_enabled(False) + class Take1(Module): + def forward(self, *args): + indices = torch.tensor([[0,0],[1,0]]) + if torch.cuda.is_available(): + indices = indices.cuda() + return torch.take(args[0], indices) + + class Take2(Module): + def forward(self, *args): + return torch.take(args[0], args[1]) + + input_data = torch.tensor([[1,2],[3,4]]) + verify_model(Take1().float().eval(), input_data=input_data) + indices = torch.tensor([[0,0],[1,0]]) + verify_model(Take2().float().eval(), input_data=[input_data, indices]) + + +def test_forward_topk(): + torch.set_grad_enabled(False) + class Topk1(Module): + def forward(self, *args): + return torch.topk(args[0], k=3) + + class Topk2(Module): + def forward(self, *args): + return torch.topk(args[0], k=3, dim=-2) + + class Topk3(Module): + def forward(self, *args): + return torch.topk(args[0], k=3, dim=3) + + class Topk4(Module): + def forward(self, *args): + return torch.topk(args[0], k=3, largest=True) + + class Topk5(Module): + def forward(self, *args): + return torch.topk(args[0], k=3, largest=False) + + class Topk6(Module): + def forward(self, *args): + return torch.topk(args[0], k=3, sorted=True) + + input_shape = [1, 3, 10, 10] + input_data = torch.rand(input_shape).float() 
+ verify_model(Topk1().float().eval(), input_data=input_data) + verify_model(Topk2().float().eval(), input_data=input_data) + verify_model(Topk3().float().eval(), input_data=input_data) + verify_model(Topk4().float().eval(), input_data=input_data) + verify_model(Topk5().float().eval(), input_data=input_data) + verify_model(Topk6().float().eval(), input_data=input_data) + + if __name__ == "__main__": # Single operator tests test_forward_add() @@ -1587,6 +1642,8 @@ def forward(self, *args): test_forward_size() test_forward_view() test_forward_select() + test_forward_take() + test_forward_topk() test_forward_clone() test_forward_softplus() test_forward_softsign()
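
For reference, below is a minimal, self-contained sketch (not part of the patch) of how the new aten::take and aten::topk mappings can be exercised end to end through relay.frontend.from_pytorch rather than through verify_model. The wrapper module, the constant index tensor, and the input name "input0" are illustrative assumptions, not code from this change; the constant-index path mirrors the Take1 test above, where traced tensor constants reach the converter as torch.Tensor values.

import torch
from tvm import relay


class TakeTopk(torch.nn.Module):
    def forward(self, x):
        # aten::take gathers by flattened (linear) index; the converter
        # added in this patch maps it to relay take with no axis argument.
        taken = torch.take(x, torch.tensor([0, 5, 11]))
        # aten::topk maps to relay topk with ret_type="both";
        # largest=True corresponds to is_ascend=False in the converter.
        values, indices = torch.topk(x, k=3, dim=-1)
        return taken, values, indices


shape = (1, 3, 10, 10)
inp = torch.rand(shape)
scripted = torch.jit.trace(TakeTopk().eval(), inp)

# The input name "input0" is an assumption made for this sketch.
mod, params = relay.frontend.from_pytorch(scripted, [("input0", shape)])
print(mod)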