Skip to content

Commit

Permalink
whisper : avoid some memory allocations
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Dec 30, 2022
1 parent 35f6168 commit 4a250c9
Showing 1 changed file with 21 additions and 6 deletions.
27 changes: 21 additions & 6 deletions whisper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ struct whisper_vocab {
std::map<token, id> token_to_id;
std::map<id, token> id_to_token;

// used to avoid memory allocations during sampling
// TODO: move to whisper_context in the future
std::vector<std::pair<double, whisper_vocab::id>> probs_id;

id token_eot = 50256;
id token_sot = 50257;
id token_prev = 50360;
Expand Down Expand Up @@ -551,6 +555,9 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx

std::string word;
std::vector<char> tmp;

tmp.reserve(128);

for (int i = 0; i < n_vocab; i++) {
uint32_t len;
read_safe(fin, len);
Expand Down Expand Up @@ -603,6 +610,11 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx
vocab.id_to_token[i] = word;
}
}

wctx.logits.reserve(vocab.n_vocab*model.hparams.n_text_ctx);
wctx.probs.reserve(vocab.n_vocab*model.hparams.n_text_ctx);

vocab.probs_id.reserve(n_vocab);
}

{
Expand Down Expand Up @@ -1021,7 +1033,7 @@ static bool whisper_model_load(const std::string & fname, whisper_context & wctx

std::string name;
std::vector<char> tmp(length); // create a buffer
fin.read( &tmp[0], tmp.size() ); // read to buffer
fin.read(&tmp[0], tmp.size()); // read to buffer
name.assign(&tmp[0], tmp.size());

if (model.tensors.find(name) == model.tensors.end()) {
Expand Down Expand Up @@ -1849,19 +1861,19 @@ static bool whisper_decode(

// the most basic sampling scheme - select the top token
static whisper_token_data whisper_sample_best(
const whisper_vocab & vocab,
whisper_vocab & vocab,
const float * probs,
bool force_timestamp,
bool is_initial) {
whisper_token_data result = {
0, 0, 0.0f, 0.0f, 0.0f, -1, -1, 0.0f,
};

int n_logits = vocab.id_to_token.size();
const int n_logits = vocab.n_vocab;

std::vector<std::pair<double, whisper_vocab::id>> probs_id;
probs_id.reserve(n_logits);
auto & probs_id = vocab.probs_id;

probs_id.clear();
for (int i = 0; i < n_logits; i++) {
probs_id.emplace_back(probs[i], i);
}
Expand Down Expand Up @@ -2001,6 +2013,9 @@ static void fft(const std::vector<float> & in, std::vector<float> & out) {
std::vector<float> even;
std::vector<float> odd;

even.reserve(N/2);
odd.reserve(N/2);

for (int i = 0; i < N; i++) {
if (i % 2 == 0) {
even.push_back(in[i]);
Expand Down Expand Up @@ -2434,7 +2449,7 @@ int whisper_lang_auto_detect(
std::vector<std::pair<float, int>> probs_id;
for (const auto & kv : g_lang) {
const auto token_lang = whisper_token_lang(ctx, kv.second.first);
probs_id.emplace_back( ctx->probs[token_lang], kv.second.first );
probs_id.emplace_back(ctx->probs[token_lang], kv.second.first);
}

// sort descending
Expand Down

0 comments on commit 4a250c9

Please sign in to comment.