whisper : add API for applying custom logits filters during decoding
ggerganov committed Feb 19, 2023
1 parent f254e78 commit 0d22916
Showing 2 changed files with 23 additions and 3 deletions.
12 changes: 9 additions & 3 deletions whisper.cpp
@@ -805,7 +805,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con
MEM_REQ_SCRATCH3.at (model.type) +
scale*MEM_REQ_MODEL.at (model.type) +
scale*MEM_REQ_KV_CROSS.at(model.type) +
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));
scale*std::max(MEM_REQ_ENCODE.at(model.type), MEM_REQ_DECODE.at(model.type));

// this is the memory required by one decoder
const size_t mem_required_decoder =
@@ -2962,6 +2962,9 @@ struct whisper_full_params whisper_full_default_params(enum whisper_sampling_str

/*.encoder_begin_callback =*/ nullptr,
/*.encoder_begin_callback_user_data =*/ nullptr,

/*.logits_filter_callback =*/ nullptr,
/*.logits_filter_callback_user_data =*/ nullptr,
};

switch (strategy) {
@@ -3089,7 +3092,7 @@ static const std::vector<std::string> non_speech_tokens = {
// - applies logit filters
// - computes logprobs and probs
static void whisper_process_logits(
const struct whisper_context & ctx,
struct whisper_context & ctx,
const struct whisper_full_params params,
struct whisper_decoder & decoder,
float temperature) {
@@ -3145,6 +3148,9 @@ static void whisper_process_logits(
logits[vocab.token_translate] = -INFINITY;
logits[vocab.token_transcribe] = -INFINITY;

if (params.logits_filter_callback) {
params.logits_filter_callback(&ctx, tokens_cur.data(), tokens_cur.size(), logits.data(), params.logits_filter_callback_user_data);
}

// suppress non-speech tokens
// ref: https://github.com/openai/whisper/blob/7858aa9c08d98f75575035ecd6481f462d66ca27/whisper/tokenizer.py#L224-L253
@@ -3848,7 +3854,7 @@ int whisper_full(
return a.sequence.sum_logprobs_all > b.sequence.sum_logprobs_all;
});

unsigned int cur_c = 0;
uint32_t cur_c = 0;

for (int j = 0; j < n_decoders_cur; ++j) {
auto & decoder = ctx->decoders[j];
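To make the new hook concrete, here is a minimal sketch of a filter callback matching the signature introduced in whisper.h below. It suppresses a caller-chosen set of token ids by setting their logits to -INFINITY, the same convention the decoder uses just above the call site for the translate/transcribe tokens. The `suppress_state` struct and the `my_logits_filter` name are illustrative and not part of this commit.

```cpp
#include <cmath>
#include <vector>
#include "whisper.h"

// hypothetical user state: token ids the caller wants to forbid
struct suppress_state {
    std::vector<whisper_token> banned;
};

// matches whisper_logits_filter_callback; called with the tokens decoded so
// far and a mutable logits array of size whisper_n_vocab(ctx)
static void my_logits_filter(
        struct whisper_context * ctx,
        const whisper_token_data * tokens,
        int n_tokens,
        float * logits,
        void * user_data) {
    (void) tokens;
    (void) n_tokens; // not needed for this simple filter

    const auto * state   = static_cast<const suppress_state *>(user_data);
    const int    n_vocab = whisper_n_vocab(ctx);

    for (const whisper_token id : state->banned) {
        if (id >= 0 && id < n_vocab) {
            logits[id] = -INFINITY;
        }
    }
}
```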
14 changes: 14 additions & 0 deletions whisper.h
@@ -243,6 +243,16 @@ extern "C" {
// If it returns false, the computation is aborted
typedef bool (*whisper_encoder_begin_callback)(struct whisper_context * ctx, void * user_data);

// Logits filter callback
// Can be used to modify the logits before sampling
// If not NULL, called after applying temperature to logits
typedef void (*whisper_logits_filter_callback)(
struct whisper_context * ctx,
const whisper_token_data * tokens,
int n_tokens,
float * logits,
void * user_data);

// Parameters for the whisper_full() function
    // If you change the order or add new parameters, make sure to update the default values in whisper.cpp:
// whisper_full_default_params()
@@ -315,6 +325,10 @@ extern "C" {
// called each time before the encoder starts
whisper_encoder_begin_callback encoder_begin_callback;
void * encoder_begin_callback_user_data;

// called by each decoder to filter obtained logits
whisper_logits_filter_callback logits_filter_callback;
void * logits_filter_callback_user_data;
};

WHISPER_API struct whisper_full_params whisper_full_default_params(enum whisper_sampling_strategy strategy);
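A usage sketch for the two new params fields, assuming the `my_logits_filter` / `suppress_state` example above, an initialized `ctx`, and caller-prepared 16 kHz mono float PCM in `pcmf32`; the surrounding setup is illustrative only.

```cpp
// register the filter before running whisper_full()
suppress_state state;
state.banned = { /* token ids to forbid */ };

struct whisper_full_params wparams = whisper_full_default_params(WHISPER_SAMPLING_GREEDY);

wparams.logits_filter_callback           = my_logits_filter;
wparams.logits_filter_callback_user_data = &state;

if (whisper_full(ctx, wparams, pcmf32.data(), pcmf32.size()) != 0) {
    fprintf(stderr, "whisper_full() failed\n"); // requires <cstdio>
}
```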
