llama : add infill sampler
ggml-ci
ggerganov committed Oct 8, 2024
1 parent 25f3b4d commit 474d0e6
Showing 11 changed files with 294 additions and 64 deletions.
14 changes: 14 additions & 0 deletions common/arg.cpp
@@ -947,6 +947,20 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
params.sparams.tfs_z = std::stof(value);
}
).set_sparam());
add_opt(llama_arg(
{"--infill-p"}, "N",
string_format("infill p threshold (default: %.1f)", (double)params.sparams.infill_p),
[](gpt_params & params, const std::string & value) {
params.sparams.infill_p = std::stof(value);
}
).set_sparam());
add_opt(llama_arg(
{"--infill-p-eog"}, "N",
string_format("infill p_eog threshold (default: %.1f)", (double)params.sparams.infill_p_eog),
[](gpt_params & params, const std::string & value) {
params.sparams.infill_p_eog = std::stof(value);
}
).set_sparam());
add_opt(llama_arg(
{"--typical"}, "N",
string_format("locally typical sampling, parameter p (default: %.1f, 1.0 = disabled)", (double)params.sparams.typ_p),
5 changes: 4 additions & 1 deletion common/common.h
@@ -90,6 +90,7 @@ enum gpt_sampler_type {
GPT_SAMPLER_TYPE_TFS_Z = 4,
GPT_SAMPLER_TYPE_TYPICAL_P = 5,
GPT_SAMPLER_TYPE_TEMPERATURE = 6,
GPT_SAMPLER_TYPE_INFILL = 7,
};

// dimensionality reduction methods, used by cvector-generator
@@ -113,6 +114,8 @@ struct gpt_sampler_params {
float temp = 0.80f; // <= 0.0 to sample greedily, 0.0 to not output probabilities
float dynatemp_range = 0.00f; // 0.0 = disabled
float dynatemp_exponent = 1.00f; // controls how entropy maps to temperature in dynamic temperature sampler
float infill_p = 0.80f;
float infill_p_eog = 0.01f;
int32_t penalty_last_n = 64; // last n tokens to penalize (0 = disable penalty, -1 = context size)
float penalty_repeat = 1.00f; // 1.0 = disabled
float penalty_freq = 0.00f; // 0.0 = disabled
@@ -130,7 +133,7 @@ struct gpt_sampler_params {
GPT_SAMPLER_TYPE_TYPICAL_P,
GPT_SAMPLER_TYPE_TOP_P,
GPT_SAMPLER_TYPE_MIN_P,
GPT_SAMPLER_TYPE_TEMPERATURE
GPT_SAMPLER_TYPE_TEMPERATURE,
};

std::string grammar; // optional BNF-like grammar to constrain sampling
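
For context, the two new fields default to 0.80 / 0.01 and are wired to the --infill-p / --infill-p-eog arguments added in common/arg.cpp above. Below is a minimal sketch of opting into the new sampler through the common sampling API; the samplers field name and the two-argument gpt_sampler_init() are taken from this diff and the surrounding common library, and the values mirror the llama.vim request further down, so treat it as illustrative only:

// sketch: enable the infill sampler via the common sampling parameters
// (assumes a loaded llama_model; field/function names follow this commit's common library)
#include "common.h"
#include "sampling.h"

static gpt_sampler * make_infill_sampler(const llama_model * model) {
    gpt_sampler_params sparams;

    sparams.top_k        = 5;      // keep a small candidate set before the infill sampler
    sparams.infill_p     = 0.20f;  // infill "p" threshold (--infill-p)
    sparams.infill_p_eog = 0.001f; // EOG probability threshold (--infill-p-eog)
    sparams.samplers     = { GPT_SAMPLER_TYPE_TOP_K, GPT_SAMPLER_TYPE_INFILL };

    return gpt_sampler_init(model, sparams);
}
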
9 changes: 8 additions & 1 deletion common/sampling.cpp
@@ -193,6 +193,9 @@ struct gpt_sampler * gpt_sampler_init(const struct llama_model * model, const st
case GPT_SAMPLER_TYPE_TEMPERATURE:
llama_sampler_chain_add(result->chain, llama_sampler_init_temp_ext (params.temp, params.dynatemp_range, params.dynatemp_exponent));
break;
case GPT_SAMPLER_TYPE_INFILL:
llama_sampler_chain_add(result->chain, llama_sampler_init_infill (model, params.infill_p, params.infill_p_eog));
break;
default:
GGML_ASSERT(false && "unknown sampler type");
}
@@ -372,6 +375,7 @@ char gpt_sampler_type_to_chr(enum gpt_sampler_type cnstr) {
case GPT_SAMPLER_TYPE_TOP_P: return 'p';
case GPT_SAMPLER_TYPE_MIN_P: return 'm';
case GPT_SAMPLER_TYPE_TEMPERATURE: return 't';
case GPT_SAMPLER_TYPE_INFILL: return 'i';
default : return '?';
}
}
@@ -384,6 +388,7 @@ std::string gpt_sampler_type_to_str(enum gpt_sampler_type cnstr) {
case GPT_SAMPLER_TYPE_TOP_P: return "top_p";
case GPT_SAMPLER_TYPE_MIN_P: return "min_p";
case GPT_SAMPLER_TYPE_TEMPERATURE: return "temperature";
case GPT_SAMPLER_TYPE_INFILL: return "infill";
default : return "";
}
}
@@ -396,6 +401,7 @@ std::vector<gpt_sampler_type> gpt_sampler_types_from_names(const std::vector<std
{ "min_p", GPT_SAMPLER_TYPE_MIN_P },
{ "tfs_z", GPT_SAMPLER_TYPE_TFS_Z },
{ "temperature", GPT_SAMPLER_TYPE_TEMPERATURE },
{ "infill", GPT_SAMPLER_TYPE_INFILL }
};

// since samplers names are written multiple ways
@@ -441,7 +447,8 @@ std::vector<gpt_sampler_type> gpt_sampler_types_from_chars(const std::string & c
{ gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TYPICAL_P), GPT_SAMPLER_TYPE_TYPICAL_P },
{ gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TOP_P), GPT_SAMPLER_TYPE_TOP_P },
{ gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_MIN_P), GPT_SAMPLER_TYPE_MIN_P },
{ gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE }
{ gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_TEMPERATURE), GPT_SAMPLER_TYPE_TEMPERATURE },
{ gpt_sampler_type_to_chr(GPT_SAMPLER_TYPE_INFILL), GPT_SAMPLER_TYPE_INFILL },
};

std::vector<gpt_sampler_type> samplers;
22 changes: 12 additions & 10 deletions examples/llama.vim
@@ -11,14 +11,14 @@
"

let s:default_config = {
\ 'endpoint': 'http://127.0.0.1:8012/infill',
\ 'prefix_lines': 32,
\ 'suffix_lines': 32,
\ 'endpoint': 'http://127.0.0.1:8012/infill',
\ 'stop': ["\n"],
\ 'n_predict': 64,
\ 'n_probs': 3,
\ 'temperature': 0.1
\}
\ 'n_predict': 64,
\ 'n_probs': 3,
\ 'temperature': 0.1,
\ 'stop': ["\n"]
\ }

let g:llama_config = get(g:, 'llama_config', s:default_config)

@@ -45,14 +45,16 @@ function! llama#fim() abort
\ 'prompt': "",
\ 'input_prefix': l:prefix,
\ 'input_suffix': l:suffix,
"\ 'stop': g:llama_config.stop,
"\ 'stop': g:llama_config.stop,
\ 'n_predict': g:llama_config.n_predict,
"\ 'n_probs': g:llama_config.n_probs,
"\ 'n_probs': g:llama_config.n_probs,
\ 'penalty_last_n': 0,
\ 'temperature': g:llama_config.temperature,
\ 'top_k': 10,
\ 'top_k': 5,
\ 'infill_p': 0.20,
\ 'infill_p_eog': 0.001,
\ 'stream': v:false,
\ 'samplers': ["top_k"]
\ 'samplers': ["top_k", "infill"]
\ })

" request completion from the server
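
The plugin above simply POSTs these knobs to the server's /infill endpoint (the 'endpoint' value in s:default_config). An illustrative request body with the same fields follows; the keys match the payload above and the json_value() reads in server.cpp below, while the prefix/suffix strings are made up for the example:

{
  "prompt": "",
  "input_prefix": "def add(a, b):\n    ",
  "input_suffix": "\n\nprint(add(1, 2))\n",
  "n_predict": 64,
  "temperature": 0.1,
  "top_k": 5,
  "infill_p": 0.20,
  "infill_p_eog": 0.001,
  "stream": false,
  "samplers": ["top_k", "infill"]
}
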
104 changes: 55 additions & 49 deletions examples/server/server.cpp
@@ -889,6 +889,8 @@ struct server_context {
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
slot.sparams.infill_p = json_value(data, "infill_p", default_sparams.infill_p);
slot.sparams.infill_p_eog = json_value(data, "infill_p_eog", default_sparams.infill_p_eog);
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
slot.sparams.penalty_last_n = json_value(data, "repeat_last_n", default_sparams.penalty_last_n);
@@ -1236,6 +1238,8 @@ struct server_context {
{"min_p", slot.sparams.min_p},
{"tfs_z", slot.sparams.tfs_z},
{"typical_p", slot.sparams.typ_p},
{"infill_p", slot.sparams.infill_p},
{"infill_p_eog", slot.sparams.infill_p_eog},
{"repeat_last_n", slot.sparams.penalty_last_n},
{"repeat_penalty", slot.sparams.penalty_repeat},
{"presence_penalty", slot.sparams.penalty_present},
@@ -1964,55 +1968,57 @@ struct server_context {
slot.t_start_process_prompt = ggml_time_us();
slot.t_start_generation = 0;

if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_INFILL) {
const bool add_bos = llama_add_bos_token(model);

auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);

prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));

auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;

if (add_bos) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}

embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());

const llama_token middle_token = llama_token_fim_mid(model);
if (middle_token >= 0) {
embd_inp.push_back(middle_token);
}

prompt_tokens = embd_inp;
} else if (slot.cmpl_type == SERVER_TASK_CMPL_TYPE_RERANK) {
// require slot.prompt to be array of 2 strings
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
slot.release();
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
continue;
}

// prompt: [BOS]query[EOS][SEP]doc[EOS]
prompt_tokens.clear();
prompt_tokens.push_back(llama_token_bos(model));
{
const auto part = tokenize(slot.prompt[0], false, false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
prompt_tokens.push_back(llama_token_sep(model));
{
const auto part = tokenize(slot.prompt[1], false, false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
} else {
prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
switch (slot.cmpl_type) {
case SERVER_TASK_CMPL_TYPE_NORMAL:
case SERVER_TASK_CMPL_TYPE_EMBEDDING:
{
prompt_tokens = tokenize(slot.prompt, system_prompt.empty(), true); // add BOS if there isn't system prompt
} break;
case SERVER_TASK_CMPL_TYPE_RERANK:
{
// require slot.prompt to be array of 2 strings
if (!slot.prompt.is_array() || slot.prompt.size() != 2) {
SLT_ERR(slot, "%s", "invalid prompt for rerank task\n");
slot.release();
send_error(slot, "invalid prompt for rerank task", ERROR_TYPE_INVALID_REQUEST);
continue;
}

// prompt: [BOS]query[EOS][SEP]doc[EOS]
prompt_tokens.clear();
prompt_tokens.push_back(llama_token_bos(model));
{
const auto part = tokenize(slot.prompt[0], false, false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
prompt_tokens.push_back(llama_token_sep(model));
{
const auto part = tokenize(slot.prompt[1], false, false);
prompt_tokens.insert(prompt_tokens.end(), part.begin(), part.end());
}
prompt_tokens.push_back(llama_token_eos(model));
} break;
case SERVER_TASK_CMPL_TYPE_INFILL:
{
auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);

prefix_tokens.insert(prefix_tokens.begin(), llama_token_fim_pre(model));
suffix_tokens.insert(suffix_tokens.begin(), llama_token_fim_suf(model));

auto embd_inp = params.spm_infill ? suffix_tokens : prefix_tokens;
auto embd_end = params.spm_infill ? prefix_tokens : suffix_tokens;

if (llama_add_bos_token(model)) {
embd_inp.insert(embd_inp.begin(), llama_token_bos(model));
}

embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end());
embd_inp.push_back(llama_token_fim_mid(model));

prompt_tokens = std::move(embd_inp);
} break;
}

slot.n_past = 0;
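
The SERVER_TASK_CMPL_TYPE_INFILL case above reduces to a fixed token layout around the user's prefix and suffix. Here is a condensed sketch of the same construction, using only the FIM helpers already referenced in the diff; it paraphrases the switch case rather than replacing it:

// resulting layout (BOS only when llama_add_bos_token() is true):
//   default:           [BOS] <FIM_PRE> prefix... <FIM_SUF> suffix... <FIM_MID>
//   with --spm-infill: [BOS] <FIM_SUF> suffix... <FIM_PRE> prefix... <FIM_MID>
#include <vector>
#include "llama.h"

static std::vector<llama_token> build_fim_prompt(
        const llama_model * model,
        std::vector<llama_token> prefix,  // input_prefix tokenized without BOS/special parsing
        std::vector<llama_token> suffix,  // input_suffix tokenized the same way
        bool spm_infill) {
    prefix.insert(prefix.begin(), llama_token_fim_pre(model));
    suffix.insert(suffix.begin(), llama_token_fim_suf(model));

    auto & first  = spm_infill ? suffix : prefix;
    auto & second = spm_infill ? prefix : suffix;

    std::vector<llama_token> out;
    if (llama_add_bos_token(model)) {
        out.push_back(llama_token_bos(model));
    }
    out.insert(out.end(), first.begin(),  first.end());
    out.insert(out.end(), second.begin(), second.end());
    out.push_back(llama_token_fim_mid(model));

    return out;
}
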
26 changes: 26 additions & 0 deletions include/llama.h
@@ -952,6 +952,12 @@ extern "C" {
int32_t lstrip,
bool special);

// check if token0 is contained as a prefix in token1
LLAMA_API bool llama_token_is_prefix(
const struct llama_model * model,
llama_token token0,
llama_token token1);

/// @details Convert the provided tokens into text (inverse of llama_tokenize()).
/// @param text The char pointer must be large enough to hold the resulting text.
/// @return Returns the number of chars/bytes on success, no more than text_len_max.
@@ -1144,6 +1150,26 @@ extern "C" {
int32_t n_logit_bias,
const llama_logit_bias * logit_bias);

// 1. if there is a high-prob token (>= 0.9f) - pick it
// 2. if sum of EOG probs is larger than p_eog -> mask non-EOG tokens away
// 3. combine probs of tokens that have the same prefix
//
// example:
//
// - before:
// "hel": 0.5
// "hell": 0.2
// "hello": 0.1
// "dummy": 0.1
//
// - after:
// "hel": 0.8
// "dummy": 0.1
//
LLAMA_API struct llama_sampler * llama_sampler_init_infill(
const struct llama_model * model,
float p,
float p_eog);

// Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
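
A minimal usage sketch for the new entry point, assuming a loaded model/context and a preceding llama_decode() call: chain top-k with the infill sampler and a final dist picker, then sample. All other calls are existing llama.h samplers; the 5 / 0.20 / 0.001 values only mirror the llama.vim request above. The new llama_token_is_prefix() presumably backs step 3's prefix merging.

#include "llama.h"

// top_k -> infill -> dist (final picker); free the chain when done
static llama_token sample_infill_token(const llama_model * model, llama_context * ctx) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    llama_sampler_chain_add(chain, llama_sampler_init_top_k(5));
    llama_sampler_chain_add(chain, llama_sampler_init_infill(model, 0.20f, 0.001f)); // p, p_eog
    llama_sampler_chain_add(chain, llama_sampler_init_dist(LLAMA_DEFAULT_SEED));     // pick from what remains

    // sample from the logits at the last position of the most recent llama_decode()
    const llama_token id = llama_sampler_sample(chain, ctx, -1);

    llama_sampler_free(chain);
    return id;
}
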
(diffs for the remaining changed files are not shown here)
