diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index af60bf20c847..7f088bacacc8 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -348,12 +348,25 @@ def _impl(inputs, input_types): msg = "Data type %s could not be parsed in ones op" % (type(data)) raise AssertionError(msg) - dtype_map = {6: "float32", 3: "int32"} - dtype_id = inputs[1] - assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id - return _op.full(_expr.const(1), shape, dtype=dtype_map[dtype_id]) + dtype = _convert_data_type(_convert_dtype_value(inputs[1])) + + return _op.full(_expr.const(1), shape, dtype=dtype) + return _impl + +def _ones_like(): + def _impl(inputs, input_types): + data = inputs[0] + out = _op.ones_like(data) + + # If the input and the output datatype is different, do a cast + dtype = _convert_data_type(_convert_dtype_value(inputs[1])) + if input_types[0] not in dtype: + out = _op.cast(out, dtype) + + return out return _impl + def _zeros(): def _impl(inputs, input_types): data = inputs[0] @@ -369,12 +382,88 @@ def _impl(inputs, input_types): msg = "Data type %s could not be parsed in zeros op" % (type(data)) raise AssertionError(msg) - dtype_map = {6: "float32", 3: "int32"} - dtype_id = inputs[1] - assert dtype_id in dtype_map, "Unsupported dtype %d" % dtype_id - return _op.full(_expr.const(0), shape, dtype=dtype_map[dtype_id]) + dtype = _convert_data_type(_convert_dtype_value(inputs[1])) + + return _op.full(_expr.const(0), shape, dtype=dtype) return _impl + +def _zeros_like(): + def _impl(inputs, input_types): + data = inputs[0] + out = _op.zeros_like(data) + + # If the input and the output datatype is different, do a cast + dtype = _convert_data_type(_convert_dtype_value(inputs[1])) + if input_types[0] not in dtype: + out = _op.cast(out, dtype) + + return out + return _impl + + +def _full(): + def _impl(inputs, input_types): + data = inputs[0] + + fill_value = inputs[1] + import torch + if isinstance(data, _expr.Expr): + shape = _infer_shape(data) + elif isinstance(data, list): + shape = data + elif isinstance(data, (torch.Tensor, np.ndarray)): + shape = data.shape + else: + msg = "Data type %s could not be parsed in zeros op" % (type(data)) + raise AssertionError(msg) + + dtype = _convert_data_type(_convert_dtype_value(inputs[2])) + + return _op.full(_expr.const(fill_value), shape, dtype=dtype) + return _impl + +def _full_like(): + def _impl(inputs, input_types): + data = inputs[0] + fill_value = inputs[1] + + out = _op.full_like(data, _expr.const(fill_value)) + + # If the input and the output datatype is different, do a cast + dtype = _convert_data_type(_convert_dtype_value(inputs[2])) + if input_types[0] not in dtype: + out = _op.cast(out, dtype) + + return out + return _impl + + +def _linspace(): + def _impl(inputs, input_types): + start = inputs[0] + stop = inputs[1] + step = inputs[2] + + # Find the spacing between values as step + if step != 1: + step = (stop - start) / (step - 1) + stop = stop + step + else: + stop = start + step + + dtype = "float" if "float" in input_types[0:3] else _convert_dtype_value(inputs[3]) + start = _create_typed_const(start, dtype) + stop = _create_typed_const(stop, dtype) + step = _create_typed_const(step, dtype) + + return _op.transform.arange(start=start, + stop=stop, + step=step, + dtype=_convert_data_type(dtype)) + return _impl + + def _relu(): def _impl(inputs, input_types): data = inputs[0] @@ -1503,7 +1592,12 @@ def _get_convert_map(prelude): "aten::div" : _elemwise("divide"), "aten::div_" : _elemwise("divide"), "aten::ones" : _ones(), + "aten::ones_like" : _ones_like(), "aten::zeros" : _zeros(), + "aten::zeros_like" : _zeros_like(), + "aten::full" : _full(), + "aten::full_like" : _full_like(), + "aten::linspace" : _linspace(), "aten::reciprocal" : _reciprocal(), "aten::repeat" : _repeat(), "aten::repeat_interleave" : _repeat_interleave(), diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 4eba4d002fe9..2056d3255b25 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1556,6 +1556,144 @@ def forward(self, *args): verify_model(Round1().float().eval(), input_data=input_data) +def test_forward_ones(): + torch.set_grad_enabled(False) + + class Ones1(Module): + def forward(self, *args): + return torch.ones(2,3) + + verify_model(Ones1().float().eval(), input_data=[]) + + +def test_forward_ones_like(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + + class OnesLike1(Module): + def forward(self, *args): + return torch.ones_like(args[0]) + + class OnesLike2(Module): + def forward(self, *args): + return torch.ones_like(args[0], dtype=torch.int8) + + class OnesLike3(Module): + def forward(self, *args): + return torch.ones_like(args[0], dtype=torch.float) + + input_data = torch.rand(input_shape).float() + verify_model(OnesLike1().float().eval(), input_data=input_data) + verify_model(OnesLike2().float().eval(), input_data=input_data) + verify_model(OnesLike3().float().eval(), input_data=input_data) + + +def test_forward_zeros(): + torch.set_grad_enabled(False) + + class Zeros1(Module): + def forward(self, *args): + return torch.zeros(2,3) + + verify_model(Zeros1().float().eval(), input_data=[]) + + +def test_forward_zeros_like(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + + class ZerosLike1(Module): + def forward(self, *args): + return torch.zeros_like(args[0]) + + class ZerosLike2(Module): + def forward(self, *args): + return torch.zeros_like(args[0], dtype=torch.int32) + + class ZerosLike3(Module): + def forward(self, *args): + return torch.zeros_like(args[0], dtype=torch.float) + + input_data = torch.rand(input_shape).float() + verify_model(ZerosLike1().float().eval(), input_data=input_data) + verify_model(ZerosLike2().float().eval(), input_data=input_data) + verify_model(ZerosLike3().float().eval(), input_data=input_data) + + +def test_forward_full(): + torch.set_grad_enabled(False) + + class Full1(Module): + def forward(self, *args): + return torch.full((2,3), 3.14) + + class Full2(Module): + def forward(self, *args): + return torch.full((1, 2,3), 1.0, dtype=torch.int32) + + verify_model(Full1().float().eval(), input_data=[]) + verify_model(Full2().float().eval(), input_data=[]) + + +def test_forward_full_like(): + torch.set_grad_enabled(False) + input_shape = [1, 3, 10, 10] + + class FullLike1(Module): + def forward(self, *args): + return torch.full_like(args[0], 3.14) + + class FullLike2(Module): + def forward(self, *args): + return torch.full_like(args[0], 22.22, dtype=torch.int32) + + class FullLike3(Module): + def forward(self, *args): + return torch.full_like(args[0], 1.4, dtype=torch.float) + + input_data = torch.rand(input_shape).float() + verify_model(FullLike1().float().eval(), input_data=input_data) + verify_model(FullLike2().float().eval(), input_data=input_data) + verify_model(FullLike3().float().eval(), input_data=input_data) + +def test_forward_linspace(): + torch.set_grad_enabled(False) + + class Linspace1(Module): + def forward(self, *args): + return torch.linspace(5, 10) + class Linspace2(Module): + def forward(self, *args): + return torch.linspace(-10, 10, steps=5) + class Linspace3(Module): + def forward(self, *args): + return torch.linspace(start=-10, end=10, steps=5) + class Linspace4(Module): + def forward(self, *args): + return torch.linspace(start=-10, end=10, steps=1) + class Linspace5(Module): + def forward(self, *args): + return torch.linspace(1, 2, 1, dtype=torch.int32) + class Linspace6(Module): + def forward(self, *args): + return torch.linspace(start=1, end=6, steps=2) + class Linspace7(Module): + def forward(self, *args): + return torch.linspace(1, 4, dtype=torch.float32) + class Linspace8(Module): + def forward(self, *args): + return torch.linspace(1, 2, 1, dtype=torch.int16) + + verify_model(Linspace1().float().eval()) + verify_model(Linspace2().float().eval()) + verify_model(Linspace3().float().eval()) + verify_model(Linspace4().float().eval()) + verify_model(Linspace5().float().eval()) + verify_model(Linspace6().float().eval()) + verify_model(Linspace7().float().eval()) + verify_model(Linspace8().float().eval()) + + def test_forward_take(): torch.set_grad_enabled(False) class Take1(Module): @@ -1770,6 +1908,13 @@ def forward(self, *args): test_forward_isfinite() test_forward_isnan() test_forward_isinf() + test_forward_ones() + test_forward_ones_like() + test_forward_zeros() + test_forward_zeros_like() + test_forward_full() + test_forward_full_like() + test_forward_linspace() test_forward_arange() test_forward_chunk() test_forward_split()