Skip to content

Commit

Permalink
llama.vim : accept/cancel suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Oct 9, 2024
1 parent 474d0e6 commit 73fa77d
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 35 deletions.
164 changes: 140 additions & 24 deletions examples/llama.vim
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,36 @@
"
"augroup llama_cpp
" autocmd!
" autocmd InsertEnter * inoremap <buffer> <silent> <C-F> <Esc>:call llama#fim()<CR>
" autocmd InsertEnter * inoremap <buffer> <silent> <C-F> <Esc>:call llama#fim()<CR>a
"augroup END
"

" color of the suggested text
highlight llama_hint guifg=#ff772f

let s:default_config = {
\ 'endpoint': 'http://127.0.0.1:8012/infill',
\ 'prefix_lines': 32,
\ 'suffix_lines': 32,
\ 'n_predict': 64,
\ 'n_probs': 3,
\ 'temperature': 0.1,
\ 'stop': ["\n"]
\ 'endpoint': 'http://127.0.0.1:8012/infill',
\ 'n_prefix': 32,
\ 'n_suffix': 32,
\ 'n_predict': 64,
\ 'n_probs': 3,
\ 'temperature': 0.1,
\ 'stop': ["\n"]
\ }

let g:llama_config = get(g:, 'llama_config', s:default_config)

function! llama#fim() abort
let l:lines_prefix = getline(max([1, line('.') - g:llama_config.suffix_lines]), line('.') - 1)
let l:lines_suffix = getline(line('.') + 1, min([line('$'), line('.') + g:llama_config.prefix_lines]))
let l:pos_x = col('.')
let l:pos_y = line('.')
let l:max_y = line('$')

let l:cursor_col = col('.')
let l:lines_prefix = getline(max([1, l:pos_y - g:llama_config.n_prefix]), l:pos_y - 1)
let l:lines_suffix = getline(l:pos_y + 1, min([l:max_y, l:pos_y + g:llama_config.n_suffix]))

let l:line_cur = getline('.')
let l:line_cur_prefix = strpart(l:line_cur, 0, l:cursor_col)
let l:line_cur_suffix = strpart(l:line_cur, l:cursor_col)
let l:line_cur_prefix = strpart(l:line_cur, 0, l:pos_x)
let l:line_cur_suffix = strpart(l:line_cur, l:pos_x)

let l:prefix = ""
\ . join(l:lines_prefix, "\n")
Expand All @@ -40,6 +45,7 @@ function! llama#fim() abort
let l:suffix = ""
\ . l:line_cur_suffix
\ . join(l:lines_suffix, "\n")
\ . "\n"

let l:request = json_encode({
\ 'prompt': "",
Expand All @@ -63,21 +69,131 @@ function! llama#fim() abort
\ g:llama_config.endpoint, shellescape(l:request)
\ )

let l:response = json_decode(system(l:curl_command))
let l:can_accept = v:true
let s:content = []

let l:raw = system(l:curl_command)
if l:can_accept && v:shell_error
call add(s:content, "<| curl error: is the server on? |>")
let l:can_accept = v:false
endif

if l:can_accept && l:raw == ""
call add(s:content, "<| empty response: is the server on? |>")
let l:can_accept = v:false
endif

" get the generated suggestion
if l:can_accept
let l:response = json_decode(l:raw)

for l:part in split(get(l:response, 'content', ''), "\n", 1)
call add(s:content, l:part)
endfor

" remove trailing new lines
while len(s:content) > 0 && s:content[-1] == ""
call remove(s:content, -1)
endwhile
endif

if len(s:content) == 0
call add(s:content, "<| nothing to suggest |>")
let l:can_accept = v:false
endif

let s:pos_dx = len(s:content[-1])
let s:content[-1] .= l:line_cur_suffix

" display virtual text with the suggestion
let l:bufnr = bufnr('%')
let s:ns_id = nvim_create_namespace('llama_virtual_text')

call nvim_buf_set_extmark(l:bufnr, s:ns_id, l:pos_y - 1, l:pos_x - 1, {
\ 'virt_text': [[s:content[0], 'llama_hint']],
\ 'virt_text_win_col': virtcol('.')
\ })

call nvim_buf_set_extmark(l:bufnr, s:ns_id, l:pos_y - 1, 0, {
\ 'virt_lines': map(s:content[1:], {idx, val -> [[val, 'llama_hint']]}),
\ 'virt_text_win_col': virtcol('.')
\ })

echom l:response
" accept suggestion with Tab and reject it with any other key
if l:can_accept
inoremap <buffer> <Tab> <C-O>:call llama#accept_virtual_text()<CR>
else
inoremap <buffer> <Tab> <C-O>:call llama#cancel_virtual_text()<CR>
endif

let l:content = []
for l:part in split(get(l:response, 'content', ''), "\n", 1)
call add(l:content, l:part)
for l:key in range(33, 127) + [8, 27]
if l:key != 0x7C
if l:key == 8
execute 'inoremap <buffer> <Bs> <C-O>:call llama#cancel_virtual_text()<CR><Bs>'
elseif l:key == 27
execute 'inoremap <buffer> <Esc> <C-O>:call llama#cancel_virtual_text()<CR><Esc>'
elseif l:key == 127
execute 'inoremap <buffer> <Del> <C-O>:call llama#cancel_virtual_text()<CR><Del>'
else
execute 'inoremap <buffer> ' . nr2char(l:key) . ' <C-O>:call llama#cancel_virtual_text()<CR>' . nr2char(l:key)
endif
endif
endfor

echom l:content
inoremap <buffer> <Up> <C-O>:call llama#cancel_virtual_text()<CR><Up>
inoremap <buffer> <Down> <C-O>:call llama#cancel_virtual_text()<CR><Down>
inoremap <buffer> <Left> <C-O>:call llama#cancel_virtual_text()<CR><Left>
inoremap <buffer> <Right> <C-O>:call llama#cancel_virtual_text()<CR><Right>
endfunction

function! llama#accept_virtual_text()
let l:pos_x = col('.')
let l:pos_y = line('.')

let l:line_cur = getline('.')

let l:pos0 = l:pos_x - 2

if l:pos_x == len(l:line_cur)
let l:pos0 = l:pos_x - 1
endif

" insert the suggestion at the cursor location
call setline(l:pos_y, l:line_cur[:l:pos0] . s:content[0])
if len(s:content) > 1
call append(l:pos_y, s:content[1:-1])
endif

" insert the 'content' at the current cursor location
let l:content[0] = l:line_cur_prefix . l:content[0]
let l:content[-1] .= l:line_cur_suffix
" move the cursor to the end of the accepted text
call cursor(l:pos_y + len(s:content) - 1, l:pos_x + s:pos_dx)

call llama#cancel_virtual_text()
endfunction

function! llama#cancel_virtual_text()
" clear the virtual text
let l:bufnr = bufnr('%')
call nvim_buf_clear_namespace(l:bufnr, s:ns_id, 0, -1)

" remove the mappings
iunmap <buffer> <Tab>

for l:key in range(33, 127) + [8, 27]
if l:key != 0x7C
if l:key == 8
execute 'iunmap <buffer> <Bs>'
elseif l:key == 27
execute 'iunmap <buffer> <Esc>'
elseif l:key == 127
execute 'iunmap <buffer> <Del>'
else
execute 'iunmap <buffer> ' . nr2char(l:key)
endif
endif
endfor

call setline('.', l:content[0])
call append (line('.'), l:content[1:-1])
iunmap <buffer> <Up>
iunmap <buffer> <Down>
iunmap <buffer> <Left>
iunmap <buffer> <Right>
endfunction
26 changes: 15 additions & 11 deletions src/llama-sampling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1724,24 +1724,28 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
}
}

// determine the token with max logit
float l_max = -INFINITY;
int i_max = -1;
for (size_t i = 0; i < cur_p->size; ++i) {
if (cur_p->data[i].logit > l_max) {
l_max = cur_p->data[i].logit;
i_max = i;
}
}

// if all probs are -INFINITY -> reduce cur_p to single EOG token
if (std::all_of(cur_p->data, cur_p->data + cur_p->size, [](const llama_token_data & td) { return td.logit == -INFINITY; })) {
if (i_max == -1) {
cur_p->size = 1;
cur_p->data[0].id = llama_token_eot_impl(*ctx->vocab);
cur_p->data[0].logit = 1.0f;
}

// resize
const auto size_org = cur_p->size;

cur_p->size = 0;

for (size_t i = 0; i < size_org; ++i) {
if (cur_p->data[i].logit != -INFINITY) {
cur_p->data[cur_p->size++] = cur_p->data[i];
}
return;
}

cur_p->size = 1;
cur_p->data[0] = cur_p->data[i_max];

for (size_t i = 0; i < cur_p->size; ++i) {
LLAMA_LOG_DEBUG("after : cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}
Expand Down

0 comments on commit 73fa77d

Please sign in to comment.