diff --git a/python/tvm/relay/frontend/tensorflow.py b/python/tvm/relay/frontend/tensorflow.py index 4a7a7da307fc..ca24e1b35374 100644 --- a/python/tvm/relay/frontend/tensorflow.py +++ b/python/tvm/relay/frontend/tensorflow.py @@ -1454,9 +1454,9 @@ def _impl(inputs, attr, params, mod): break if is_symbolic_shape: - ret = _op.shape_of(inputs[0], dtype="int32") + ret = _op.shape_of(inputs[0], dtype=attr["out_type"].name) else: - ret = np.array(input_shape, dtype="int32") + ret = np.array(input_shape, dtype=attr["out_type"].name) return ret return _impl @@ -1862,11 +1862,11 @@ def _impl(inputs, attr, params, mod): dtype = attr["Tidx"].name if "Tidx" in attr else str(start.dtype) if isinstance(start, (np.int32, np.int64, int, np.float32, np.float64, float)): - start = _expr.const(start) + start = _expr.const(start, dtype=dtype) if isinstance(limit, (np.int32, np.int64, int, np.float32, np.float64, float)): - limit = _expr.const(limit) + limit = _expr.const(limit, dtype=dtype) if isinstance(delta, (np.int32, np.int64, int, np.float32, np.float64, float)): - delta = _expr.const(delta) + delta = _expr.const(delta, dtype=dtype) return AttrCvt( op_name="arange", diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 93bfd0cbaf83..23a4b7abe5ab 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -2783,10 +2783,11 @@ def test_forward_unpack(): def test_forward_range(): """test operator Range""" - tf.reset_default_graph() - with tf.Graph().as_default(): - tf.range(1, 18, 3, name="range") - compare_tf_with_tvm([], [], "range:0") + for dtype in [tf.int32, tf.int64]: + tf.reset_default_graph() + with tf.Graph().as_default(): + tf.range(1, 18, 3, name="range", dtype=dtype) + compare_tf_with_tvm([], [], "range:0") """test type assignment for operator Range""" tf.reset_default_graph()