Reset kv cache after each query and infinite inference features #2560

Open · wants to merge 4 commits into master
35 changes: 30 additions & 5 deletions examples/talk-llama/talk-llama.cpp
@@ -66,7 +66,8 @@ struct whisper_params {
bool verbose_prompt = false;
bool use_gpu = true;
bool flash_attn = false;

bool reset_cache = false;
bool infinite_inference = false;
std::string person = "Georgi";
std::string bot_name = "LLaMA";
std::string wake_cmd = "";
@@ -115,6 +116,8 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
else if (arg == "-ml" || arg == "--model-llama") { params.model_llama = argv[++i]; }
else if (arg == "-s" || arg == "--speak") { params.speak = argv[++i]; }
else if (arg == "-sf" || arg == "--speak-file") { params.speak_file = argv[++i]; }
else if (arg == "-inf" || arg == "--infinite_inference") { params.infinite_inference = True; }
else if (arg == "-reset" || arg == "--reset_cache") { params.reset_cache = True; }
else if (arg == "--prompt-file") {
std::ifstream file(argv[++i]);
std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
@@ -165,6 +168,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
fprintf(stderr, " --prompt-file FNAME [%-7s] file with custom prompt to start dialog\n", "");
fprintf(stderr, " --session FNAME file to cache model state in (may be large!) (default: none)\n");
fprintf(stderr, " -f FNAME, --file FNAME [%-7s] text output file name\n", params.fname_out.c_str());
fprintf(stderr, " -inf, --infinite_inference [%-7s] infinite inference\n", params.infinite_inference ? "true" : "false");
fprintf(stderr, " -reset, --reset_cache [%-7s] reset cache after each question\n", params.reset_cache ? "true" : "false");
fprintf(stderr, "\n");
}

@@ -194,6 +199,7 @@ static std::string transcribe(
wparams.translate = params.translate;
wparams.no_context = true;
wparams.single_segment = true;
wparams.infinite_inference = true;
wparams.max_tokens = params.max_tokens;
wparams.language = params.language.c_str();
wparams.n_threads = params.n_threads;
@@ -643,16 +649,35 @@ int main(int argc, char ** argv) {
// text inference
bool done = false;
std::string text_to_speak;
if (params.reset_cache) {
int n_discard = lcparams.n_ctx - n_keep;
//std::cout << "Number of tokens to discard: " << n_discard << "\n";

llama_kv_cache_seq_rm(ctx_llama, 0, n_keep, n_keep + n_discard); // drop every cached token beyond the kept prompt
n_past = n_keep;
}
while (true) {
// predict
if (embd.size() > 0) {
if (n_past + (int) embd.size() > n_ctx) {
n_past = n_keep;
if (params.infinite_inference && (n_past + (int) embd.size() > n_ctx) && n_past != n_keep) {
std::cout<<"\nInfinite context enabled\n";
const int n_left = n_past - n_keep; //the number of tokens beyond the ones we want to keep
const int n_discard = n_left/2; //we decide to discard half of the tokens beyond the ones we want to keep

llama_kv_cache_seq_rm (ctx_llama, 0, n_keep , n_keep + n_discard); //the number of tokens beyond the on
llama_kv_cache_seq_add(ctx_llama, 0, n_keep + n_discard, n_past, -n_discard); // this function is adjusting the cache by

n_past -= n_discard;
// insert n_left/2 tokens at the start of embd from last_n_tokens
//embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
// stop saving session if we run out of context
path_session.clear();
continue;

// insert n_left/2 tokens at the start of embd from last_n_tokens
embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
//embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end());
// stop saving session if we run out of context
path_session = "";
//path_session = "";
//printf("\n---\n");
//printf("resetting: '");
//for (int i = 0; i < (int) embd.size(); i++) {
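For reference, the cache arithmetic used by the two new flags can be illustrated in isolation. The following is a minimal sketch with assumed values for n_ctx, n_keep and n_past; it does not call llama.cpp, and the comments only map the arithmetic onto the llama_kv_cache_seq_rm / llama_kv_cache_seq_add calls that appear in the patch.

// Minimal sketch (not part of the patch): the context-shift arithmetic above, with assumed values.
#include <cstdio>

int main() {
    const int n_ctx  = 2048;   // assumed context size
    const int n_keep = 64;     // tokens of the initial prompt that are always kept
    int       n_past = 2040;   // tokens currently in the kv cache

    // --reset_cache path: drop everything beyond the kept prompt, i.e.
    // n_discard = n_ctx - n_keep, remove cells [n_keep, n_keep + n_discard), then n_past = n_keep.

    // --infinite_inference path: drop half of the tokens beyond the kept prompt
    // and shift the rest back so generation can continue.
    const int n_left    = n_past - n_keep; // tokens beyond the kept prompt
    const int n_discard = n_left / 2;      // discard half of them

    // llama_kv_cache_seq_rm  removes cells [n_keep, n_keep + n_discard)
    // llama_kv_cache_seq_add shifts cells [n_keep + n_discard, n_past) by -n_discard
    n_past -= n_discard;

    std::printf("n_left = %d, n_discard = %d, new n_past = %d\n", n_left, n_discard, n_past);
    return 0;
}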