Commit
Merge pull request apache#45 from wjj19950828/paddle_frontend
Add SimpleRNN op
jiangjiajun committed Sep 22, 2021
2 parents 4dd04fb + 385917c commit 9ea0e53
Showing 2 changed files with 90 additions and 0 deletions.
57 changes: 57 additions & 0 deletions python/tvm/relay/frontend/paddlepaddle.py
@@ -1711,6 +1711,38 @@ def generate_gru(

    return output, hidden_state

def generate_simplernn(
    input_seqs, hidden_state, w_inp, w_hid, b_inp, b_hid, n_act, backwards=False
):
    """Implementation of the SimpleRNN cell for the PaddlePaddle frontend of TVM."""

    h_list = []
    seq_length = len(input_seqs)
    for i in range(seq_length):
        # Walk the sequence back to front for the backward direction.
        step = input_seqs[i] if not backwards else input_seqs[seq_length - (i + 1)]
        step = _op.squeeze(step, axis=[0])
        xwt = _op.nn.dense(step, w_inp)
        hwt = _op.nn.dense(hidden_state, w_hid)
        if b_inp is not None:
            xwt += b_inp
        if b_hid is not None:
            hwt += b_hid

        # h_t = act(x_t * W_ih^T + b_ih + h_{t-1} * W_hh^T + b_hh)
        n_gate = n_act(xwt + hwt)

        hidden_state = n_gate
        h_list.append(_op.expand_dims(hidden_state, axis=0))

    if backwards:
        h_list = h_list[::-1]

    # Concatenate outputs and add back in direction axis.
    concatenated = _op.concatenate(h_list, 0)
    output = _op.expand_dims(concatenated, axis=1)
    hidden_state = _op.expand_dims(hidden_state, axis=0)

    return output, hidden_state
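For reference, the Relay graph built above realizes the standard simple-RNN recurrence. Below is a minimal NumPy sketch of that same recurrence; it is not part of the patch, and the helper name and shape conventions are illustrative only, but it can be used to sanity-check the converter's output.

import numpy as np

def simplernn_reference(x_steps, h0, w_ih, w_hh, b_ih, b_hh, act=np.tanh):
    # Reference recurrence: h_t = act(x_t @ w_ih.T + b_ih + h_{t-1} @ w_hh.T + b_hh)
    # x_steps: (seq_len, batch, input_size); h0: (batch, hidden_size)
    h = h0
    outputs = []
    for x_t in x_steps:
        h = act(x_t @ w_ih.T + b_ih + h @ w_hh.T + b_hh)
        outputs.append(h)
    return np.stack(outputs), h  # (seq_len, batch, hidden_size), final state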

def make_param_inputs(g, node, layer, hidden_size, num_layers):
    """Param for weight and bias."""

@@ -1840,6 +1872,31 @@ def make_init_param_inputs(g, node, layer):
        result_H.append(H)
    output = _op.concatenate(result_output, axis=1)
    H = _op.concatenate(result_H, axis=0)
elif mode == "RNN_TANH":
    # Vanilla (tanh) SimpleRNN: run each direction separately, then stack the results.
    init_h = make_init_param_inputs(g, op, layer)
    init_hs = _op.split(init_h, num_directions)
    result_output = []
    result_H = []
    for i in range(num_directions):
        H_t = _op.squeeze(init_hs[i], axis=[0])
        W = g.get_node(input_weights[i])
        R = g.get_node(hidden_weights[i])
        WB = g.get_node(input_bias[i])
        RB = g.get_node(hidden_bias[i])
        output, H = generate_simplernn(
            input_seqs=x_steps,
            hidden_state=H_t,
            w_inp=W,
            w_hid=R,
            b_inp=WB,
            b_hid=RB,
            n_act=_op.tanh,
            backwards=i == 1,
        )
        result_output.append(output)
        result_H.append(H)
    output = _op.concatenate(result_output, axis=1)
    H = _op.concatenate(result_H, axis=0)

# (seq_len, num_directions, batch, hidden) -> (seq_len, batch, num_directions * hidden)
output = _op.transpose(output, axes=[0, 2, 1, 3])
output = _op.reshape(output, newshape=(0, 0, -1))
33 changes: 33 additions & 0 deletions tests/python/frontend/paddlepaddle/test_forward.py
@@ -1198,6 +1198,38 @@ def forward(self, inputs, prev_h):
    verify_model(GRU2(), input_data=[gru_input_data, prev_h])


@tvm.testing.uses_gpu
def test_forward_simplernn():
    class SimpleRNN1(nn.Layer):
        def __init__(self):
            super(SimpleRNN1, self).__init__()
            self.simplernn = nn.SimpleRNN(288, 48, 2, direction="bidirect", time_major=True)

        @paddle.jit.to_static
        def forward(self, inputs, prev_h):
            y, h = self.simplernn(inputs, prev_h)
            return y

    class SimpleRNN2(nn.Layer):
        def __init__(self):
            super(SimpleRNN2, self).__init__()
            self.simplernn = nn.SimpleRNNCell(16, 32)

        @paddle.jit.to_static
        def forward(self, inputs, prev_h):
            y, h = self.simplernn(inputs, prev_h)
            return y

    # Time-major input: (seq_len, batch, input_size); state: (num_layers * num_directions, batch, hidden).
    simplernn_input_shape = [25, 1, 288]
    simplernn_input_data = paddle.rand(simplernn_input_shape, dtype="float32")
    prev_h = paddle.rand([4, 1, 48], dtype="float32")
    verify_model(SimpleRNN1(), input_data=[simplernn_input_data, prev_h])
    # Single-cell step: input (batch, input_size); state (batch, hidden).
    simplernn_input_shape = [4, 16]
    simplernn_input_data = paddle.rand(simplernn_input_shape, dtype="float32")
    prev_h = paddle.rand([4, 32], dtype="float32")
    verify_model(SimpleRNN2(), input_data=[simplernn_input_data, prev_h])
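For quick local iteration, the new case can also be invoked directly (the second hunk below registers it in the file's __main__ call list); a minimal standalone guard, assuming tvm and paddlepaddle are importable:

if __name__ == "__main__":
    test_forward_simplernn()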


@tvm.testing.uses_gpu
def test_forward_multiply():
    @paddle.jit.to_static
@@ -2149,6 +2181,7 @@ def forward(self, x):
    test_forward_look_up()
    test_forward_lstm()
    test_forward_gru()
    test_forward_simplernn()
    test_forward_masked_select()
    test_forward_matmul()
    test_forward_meshgrid()
