From 73fa77db21b35c8d2c88e11fd56cc450006d95af Mon Sep 17 00:00:00 2001
From: Georgi Gerganov <ggerganov@gmail.com>
Date: Tue, 8 Oct 2024 21:28:27 +0300
Subject: [PATCH] llama.vim : accept/cancel suggestions

---
 examples/llama.vim     | 164 +++++++++++++++++++++++++++++++++++------
 src/llama-sampling.cpp |  26 ++++---
 2 files changed, 155 insertions(+), 35 deletions(-)

diff --git a/examples/llama.vim b/examples/llama.vim
index d4d809489f8aa..30a717181062b 100644
--- a/examples/llama.vim
+++ b/examples/llama.vim
@@ -6,31 +6,36 @@
 "
 "augroup llama_cpp
 "    autocmd!
-"    autocmd InsertEnter * inoremap <C-F> <Esc>:call llama#fim()<CR>
+"    autocmd InsertEnter * inoremap <C-F> <Esc>:call llama#fim()<CR>a
 "augroup END
 "
+" color of the suggested text
+highlight llama_hint guifg=#ff772f
+
 let s:default_config = {
-    \ 'endpoint':     'http://127.0.0.1:8012/infill',
-    \ 'prefix_lines': 32,
-    \ 'suffix_lines': 32,
-    \ 'n_predict':    64,
-    \ 'n_probs':      3,
-    \ 'temperature':  0.1,
-    \ 'stop':         ["\n"]
+    \ 'endpoint':    'http://127.0.0.1:8012/infill',
+    \ 'n_prefix':    32,
+    \ 'n_suffix':    32,
+    \ 'n_predict':   64,
+    \ 'n_probs':     3,
+    \ 'temperature': 0.1,
+    \ 'stop':        ["\n"]
     \ }
 
 let g:llama_config = get(g:, 'llama_config', s:default_config)
 
 function! llama#fim() abort
-    let l:lines_prefix = getline(max([1, line('.') - g:llama_config.suffix_lines]), line('.') - 1)
-    let l:lines_suffix = getline(line('.') + 1, min([line('$'), line('.') + g:llama_config.prefix_lines]))
+    let l:pos_x = col('.')
+    let l:pos_y = line('.')
+    let l:max_y = line('$')
 
-    let l:cursor_col = col('.')
+    let l:lines_prefix = getline(max([1, l:pos_y - g:llama_config.n_prefix]), l:pos_y - 1)
+    let l:lines_suffix = getline(l:pos_y + 1, min([l:max_y, l:pos_y + g:llama_config.n_suffix]))
 
     let l:line_cur = getline('.')
 
-    let l:line_cur_prefix = strpart(l:line_cur, 0, l:cursor_col)
-    let l:line_cur_suffix = strpart(l:line_cur, l:cursor_col)
+    let l:line_cur_prefix = strpart(l:line_cur, 0, l:pos_x)
+    let l:line_cur_suffix = strpart(l:line_cur, l:pos_x)
 
     let l:prefix = ""
         \ . join(l:lines_prefix, "\n")
@@ -40,6 +45,7 @@ function! llama#fim() abort
     let l:suffix = ""
         \ . l:line_cur_suffix
         \ . join(l:lines_suffix, "\n")
+        \ . "\n"
 
     let l:request = json_encode({
         \ 'prompt': "",
@@ -63,21 +69,131 @@ function! llama#fim() abort
         \ g:llama_config.endpoint, shellescape(l:request)
         \ )
 
-    let l:response = json_decode(system(l:curl_command))
+    let l:can_accept = v:true
+    let s:content = []
+
+    let l:raw = system(l:curl_command)
+    if l:can_accept && v:shell_error
+        call add(s:content, "<| curl error: is the server on? |>")
+        let l:can_accept = v:false
+    endif
+
+    if l:can_accept && l:raw == ""
+        call add(s:content, "<| empty response: is the server on? |>")
|>") + let l:can_accept = v:false + endif + + " get the generated suggestion + if l:can_accept + let l:response = json_decode(l:raw) + + for l:part in split(get(l:response, 'content', ''), "\n", 1) + call add(s:content, l:part) + endfor + + " remove trailing new lines + while len(s:content) > 0 && s:content[-1] == "" + call remove(s:content, -1) + endwhile + endif + + if len(s:content) == 0 + call add(s:content, "<| nothing to suggest |>") + let l:can_accept = v:false + endif + + let s:pos_dx = len(s:content[-1]) + let s:content[-1] .= l:line_cur_suffix + + " display virtual text with the suggestion + let l:bufnr = bufnr('%') + let s:ns_id = nvim_create_namespace('llama_virtual_text') + + call nvim_buf_set_extmark(l:bufnr, s:ns_id, l:pos_y - 1, l:pos_x - 1, { + \ 'virt_text': [[s:content[0], 'llama_hint']], + \ 'virt_text_win_col': virtcol('.') + \ }) + + call nvim_buf_set_extmark(l:bufnr, s:ns_id, l:pos_y - 1, 0, { + \ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hint']]}), + \ 'virt_text_win_col': virtcol('.') + \ }) - echom l:response + " accept suggestion with Tab and reject it with any other key + if l:can_accept + inoremap :call llama#accept_virtual_text() + else + inoremap :call llama#cancel_virtual_text() + endif - let l:content = [] - for l:part in split(get(l:response, 'content', ''), "\n", 1) - call add(l:content, l:part) + for l:key in range(33, 127) + [8, 27] + if l:key != 0x7C + if l:key == 8 + execute 'inoremap :call llama#cancel_virtual_text()' + elseif l:key == 27 + execute 'inoremap :call llama#cancel_virtual_text()' + elseif l:key == 127 + execute 'inoremap :call llama#cancel_virtual_text()' + else + execute 'inoremap ' . nr2char(l:key) . ' :call llama#cancel_virtual_text()' . nr2char(l:key) + endif + endif endfor - echom l:content + inoremap :call llama#cancel_virtual_text() + inoremap :call llama#cancel_virtual_text() + inoremap :call llama#cancel_virtual_text() + inoremap :call llama#cancel_virtual_text() +endfunction + +function! llama#accept_virtual_text() + let l:pos_x = col('.') + let l:pos_y = line('.') + + let l:line_cur = getline('.') + + let l:pos0 = l:pos_x - 2 + + if l:pos_x == len(l:line_cur) + let l:pos0 = l:pos_x - 1 + endif + + " insert the suggestion at the cursor location + call setline(l:pos_y, l:line_cur[:l:pos0] . s:content[0]) + if len(s:content) > 1 + call append(l:pos_y, s:content[1:-1]) + endif - " insert the 'content' at the current cursor location - let l:content[0] = l:line_cur_prefix . l:content[0] - let l:content[-1] .= l:line_cur_suffix + " move the cursor to the end of the accepted text + call cursor(l:pos_y + len(s:content) - 1, l:pos_x + s:pos_dx) + + call llama#cancel_virtual_text() +endfunction + +function! llama#cancel_virtual_text() + " clear the virtual text + let l:bufnr = bufnr('%') + call nvim_buf_clear_namespace(l:bufnr, s:ns_id, 0, -1) + + " remove the mappings + iunmap + + for l:key in range(33, 127) + [8, 27] + if l:key != 0x7C + if l:key == 8 + execute 'iunmap ' + elseif l:key == 27 + execute 'iunmap ' + elseif l:key == 127 + execute 'iunmap ' + else + execute 'iunmap ' . 
diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp
index a61444018c00f..fbb3997e9c7e6 100644
--- a/src/llama-sampling.cpp
+++ b/src/llama-sampling.cpp
@@ -1724,24 +1724,28 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token
         }
     }
 
+    // determine the token with max logit
+    float l_max = -INFINITY;
+    int   i_max = -1;
+    for (size_t i = 0; i < cur_p->size; ++i) {
+        if (cur_p->data[i].logit > l_max) {
+            l_max = cur_p->data[i].logit;
+            i_max = i;
+        }
+    }
+
     // if all probs are -INFINITY -> reduce cur_p to single EOG token
-    if (std::all_of(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & td) { return td.logit == -INFINITY; })) {
+    if (i_max == -1) {
         cur_p->size = 1;
         cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
         cur_p->data[0].logit = 1.0f;
-    }
 
-    // resize
-    const auto size_org = cur_p->size;
-
-    cur_p->size = 0;
-
-    for (size_t i = 0; i < size_org; ++i) {
-        if (cur_p->data[i].logit != -INFINITY) {
-            cur_p->data[cur_p->size++] = cur_p->data[i];
-        }
+        return;
     }
 
+    cur_p->size = 1;
+    cur_p->data[0] = cur_p->data[i_max];
+
     for (size_t i = 0; i < cur_p->size; ++i) {
         LLAMA_LOG_DEBUG("after : cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
     }
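A note on configuration: llama#fim() reads its settings through
get(g:, 'llama_config', s:default_config), which substitutes the default
dictionary only when the variable is entirely absent; it does not merge
missing keys. An override in .vimrc therefore has to spell out every field.
A minimal sketch, with illustrative values for the larger context and
prediction sizes:

    " must define all keys; get() falls back to the defaults wholesale,
    " not per key
    let g:llama_config = {
        \ 'endpoint':    'http://127.0.0.1:8012/infill',
        \ 'n_prefix':    64,
        \ 'n_suffix':    64,
        \ 'n_predict':   128,
        \ 'n_probs':     3,
        \ 'temperature': 0.1,
        \ 'stop':        ["\n"]
        \ }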