From e6a2d0da3bc09b307ff8b753ee9b9e833d8ae6aa Mon Sep 17 00:00:00 2001 From: HieDean <34408026+HieDean@users.noreply.github.com> Date: Mon, 20 Nov 2023 09:20:50 +0800 Subject: [PATCH] Replace Clone() with View() (#432) Co-authored-by: hiedean --- .../offline-transducer-modified-beam-search-decoder.cc | 2 +- sherpa-onnx/csrc/online-rnn-lm.cc | 8 ++++---- .../csrc/online-transducer-greedy-search-decoder.cc | 8 +++++--- .../online-transducer-modified-beam-search-decoder.cc | 2 +- sherpa-onnx/csrc/online-wenet-ctc-model.cc | 6 +++--- 5 files changed, 14 insertions(+), 12 deletions(-) diff --git a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc index e845b3138..142acb4ac 100644 --- a/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/offline-transducer-modified-beam-search-decoder.cc @@ -94,7 +94,7 @@ OfflineTransducerModifiedBeamSearchDecoder::Decode( // now cur_encoder_out is of shape (num_hyps, joiner_dim) Ort::Value logit = model_->RunJoiner( - std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); + std::move(cur_encoder_out), View(&decoder_out)); float *p_logit = logit.GetTensorMutableData(); LogSoftmax(p_logit, vocab_size, num_hyps); diff --git a/sherpa-onnx/csrc/online-rnn-lm.cc b/sherpa-onnx/csrc/online-rnn-lm.cc index 29b150e45..ff493c930 100644 --- a/sherpa-onnx/csrc/online-rnn-lm.cc +++ b/sherpa-onnx/csrc/online-rnn-lm.cc @@ -67,13 +67,13 @@ class OnlineRnnLM::Impl { return {std::move(out[0]), std::move(next_states)}; } - std::pair> GetInitStates() const { + std::pair> GetInitStates() { std::vector ans; ans.reserve(init_states_.size()); - for (const auto &s : init_states_) { - ans.emplace_back(Clone(allocator_, &s)); + for (auto &s : init_states_) { + ans.emplace_back(View(&s)); } - return {std::move(Clone(allocator_, &init_scores_.value)), std::move(ans)}; + return {View(&init_scores_.value), std::move(ans)}; } private: diff --git a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc index c2fc1103d..132aa87d2 100644 --- a/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-greedy-search-decoder.cc @@ -99,9 +99,11 @@ void OnlineTransducerGreedySearchDecoder::Decode( } if (is_batch_decoder_out_cached) { auto &r = result->front(); - std::vector decoder_out_shape = r.decoder_out.GetTensorTypeAndShapeInfo().GetShape(); + std::vector decoder_out_shape = + r.decoder_out.GetTensorTypeAndShapeInfo().GetShape(); decoder_out_shape[0] = batch_size; - decoder_out = Ort::Value::CreateTensor(model_->Allocator(), decoder_out_shape.data(), decoder_out_shape.size()); + decoder_out = Ort::Value::CreateTensor(model_->Allocator(), + decoder_out_shape.data(), decoder_out_shape.size()); UseCachedDecoderOut(*result, &decoder_out); } else { Ort::Value decoder_input = model_->BuildDecoderInput(*result); @@ -112,7 +114,7 @@ void OnlineTransducerGreedySearchDecoder::Decode( Ort::Value cur_encoder_out = GetEncoderOutFrame(model_->Allocator(), &encoder_out, t); Ort::Value logit = model_->RunJoiner( - std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); + std::move(cur_encoder_out), View(&decoder_out)); const float *p_logit = logit.GetTensorData(); diff --git a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc index a98f19dad..a02e34503 100644 --- a/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc +++ b/sherpa-onnx/csrc/online-transducer-modified-beam-search-decoder.cc @@ -120,7 +120,7 @@ void OnlineTransducerModifiedBeamSearchDecoder::Decode( cur_encoder_out = Repeat(model_->Allocator(), &cur_encoder_out, hyps_row_splits); Ort::Value logit = model_->RunJoiner( - std::move(cur_encoder_out), Clone(model_->Allocator(), &decoder_out)); + std::move(cur_encoder_out), View(&decoder_out)); float *p_logit = logit.GetTensorMutableData(); LogSoftmax(p_logit, vocab_size, num_hyps); diff --git a/sherpa-onnx/csrc/online-wenet-ctc-model.cc b/sherpa-onnx/csrc/online-wenet-ctc-model.cc index 5d7e90964..eac1a21cb 100644 --- a/sherpa-onnx/csrc/online-wenet-ctc-model.cc +++ b/sherpa-onnx/csrc/online-wenet-ctc-model.cc @@ -105,11 +105,11 @@ class OnlineWenetCtcModel::Impl { // - attn_cache // - conv_cache // - offset - std::vector GetInitStates() const { + std::vector GetInitStates() { std::vector ans; ans.reserve(3); - ans.push_back(Clone(Allocator(), &attn_cache_)); - ans.push_back(Clone(Allocator(), &conv_cache_)); + ans.push_back(View(&attn_cache_)); + ans.push_back(View(&conv_cache_)); int64_t offset_shape = 1;