From 82fc618147537069e1d0eb13d0dc6817685939d7 Mon Sep 17 00:00:00 2001
From: vegetableysm <108774481+vegetableysm@users.noreply.github.com>
Date: Wed, 13 Mar 2024 17:54:34 +0800
Subject: [PATCH] Fix llm memory leak and replace mutex_lock with unique_lock
 (#1815)

- Fix the memory leak of the llm module.
- Use unique_lock to manage mutex locks in the llm module.

Fixes #1813
Fixes #1814

Signed-off-by: vegetableysm
---
 modules/llm-cache/ds/kv_state_cache.cc        | 20 ++++
 modules/llm-cache/ds/kv_state_cache.h         |  2 +
 .../llm-cache/ds/kv_state_cache_manager.cc    | 94 +++++++++++--------
 modules/llm-cache/ds/kv_state_cache_manager.h |  5 +-
 .../tests/kv_state_cache_benchmark_test.cc    |  2 +
 .../llm-cache/tests/kv_state_cache_test.cc    |  7 +-
 6 files changed, 89 insertions(+), 41 deletions(-)

diff --git a/modules/llm-cache/ds/kv_state_cache.cc b/modules/llm-cache/ds/kv_state_cache.cc
index 38cdf38a..c7370b7d 100644
--- a/modules/llm-cache/ds/kv_state_cache.cc
+++ b/modules/llm-cache/ds/kv_state_cache.cc
@@ -259,6 +259,9 @@ void KVStateCacheBuilder::Delete(std::shared_ptr<NodeData> evictedNodeData) {
   // delete (DataWrapper*) evictedNodeData->nodeData;
   if (evictedNodeData->cleanTreeData) {
     this->rootTree->ClearSubtreeData(treeData);
+    std::shared_ptr<Object> blockObject =
+        kvStateCacheBlockBuilder->_Seal(client);
+    client.DelData(blockObject->id());
     delete kvStateCacheBlockBuilder;
   }
   evictedNodeData->RecycleSource();
@@ -326,6 +329,7 @@ Status KVStateCacheBuilder::Merge(std::shared_ptr<KVStateCache> kvStateCache) {
   }
 
   this->version = globalCacheBuilder->GetVersion();
+  globalCacheBuilder->Close();
   return Status::OK();
 }
 
@@ -403,4 +407,20 @@ KVStateCacheBuilder::~KVStateCacheBuilder() {
   }
 }
 
+void KVStateCacheBuilder::Close() {
+  std::set<void*> subTreeDataSet = rootTree->GetSubTreeDataSet();
+  for (auto iter = subTreeDataSet.begin(); iter != subTreeDataSet.end();
+       ++iter) {
+    TreeData* treeData = reinterpret_cast<TreeData*>(*iter);
+    if (treeData->isPtr == true &&
+        treeData->kvStateCacheBlockBuilder != nullptr) {
+      std::shared_ptr<Object> object =
+          reinterpret_cast<KVStateCacheBlockBuilder*>(
+              treeData->kvStateCacheBlockBuilder)
+              ->_Seal(client);
+      client.DelData(object->id());
+    }
+  }
+}
+
 }  // namespace vineyard
diff --git a/modules/llm-cache/ds/kv_state_cache.h b/modules/llm-cache/ds/kv_state_cache.h
index f8ddd343..0a01912c 100644
--- a/modules/llm-cache/ds/kv_state_cache.h
+++ b/modules/llm-cache/ds/kv_state_cache.h
@@ -128,6 +128,8 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder {
 
   int GetLayer() { return this->layer; }
 
+  void Close();
+
   ~KVStateCacheBuilder();
 };
 
diff --git a/modules/llm-cache/ds/kv_state_cache_manager.cc b/modules/llm-cache/ds/kv_state_cache_manager.cc
index 3f1ebc64..f2c7a4bc 100644
--- a/modules/llm-cache/ds/kv_state_cache_manager.cc
+++ b/modules/llm-cache/ds/kv_state_cache_manager.cc
@@ -98,64 +98,65 @@ Status KVStateCacheManager::QueryInternal(
 Status KVStateCacheManager::Update(
     const std::vector<int>& tokenList, int nextToken,
     const std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
-  Status result =
-      Status::Invalid("Query cache failed: can not gain the cache lock.");
-
-  if (!syncMutex.try_lock()) {
-    return result;
+  std::unique_lock<std::mutex> lock(cacheAccessMutex, std::defer_lock);
+  if (!lock.try_lock()) {
+    // If failed to gain the lock, return OK and wait for next time
+    return Status::OK();
   }
 
-  result = UpdateInternal(tokenList, nextToken, kvState);
+  if (isClosed) {
+    return Status::Invalid("The cache manager is closed.");
+  }
 
-  syncMutex.unlock();
-  return result;
+  return UpdateInternal(tokenList, nextToken, kvState);
 }
 
 Status KVStateCacheManager::Update(
     const std::vector<int>& tokenList,
     const std::vector<std::map<int, std::pair<LLMKV, LLMKV>>>& kvState) {
-  Status result =
-      Status::Invalid("Update cache failed: can not gain the cache lock.");
-  if (!syncMutex.try_lock()) {
-    return result;
+  std::unique_lock<std::mutex> lock(cacheAccessMutex, std::defer_lock);
+  if (!lock.try_lock()) {
+    return Status::OK();
+  }
+  if (isClosed) {
+    return Status::Invalid("The cache manager is closed.");
   }
-
   std::vector<int> tokenListCopy;
   for (size_t i = 0; i < tokenList.size(); i++) {
-    result = UpdateInternal(tokenListCopy, tokenList[i], kvState[i]);
+    Status result = UpdateInternal(tokenListCopy, tokenList[i], kvState[i]);
     if (!result.ok()) {
       break;
     }
     tokenListCopy.push_back(tokenList[i]);
   }
 
-  syncMutex.unlock();
-  return result;
+  return Status::OK();
 }
 
 Status KVStateCacheManager::Query(
     const std::vector<int>& tokenList, int token,
     std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
-  Status result =
-      Status::Invalid("Query cache failed: can not gain the cache lock.");
-
-  if (!syncMutex.try_lock()) {
-    return result;
+  std::unique_lock<std::mutex> lock(cacheAccessMutex, std::defer_lock);
+  if (!lock.try_lock()) {
+    // If failed to gain the lock, return OK and wait for next time
+    return Status::OK();
+  }
+  if (isClosed) {
+    return Status::Invalid("The cache manager is closed.");
   }
 
-  result = QueryInternal(tokenList, token, kvState);
-  syncMutex.unlock();
-
-  return result;
+  return QueryInternal(tokenList, token, kvState);
 }
 
 Status KVStateCacheManager::Query(
     const std::vector<int>& tokenList,
     std::vector<std::map<int, std::pair<LLMKV, LLMKV>>>& listKVState) {
-  Status result =
-      Status::Invalid("Query cache failed: can not gain the cache lock.");
-  if (!syncMutex.try_lock()) {
-    return result;
+  std::unique_lock<std::mutex> lock(cacheAccessMutex, std::defer_lock);
+  if (!lock.try_lock()) {
+    return Status::Invalid("Query cache failed: can not gain the cache lock.");
+  }
+  if (isClosed) {
+    return Status::Invalid("The cache manager is closed.");
   }
 
   // support partial match of the token list
@@ -163,9 +164,8 @@ Status KVStateCacheManager::Query(
   std::vector<int> tokenListCopy;
   std::map<int, std::pair<LLMKV, LLMKV>> kvState;
   for (size_t i = 0; i < tokenList.size(); i++) {
-    result = QueryInternal(tokenListCopy, tokenList[i], kvState);
+    Status result = QueryInternal(tokenListCopy, tokenList[i], kvState);
     if (!result.ok()) {
-      syncMutex.unlock();
       return Status::OK();
     }
     tokenListCopy.push_back(tokenList[i]);
@@ -173,18 +173,18 @@ Status KVStateCacheManager::Query(
     kvState.clear();
   }
 
-  syncMutex.unlock();
-  return result;
+  return Status::OK();
 }
 
 KVStateCacheManager::~KVStateCacheManager() {
   LOG(INFO) << "Wait for sync thread to exit.";
-  {
-    std::lock_guard<std::mutex> lock(exitMutex);
+  std::lock_guard<std::mutex> lock(exitMutex);
+  if (!exitFlag) {
     exitFlag = true;
+    exitMutex.unlock();
+    cv.notify_one();
+    syncThread.join();
   }
-  cv.notify_one();
-  syncThread.join();
   LOG(INFO) << "KVStateCacheManager exit.";
 }
 
@@ -279,7 +279,7 @@ void KVStateCacheManager::SyncThreadFunc(KVStateCacheManager* manager) {
       if (manager->exitFlag) {
        break;
      }
-      manager->syncMutex.lock();
+      std::lock_guard<std::mutex> lock(manager->cacheAccessMutex);
      std::string actualKey;
 
      AcquireServerLock(manager->client, manager->llmCacheSyncLock, actualKey);
@@ -292,7 +292,6 @@ void KVStateCacheManager::SyncThreadFunc(KVStateCacheManager* manager) {
      }
 
      ReleaseServerLock(manager->client, actualKey);
-      manager->syncMutex.unlock();
 
      last_time = std::chrono::duration_cast<std::chrono::seconds>(
                      std::chrono::system_clock::now().time_since_epoch())
@@ -354,4 +353,21 @@ void KVStateCacheManager::ReleaseServerLock(Client& client,
   }
 }
 
+void KVStateCacheManager::Close() {
+  // recycle blob
+  LOG(INFO) << "Wait for sync thread to exit.";
+  std::lock_guard<std::mutex> exitLock(exitMutex);
+  if (!exitFlag) {
+    exitFlag = true;
+    exitMutex.unlock();
+    cv.notify_one();
+    syncThread.join();
+  }
+
+  LOG(INFO) << "Recycle blob.";
+  std::lock_guard<std::mutex> cacheLock(cacheAccessMutex);
+  this->kvStateCacheBuilder->Close();
+  this->isClosed = true;
+}
+
 }  // namespace vineyard
diff --git a/modules/llm-cache/ds/kv_state_cache_manager.h b/modules/llm-cache/ds/kv_state_cache_manager.h
index 21d2e0d0..be932cf1 100644
--- a/modules/llm-cache/ds/kv_state_cache_manager.h
+++ b/modules/llm-cache/ds/kv_state_cache_manager.h
@@ -36,11 +36,12 @@ class KVStateCacheManager {
   std::string llmCacheSyncLock;
   std::string llmCacheObjectName;
   std::thread syncThread;
-  std::mutex syncMutex;
+  std::mutex cacheAccessMutex;
   int syncInterval;
   bool exitFlag = false;
   std::condition_variable cv;
   std::mutex exitMutex;
+  bool isClosed = false;
 
  public:
   KVStateCacheManager(Client& client,
@@ -69,6 +70,8 @@ class KVStateCacheManager {
       const std::vector<int>& tokenList,
       std::vector<std::map<int, std::pair<LLMKV, LLMKV>>>& listKVState);
 
+  void Close();
+
   ~KVStateCacheManager();
 
  private:
diff --git a/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc b/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc
index 93c8cd95..1fd563e9 100644
--- a/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc
+++ b/modules/llm-cache/tests/kv_state_cache_benchmark_test.cc
@@ -156,7 +156,9 @@ int main(int argc, char** argv) {
     }
 
     benchmark_inference(all_token_lists);
+    sleep(5);
     inference_done = true;
+    manager->Close();
   });
 
   memory_monitor.join();
diff --git a/modules/llm-cache/tests/kv_state_cache_test.cc b/modules/llm-cache/tests/kv_state_cache_test.cc
index ffa1fa20..85fce56e 100644
--- a/modules/llm-cache/tests/kv_state_cache_test.cc
+++ b/modules/llm-cache/tests/kv_state_cache_test.cc
@@ -145,7 +145,7 @@ void inference(std::shared_ptr<KVStateCacheManager>& kv_state_cache_manager,
     kv_state.clear();
     Status result =
         kv_state_cache_manager->Query(inference_tokens, tokens[i], kv_state);
-    if (!result.ok()) {
+    if (!result.ok() || kv_state.empty()) {
       LOG(INFO) << "Can not find the kv_state from cache:";
       print_current_tokens(inference_tokens, tokens[i]);
       LOG(INFO) << "Generate the kv_state and update the cache.";
@@ -184,6 +184,11 @@ void threadFunc(std::string socket) {
   }
 
   LOG(INFO) << "inference end";
+
+  manager->Close();
+  std::shared_ptr<InstanceStatus> status;
+  VINEYARD_CHECK_OK(client.InstanceStatus(status));
+  LOG(INFO) << "memory usage:" << status->memory_usage;
 
   client.Disconnect();
 }
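
Note on the locking pattern: every early return in the old code had to be
paired with a manual syncMutex.unlock(), and a missed pairing left the mutex
held forever. The patch instead constructs a std::unique_lock with
std::defer_lock and calls try_lock(), so the mutex is released on every
return path by RAII. A minimal standalone sketch of that pattern, assuming
nothing beyond the C++ standard library; DemoCache, TryUpdate, and the
member names are illustrative stand-ins, not vineyard APIs:

    #include <iostream>
    #include <map>
    #include <mutex>

    class DemoCache {
      std::mutex cacheAccessMutex;
      bool isClosed = false;
      std::map<int, int> cache;

     public:
      // Non-blocking update: construct the lock deferred, then try_lock().
      // On contention the caller simply skips this round; RAII releases the
      // mutex on every return path, so no manual unlock can be forgotten.
      bool TryUpdate(int token, int value) {
        std::unique_lock<std::mutex> lock(cacheAccessMutex, std::defer_lock);
        if (!lock.try_lock()) {
          return false;  // lock busy; retry on the next call
        }
        if (isClosed) {
          return false;  // refuse new work after Close()
        }
        cache[token] = value;
        return true;  // `lock` unlocks here automatically, even on exceptions
      }

      void Close() {
        std::lock_guard<std::mutex> lock(cacheAccessMutex);
        isClosed = true;
      }
    };

    int main() {
      DemoCache cache;
      std::cout << cache.TryUpdate(1, 42) << std::endl;  // prints 1
      cache.Close();
      std::cout << cache.TryUpdate(2, 7) << std::endl;   // prints 0
    }

Calling TryUpdate after Close() returns false, mirroring how the manager
rejects new work once isClosed is set under the same mutex.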
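
Note on the shutdown handshake: Close() and the destructor both set exitFlag
under exitMutex, notify the condition variable, and join the sync thread,
with the if (!exitFlag) guard ensuring the thread is joined at most once
even when an explicit Close() is followed by destruction. A reduced sketch
of that handshake, again with hypothetical names (DemoSyncLoop is not a
vineyard type):

    #include <chrono>
    #include <condition_variable>
    #include <iostream>
    #include <mutex>
    #include <thread>

    class DemoSyncLoop {
      std::mutex exitMutex;
      std::condition_variable cv;
      bool exitFlag = false;
      std::thread syncThread;  // declared last so it starts after the flag

     public:
      DemoSyncLoop()
          : syncThread([this] {
              std::unique_lock<std::mutex> lock(exitMutex);
              while (!exitFlag) {
                // Wake periodically to "sync", or immediately on shutdown.
                cv.wait_for(lock, std::chrono::seconds(1));
              }
            }) {}

      void Close() {
        {
          std::lock_guard<std::mutex> lock(exitMutex);
          if (exitFlag) {
            return;  // already closed; never join the thread twice
          }
          exitFlag = true;
        }  // release before notifying so the worker can re-acquire promptly
        cv.notify_one();
        syncThread.join();
      }

      ~DemoSyncLoop() { Close(); }
    };

    int main() {
      DemoSyncLoop loop;
      loop.Close();  // the destructor would also close it if we forgot
      std::cout << "sync thread exited cleanly" << std::endl;
    }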