llama : add infill sampler
ggerganov committed Oct 9, 2024
1 parent 61a66f2 commit 5a2bf08
Showing 10 changed files with 235 additions and 5 deletions.
14 changes: 14 additions & 0 deletions common/arg.cpp
@@ -947,6 +947,20 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.sparams.tfs_z = std::stof(value);
}
).set_sparam());
add_opt(llama_arg(
{"--infill-p"}, "N",
string_format("infill p threshold (default: %.1f)", (double)params.sparams.infill_p),
[](gpt_params & params, const std::string & value) {
params.sparams.infill_p = std::stof(value);
}
).set_sparam());
add_opt(llama_arg(
{"--infill-p-eog"}, "N",
string_format("infill p_eog threshold (default: %.1f)", (double)params.sparams.infill_p_eog),
[](gpt_params & params, const std::string & value) {
params.sparams.infill_p_eog = std::stof(value);
}
).set_sparam());
add_opt(llama_arg(
{"--typical"}, "N",
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
5 changes: 4 additions & 1 deletion common/common.h
@@ -90,6 +90,7 @@ enum gpt_sampler_type {
GPT_SAMPLER_TYPE_TFS_Z = 4,
GPT_SAMPLER_TYPE_TYPICAL_P = 5,
GPT_SAMPLER_TYPE_TEMPERATURE = 6,
GPT_SAMPLER_TYPE_INFILL = 7,
};

// dimensionality reduction methods, used by cvector-generator
@@ -113,6 +114,8 @@ struct gpt_sampler_params {
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
float infill_p = 0.80f; // infill: non-EOG tokens with prob < infill_p are masked (see llama_sampler_init_infill)
float infill_p_eog = 0.01f; // infill: if no token dominates and the sum of EOG probs exceeds this, keep only EOG tokens
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
@@ -130,7 +133,7 @@ struct gpt_sampler_params {
GPT_SAMPLER_TYPE_TYPICAL_P,
GPT_SAMPLER_TYPE_TOP_P,
GPT_SAMPLER_TYPE_MIN_P,
GPT_SAMPLER_TYPE_TEMPERATURE
GPT_SAMPLER_TYPE_TEMPERATURE,
};

std::string grammar; // optional BNF-like grammar to constrain sampling
9 changes: 8 additions & 1 deletion common/sampling.cpp
@@ -193,6 +193,9 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
case GPT_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case GPT_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model, params.infill_p, params.infill_p_eog));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
@@ -372,6 +375,7 @@ char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr) {
case GPT_SAMPLER_TYPE_TOP_P: return 'p';
case GPT_SAMPLER_TYPE_MIN_P: return 'm';
case GPT_SAMPLER_TYPE_TEMPERATURE: return 't';
case GPT_SAMPLER_TYPE_INFILL: return 'i';
default : return '?';
}
}
@@ -384,6 +388,7 @@ std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr) {
case GPT_SAMPLER_TYPE_TOP_P: return "top_p";
case GPT_SAMPLER_TYPE_MIN_P: return "min_p";
case GPT_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case GPT_SAMPLER_TYPE_INFILL: return "infill";
default : return "";
}
}
@@ -396,6 +401,7 @@ std::vector<gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std
{ "min_p", GPT_SAMPLER_TYPE_MIN_P },
{ "tfs_z", GPT_SAMPLER_TYPE_TFS_Z },
{ "temperature", GPT_SAMPLER_TYPE_TEMPERATURE },
{ "infill", GPT_SAMPLER_TYPE_INFILL }
};

// since sampler names are written in multiple ways
@@ -441,7 +447,8 @@ std::vector<gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & c
{ gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TYPICAL_P), GPT_SAMPLER_TYPE_TYPICAL_P },
{ gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_P), GPT_SAMPLER_TYPE_TOP_P },
{ gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_MIN_P), GPT_SAMPLER_TYPE_MIN_P },
{ gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE }
{ gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE },
{ gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_INFILL), GPT_SAMPLER_TYPE_INFILL },
};

std::vector<gpt_sampler_type> samplers;
4 changes: 4 additions & 0 deletions examples/server/server.cpp
@@ -889,6 +889,8 @@ struct server_context {
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.infill_p = json_value(data, "infill_p", default_sparams.infill_p);
slot.sparams.infill_p_eog = json_value(data, "infill_p_eog", default_sparams.infill_p_eog);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
@@ -1236,6 +1238,8 @@ struct server_context {
{"min_p", slot.sparams.min_p},
{"tfs_z", slot.sparams.tfs_z},
{"typical_p", slot.sparams.typ_p},
{"infill_p", slot.sparams.infill_p},
{"infill_p_eog", slot.sparams.infill_p_eog},
{"repeat_last_n", slot.sparams.penalty_last_n},
{"repeat_penalty", slot.sparams.penalty_repeat},
{"presence_penalty", slot.sparams.penalty_present},
26 changes: 26 additions & 0 deletions include/llama.h
@@ -952,6 +952,12 @@ extern "C" {
int32_t lstrip,
bool special);

// check if token0 is a prefix of token1
LLAMA_API bool llama_token_is_prefix(
const struct llama_model * model,
llama_token token0,
llama_token token1);

/// @details Convert the provided tokens into text (inverse of llama_tokenize()).
/// @param text The char pointer must be large enough to hold the resulting text.
/// @return Returns the number of chars/bytes on success, no more than text_len_max.
@@ -1144,6 +1150,26 @@ extern "C" {
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);

// 1. if there is a high-prob token (>= 0.9f) - pick it
// 2. if sum of EOG probs is larger than p_eog -> mask non-EOG tokens away
// 3. combine probs of tokens that have the same prefix
//
// example:
//
// - before:
// "hel": 0.5
// "hell": 0.2
// "hello": 0.1
// "dummy": 0.1
//
// - after:
// "hel": 0.8
// "dummy": 0.1
//
LLAMA_API struct llama_sampler * llama_sampler_init_infill(
const struct llama_model * model,
float p,
float p_eog);

// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
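To illustrate the intended use of llama_sampler_init_infill, here is a minimal sketch (not part of this commit). It assumes an already-initialized llama_model * model and llama_context * ctx, and otherwise uses only the existing chain API from this header:

// sketch: sampler chain ending in the new infill sampler
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

// p = 0.8 and p_eog = 0.01 mirror the common defaults added above
llama_sampler_chain_add(chain, llama_sampler_init_infill(model, 0.80f, 0.01f));

// the infill sampler reduces the candidates to a single token, so a
// distribution sampler at the end of the chain simply selects it
llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

const llama_token id = llama_sampler_sample(chain, ctx, -1);

llama_sampler_free(chain);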
139 changes: 139 additions & 0 deletions src/llama-sampling.cpp
@@ -1644,6 +1644,145 @@ struct llama_sampler * llama_sampler_init_logit_bias(
};
}

// infill

struct llama_sampler_infill {
const struct llama_vocab * vocab;

const float p;
const float p_eog;
};

static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
return "infill";
}

static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
auto * ctx = (llama_sampler_infill *) smpl->ctx;

llama_sampler_softmax_impl(cur_p);

// print cur_p:
for (size_t i = 0; i < cur_p->size; ++i) {
LLAMA_LOG_DEBUG("infill: cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}

float p_max = 0.0f;
float p_eog_sum = 0.0f;

for (size_t i = 0; i < cur_p->size; ++i) {
p_max = fmaxf(p_max, cur_p->data[i].p);
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
p_eog_sum += cur_p->data[i].p;
}
}

if (p_max < 0.90f && p_eog_sum > ctx->p_eog) {
LLAMA_LOG_DEBUG("infill: all EOG tokens are more likely than p_eog (%f), keeping only EOG tokens\n", ctx->p_eog);

// keep just the EOG tokens
const auto size_org = cur_p->size;

cur_p->size = 0;

for (size_t i = 0; i < size_org; ++i) {
if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
cur_p->data[cur_p->size++] = cur_p->data[i];
}
}

return;
}

// combine tokens with common prefix
for (size_t i = 0; i < cur_p->size; ++i) {
for (size_t j = 0; j < cur_p->size; ++j) {
if (cur_p->data[i].logit == -INFINITY) {
break;
}

if (i == j || cur_p->data[j].logit == -INFINITY) {
continue;
}

if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
if (cur_p->data[i].p > cur_p->data[j].p) {
cur_p->data[i].p += cur_p->data[j].p;
cur_p->data[j].logit = -INFINITY;
} else {
cur_p->data[j].p += cur_p->data[i].p;
cur_p->data[i].logit = -INFINITY;
}
}
}
}

// mask non-EOG tokens with prob < ctx->p
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].p < ctx->p && !llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
cur_p->data[i].logit = -INFINITY;
}
}

// determine the token with max logit
float l_max = -INFINITY;
int i_max = -1;
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit > l_max) {
l_max = cur_p->data[i].logit;
i_max = i;
}
}

// if all logits are -INFINITY -> reduce cur_p to a single EOG token
if (i_max == -1) {
cur_p->size = 1;
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
cur_p->data[0].logit = 1.0f;

return;
}

cur_p->size = 1;
cur_p->data[0] = cur_p->data[i_max];

for (size_t i = 0; i < cur_p->size; ++i) {
LLAMA_LOG_DEBUG("after : cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}
}

static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
return llama_sampler_init_infill_impl(*ctx->vocab, ctx->p, ctx->p_eog);
}

static void llama_sampler_infill_free(struct llama_sampler * smpl) {
delete (llama_sampler_infill *) smpl->ctx;
}

static struct llama_sampler_i llama_sampler_infill_i = {
/* .name = */ llama_sampler_infill_name,
/* .accept = */ nullptr,
/* .apply = */ llama_sampler_infill_apply,
/* .reset = */ nullptr,
/* .clone = */ llama_sampler_infill_clone,
/* .free = */ llama_sampler_infill_free,
};

struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab,
float p,
float p_eog) {
return new llama_sampler {
/* .iface = */ &llama_sampler_infill_i,
/* .ctx = */ new llama_sampler_infill {
/* .vocab = */ &vocab,
/* .p = */ p,
/* .p_eog = */ p_eog,
},
};
}
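For reference, the prefix-combining step can be reproduced in isolation. The following standalone sketch (not part of the commit; plain strings stand in for the vocab lookup done by llama_token_is_prefix_impl) reruns the example from the llama.h comment:

// toy reproduction of the "combine tokens with common prefix" loop above
#include <cmath>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

struct cand { std::string text; float p; float logit; };

// same rule as llama_token_is_prefix_impl, on plain strings
static bool is_strict_prefix(const std::string & a, const std::string & b) {
    return a.size() < b.size() && memcmp(a.data(), b.data(), a.size()) == 0;
}

int main() {
    std::vector<cand> cur = {
        { "hel",   0.5f, 0.0f },
        { "hell",  0.2f, 0.0f },
        { "hello", 0.1f, 0.0f },
        { "dummy", 0.1f, 0.0f },
    };

    // fold the probability of each prefix/extension pair into the more
    // likely member; the other one is masked with -INFINITY
    for (size_t i = 0; i < cur.size(); ++i) {
        for (size_t j = 0; j < cur.size(); ++j) {
            if (cur[i].logit == -INFINITY) {
                break;
            }
            if (i == j || cur[j].logit == -INFINITY) {
                continue;
            }
            if (is_strict_prefix(cur[i].text, cur[j].text)) {
                if (cur[i].p > cur[j].p) {
                    cur[i].p += cur[j].p;
                    cur[j].logit = -INFINITY;
                } else {
                    cur[j].p += cur[i].p;
                    cur[i].logit = -INFINITY;
                }
            }
        }
    }

    for (const auto & c : cur) {
        if (c.logit != -INFINITY) {
            printf("%s: %.1f\n", c.text.c_str(), c.p);
        }
    }

    return 0;
}

Running it prints "hel: 0.8" and "dummy: 0.1", matching the before/after example in the llama.h comment.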

// utils

uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
7 changes: 5 additions & 2 deletions src/llama-sampling.h
@@ -4,8 +4,6 @@

#include "llama-grammar.h"

#include <unordered_map>

struct llama_vocab;
struct llama_grammar;

@@ -27,3 +25,8 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
const struct llama_vocab & vocab,
const char * grammar_str,
const char * grammar_root);

struct llama_sampler * llama_sampler_init_infill_impl(
const struct llama_vocab & vocab,
float p,
float p_eog);
17 changes: 17 additions & 0 deletions src/llama-vocab.cpp
@@ -1858,6 +1858,23 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
return 0;
}

bool llama_token_is_prefix_impl(
const struct llama_vocab & vocab,
llama_token token0,
llama_token token1) {
char text_buf_0[128];
char text_buf_1[128];

const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, 128, 0, false);
const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, 128, 0, false);

if (len0 <= 0 || len1 <= 0) {
return false;
}

return len0 < len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
}

int32_t llama_detokenize_impl(
const struct llama_vocab & vocab,
const llama_token * tokens,
8 changes: 7 additions & 1 deletion src/llama-vocab.h
@@ -48,7 +48,7 @@ struct llama_vocab {
id special_cls_id = LLAMA_TOKEN_NULL;
id special_mask_id = LLAMA_TOKEN_NULL;

id linefeed_id = 13;
id linefeed_id = 13;

// fim tokens
id special_fim_pre_id = LLAMA_TOKEN_NULL;
@@ -149,6 +149,12 @@ int32_t llama_token_to_piece_impl(
int32_t lstrip,
bool special);

// check if token0 is a prefix of token1
bool llama_token_is_prefix_impl(
const struct llama_vocab & vocab,
llama_token token0,
llama_token token1);

int32_t llama_detokenize_impl(
const struct llama_vocab & vocab,
const llama_token * tokens,
11 changes: 11 additions & 0 deletions src/llama.cpp
@@ -21466,6 +21466,13 @@ int32_t llama_token_to_piece(
return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
}

bool llama_token_is_prefix(
const struct llama_model * model,
llama_token token0,
llama_token token1) {
return llama_token_is_prefix_impl(model->vocab, token0, token1);
}

int32_t llama_detokenize(
const struct llama_model * model,
const llama_token * tokens,
@@ -21796,6 +21803,10 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * mod
return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
}

struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model, float p, float p_eog) {
return llama_sampler_init_infill_impl(model->vocab, p, p_eog);
}

//
// model split
//
