Skip to content

Commit

Permalink
Merge pull request apache#41 from wjj19950828/support_textcnn
Browse files Browse the repository at this point in the history
support for textcnn
  • Loading branch information
jiangjiajun committed Sep 18, 2021
2 parents 82345f2 + 52e7063 commit 69824a4
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
31 changes: 31 additions & 0 deletions python/tvm/relay/frontend/paddlepaddle.py
Original file line number Diff line number Diff line change
Expand Up @@ -1406,6 +1406,36 @@ def convert_pool2d(g, op, block):
g.add_node(op.output("Out")[0], out)


def convert_max_pool2d_with_index(g, op, block):
"""Operator converter for max_pool2d_with_index."""

adaptive = op.attr("adaptive")
global_pooling = op.attr("global_pooling")
ksize = op.attr("ksize")
paddings = op.attr("paddings")
if global_pooling:
adaptive = True
ksize = [1, 1]

input_x = g.get_node(op.input("X")[0])

strides = op.attr("strides")
if isinstance(strides, int):
strides = [strides, strides]
if isinstance(ksize, int):
ksize = [ksize, ksize]
if isinstance(paddings, int):
paddings = [paddings] * 2

if not adaptive:
out = getattr(_op.nn, "max_pool2d")(
input_x, pool_size=ksize, strides=strides, padding=paddings
)
else:
out = getattr(_op.nn, "adaptive_max_pool2d")(input_x, output_size=ksize)
g.add_node(op.output("Out")[0], out)


def convert_padding(g, op, block):
"""Operator converter for padding."""

Expand Down Expand Up @@ -2277,6 +2307,7 @@ def convert_where(g, op, block):
"nearest_interp_v2": convert_interpolate,
"not_equal": convert_elementwise_op,
"pool2d": convert_pool2d,
"max_pool2d_with_index": convert_max_pool2d_with_index,
"pad1d": convert_padding,
"pad2d": convert_padding,
"pad3d": convert_padding,
Expand Down
5 changes: 3 additions & 2 deletions tests/python/frontend/paddlepaddle/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1375,15 +1375,16 @@ def pool2d2(inputs):

@paddle.jit.to_static
def pool2d3(inputs):
return nn.functional.max_pool2d(
output, max_indices = nn.functional.max_pool2d(
inputs, kernel_size=2, stride=2, padding=0, return_mask=True
)
return output

input_data = paddle.uniform(shape=[1, 2, 32, 32], dtype="float32", min=-1, max=1)
verify_model(pool2d1, input_data=input_data)
verify_model(pool2d2, input_data=input_data)
# need op max_pool2d_with_index
# verify_model(pool2d3, input_data=input_data)
verify_model(pool2d3, input_data=input_data)


@tvm.testing.uses_gpu
Expand Down

0 comments on commit 69824a4

Please sign in to comment.