diff --git a/common/arg.cpp b/common/arg.cpp
index dba0bd14472f8..bd8b1e3a98ad9 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -947,6 +947,20 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.sparams.tfs_z = std::stof(value);
         }
     ).set_sparam());
+    add_opt(llama_arg(
+        {"--infill-p"}, "N",
+        string_format("infill p threshold (default: %.1f)", (double)params.sparams.infill_p),
+        [](gpt_params & params, const std::string & value) {
+            params.sparams.infill_p = std::stof(value);
+        }
+    ).set_sparam());
+    add_opt(llama_arg(
+        {"--infill-p-eog"}, "N",
+        string_format("infill p_eog threshold (default: %.1f)", (double)params.sparams.infill_p_eog),
+        [](gpt_params & params, const std::string & value) {
+            params.sparams.infill_p_eog = std::stof(value);
+        }
+    ).set_sparam());
     add_opt(llama_arg(
         {"--typical"}, "N",
         string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
diff --git a/common/common.h b/common/common.h
index 6dd1f9d27559b..6bbd5989b585b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -90,6 +90,7 @@ enum gpt_sampler_type {
     GPT_SAMPLER_TYPE_TFS_Z       = 4,
     GPT_SAMPLER_TYPE_TYPICAL_P   = 5,
     GPT_SAMPLER_TYPE_TEMPERATURE = 6,
+    GPT_SAMPLER_TYPE_INFILL      = 7,
 };
 
 // dimensionality reduction methods, used by cvector-generator
@@ -113,6 +114,8 @@ struct gpt_sampler_params {
     float   temp              = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
     float   dynatemp_range    = 0.00f; // 0.0 = disabled
     float   dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
+    float   infill_p          = 0.80f;
+    float   infill_p_eog      = 0.01f;
     int32_t penalty_last_n    = 64;    // last n tokens to penalize (0 = disable penalty, -1 = context size)
     float   penalty_repeat    = 1.00f; // 1.0 = disabled
     float   penalty_freq      = 0.00f; // 0.0 = disabled
@@ -130,7 +133,7 @@ struct gpt_sampler_params {
         GPT_SAMPLER_TYPE_TYPICAL_P,
         GPT_SAMPLER_TYPE_TOP_P,
         GPT_SAMPLER_TYPE_MIN_P,
-        GPT_SAMPLER_TYPE_TEMPERATURE
+        GPT_SAMPLER_TYPE_TEMPERATURE,
     };
 
     std::string grammar; // optional BNF-like grammar to constrain sampling
diff --git a/common/sampling.cpp b/common/sampling.cpp
index 3dc7f112094e6..7163ded0bef84 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -193,6 +193,9 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
             case GPT_SAMPLER_TYPE_TEMPERATURE:
                 llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
                 break;
+            case GPT_SAMPLER_TYPE_INFILL:
+                llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model, params.infill_p, params.infill_p_eog));
+                break;
             default:
                 GGML_ASSERT(false && "unknown sampler type");
         }
@@ -372,6 +375,7 @@ char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr) {
         case GPT_SAMPLER_TYPE_TOP_P:       return 'p';
         case GPT_SAMPLER_TYPE_MIN_P:       return 'm';
         case GPT_SAMPLER_TYPE_TEMPERATURE: return 't';
+        case GPT_SAMPLER_TYPE_INFILL:      return 'i';
         default : return '?';
     }
 }
@@ -384,6 +388,7 @@ std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr) {
         case GPT_SAMPLER_TYPE_TOP_P:       return "top_p";
         case GPT_SAMPLER_TYPE_MIN_P:       return "min_p";
         case GPT_SAMPLER_TYPE_TEMPERATURE: return "temperature";
+        case GPT_SAMPLER_TYPE_INFILL:      return "infill";
         default : return "";
     }
 }
@@ -396,6 +401,7 @@ std::vector<gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<st
+        { "infill",      GPT_SAMPLER_TYPE_INFILL },
@@ ... @@ std::vector<gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & c
         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TYPICAL_P),   GPT_SAMPLER_TYPE_TYPICAL_P },
         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_P),       GPT_SAMPLER_TYPE_TOP_P },
         { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_MIN_P),       GPT_SAMPLER_TYPE_MIN_P },
-        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE }
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE },
+        { gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_INFILL),      GPT_SAMPLER_TYPE_INFILL },
     };
 
     std::vector<gpt_sampler_type> samplers;
diff --git a/examples/llama.vim b/examples/llama.vim
index 3929797039002..d4d809489f8aa 100644
--- a/examples/llama.vim
+++ b/examples/llama.vim
@@ -11,14 +11,14 @@
 "
 
 let s:default_config = {
+    \ 'endpoint':     'http://127.0.0.1:8012/infill',
     \ 'prefix_lines': 32,
     \ 'suffix_lines': 32,
-    \ 'endpoint':     'http://127.0.0.1:8012/infill',
-    \ 'stop':         ["\n"],
-    \ 'n_predict':    64,
-    \ 'n_probs':      3,
-    \ 'temperature':  0.1
-    \}
+    \ 'n_predict':    64,
+    \ 'n_probs':      3,
+    \ 'temperature':  0.1,
+    \ 'stop':         ["\n"]
+    \ }
 
 let g:llama_config = get(g:, 'llama_config', s:default_config)
 
@@ -45,14 +45,16 @@ function! llama#fim() abort
         \ 'prompt':         "",
         \ 'input_prefix':   l:prefix,
         \ 'input_suffix':   l:suffix,
-        "\ 'stop':          g:llama_config.stop,
+        "\ 'stop':          g:llama_config.stop,
         \ 'n_predict':      g:llama_config.n_predict,
-        "\ 'n_probs':       g:llama_config.n_probs,
+        "\ 'n_probs':       g:llama_config.n_probs,
         \ 'penalty_last_n': 0,
         \ 'temperature':    g:llama_config.temperature,
-        \ 'top_k':          10,
+        \ 'top_k':          5,
+        \ 'infill_p':       0.20,
+        \ 'infill_p_eog':   0.001,
         \ 'stream':         v:false,
-        \ 'samplers':       ["top_k"]
+        \ 'samplers':       ["top_k", "infill"]
         \ })
 
     " request completion from the server
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index b51cc68bd6e6b..3c5f7e51bda99 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -889,6 +889,8 @@ struct server_context {
         slot.sparams.tfs_z             = json_value(data, "tfs_z",             default_sparams.tfs_z);
         slot.sparams.typ_p             = json_value(data, "typical_p",         default_sparams.typ_p);
         slot.sparams.temp              = json_value(data, "temperature",       default_sparams.temp);
+        slot.sparams.infill_p          = json_value(data, "infill_p",          default_sparams.infill_p);
+        slot.sparams.infill_p_eog      = json_value(data, "infill_p_eog",      default_sparams.infill_p_eog);
        slot.sparams.dynatemp_range    = json_value(data, "dynatemp_range",    default_sparams.dynatemp_range);
         slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
         slot.sparams.penalty_last_n    = json_value(data, "repeat_last_n",     default_sparams.penalty_last_n);
@@ -1236,6 +1238,8 @@ struct server_context {
             {"min_p",                     slot.sparams.min_p},
             {"tfs_z",                     slot.sparams.tfs_z},
             {"typical_p",                 slot.sparams.typ_p},
+            {"infill_p",                  slot.sparams.infill_p},
+            {"infill_p_eog",              slot.sparams.infill_p_eog},
             {"repeat_last_n",             slot.sparams.penalty_last_n},
             {"repeat_penalty",            slot.sparams.penalty_repeat},
             {"presence_penalty",          slot.sparams.penalty_present},
@@ -1964,55 +1968,57 @@ struct server_context {
                 slot.t_start_process_prompt = ggml_time_us();
                 slot.t_start_generation = 0;
 
-                if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
-                    const bool add_bos = llama_add_bos_token(model);
-
-                    auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
-                    auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
-
-                    prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
-                    suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
-
-                    auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
-                    auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
-
-                    if (add_bos) {
-                        embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
-                    }
-
-                    embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
-
-                    const llama_token middle_token = llama_token_fim_mid(model);
-                    if (middle_token >= 0) {
-                        embd_inp.push_back(middle_token);
-                    }
-
-                    prompt_tokens = embd_inp;
-                } else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
-                    // require slot.prompt to be array of 2 strings
-                    if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
-                        SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
-                        slot.release();
-                        send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
-                        continue;
-                    }
-
-                    // prompt: [BOS]query[EOS][SEP]doc[EOS]
-                    prompt_tokens.clear();
-                    prompt_tokens.push_back(llama_token_bos(model));
-                    {
-                        const auto part = tokenize(slot.prompt[0], false, false);
-                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                    }
-                    prompt_tokens.push_back(llama_token_eos(model));
-                    prompt_tokens.push_back(llama_token_sep(model));
-                    {
-                        const auto part = tokenize(slot.prompt[1], false, false);
-                        prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
-                    }
-                    prompt_tokens.push_back(llama_token_eos(model));
-                } else {
-                    prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
+                switch (slot.cmpl_type) {
+                    case SERVER_TASK_CMPL_TYPE_NORMAL:
+                    case SERVER_TASK_CMPL_TYPE_EMBEDDING:
+                        {
+                            prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
+                        } break;
+                    case SERVER_TASK_CMPL_TYPE_RERANK:
+                        {
+                            // require slot.prompt to be array of 2 strings
+                            if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
+                                SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
+                                slot.release();
+                                send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
+                                continue;
+                            }
+
+                            // prompt: [BOS]query[EOS][SEP]doc[EOS]
+                            prompt_tokens.clear();
+                            prompt_tokens.push_back(llama_token_bos(model));
+                            {
+                                const auto part = tokenize(slot.prompt[0], false, false);
+                                prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                            }
+                            prompt_tokens.push_back(llama_token_eos(model));
+                            prompt_tokens.push_back(llama_token_sep(model));
+                            {
+                                const auto part = tokenize(slot.prompt[1], false, false);
+                                prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
+                            }
+                            prompt_tokens.push_back(llama_token_eos(model));
+                        } break;
+                    case SERVER_TASK_CMPL_TYPE_INFILL:
+                        {
+                            auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
+                            auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);
+
+                            prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
+                            suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));
+
+                            auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
+                            auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;
+
+                            if (llama_add_bos_token(model)) {
+                                embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
+                            }
+
+                            embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
+                            embd_inp.push_back(llama_token_fim_mid(model));
+
+                            prompt_tokens = std::move(embd_inp);
+                        } break;
                 }
 
                 slot.n_past = 0;
diff --git a/include/llama.h b/include/llama.h
index 65d50b41977df..457acc30a2c74 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -952,6 +952,12 @@ extern "C" {
                            int32_t   lstrip,
                               bool   special);
 
+    // check if token0 is contained as a prefix in token1
+    LLAMA_API bool llama_token_is_prefix(
+              const struct llama_model * model,
+                           llama_token   token0,
+                           llama_token   token1);
+
     /// @details Convert the provided tokens into text (inverse of llama_tokenize()).
     /// @param text The char pointer must be large enough to hold the resulting text.
     /// @return Returns the number of chars/bytes on success, no more than text_len_max.
@@ -1144,6 +1150,26 @@ extern "C" {
                              int32_t   n_logit_bias,
               const llama_logit_bias * logit_bias);
 
+    // 1. if there is a high-prob token (>= 0.9f) - pick it
+    // 2. if sum of EOG probs is larger than p_eog -> mask non-EOG tokens away
+    // 3. combine probs of tokens that have the same prefix
+    //
+    // example:
+    //
+    // - before:
+    //   "hel":   0.5
+    //   "hell":  0.2
+    //   "hello": 0.1
+    //   "dummy": 0.1
+    //
+    // - after:
+    //   "hel":   0.8
+    //   "dummy": 0.1
+    //
+    LLAMA_API struct llama_sampler * llama_sampler_init_infill(
+            const struct llama_model * model,
+                                 float   p,
+                                 float   p_eog);
     // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
     LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
 
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index e255a8fc4fd54..a61444018c00f 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1644,6 +1644,141 @@ struct llama_sampler * llama_sampler_init_logit_bias(
     };
 }
 
+// infill
+
+struct llama_sampler_infill {
+    const struct llama_vocab * vocab;
+
+    const float p;
+    const float p_eog;
+};
+
+static const char * llama_sampler_infill_name(const struct llama_sampler * /*smpl*/) {
+    return "infill";
+}
+
+static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
+    auto * ctx = (llama_sampler_infill *) smpl->ctx;
+
+    llama_sampler_softmax_impl(cur_p);
+
+    // print cur_p:
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LLAMA_LOG_DEBUG("infill: cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+
+    float p_max     = 0.0f;
+    float p_eog_sum = 0.0f;
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        p_max = fmaxf(p_max, cur_p->data[i].p);
+        if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            p_eog_sum += cur_p->data[i].p;
+        }
+    }
+
+    if (p_max < 0.90f && p_eog_sum > ctx->p_eog) {
+        LLAMA_LOG_DEBUG("infill: all EOG tokens are more likely than p_eog (%f), keeping only EOG tokens\n", ctx->p_eog);
+
+        // keep just the EOG tokens
+        const auto size_org = cur_p->size;
+
+        cur_p->size = 0;
+
+        for (size_t i = 0; i < size_org; ++i) {
+            if (llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+                cur_p->data[cur_p->size++] = cur_p->data[i];
+            }
+        }
+
+        return;
+    }
+
+    // combine tokens with common prefix
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        for (size_t j = 0; j < cur_p->size; ++j) {
+            if (cur_p->data[i].logit == -INFINITY) {
+                break;
+            }
+
+            if (i == j || cur_p->data[j].logit == -INFINITY) {
+                continue;
+            }
+
+            if (llama_token_is_prefix_impl(*ctx->vocab, cur_p->data[i].id, cur_p->data[j].id)) {
+                if (cur_p->data[i].p > cur_p->data[j].p) {
+                    cur_p->data[i].p += cur_p->data[j].p;
+                    cur_p->data[j].logit = -INFINITY;
+                } else {
+                    cur_p->data[j].p += cur_p->data[i].p;
+                    cur_p->data[i].logit = -INFINITY;
+                }
+            }
+        }
+    }
+
+    // mask non-EOG tokens with prob < ctx->p
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].p < ctx->p && !llama_token_is_eog_impl(*ctx->vocab, cur_p->data[i].id)) {
+            cur_p->data[i].logit = -INFINITY;
+        }
+    }
+
+    // if all probs are -INFINITY -> reduce cur_p to single EOG token
+    if (std::all_of(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & td) { return td.logit == -INFINITY; })) {
+        cur_p->size = 1;
+        cur_p->data[0].id    = llama_token_eot_impl(*ctx->vocab);
+        cur_p->data[0].logit = 1.0f;
+    }
+
+    // resize
+    const auto size_org = cur_p->size;
+
+    cur_p->size = 0;
+
+    for (size_t i = 0; i < size_org; ++i) {
+        if (cur_p->data[i].logit != -INFINITY) {
+            cur_p->data[cur_p->size++] = cur_p->data[i];
+        }
+    }
+
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        LLAMA_LOG_DEBUG("after : cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
+    }
+}
+
+static struct llama_sampler * llama_sampler_infill_clone(const struct llama_sampler * smpl) {
+    const auto * ctx = (const llama_sampler_infill *) smpl->ctx;
+    return llama_sampler_init_infill_impl(*ctx->vocab, ctx->p, ctx->p_eog);
+}
+
+static void llama_sampler_infill_free(struct llama_sampler * smpl) {
+    delete (llama_sampler_infill *) smpl->ctx;
+}
+
+static struct llama_sampler_i llama_sampler_infill_i = {
+    /* .name   = */ llama_sampler_infill_name,
+    /* .accept = */ nullptr,
+    /* .apply  = */ llama_sampler_infill_apply,
+    /* .reset  = */ nullptr,
+    /* .clone  = */ llama_sampler_infill_clone,
+    /* .free   = */ llama_sampler_infill_free,
+};
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab,
+                             float   p,
+                             float   p_eog) {
+    return new llama_sampler {
+        /* .iface = */ &llama_sampler_infill_i,
+        /* .ctx   = */ new llama_sampler_infill {
+            /* .vocab = */ &vocab,
+            /* .p     = */ p,
+            /* .p_eog = */ p_eog,
+        },
+    };
+}
+
 // utils
 
 uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
diff --git a/src/llama-sampling.h b/src/llama-sampling.h
index d90b147130e4b..0d78d624db450 100644
--- a/src/llama-sampling.h
+++ b/src/llama-sampling.h
@@ -4,8 +4,6 @@
 
 #include "llama-grammar.h"
 
-#include 
-
 struct llama_vocab;
 struct llama_grammar;
 
@@ -27,3 +25,8 @@ struct llama_sampler * llama_sampler_init_grammar_impl(
         const struct llama_vocab & vocab,
                       const char * grammar_str,
                       const char * grammar_root);
+
+struct llama_sampler * llama_sampler_init_infill_impl(
+        const struct llama_vocab & vocab,
+                             float   p,
+                             float   p_eog);
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index a27394a377231..367b31bac541d 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -1858,6 +1858,23 @@ int32_t llama_token_to_piece_impl(const struct llama_vocab & vocab, llama_token
     return 0;
 }
 
+bool llama_token_is_prefix_impl(
+        const struct llama_vocab & vocab,
+                       llama_token   token0,
+                       llama_token   token1) {
+    char text_buf_0[128];
+    char text_buf_1[128];
+
+    const int32_t len0 = llama_token_to_piece_impl(vocab, token0, text_buf_0, 128, 0, false);
+    const int32_t len1 = llama_token_to_piece_impl(vocab, token1, text_buf_1, 128, 0, false);
+
+    if (len0 <= 0 || len1 <= 0) {
+        return false;
+    }
+
+    return len0 < len1 && memcmp(text_buf_0, text_buf_1, len0) == 0;
+}
+
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
         const llama_token * tokens,
diff --git a/src/llama-vocab.h b/src/llama-vocab.h
index 17e14488a4d52..d958d0073be95 100644
--- a/src/llama-vocab.h
+++ b/src/llama-vocab.h
@@ -48,7 +48,7 @@ struct llama_vocab {
     id special_cls_id  = LLAMA_TOKEN_NULL;
     id special_mask_id = LLAMA_TOKEN_NULL;
 
-    id linefeed_id       = 13;
+    id linefeed_id = 13;
 
     // fim tokens
     id special_fim_pre_id = LLAMA_TOKEN_NULL;
@@ -149,6 +149,12 @@ int32_t llama_token_to_piece_impl(
         int32_t   lstrip,
            bool   special);
 
+// check if token0 is contained as a prefix in token1
+bool llama_token_is_prefix_impl(
+        const struct llama_vocab & vocab,
+                       llama_token   token0,
+                       llama_token   token1);
+
 int32_t llama_detokenize_impl(
         const struct llama_vocab & vocab,
         const llama_token * tokens,
diff --git a/src/llama.cpp b/src/llama.cpp
index 1fad760778b4f..d1ee77b0c7891 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -21478,6 +21478,13 @@ int32_t llama_token_to_piece(
     return llama_token_to_piece_impl(model->vocab, token, buf, length, lstrip, special);
 }
 
+bool llama_token_is_prefix(
+          const struct llama_model * model,
+                         llama_token   token0,
+                         llama_token   token1) {
+    return llama_token_is_prefix_impl(model->vocab, token0, token1);
+}
+
 int32_t llama_detokenize(
     const struct llama_model * model,
     const llama_token * tokens,
@@ -21808,6 +21815,10 @@ struct llama_sampler * llama_sampler_init_grammar(const struct llama_model * mod
     return llama_sampler_init_grammar_impl(model->vocab, grammar_str, grammar_root);
 }
 
+struct llama_sampler * llama_sampler_init_infill(const struct llama_model * model, float p, float p_eog) {
+    return llama_sampler_init_infill_impl(model->vocab, p, p_eog);
+}
+
 //
 // model split
 //
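
A minimal usage sketch of the new sampler (an illustration, not part of the patch): it assumes the pre-existing llama.h sampler-chain API (llama_sampler_chain_init, llama_sampler_init_top_k, llama_sampler_init_dist), uses a hypothetical helper name make_infill_chain, and mirrors the parameters the updated llama.vim request sends (top_k = 5, infill_p = 0.20, infill_p_eog = 0.001). The only call introduced by this diff is llama_sampler_init_infill().

#include "llama.h"

// build a sampler chain approximating the llama.vim request above:
// samplers = ["top_k", "infill"], followed by a final distribution sampler
static struct llama_sampler * make_infill_chain(const struct llama_model * model) {
    struct llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    // narrow the candidate set first
    llama_sampler_chain_add(chain, llama_sampler_init_top_k(5));

    // then apply the infill-specific filtering added here: keep only EOG tokens when no
    // candidate is dominant and the summed EOG probability exceeds p_eog, merge candidates
    // where one token is a prefix of another, and mask non-EOG candidates with p below 0.20
    llama_sampler_chain_add(chain, llama_sampler_init_infill(model, /* p     */ 0.20f,
                                                                    /* p_eog */ 0.001f));

    // pick a token from whatever survives
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));

    return chain;
}

The returned chain is used with llama_sampler_sample() and released with llama_sampler_free(). Server clients get the same behavior without touching C++ by sending "samplers": ["top_k", "infill"] together with infill_p / infill_p_eog in the /infill request, exactly as the llama.vim example above now does.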