diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index ffbc0173a1c4..1ac365f414f8 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -1406,6 +1406,36 @@ def convert_pool2d(g, op, block): g.add_node(op.output("Out")[0], out) +def convert_max_pool2d_with_index(g, op, block): + """Operator converter for max_pool2d_with_index.""" + + adaptive = op.attr("adaptive") + global_pooling = op.attr("global_pooling") + ksize = op.attr("ksize") + paddings = op.attr("paddings") + if global_pooling: + adaptive = True + ksize = [1, 1] + + input_x = g.get_node(op.input("X")[0]) + + strides = op.attr("strides") + if isinstance(strides, int): + strides = [strides, strides] + if isinstance(ksize, int): + ksize = [ksize, ksize] + if isinstance(paddings, int): + paddings = [paddings] * 2 + + if not adaptive: + out = getattr(_op.nn, "max_pool2d")( + input_x, pool_size=ksize, strides=strides, padding=paddings + ) + else: + out = getattr(_op.nn, "adaptive_max_pool2d")(input_x, output_size=ksize) + g.add_node(op.output("Out")[0], out) + + def convert_padding(g, op, block): """Operator converter for padding.""" @@ -2277,6 +2307,7 @@ def convert_where(g, op, block): "nearest_interp_v2": convert_interpolate, "not_equal": convert_elementwise_op, "pool2d": convert_pool2d, + "max_pool2d_with_index": convert_max_pool2d_with_index, "pad1d": convert_padding, "pad2d": convert_padding, "pad3d": convert_padding, diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index 2e9a81f30b3b..62fcfe9c09fd 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -1375,15 +1375,16 @@ def pool2d2(inputs): @paddle.jit.to_static def pool2d3(inputs): - return nn.functional.max_pool2d( + output, max_indices = nn.functional.max_pool2d( inputs, kernel_size=2, stride=2, padding=0, return_mask=True ) + return output input_data = paddle.uniform(shape=[1, 2, 32, 32], dtype="float32", min=-1, max=1) verify_model(pool2d1, input_data=input_data) verify_model(pool2d2, input_data=input_data) # need op max_pool2d_with_index - # verify_model(pool2d3, input_data=input_data) + verify_model(pool2d3, input_data=input_data) @tvm.testing.uses_gpu