Merge pull request apache#34 from heliqi/paddle
Add elu, group_norm, hardtanh, hard_shrink and instance_norm ops
jiangjiajun committed Sep 15, 2021
2 parents 9c536e6 + d3f29eb commit 2ae2826
Showing 2 changed files with 123 additions and 17 deletions.
79 changes: 76 additions & 3 deletions python/tvm/relay/frontend/paddlepaddle.py
@@ -547,7 +547,17 @@ def convert_dropout(g, op, block):
    """Operator converter for dropout."""

    x = g.get_node(op.input("X")[0])
-    out = _op.copy(x)
+    g.add_node(op.output("Out")[0], x)


def convert_elu(g, op, block):
    """Operator converter for elu."""

    x = g.get_node(op.input("X")[0])
    dtype = infer_type(x).checked_type.dtype
    alpha = op.attr("alpha")
    alpha = _expr.const(-1.0 * alpha, dtype=dtype)
    out = alpha * _op.nn.relu(_expr.const(1, dtype=dtype) - _op.exp(x)) + _op.nn.relu(x)
    g.add_node(op.output("Out")[0], out)

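For context on the construction above: ELU is x for x > 0 and alpha * (exp(x) - 1) otherwise, and the converter expresses this with only relu and exp as -alpha * relu(1 - exp(x)) + relu(x), which is why the alpha constant enters with a negative sign. A minimal NumPy sketch of that identity, illustrative only and not part of the commit:

import numpy as np

def elu_reference(x, alpha=1.0):
    # Textbook ELU: x for x > 0, alpha * (exp(x) - 1) otherwise.
    return np.where(x > 0, x, alpha * (np.exp(x) - 1))

def elu_via_relu(x, alpha=1.0):
    # The relu/exp construction the converter emits.
    relu = lambda v: np.maximum(v, 0.0)
    return -alpha * relu(1.0 - np.exp(x)) + relu(x)

x = np.linspace(-3, 3, 13).astype("float32")
assert np.allclose(elu_reference(x), elu_via_relu(x), atol=1e-6)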

@@ -800,12 +810,46 @@ def convert_gelu(g, op, block):
    g.add_node(op.output("Out")[0], out)


def convert_group_norm(g, op, block):
    """Operator converter for group_norm."""

    x = g.get_node(op.input("X")[0])
    num_groups = op.attr("groups")
    epsilon = op.attr("epsilon")
    gamma = g.get_node(op.input("Scale")[0])
    beta = g.get_node(op.input("Bias")[0])
    out = _op.nn.group_norm(
        x,
        gamma=gamma,
        beta=beta,
        num_groups=num_groups,
        axis=1,
        epsilon=epsilon,
        center=True,
        scale=True,
    )
    g.add_node(op.output("Y")[0], out)

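The group_norm converter maps directly onto Relay's group_norm with axis=1 (NCHW layout). For readers unfamiliar with the op, a rough NumPy sketch of the computation it is expected to perform, under the usual definition of group normalization (illustrative only, not part of the commit):

import numpy as np

def group_norm_reference(x, gamma, beta, num_groups, epsilon=1e-5):
    # x: (N, C, H, W); gamma, beta: (C,). Each (sample, group) slice is
    # normalized with its own mean/variance, then scaled and shifted per channel.
    n, c, h, w = x.shape
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    mean = grouped.mean(axis=(2, 3, 4), keepdims=True)
    var = grouped.var(axis=(2, 3, 4), keepdims=True)
    normed = ((grouped - mean) / np.sqrt(var + epsilon)).reshape(n, c, h, w)
    return normed * gamma.reshape(1, c, 1, 1) + beta.reshape(1, c, 1, 1)

x = np.random.rand(2, 6, 4, 4).astype("float32")
out = group_norm_reference(x, np.ones(6, "float32"), np.zeros(6, "float32"), num_groups=3)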

def convert_hard_shrink(g, op, block):
    """Operator converter for hard_shrink."""

    x = g.get_node(op.input("X")[0])
    dtype = infer_type(x).checked_type.dtype
    threshold = op.attr("threshold")
    threshold = _op.const(threshold, dtype)
    out = _op.logical_or(x < _op.const(-1.0, dtype) * threshold, x > threshold)
    out = _op.cast(out, dtype) * x
    g.add_node(op.output("Out")[0], out)

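The hard_shrink lowering avoids a select/where op: logical_or builds a boolean mask for |x| > threshold, the mask is cast back to the input dtype, and the product zeroes everything inside [-threshold, threshold]. A small NumPy illustration of the same trick (not part of the commit):

import numpy as np

def hardshrink_via_mask(x, threshold=0.5):
    # 1.0 where |x| > threshold, 0.0 elsewhere, then multiplied back into x.
    mask = np.logical_or(x < -threshold, x > threshold)
    return mask.astype(x.dtype) * x

x = np.array([-1.0, -0.3, 0.0, 0.2, 0.9], dtype="float32")
print(hardshrink_via_mask(x))  # approximately [-1.0, 0.0, 0.0, 0.0, 0.9]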

def convert_hard_sigmoid(g, op, block):
    """Operator converter for hard_sigmoid."""

    slope = op.attr("slope")
    x = g.get_node(op.input("X")[0])
-    out = x * _expr.const(slope) + _expr.const(0.5)
+    dtype = infer_type(x).checked_type.dtype
+    out = x * _expr.const(slope, dtype) + _expr.const(0.5, dtype)
    out = _op.clip(out, 0, 1)
    g.add_node(op.output("Out")[0], out)

@@ -820,12 +864,23 @@ def convert_hard_swish(g, op, block):
    assert np.isclose(scale, 6.0), "Only support scale==6.0 for PaddlePaddle's hard_swish"
    assert np.isclose(threshold, 6.0), "Only support threshold==6.0 for PaddlePaddle's hard_swish"
    x = g.get_node(op.input("X")[0])
+    dtype = infer_type(x).checked_type.dtype
    out = _op.clip(x, -1 * offset, offset)
-    out = out / _expr.const(threshold) + _expr.const(0.5)
+    out = out / _expr.const(threshold, dtype) + _expr.const(0.5, dtype)
    out = x * out
    g.add_node(op.output("Out")[0], out)

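Both hard_sigmoid and hard_swish above are affine-plus-clip constructions. With Paddle's defaults (offset=3, threshold=scale=6), clipping x to [-offset, offset], dividing by threshold and adding 0.5 is the same as clip(x + 3, 0, 6) / 6, so hard_swish comes out as x * clip(x + 3, 0, 6) / 6. A quick NumPy check of that equivalence (illustrative only, not part of the commit):

import numpy as np

def hardswish_reference(x):
    # Common hard_swish definition with offset=3, threshold=scale=6.
    return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0

def hardswish_as_converted(x, offset=3.0, threshold=6.0):
    # Mirrors the converter: clip, divide, shift by 0.5, multiply by x.
    return x * (np.clip(x, -offset, offset) / threshold + 0.5)

x = np.linspace(-6, 6, 25).astype("float32")
assert np.allclose(hardswish_reference(x), hardswish_as_converted(x), atol=1e-6)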

def convert_hard_tanh(g, op, block):
    """Operator converter for hard_tanh."""

    x = g.get_node(op.input("X")[0])
    t_max = op.attr("t_max")
    t_min = op.attr("t_min")
    out = _op.tensor.clip(x, t_min, t_max)
    g.add_node(op.output("Out")[0], out)


def convert_index_select(g, op, block):
    """Operator converter for index_select."""

@@ -837,6 +892,19 @@ def convert_index_select(g, op, block):
    g.add_node(op.output("Out")[0], out)


def convert_instance_norm(g, op, block):
    """Operator converter for instance_norm."""

    x = g.get_node(op.input("X")[0])
    gamma = g.get_node(op.input("Scale")[0])
    beta = g.get_node(op.input("Bias")[0])
    epsilon = op.attr("epsilon")

    scale = center = True
    out = _op.nn.instance_norm(x, gamma, beta, axis=1, epsilon=epsilon, center=center, scale=scale)
    g.add_node(op.output("Y")[0], out)


def convert_layer_norm(g, op, block):
"""Operator converter for layer_norm."""

@@ -1751,6 +1819,7 @@ def convert_where(g, op, block):
"bicubic_interp_v2": convert_interpolate,
"bilinear_interp_v2": convert_interpolate,
"bmm": convert_bmm,
"brelu": convert_hard_tanh,
"cast": convert_cast,
"ceil": convert_unary_op,
"clip": convert_clip,
@@ -1774,6 +1843,7 @@ def convert_where(g, op, block):
"elementwise_min": convert_elementwise_op,
"elementwise_pow": convert_elementwise_op,
"elementwise_floordiv": convert_elementwise_op,
"elu": convert_elu,
"equal": convert_elementwise_op,
"erf": convert_unary_op,
"exp": convert_unary_op,
@@ -1790,9 +1860,12 @@ def convert_where(g, op, block):
"gather_nd": convert_gather_nd,
"gelu": convert_gelu,
"greater_than": convert_elementwise_op,
"group_norm": convert_group_norm,
"hard_shrink": convert_hard_shrink,
"hard_sigmoid": convert_hard_sigmoid,
"hard_swish": convert_hard_swish,
"index_select": convert_index_select,
"instance_norm": convert_instance_norm,
"isinf": convert_unary_op,
"isinf_v2": convert_unary_op,
"layer_norm": convert_layer_norm,
61 changes: 47 additions & 14 deletions tests/python/frontend/paddlepaddle/test_forward.py
@@ -838,25 +838,42 @@ def gelu(inputs):


@tvm.testing.uses_gpu
-def test_forward_hard_sigmoid():
-    @paddle.jit.to_static
-    def hard_sigmoid(inputs):
-        return nn.functional.hardsigmoid(inputs)
-
-    input_shape = [1, 3, 10, 10]
-    input_data = paddle.rand(input_shape, dtype="float32")
-    verify_model(hard_sigmoid, input_data=input_data)
+def test_forward_group_norm():
+    class GroupNorm(nn.Layer):
+        def __init__(self, channels, groups):
+            super(GroupNorm, self).__init__()
+            self.group_norm = paddle.nn.GroupNorm(num_channels=channels, num_groups=groups)
+
+        def forward(self, inputs):
+            return self.group_norm(inputs)
+
+    x_data = np.random.random(size=(2, 6, 2, 2)).astype("float32")
+    x = paddle.to_tensor(x_data)
+    verify_model(GroupNorm(6, 6), x)


@tvm.testing.uses_gpu
-def test_forward_hard_swish():
-    @paddle.jit.to_static
-    def hard_swish(inputs):
-        return nn.functional.hardswish(inputs)
+def test_forward_hard_activation():
+    class Activation(nn.Layer):
+        def __init__(self, op_name):
+            super(Activation, self).__init__()
+            self.op_name_ = op_name
+            for candidate in (paddle.nn.functional, paddle):
+                self.func = getattr(candidate, op_name, None)
+                if self.func:
+                    break
+
+        @paddle.jit.to_static
+        def forward(self, inputs):
+            return self.func(inputs)
+
    input_shape = [1, 3, 10, 10]
    input_data = paddle.rand(input_shape, dtype="float32")
-    verify_model(hard_swish, input_data=input_data)
+    input_data_2 = paddle.randint(1, 100, input_shape, dtype="int32")
+    op_list = ["elu", "hardshrink", "hardsigmoid", "hardswish", "hardtanh"]
+    for op_name in op_list:
+        verify_model(Activation(op_name), input_data=input_data)
+        verify_model(Activation(op_name), input_data=input_data_2)


@tvm.testing.uses_gpu
Expand All @@ -876,6 +893,21 @@ def index_select2(x, index):
verify_model(index_select2, input_data=[input_data, index])


@tvm.testing.uses_gpu
def test_forward_instance_norm():
    class InstanceNorm(nn.Layer):
        def __init__(self):
            super(InstanceNorm, self).__init__()
            self.instance_norm = paddle.nn.InstanceNorm2D(2)

        def forward(self, inputs):
            return self.instance_norm(inputs)

    input_shape = [2, 2, 2, 3]
    input_data = paddle.rand(input_shape, dtype="float32")
    verify_model(InstanceNorm(), input_data)


@tvm.testing.uses_gpu
def test_forward_isinf():
    @paddle.jit.to_static
@@ -1622,9 +1654,10 @@ def forward(self, c, x, y):
    test_forward_gather_assign_value()
    test_forward_gather_nd()
    test_forward_gelu()
-    test_forward_hard_sigmoid()
-    test_forward_hard_swish()
+    test_forward_group_norm()
+    test_forward_hard_activation()
    test_forward_index_select()
+    test_forward_instance_norm()
    test_forward_interpolate()
    test_forward_isinf()
    test_forward_layer_norm()
