From 861a12eaca3c2824be3611300dc883356789dfa8 Mon Sep 17 00:00:00 2001
From: tikikun
Date: Mon, 13 Nov 2023 08:25:58 +0700
Subject: [PATCH] feat:unload model

---
 controllers/llamaCPP.cc | 34 +++++++++++++++++++++++++++++++++-
 controllers/llamaCPP.h  |  8 +++++++-
 2 files changed, 40 insertions(+), 2 deletions(-)

diff --git a/controllers/llamaCPP.cc b/controllers/llamaCPP.cc
index 868b7aa06..b7963c6c5 100644
--- a/controllers/llamaCPP.cc
+++ b/controllers/llamaCPP.cc
@@ -196,6 +196,25 @@ void llamaCPP::embedding(
   return;
 }
 
+void llamaCPP::unloadModel(
+    const HttpRequestPtr &req,
+    std::function<void(const HttpResponsePtr &)> &&callback) {
+  Json::Value jsonResp;
+  jsonResp["message"] = "No model loaded";
+  if (model_loaded) {
+    stopBackgroundTask();
+
+    llama_free(llama.ctx);
+    llama_free_model(llama.model);
+    llama.ctx = nullptr;
+    llama.model = nullptr;
+    jsonResp["message"] = "Model unloaded successfully";
+  }
+  auto resp = nitro_utils::nitroHttpJsonResponse(jsonResp);
+  callback(resp);
+  return;
+}
+
 void llamaCPP::loadModel(
     const HttpRequestPtr &req,
     std::function<void(const HttpResponsePtr &)> &&callback) {
@@ -274,7 +293,20 @@ void llamaCPP::loadModel(
 
 void llamaCPP::backgroundTask() {
   while (model_loaded) {
-    model_loaded = llama.update_slots();
+    // model_loaded =
+    llama.update_slots();
+    LOG_INFO << "Background state refresh!";
   }
+  LOG_INFO << "Background task stopped!";
   return;
 }
+
+void llamaCPP::stopBackgroundTask() {
+  if (model_loaded) {
+    model_loaded = false;
+    LOG_INFO << "changed to false";
+    if (backgroundThread.joinable()) {
+      backgroundThread.join();
+    }
+  }
+}
diff --git a/controllers/llamaCPP.h b/controllers/llamaCPP.h
index 4e755d524..03999b18a 100644
--- a/controllers/llamaCPP.h
+++ b/controllers/llamaCPP.h
@@ -2124,6 +2124,8 @@ class llamaCPP : public drogon::HttpController {
   METHOD_ADD(llamaCPP::chatCompletion, "chat_completion", Post);
   METHOD_ADD(llamaCPP::embedding, "embedding", Post);
   METHOD_ADD(llamaCPP::loadModel, "loadmodel", Post);
+  METHOD_ADD(llamaCPP::unloadModel, "unloadmodel", Get);
+
   // PATH_ADD("/llama/chat_completion", Post);
   METHOD_LIST_END
   void chatCompletion(const HttpRequestPtr &req,
@@ -2132,13 +2134,17 @@ class llamaCPP : public drogon::HttpController {
                  std::function<void(const HttpResponsePtr &)> &&callback);
   void loadModel(const HttpRequestPtr &req,
                  std::function<void(const HttpResponsePtr &)> &&callback);
+  void unloadModel(const HttpRequestPtr &req,
+                   std::function<void(const HttpResponsePtr &)> &&callback);
   void warmupModel();
 
   void backgroundTask();
 
+  void stopBackgroundTask();
+
 private:
   llama_server_context llama;
-  bool model_loaded = false;
+  std::atomic<bool> model_loaded = false;
   size_t sent_count = 0;
   size_t sent_token_probs_index = 0;
   std::thread backgroundThread;
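
Note on the shutdown sequence: unloadModel() frees llama.ctx and llama.model only after stopBackgroundTask() has cleared the std::atomic<bool> flag and joined backgroundThread, so update_slots() can no longer touch a context that is being freed. A minimal standalone sketch of that flag-and-join pattern, using only the C++ standard library (the names running and backgroundLoop are illustrative and not part of this patch):

    #include <atomic>
    #include <chrono>
    #include <iostream>
    #include <thread>

    // Worker loop gated by an atomic flag, mirroring backgroundTask().
    std::atomic<bool> running{true};

    void backgroundLoop() {
      while (running) {
        // Stand-in for llama.update_slots(); the real loop services inference slots.
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
      }
      std::cout << "Background task stopped!" << std::endl;
    }

    int main() {
      std::thread worker(backgroundLoop);        // started when the model is loaded
      std::this_thread::sleep_for(std::chrono::seconds(1));
      running = false;                           // as in stopBackgroundTask(): clear the flag...
      if (worker.joinable()) {
        worker.join();                           // ...then join before freeing any shared state
      }
      return 0;
    }

The new route itself is registered as a GET handler named "unloadmodel" via METHOD_ADD; the full URL it is served under depends on the controller's base path, which this patch does not change.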