diff --git a/examples/talk-llama/talk-llama.cpp b/examples/talk-llama/talk-llama.cpp
index 1b9de94d724..3ca7f32b13f 100644
--- a/examples/talk-llama/talk-llama.cpp
+++ b/examples/talk-llama/talk-llama.cpp
@@ -66,7 +66,8 @@ struct whisper_params {
     bool verbose_prompt = false;
     bool use_gpu        = true;
     bool flash_attn     = false;
-
+    bool reset_cache        = false;
+    bool infinite_inference = false;
     std::string person    = "Georgi";
     std::string bot_name  = "LLaMA";
     std::string wake_cmd  = "";
@@ -115,6 +116,8 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
     else if (arg == "-ml"    || arg == "--model-llama")        { params.model_llama = argv[++i]; }
     else if (arg == "-s"     || arg == "--speak")              { params.speak       = argv[++i]; }
     else if (arg == "-sf"    || arg == "--speak-file")         { params.speak_file  = argv[++i]; }
+    else if (arg == "-inf"   || arg == "--infinite_inference") { params.infinite_inference = true; }
+    else if (arg == "-reset" || arg == "--reset_cache")        { params.reset_cache        = true; }
     else if (arg == "--prompt-file") {
         std::ifstream file(argv[++i]);
         std::copy(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>(), back_inserter(params.prompt));
@@ -165,6 +168,8 @@ void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params & para
     fprintf(stderr, "  --prompt-file FNAME            [%-7s] file with custom prompt to start dialog\n", "");
     fprintf(stderr, "  --session FNAME                        file to cache model state in (may be large!) (default: none)\n");
     fprintf(stderr, "  -f FNAME, --file FNAME         [%-7s] text output file name\n", params.fname_out.c_str());
+    fprintf(stderr, "  -inf,     --infinite_inference [%-7s] run indefinitely, shifting the KV cache once the context fills up\n", params.infinite_inference ? "true" : "false");
+    fprintf(stderr, "  -reset,   --reset_cache        [%-7s] reset the KV cache after each question\n", params.reset_cache ?
"true" : "false"); fprintf(stderr, "\n"); } @@ -194,6 +199,7 @@ static std::string transcribe( wparams.translate = params.translate; wparams.no_context = true; wparams.single_segment = true; + wparams.infinite_inference = true; wparams.max_tokens = params.max_tokens; wparams.language = params.language.c_str(); wparams.n_threads = params.n_threads; @@ -643,16 +649,35 @@ int main(int argc, char ** argv) { // text inference bool done = false; std::string text_to_speak; + if (params.reset_cache) { + int n_discard = lcparams.n_ctx - n_keep; + //std::cout << "Number of tokens to discard: " << n_discard << "\n"; + + llama_kv_cache_seq_rm(ctx_llama, 0, n_keep, n_keep + n_discard); //the number of tokens beyond the on + n_past = n_keep; + } while (true) { // predict if (embd.size() > 0) { - if (n_past + (int) embd.size() > n_ctx) { - n_past = n_keep; + if (params.infinite_inference && (n_past + (int) embd.size() > n_ctx) && n_past != n_keep) { + std::cout<<"\nInfinite context enabled\n"; + const int n_left = n_past - n_keep; //the number of tokens beyond the ones we want to keep + const int n_discard = n_left/2; //we decide to discard half of the tokens beyond the ones we want to keep + + llama_kv_cache_seq_rm (ctx_llama, 0, n_keep , n_keep + n_discard); //the number of tokens beyond the on + llama_kv_cache_seq_add(ctx_llama, 0, n_keep + n_discard, n_past, -n_discard); // this function is adjusting the cache by + + n_past -= n_discard; + // insert n_left/2 tokens at the start of embd from last_n_tokens + //embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end()); + // stop saving session if we run out of context + path_session.clear(); + continue; // insert n_left/2 tokens at the start of embd from last_n_tokens - embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end()); + //embd.insert(embd.begin(), embd_inp.begin() + embd_inp.size() - n_prev, embd_inp.end()); // stop saving session if we run out of context - path_session = ""; + //path_session = ""; //printf("\n---\n"); //printf("resetting: '"); //for (int i = 0; i < (int) embd.size(); i++) {