diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py index 4edf0b80de4c..be2f1105d960 100644 --- a/python/tvm/relay/frontend/mxnet.py +++ b/python/tvm/relay/frontend/mxnet.py @@ -1073,6 +1073,20 @@ def _mx_one_hot(inputs, attrs): return _op.one_hot(indices, on_value, off_value, depth, -1, dtype) +def _mx_depth_to_space(inputs, attrs): + assert len(inputs) == 1 + new_attrs = {} + new_attrs["block_size"] = attrs.get_int("block_size") + return _op.nn.depth_to_space(*inputs, **new_attrs) + + +def _mx_space_to_depth(inputs, attrs): + assert len(inputs) == 1 + new_attrs = {} + new_attrs["block_size"] = attrs.get_int("block_size") + return _op.nn.space_to_depth(*inputs, **new_attrs) + + def _mx_contrib_fifo_buffer(inputs, attrs): new_attrs = {} new_attrs['axis'] = attrs.get_int('axis') @@ -1854,6 +1868,8 @@ def _get_bias_requantize_scale(_inputs, _data_scale, _kernel_scale): "make_loss" : _mx_make_loss, "_contrib_div_sqrt_dim": _mx_contrib_div_sqrt_dim, "one_hot" : _mx_one_hot, + "depth_to_space" : _mx_depth_to_space, + "space_to_depth" : _mx_space_to_depth, # vision "_contrib_BilinearResize2D" : _mx_resize, "_contrib_MultiBoxPrior" : _mx_multibox_prior, diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py index 4a9848e03b5e..10edff9031fe 100644 --- a/tests/python/frontend/mxnet/test_forward.py +++ b/tests/python/frontend/mxnet/test_forward.py @@ -995,6 +995,38 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2): # _verify_swap_axis((4, 5), (5, 4), 0, 0) +def test_forward_depth_to_space(): + def verify(shape, blocksize=2): + x = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.depth_to_space(mx.nd.array(x), blocksize) + mx_sym = mx.sym.depth_to_space(mx.sym.var("x"), blocksize) + shape_dict = {"x": x.shape, } + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) + + verify((1, 18, 3, 3), 3) + + +def test_forward_space_to_depth(): + def verify(shape, blocksize=2): + x = np.random.uniform(size=shape).astype("float32") + ref_res = mx.nd.space_to_depth(mx.nd.array(x), blocksize) + mx_sym = mx.sym.space_to_depth(mx.sym.var("x"), blocksize) + shape_dict = {"x": x.shape, } + mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict) + for target, ctx in ctx_list(): + for kind in ["graph", "debug"]: + intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target) + op_res = intrp.evaluate()(x) + tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5) + + verify((1, 1, 9, 9), 3) + + if __name__ == '__main__': test_forward_mlp() test_forward_vgg() @@ -1047,6 +1079,8 @@ def _verify_swap_axis(in_shape, out_shape, dim1, dim2): test_forward_instance_norm() test_forward_layer_norm() test_forward_one_hot() + test_forward_depth_to_space() + test_forward_space_to_depth() test_forward_convolution() test_forward_deconvolution() test_forward_cond()