From 385917c8035795d1563fe6c9f26558c9d032012a Mon Sep 17 00:00:00 2001 From: wjj19950828 Date: Wed, 22 Sep 2021 17:41:26 +0800 Subject: [PATCH] Add Simplernn op --- python/tvm/relay/frontend/paddlepaddle.py | 57 +++++++++++++++++++ .../frontend/paddlepaddle/test_forward.py | 33 +++++++++++ 2 files changed, 90 insertions(+) diff --git a/python/tvm/relay/frontend/paddlepaddle.py b/python/tvm/relay/frontend/paddlepaddle.py index 5b6bdd78721b..988a32a399bc 100644 --- a/python/tvm/relay/frontend/paddlepaddle.py +++ b/python/tvm/relay/frontend/paddlepaddle.py @@ -1711,6 +1711,38 @@ def generate_gru( return output, hidden_state + def generate_simplernn( + input_seqs, hidden_state, w_inp, w_hid, b_inp, b_hid, n_act, backwards=False + ): + """Implementation of SimpleRNN cell for paddlepaddle of TVM""" + + h_list = [] + seq_length = len(input_seqs) + for i in range(seq_length): + step = input_seqs[i] if not backwards else input_seqs[seq_length - (i + 1)] + step = _op.squeeze(step, axis=[0]) + xwt = _op.nn.dense(step, w_inp) + hwt = _op.nn.dense(hidden_state, w_hid) + if b_inp is not None: + xwt += b_inp + if b_hid is not None: + hwt += b_hid + + n_gate = n_act(xwt + hwt) + + hidden_state = n_gate + h_list.append(_op.expand_dims(hidden_state, axis=0)) + + if backwards: + h_list = h_list[::-1] + + # Concatenate outputs and add back in direction axis. 
+ concatenated = _op.concatenate(h_list, 0) + output = _op.expand_dims(concatenated, axis=1) + hidden_state = _op.expand_dims(hidden_state, axis=0) + + return output, hidden_state + def make_param_inputs(g, node, layer, hidden_size, num_layers): """Param for weight and bias.""" @@ -1840,6 +1872,31 @@ def make_init_param_inputs(g, node, layer): result_H.append(H) output = _op.concatenate(result_output, axis=1) H = _op.concatenate(result_H, axis=0) + elif mode == "RNN_TANH": + init_h = make_init_param_inputs(g, op, layer) + init_hs = _op.split(init_h, num_directions) + result_output = [] + result_H = [] + for i in range(num_directions): + H_t = _op.squeeze(init_hs[i], axis=[0]) + W = g.get_node(input_weights[i]) + R = g.get_node(hidden_weights[i]) + WB = g.get_node(input_bias[i]) + RB = g.get_node(hidden_bias[i]) + output, H = generate_simplernn( + input_seqs=x_steps, + hidden_state=H_t, + w_inp=W, + w_hid=R, + b_inp=WB, + b_hid=RB, + n_act=_op.tanh, + backwards=i == 1, + ) + result_output.append(output) + result_H.append(H) + output = _op.concatenate(result_output, axis=1) + H = _op.concatenate(result_H, axis=0) output = _op.transpose(output, axes=[0, 2, 1, 3]) output = _op.reshape(output, newshape=(0, 0, -1)) diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index b7c850574bc0..73530623d3c5 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -1192,6 +1192,38 @@ def forward(self, inputs, prev_h): verify_model(GRU2(), input_data=[gru_input_data, prev_h]) +@tvm.testing.uses_gpu +def test_forward_simplernn(): + class SimpleRNN1(nn.Layer): + def __init__(self): + super(SimpleRNN1, self).__init__() + self.simplernn = nn.SimpleRNN(288, 48, 2, direction="bidirect", time_major=True) + + @paddle.jit.to_static + def forward(self, inputs, prev_h): + y, h = self.simplernn(inputs, prev_h) + return y + + class SimpleRNN2(nn.Layer): + def 
__init__(self):
+            super(SimpleRNN2, self).__init__()
+            self.simplernn = nn.SimpleRNNCell(16, 32)
+
+        @paddle.jit.to_static
+        def forward(self, inputs, prev_h):
+            y, h = self.simplernn(inputs, prev_h)
+            return y
+
+    simplernn_input_shape = [25, 1, 288]
+    simplernn_input_data = paddle.rand(simplernn_input_shape, dtype="float32")
+    prev_h = paddle.rand([4, 1, 48], dtype="float32")
+    verify_model(SimpleRNN1(), input_data=[simplernn_input_data, prev_h])
+    simplernn_input_shape = [4, 16]
+    simplernn_input_data = paddle.rand(simplernn_input_shape, dtype="float32")
+    prev_h = paddle.rand([4, 32], dtype="float32")
+    verify_model(SimpleRNN2(), input_data=[simplernn_input_data, prev_h])
+
+
 @tvm.testing.uses_gpu
 def test_forward_multiply():
     @paddle.jit.to_static
@@ -2146,6 +2178,7 @@ def forward(self, x):
     test_forward_look_up()
     test_forward_lstm()
     test_forward_gru()
+    test_forward_simplernn()
     test_forward_masked_select()
     test_forward_matmul()
     test_forward_meshgrid()