llama : clean-up
ggml-ci
ggerganov committed Oct 11, 2024
1 parent 6384002 commit cefd4ac
Showing 3 changed files with 104 additions and 35 deletions.
111 changes: 84 additions & 27 deletions examples/llama.vim
@@ -1,31 +1,72 @@
" LLM-based text completion using llama.cpp
"
" requires:
"
" - neovim
" - curl
" - llama.cpp server instance
" - FIM-compatible model
"
" sample config:
"
" - Ctrl+F - trigger FIM completion manually
" - Tab - accept the current suggestion
" - Shift+Tab - accept just the first line of the segguestion
" - Ctrl+F - trigger FIM completion manually
"
" make symlink or copy this file to ~/.config/nvim/autoload/llama.vim
"
" start the llama.cpp server with a FIM-compatible model. for example:
"
" $ llama-server -m {model.gguf} --port 8012 -ngl 99 -fa --ubatch-size 1024 --batch-size 2048
"
" --batch-size [512, model max context]
"
" adjust the batch size to control how much of the provided context will be used during the inference
" lower values will use smaller part of the context around the cursor, which will result in faster processing
"
" run this once to initialise the plugin:
" --ubatch-size [64, 2048]
"
" :call llama#init()
" chunks the batch into smaller chunks for faster processing
" depends on the specific hardware. use llama-bench to profile and determine the best size
"
" run this once to initialise llama.vim:
"
" :call llama#init()
"

" color of the suggested text
highlight llama_hl_hint guifg=#ff772f
highlight llama_hl_info guifg=#77ff2f

" endpoint: llama.cpp server endpoint
" n_prefix: number of lines to include in the prefix
" n_suffix: number of lines to include in the suffix
" n_predict: max number of tokens to predict
" t_max_prompt_ms: max alloted time for the text generation
" show_info: show extra info about the inference
" auto_fim: trigger FIM completion automatically on cursor movement
let s:default_config = {
\ 'endpoint': 'http://127.0.0.1:8012/infill',
\ 'n_prefix': 128,
\ 'n_suffix': 128,
\ 'n_prefix': 256,
\ 'n_suffix': 256,
\ 'n_predict': 64,
\ 't_max_prompt_ms': 300,
\ 't_max_prompt_ms': 500,
\ 't_max_predict_ms': 200,
\ 'show_info': v:true,
\ 'auto_fim': v:true,
\ 'stop': ["\n"]
\ }

let g:llama_config = get(g:, 'llama_config', s:default_config)

function! llama#init()
let s:pos_x = 0
if !executable('curl')
echohl WarningMsg
echo 'llama.vim requires the "curl" command to be available'
echohl None
return
endif

let s:pos_x = 0 " cursor position upon start of completion
let s:pos_y = 0
let s:pos_x0 = 0 " pos_x corrected for end-of-line edge case

@@ -46,8 +87,8 @@ function! llama#init()

augroup llama
autocmd!
autocmd InsertEnter * inoremap <buffer> <silent> <C-F> <C-O>:call llama#fim(v:false)<CR>
autocmd InsertLeave * call llama#fim_cancel()
autocmd InsertEnter * inoremap <buffer> <silent> <C-F> <C-O>:call llama#fim(v:false)<CR>
autocmd InsertLeavePre * call llama#fim_cancel()

autocmd CursorMoved * call llama#fim_cancel()
augroup END
@@ -90,7 +131,6 @@ function! llama#fim(is_auto) abort
\ 'prompt': "",
\ 'input_prefix': l:prefix,
\ 'input_suffix': l:suffix,
"\ 'stop': g:llama_config.stop,
\ 'n_predict': g:llama_config.n_predict,
\ 'penalty_last_n': 0,
\ 'top_k': 100,
@@ -126,16 +166,23 @@ function! llama#fim(is_auto) abort
endif
endfunction
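For a quick sanity check outside the plugin, here is a rough standalone sketch (not part of the commit) that issues the same kind of /infill request synchronously via system() and curl. The endpoint matches the default above, the prefix/suffix strings are placeholders, and the response is assumed to carry a 'content' field, as the callback below expects:

" standalone sketch of the /infill request, for illustration only
let s:req = json_encode({
    \ 'prompt':       '',
    \ 'input_prefix': "def add(a, b):\n    ",
    \ 'input_suffix': "\n\nprint(add(1, 2))\n",
    \ 'n_predict':    64
    \ })
let s:res = system(printf(
    \ 'curl -s -X POST %s -H "Content-Type: application/json" -d %s',
    \ shellescape('http://127.0.0.1:8012/infill'),
    \ shellescape(s:req)))
echo json_decode(s:res)['content']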

function! llama#fim_accept()
" if first_line == v:true accept only the first line of the response
function! llama#fim_accept(first_line)
" insert the suggestion at the cursor location
if s:can_accept && len(s:content) > 0
call setline(s:pos_y, s:line_cur[:(s:pos_x0 - 1)] . s:content[0])
if len(s:content) > 1
call append(s:pos_y, s:content[1:-1])
if !a:first_line
call append(s:pos_y, s:content[1:-1])
endif
endif

" move the cursor to the end of the accepted text
call cursor(s:pos_y + len(s:content) - 1, s:pos_x + s:pos_dx)
if !a:first_line
call cursor(s:pos_y + len(s:content) - 1, s:pos_x + s:pos_dx)
else
call cursor(s:pos_y, s:pos_x + len(s:content[0]) - 1)
endif
endif

call llama#fim_cancel()
@@ -146,6 +193,11 @@ function! llama#fim_cancel()
call jobstop(s:current_job)
endif

if s:timer_fim != -1
call timer_stop(s:timer_fim)
let s:timer_fim = -1
endif

" clear the virtual text
let l:bufnr = bufnr('%')

@@ -155,7 +207,9 @@
call nvim_buf_clear_namespace(l:bufnr, l:id_vt_fim, 0, -1)
call nvim_buf_clear_namespace(l:bufnr, l:id_vt_info, 0, -1)

" remove the mappings
silent! iunmap <buffer> <Tab>
silent! iunmap <buffer> <S-Tab>
silent! iunmap <buffer> <Esc>

augroup llama_insert
@@ -173,6 +227,8 @@ function! s:fim_auto_enable()
augroup END
endfunction

" auto-start a fim job a short time after the cursor has moved
" if there is already a job queued - cancel it
function! s:fim_auto()
if s:current_job != v:null
call jobstop(s:current_job)
@@ -189,7 +245,7 @@ function! s:fim_auto()
let s:timer_fim = timer_start(500, {-> llama#fim(v:true)})
endfunction


" callback that processes the result from the server
function! s:fim_on_stdout(job_id, data, event) dict
let l:raw = join(a:data, "\n")
if len(l:raw) == 0
@@ -199,6 +255,13 @@ function! s:fim_on_stdout(job_id, data, event) dict
let s:can_accept = v:true
let l:has_info = v:false

if s:can_accept && v:shell_error
if !self.is_auto
call add(s:content, "<| curl error: is the server on? |>")
endif
let s:can_accept = v:false
endif

let l:n_prompt = 0
let l:t_prompt_ms = 1.0
let l:s_prompt = 0
@@ -207,13 +270,6 @@ function! s:fim_on_stdout(job_id, data, event) dict
let l:t_predict_ms = 1.0
let l:s_predict = 0

if s:can_accept && v:shell_error
if !self.is_auto
call add(s:content, "<| curl error: is the server on? |>")
endif
let s:can_accept = v:false
endif

" get the generated suggestion
if s:can_accept
let l:response = json_decode(l:raw)
@@ -227,7 +283,7 @@ function! s:fim_on_stdout(job_id, data, event) dict
call remove(s:content, -1)
endwhile

" if response.timings
" if response.timings is available
if len(get(l:response, 'timings', {})) > 0
let l:has_info = v:true
let l:timings = get(l:response, 'timings', {})
@@ -264,8 +320,8 @@ function! s:fim_on_stdout(job_id, data, event) dict
let l:id_vt_fim = nvim_create_namespace('vt_fim')
let l:id_vt_info = nvim_create_namespace('vt_info')

" construct the info message:
if l:has_info
" construct the info message and display it to the right of the current line
if g:llama_config.show_info && l:has_info
" prefix the info string with whitespace in order to offset it to the right of the fim overlay
let l:prefix = repeat(' ', len(s:content[0]) - len(s:line_cur_suffix) + 3)

@@ -282,6 +338,7 @@ function! s:fim_on_stdout(job_id, data, event) dict
\ })
endif

" display the suggestion
call nvim_buf_set_extmark(l:bufnr, l:id_vt_fim, s:pos_y - 1, s:pos_x - 1, {
\ 'virt_text': [[s:content[0], 'llama_hl_hint']],
\ 'virt_text_win_col': virtcol('.') - 1
@@ -293,8 +350,8 @@ function! s:fim_on_stdout(job_id, data, event) dict
\ })

" setup accept/cancel events
inoremap <buffer> <Tab> <C-O>:call llama#fim_accept()<CR>
inoremap <buffer> <Esc> <C-O>:call llama#fim_cancel()<CR><Esc>
inoremap <buffer> <Tab> <C-O>:call llama#fim_accept(v:false)<CR>
inoremap <buffer> <S-Tab> <C-O>:call llama#fim_accept(v:true)<CR>
augroup llama_insert
autocmd!
22 changes: 15 additions & 7 deletions examples/server/server.cpp
@@ -132,8 +132,8 @@ struct slot_params {
int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half
int32_t n_predict = -1; // new tokens to predict

int64_t t_max_prompt_ms = -1;
int64_t t_max_predict_ms = -1;
int64_t t_max_prompt_ms = -1; // TODO: not implemented
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

std::vector<std::string> antiprompt;

@@ -2028,8 +2028,8 @@ struct server_context {
auto prefix_tokens = tokenize(slot.params.input_prefix, false, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false, false);

// for now pick context to fit in a single batch
const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/2);
// for now pick context to fit in a single batch (ratio prefix:suffix = 3:1, TODO: configurable?)
const int n_suffix_take = std::min<int>(suffix_tokens.size(), n_batch/4);
const int n_prefix_take = std::min<int>(prefix_tokens.size(), (n_batch - 3) - n_suffix_take);

prefix_tokens.erase(prefix_tokens.begin(), prefix_tokens.begin() + prefix_tokens.size() - n_prefix_take);
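As a concrete example (assuming n_batch here equals the --batch-size 2048 suggested in llama.vim above): n_suffix_take is capped at 2048/4 = 512 suffix tokens and n_prefix_take at (2048 - 3) - 512 = 1533 prefix tokens when both sides are long enough, i.e. roughly the 3:1 prefix:suffix ratio mentioned in the comment; shorter prefixes or suffixes are taken whole.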
@@ -2057,9 +2057,17 @@

SLT_INF(slot, "prompt tokenized, n_ctx_slot = %d, n_keep = %d, n_prompt_tokens = %d\n", slot.n_ctx, slot.params.n_keep, slot.n_prompt_tokens);

// print prompt tokens:
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
// print prompt tokens (for debugging)
if (1) {
// first 16 tokens (avoid flooding logs)
for (int i = 0; i < std::min<int>(16, prompt_tokens.size()); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
}
} else {
// all
for (int i = 0; i < (int) prompt_tokens.size(); i++) {
SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str());
}
}

// empty prompt passed -> release the slot and send empty response
6 changes: 5 additions & 1 deletion src/llama-sampling.cpp
@@ -1646,6 +1646,8 @@ struct llama_sampler * llama_sampler_init_logit_bias(

// infill

//#define GGML_DEBUG_SAMPLER_INFILL

struct llama_sampler_infill {
const struct llama_vocab * vocab;
};
@@ -1659,10 +1661,11 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_

llama_sampler_softmax_impl(cur_p);

// print cur_p:
#if defined(GGML_DEBUG_SAMPLER_INFILL)
for (size_t i = 0; i < cur_p->size; ++i) {
LLAMA_LOG_DEBUG("infill: cur_p[%zu] = { id: %d, p: %f, logit: %f }\n", i, cur_p->data[i].id, cur_p->data[i].p, cur_p->data[i].logit);
}
#endif

float p_max = 0.0f;
float p_txt_sum = 0.0f;
@@ -1746,6 +1749,7 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_
return;
}

// pick the best token
cur_p->size = 1;
cur_p->data[0] = cur_p->data[i_max];

