diff --git a/common/arg.cpp b/common/arg.cpp
index dba0bd14472f8..bd8b1e3a98ad9 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -947,6 +947,20 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.sparams.tfs_z = std::stof(value);
         }
     ).set_sparam());
+    add_opt(llama_arg(
+        {"--infill-p"}, "N",
+        string_format("infill p threshold (default: %.1f)", (double)params.sparams.infill_p),
+        [](gpt_params & params, const std::string & value) {
+            params.sparams.infill_p = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(llama_arg(
+        {"--infill-p-eog"}, "N",
+        string_format("infill p_eog threshold (default: %.1f)", (double)params.sparams.infill_p_eog),
+        [](gpt_params & params, const std::string & value) {
+            params.sparams.infill_p_eog = std::stof(value);
+        }
+    ).set_sparam());
     add_opt(llama_arg(
         {"--typical"}, "N",
         string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
diff --git a/common/common.h b/common/common.h
index 6dd1f9d27559b..6bbd5989b585b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -90,6 +90,7 @@ enum gpt_sampler_type {
     GPT_SAMPLER_TYPE_TFS_Z       = 4,
     GPT_SAMPLER_TYPE_TYPICAL_P   = 5,
     GPT_SAMPLER_TYPE_TEMPERATURE = 6,
+    GPT_SAMPLER_TYPE_INFILL      = 7,
 };
 
 // dimensionality reduction methods, used by cvector-generator
@@ -113,6 +114,8 @@ struct gpt_sampler_params {
     float     temp              = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
     float     dynatemp_range    = 0.00f; // 0.0 = disabled
     float     dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+    float     infill_p          = 0.80f;
+    float     infill_p_eog      = 0.01f;
     int32_t   penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float     penalty_repeat    = 1.00f; // 1.0 = disabled
     float     penalty_freq      = 0.00f; // 0.0 = disabled
@@ -130,7 +133,7 @@ struct gpt_sampler_params {
         GPT_SAMPLER_TYPE_TYPICAL_P,
         GPT_SAMPLER_TYPE_TOP_P,
         GPT_SAMPLER_TYPE_MIN_P,
-        GPT_SAMPLER_TYPE_TEMPERATURE
+        GPT_SAMPLER_TYPE_TEMPERATURE,
     };
 
     std::string grammar; // optional BNF-like grammar to constrain sampling
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 3dc7f112094e6..7163ded0bef84 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -193,6 +193,9 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
             case GPT_SAMPLER_TYPE_TEMPERATURE:
                 llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                 break;
+            case GPT_SAMPLER_TYPE_INFILL:
+                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model, params.infill_p, params.infill_p_eog));
+                break;
             default:
                 GGML_ASSERT(false && "unknown sampler type");
         }
@@ -372,6 +375,7 @@ char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr) {
         case GPT_SAMPLER_TYPE_TOP_P:       return 'p';
         case GPT_SAMPLER_TYPE_MIN_P:       return 'm';
         case GPT_SAMPLER_TYPE_TEMPERATURE: return 't';
+        case GPT_SAMPLER_TYPE_INFILL:      return 'i';
         default : return '?';
     }
 }
@@ -384,6 +388,7 @@ std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr) {
         case GPT_SAMPLER_TYPE_TOP_P:       return "top_p";
         case GPT_SAMPLER_TYPE_MIN_P:       return "min_p";
         case GPT_SAMPLER_TYPE_TEMPERATURE: return "temperature";
+        case GPT_SAMPLER_TYPE_INFILL:      return "infill";
         default : return "";
     }
 }
@@ -396,6 +401,7 @@ std::vector gpt_sampler_types_from_names(const std::vector gpt_sampler_types_from_chars(const std::string & c
         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TYPICAL_P),   GPT_SAMPLER_TYPE_TYPICAL_P },
         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_P),       GPT_SAMPLER_TYPE_TOP_P },
         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_MIN_P),       GPT_SAMPLER_TYPE_MIN_P },
-        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE }
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE },
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_INFILL),      GPT_SAMPLER_TYPE_INFILL },
     };
 
     std::vector<gpt_sampler_type> samplers;
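The changes above register the new sampler type in the common layer: `GPT_SAMPLER_TYPE_INFILL` maps to `llama_sampler_init_infill(model, infill_p, infill_p_eog)` in the chain, is printed as `"infill"`, and is selected by the chain character `'i'`. A minimal sketch (not part of this patch) of how the new type would be picked through `gpt_sampler_params`; the include paths and the model/context setup are assumptions, the helpers `gpt_sampler_init()`, `gpt_sampler_sample()` and `gpt_sampler_free()` are the existing ones from `common/sampling.h`:

```cpp
// Sketch, not part of this patch: select the infill sampler via gpt_sampler_params.
#include "common.h"
#include "sampling.h"

static llama_token sample_one(struct llama_model * model, struct llama_context * ctx) {
    gpt_sampler_params sparams;

    sparams.infill_p     = 0.80f; // same defaults as added in common.h above
    sparams.infill_p_eog = 0.01f;

    // run only the infill sampler (roughly what --samplers "infill" would select)
    sparams.samplers = { GPT_SAMPLER_TYPE_INFILL };

    struct gpt_sampler * smpl = gpt_sampler_init(model, sparams);

    const llama_token id = gpt_sampler_sample(smpl, ctx, -1); // sample from the last logits

    gpt_sampler_free(smpl);

    return id;
}
```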
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index b3773d256ef43..3c5f7e51bda99 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -889,6 +889,8 @@ struct server_context {
         slot.sparams.tfs_z             = json_value(data, "tfs_z",             default_sparams.tfs_z);
         slot.sparams.typ_p             = json_value(data, "typical_p",         default_sparams.typ_p);
         slot.sparams.temp              = json_value(data, "temperature",       default_sparams.temp);
+        slot.sparams.infill_p          = json_value(data, "infill_p",          default_sparams.infill_p);
+        slot.sparams.infill_p_eog      = json_value(data, "infill_p_eog",      default_sparams.infill_p_eog);
         slot.sparams.dynatemp_range    = json_value(data, "dynatemp_range",    default_sparams.dynatemp_range);
         slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
         slot.sparams.penalty_last_n    = json_value(data, "repeat_last_n",     default_sparams.penalty_last_n);
@@ -1236,6 +1238,8 @@ struct server_context {
             {"min_p",             slot.sparams.min_p},
             {"tfs_z",             slot.sparams.tfs_z},
             {"typical_p",         slot.sparams.typ_p},
+            {"infill_p",          slot.sparams.infill_p},
+            {"infill_p_eog",      slot.sparams.infill_p_eog},
             {"repeat_last_n",     slot.sparams.penalty_last_n},
             {"repeat_penalty",    slot.sparams.penalty_repeat},
             {"presence_penalty",  slot.sparams.penalty_present},
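The server changes forward two new request fields, `infill_p` and `infill_p_eog`, into the slot's sampling parameters and echo them back in the reported settings. A hedged sketch (not part of this patch) of a request body that sets them; only the two `infill_*` fields come from this diff, while `input_prefix`/`input_suffix` are the server's pre-existing infill fields and are shown purely for illustration:

```cpp
// Sketch, not part of this patch: build a request body carrying the new fields.
// Assumes nlohmann::json, which the server already uses.
#include <cstdio>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

int main() {
    const json body = {
        {"input_prefix", "int add(int a, int b) {\n    return "},
        {"input_suffix", ";\n}\n"},
        {"infill_p",     0.80},  // matches default_sparams.infill_p
        {"infill_p_eog", 0.01},  // matches default_sparams.infill_p_eog
    };

    printf("%s\n", body.dump(2).c_str());
}
```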
diff --git a/include/llama.h b/include/llama.h
index 65d50b41977df..457acc30a2c74 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -952,6 +952,12 @@ extern "C" {
                            int32_t   lstrip,
                               bool   special);
 
+    // check if token0 is contained as a prefix in token1
+    LLAMA_API bool llama_token_is_prefix(
+        const struct llama_model * model,
+                     llama_token   token0,
+                     llama_token   token1);
+
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.
@@ -1144,6 +1150,26 @@ extern "C" {
                              int32_t   n_logit_bias,
               const llama_logit_bias * logit_bias);
 
+    // 1. if there is a high-prob token (>= 0.9f) - pick it
+    // 2. if sum of EOG probs is larger than p_eog -> mask non-EOG tokens away
+    // 3. combine probs of tokens that have the same prefix
+    //
+    // example:
+    //
+    // - before:
+    //   "hel":   0.5
+    //   "hell":  0.2
+    //   "hello": 0.1
+    //   "dummy": 0.1
+    //
+    // - after:
+    //   "hel":   0.8
+    //   "dummy": 0.1
+    //
+    LLAMA_API struct llama_sampler * llama_sampler_init_infill(
+        const struct llama_model * model,
+                             float   p,
+                             float   p_eog);
 
     // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
     LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
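Step 3 of the header comment is the interesting part: candidates whose text piece extends another candidate's piece have their probability folded into the shorter/stronger entry. A self-contained toy illustration (not part of this patch) of that merging rule, using plain strings in place of vocab tokens so it reproduces the "hel"/"hell"/"hello"/"dummy" example above:

```cpp
// Toy sketch of the prefix-merging rule; 0.0f stands in for the -INFINITY mask.
#include <cstdio>
#include <string>
#include <utility>
#include <vector>

int main() {
    std::vector<std::pair<std::string, float>> cand = {
        {"hel", 0.5f}, {"hell", 0.2f}, {"hello", 0.1f}, {"dummy", 0.1f},
    };

    for (size_t i = 0; i < cand.size(); ++i) {
        for (size_t j = 0; j < cand.size(); ++j) {
            if (i == j || cand[i].second == 0.0f || cand[j].second == 0.0f) {
                continue;
            }
            // cand[i] is a strict prefix of cand[j]?
            if (cand[i].first.size() < cand[j].first.size() &&
                cand[j].first.compare(0, cand[i].first.size(), cand[i].first) == 0) {
                // fold the smaller probability into the larger one, mask the other entry
                if (cand[i].second > cand[j].second) {
                    cand[i].second += cand[j].second;
                    cand[j].second  = 0.0f;
                } else {
                    cand[j].second += cand[i].second;
                    cand[i].second  = 0.0f;
                }
            }
        }
    }

    for (const auto & c : cand) {
        if (c.second > 0.0f) {
            printf("%-6s %.1f\n", c.first.c_str(), c.second); // prints: hel 0.8, dummy 0.1
        }
    }
}
```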
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index e255a8fc4fd54..fbb3997e9c7e6 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1644,6 +1644,145 @@ struct llama_sampler * llama_sampler_init_logit_bias(
     };
 }
 
+// infill
+
+struct llama_sampler_infill {
+    const struct llama_vocab * vocab;
+
+    const float p;
+    const float p_eog;
+};
+
+static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
+    return "infill";
+}
+
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_infill *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
+    // print cur_p:
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LLAMA_LOG_DEBUG("infill: cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    float p_max     = 0.0f;
+    float p_eog_sum = 0.0f;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        p_max = fmaxf(p_max, cur_p->data[i].p);
+        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            p_eog_sum += cur_p->data[i].p;
+        }
+    }
+
+    if (p_max < 0.90f && p_eog_sum > ctx->p_eog) {
+        LLAMA_LOG_DEBUG("infill: all EOG tokens are more likely than p_eog (%f), keeping only EOG tokens\n", ctx->p_eog);
+
+        // keep just the EOG tokens
+        const auto size_org = cur_p->size;
+
+        cur_p->size = 0;
+
+        for (size_t i = 0; i < size_org; ++i) {
+            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+                cur_p->data[cur_p->size++] = cur_p->data[i];
+            }
+        }
+
+        return;
+    }
+
+    // combine tokens with common prefix
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (size_t j = 0; j < cur_p->size; ++j) {
+            if (cur_p->data[i].logit == -INFINITY) {
+                break;
+            }
+
+            if (i == j || cur_p->data[j].logit == -INFINITY) {
+                continue;
+            }
+
+            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
+                if (cur_p->data[i].p > cur_p->data[j].p) {
+                    cur_p->data[i].p += cur_p->data[j].p;
+                    cur_p->data[j].logit = -INFINITY;
+                } else {
+                    cur_p->data[j].p += cur_p->data[i].p;
+                    cur_p->data[i].logit = -INFINITY;
+                }
+            }
+        }
+    }
+
+    // mask non-EOG tokens with prob < ctx->p
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].p < ctx->p && !llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            cur_p->data[i].logit = -INFINITY;
+        }
+    }
+
+    // determine the token with max logit
+    float l_max = -INFINITY;
+    int   i_max = -1;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit > l_max) {
+            l_max = cur_p->data[i].logit;
+            i_max = i;
+        }
+    }
+
+    // if all probs are -INFINITY -> reduce cur_p to single EOG token
+    if (i_max == -1) {
+        cur_p->size = 1;
+        cur_p->data[0].id    = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].logit = 1.0f;
+
+        return;
+    }
+
+    cur_p->size = 1;
+    cur_p->data[0] = cur_p->data[i_max];
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LLAMA_LOG_DEBUG("after : cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+}
+
+static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
+    return llama_sampler_init_infill_impl(*ctx->vocab, ctx->p, ctx->p_eog);
+}
+
+static void llama_sampler_infill_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_infill *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_infill_i = {
+    /* .name   = */ llama_sampler_infill_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_infill_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_infill_clone,
+    /* .free   = */ llama_sampler_infill_free,
+};
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab,
+                           float   p,
+                           float   p_eog) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_infill_i,
+        /* .ctx   = */ new llama_sampler_infill {
+            /* .vocab = */ &vocab,
+            /* .p     = */ p,
+            /* .p_eog = */ p_eog,
+        },
+    };
+}
+
 // utils
 
 uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index d90b147130e4b..0d78d624db450 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -4,8 +4,6 @@
 
 #include "llama-grammar.h"
 
-#include 
-
 struct llama_vocab;
 struct llama_grammar;
 
@@ -27,3 +25,8 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
     const struct llama_vocab & vocab,
     const char * grammar_str,
     const char * grammar_root);
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+    const struct llama_vocab & vocab,
+    float p,
+    float p_eog);
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index a27394a377231..367b31bac541d 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1858,6 +1858,23 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
 }
 
+bool llama_token_is_prefix_impl(
+        const struct llama_vocab & vocab,
+                     llama_token   token0,
+                     llama_token   token1) {
+    char text_buf_0[128];
+    char text_buf_1[128];
+
+    const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, 128, 0, false);
+    const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, 128, 0, false);
+
+    if (len0 <= 0 || len1 <= 0) {
+        return false;
+    }
+
+    return len0 < len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
+}
+
 int32_t llama_detokenize_impl(
     const struct llama_vocab & vocab,
     const llama_token * tokens,
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index 17e14488a4d52..d958d0073be95 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -48,7 +48,7 @@ struct llama_vocab {
     id special_cls_id  = LLAMA_TOKEN_NULL;
     id special_mask_id = LLAMA_TOKEN_NULL;
 
-    id linefeed_id       = 13;
+    id linefeed_id = 13;
 
     // fim tokens
     id special_fim_pre_id = LLAMA_TOKEN_NULL;
@@ -149,6 +149,12 @@ int32_t llama_token_to_piece_impl(
     int32_t   lstrip,
     bool   special);
 
+// check if token0 is contained as a prefix in token1
+bool llama_token_is_prefix_impl(
+        const struct llama_vocab & vocab,
+                     llama_token   token0,
+                     llama_token   token1);
+
 int32_t llama_detokenize_impl(
     const struct llama_vocab & vocab,
     const llama_token * tokens,
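The vocab helper above decides the prefix relation purely on the UTF-8 bytes of the two token pieces: both tokens are rendered with `llama_token_to_piece_impl()` and the shorter piece must be a strict byte prefix of the longer one. A small sketch (not part of this patch) of the same comparison on plain C strings, to make the corner cases explicit:

```cpp
// Sketch of the byte-wise prefix rule used by llama_token_is_prefix_impl().
#include <cstdio>
#include <cstring>

static bool piece_is_prefix(const char * piece0, int len0, const char * piece1, int len1) {
    if (len0 <= 0 || len1 <= 0) {
        return false; // empty or failed piece -> never a prefix
    }
    // strict prefix: equal pieces do not count
    return len0 < len1 && memcmp(piece0, piece1, len0) == 0;
}

int main() {
    printf("%d\n", piece_is_prefix("hel",   3, "hello", 5)); // 1
    printf("%d\n", piece_is_prefix("hello", 5, "hello", 5)); // 0 (equal pieces)
    printf("%d\n", piece_is_prefix("dum",   3, "hello", 5)); // 0
}
```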
diff --git a/src/llama.cpp b/src/llama.cpp
index 78f0670f5a36f..a14431ec655cb 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -21466,6 +21466,13 @@ int32_t llama_token_to_piece(
     return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
 }
 
+bool llama_token_is_prefix(
+    const struct llama_model * model,
+                 llama_token   token0,
+                 llama_token   token1) {
+    return llama_token_is_prefix_impl(model->vocab, token0, token1);
+}
+
 int32_t llama_detokenize(
     const struct llama_model * model,
     const llama_token * tokens,
@@ -21796,6 +21803,10 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * mod
     return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }
 
+struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model, float p, float p_eog) {
+    return llama_sampler_init_infill_impl(model->vocab, p, p_eog);
+}
+
 //
 // model split
 //
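With the public wrappers in place, the sampler can also be used directly through the `llama.h` chain API, independent of the `common` helpers. A hedged end-to-end sketch (not part of this patch); model loading and `llama_decode()` are elided, and the trailing `dist` sampler is only one reasonable way to pick from whatever candidates the infill sampler leaves:

```cpp
// Sketch, not part of this patch: build a minimal chain around the new sampler.
#include "llama.h"

static llama_token sample_infill_token(struct llama_model * model, struct llama_context * ctx) {
    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    // p = 0.8, p_eog = 0.01 -- the same defaults as gpt_sampler_params above
    llama_sampler_chain_add(chain, llama_sampler_init_infill(model, 0.80f, 0.01f));
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    // sample from the last set of logits produced by llama_decode()
    const llama_token id = llama_sampler_sample(chain, ctx, -1);

    llama_sampler_free(chain);

    return id;
}
```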