Skip to content

Commit

Permalink
Fix llm memory leak and replace mutex_lock with unique_lock (#1815)
Browse files Browse the repository at this point in the history
- Fix memory leak of llm module.
- Use unique_lock to manage mutex locks in llm module.

Fixes #1813 
Fixes #1814

Signed-off-by: vegetableysm <[email protected]>
  • Loading branch information
vegetableysm authored Mar 13, 2024
1 parent 2839088 commit 82fc618
Show file tree
Hide file tree
Showing 6 changed files with 89 additions and 41 deletions.
20 changes: 20 additions & 0 deletions modules/llm-cache/ds/kv_state_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,9 @@ void KVStateCacheBuilder::Delete(std::shared_ptr<NodeData> evictedNodeData) {
// delete (DataWrapper*) evictedNodeData->nodeData;
if (evictedNodeData->cleanTreeData) {
this->rootTree->ClearSubtreeData(treeData);
std::shared_ptr<Object> blockObject =
kvStateCacheBlockBuilder->_Seal(client);
client.DelData(blockObject->id());
delete kvStateCacheBlockBuilder;
}
evictedNodeData->RecycleSource();
Expand Down Expand Up @@ -326,6 +329,7 @@ Status KVStateCacheBuilder::Merge(std::shared_ptr<KVStateCache> kvStateCache) {
}

this->version = globalCacheBuilder->GetVersion();
globalCacheBuilder->Close();
return Status::OK();
}

Expand Down Expand Up @@ -403,4 +407,20 @@ KVStateCacheBuilder::~KVStateCacheBuilder() {
}
}

void KVStateCacheBuilder::Close() {
std::set<void*> subTreeDataSet = rootTree->GetSubTreeDataSet();
for (auto iter = subTreeDataSet.begin(); iter != subTreeDataSet.end();
++iter) {
TreeData* treeData = reinterpret_cast<TreeData*>(*iter);
if (treeData->isPtr == true &&
treeData->kvStateCacheBlockBuilder != nullptr) {
std::shared_ptr<Object> object =
reinterpret_cast<KVStateCacheBlockBuilder*>(
treeData->kvStateCacheBlockBuilder)
->_Seal(client);
client.DelData(object->id());
}
}
}

} // namespace vineyard
2 changes: 2 additions & 0 deletions modules/llm-cache/ds/kv_state_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ class KVStateCacheBuilder : public vineyard::ObjectBuilder {

int GetLayer() { return this->layer; }

void Close();

~KVStateCacheBuilder();
};

Expand Down
94 changes: 55 additions & 39 deletions modules/llm-cache/ds/kv_state_cache_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,93 +98,93 @@ Status KVStateCacheManager::QueryInternal(
Status KVStateCacheManager::Update(
const std::vector<int>& tokenList, int nextToken,
const std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
Status result =
Status::Invalid("Query cache failed: can not gain the cache lock.");

if (!syncMutex.try_lock()) {
return result;
std::unique_lock<std::mutex> lock(cacheAccessMutex, std::defer_lock);
if (!lock.try_lock()) {
// If failed to gain the lock, return OK and wait for next time
return Status::OK();
}

result = UpdateInternal(tokenList, nextToken, kvState);
if (isClosed) {
return Status::Invalid("The cache manager is closed.");
}

syncMutex.unlock();
return result;
return UpdateInternal(tokenList, nextToken, kvState);
}

Status KVStateCacheManager::Update(
const std::vector<int>& tokenList,
const std::vector<std::map<int, std::pair<LLMKV, LLMKV>>>& kvState) {
Status result =
Status::Invalid("Update cache failed: can not gain the cache lock.");
if (!syncMutex.try_lock()) {
return result;
std::unique_lock<std::mutex> lock(cacheAccessMutex, std::defer_lock);
if (!lock.try_lock()) {
return Status::OK();
}
if (isClosed) {
return Status::Invalid("The cache manager is closed.");
}

std::vector<int> tokenListCopy;
for (size_t i = 0; i < tokenList.size(); i++) {
result = UpdateInternal(tokenListCopy, tokenList[i], kvState[i]);
Status result = UpdateInternal(tokenListCopy, tokenList[i], kvState[i]);
if (!result.ok()) {
break;
}
tokenListCopy.push_back(tokenList[i]);
}

syncMutex.unlock();
return result;
return Status::OK();
}

Status KVStateCacheManager::Query(
const std::vector<int>& tokenList, int token,
std::map<int, std::pair<LLMKV, LLMKV>>& kvState) {
Status result =
Status::Invalid("Query cache failed: can not gain the cache lock.");

if (!syncMutex.try_lock()) {
return result;
std::unique_lock<std::mutex> lock(cacheAccessMutex, std::defer_lock);
if (!lock.try_lock()) {
// If failed to gain the lock, return OK and wait for next time
return Status::OK();
}
if (isClosed) {
return Status::Invalid("The cache manager is closed.");
}

result = QueryInternal(tokenList, token, kvState);
syncMutex.unlock();

return result;
return QueryInternal(tokenList, token, kvState);
}

Status KVStateCacheManager::Query(
const std::vector<int>& tokenList,
std::vector<std::map<int, std::pair<LLMKV, LLMKV>>>& listKVState) {
Status result =
Status::Invalid("Query cache failed: can not gain the cache lock.");
if (!syncMutex.try_lock()) {
return result;
std::unique_lock<std::mutex> lock(cacheAccessMutex, std::defer_lock);
if (!lock.try_lock()) {
return Status::Invalid("Query cache failed: can not gain the cache lock.");
}
if (isClosed) {
return Status::Invalid("The cache manager is closed.");
}

// support partial match of the token list
// copy the token list and query the cache one token by one token
std::vector<int> tokenListCopy;
std::map<int, std::pair<LLMKV, LLMKV>> kvState;
for (size_t i = 0; i < tokenList.size(); i++) {
result = QueryInternal(tokenListCopy, tokenList[i], kvState);
Status result = QueryInternal(tokenListCopy, tokenList[i], kvState);
if (!result.ok()) {
syncMutex.unlock();
return Status::OK();
}
tokenListCopy.push_back(tokenList[i]);
listKVState.push_back(kvState);
kvState.clear();
}

syncMutex.unlock();
return result;
return Status::OK();
}

KVStateCacheManager::~KVStateCacheManager() {
LOG(INFO) << "Wait for sync thread to exit.";
{
std::lock_guard<std::mutex> lock(exitMutex);
std::lock_guard<std::mutex> lock(exitMutex);
if (!exitFlag) {
exitFlag = true;
exitMutex.unlock();
cv.notify_one();
syncThread.join();
}
cv.notify_one();
syncThread.join();
LOG(INFO) << "KVStateCacheManager exit.";
}

Expand Down Expand Up @@ -279,7 +279,7 @@ void KVStateCacheManager::SyncThreadFunc(KVStateCacheManager* manager) {
if (manager->exitFlag) {
break;
}
manager->syncMutex.lock();
std::lock_guard<std::mutex> lock(manager->cacheAccessMutex);
std::string actualKey;

AcquireServerLock(manager->client, manager->llmCacheSyncLock, actualKey);
Expand All @@ -292,7 +292,6 @@ void KVStateCacheManager::SyncThreadFunc(KVStateCacheManager* manager) {
}

ReleaseServerLock(manager->client, actualKey);
manager->syncMutex.unlock();

last_time = std::chrono::duration_cast<std::chrono::seconds>(
std::chrono::system_clock::now().time_since_epoch())
Expand Down Expand Up @@ -354,4 +353,21 @@ void KVStateCacheManager::ReleaseServerLock(Client& client,
}
}

void KVStateCacheManager::Close() {
  // Shut the manager down: stop the background sync thread, then seal
  // and delete the cache blobs so their memory is returned to vineyard.
  // Safe to call at most once before destruction; after it returns,
  // Update/Query observe isClosed and refuse further work.
  LOG(INFO) << "Wait for sync thread to exit.";
  // Use unique_lock, not lock_guard: we must release the mutex before
  // notifying/joining, and unlocking the raw mutex underneath a
  // lock_guard makes its destructor unlock a second time — UB.
  std::unique_lock<std::mutex> exitLock(exitMutex);
  if (!exitFlag) {
    exitFlag = true;
    // Release before waking the sync thread so it can take exitMutex
    // without deadlocking against us while we join it.
    exitLock.unlock();
    cv.notify_one();
    syncThread.join();
  }

  LOG(INFO) << "Recycle blob.";
  // Serialize against in-flight Update/Query before tearing down blobs.
  std::lock_guard<std::mutex> cacheLock(cacheAccessMutex);
  this->kvStateCacheBuilder->Close();
  this->isClosed = true;
}

} // namespace vineyard
5 changes: 4 additions & 1 deletion modules/llm-cache/ds/kv_state_cache_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ class KVStateCacheManager {
std::string llmCacheSyncLock;
std::string llmCacheObjectName;
std::thread syncThread;
std::mutex syncMutex;
std::mutex cacheAccessMutex;
int syncInterval;
bool exitFlag = false;
std::condition_variable cv;
std::mutex exitMutex;
bool isClosed = false;

public:
KVStateCacheManager(Client& client,
Expand Down Expand Up @@ -69,6 +70,8 @@ class KVStateCacheManager {
const std::vector<int>& tokenList,
std::vector<std::map<int, std::pair<LLMKV, LLMKV>>>& listKVState);

void Close();

~KVStateCacheManager();

private:
Expand Down
2 changes: 2 additions & 0 deletions modules/llm-cache/tests/kv_state_cache_benchmark_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ int main(int argc, char** argv) {
}

benchmark_inference(all_token_lists);
sleep(5);
inference_done = true;
manager->Close();
});

memory_monitor.join();
Expand Down
7 changes: 6 additions & 1 deletion modules/llm-cache/tests/kv_state_cache_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ void inference(std::shared_ptr<KVStateCacheManager>& kv_state_cache_manager,
kv_state.clear();
Status result =
kv_state_cache_manager->Query(inference_tokens, tokens[i], kv_state);
if (!result.ok()) {
if (!result.ok() || kv_state.empty()) {
LOG(INFO) << "Can not find the kv_state from cache:";
print_current_tokens(inference_tokens, tokens[i]);
LOG(INFO) << "Generate the kv_state and update the cache.";
Expand Down Expand Up @@ -184,6 +184,11 @@ void threadFunc(std::string socket) {
}

LOG(INFO) << "inference end";

manager->Close();
std::shared_ptr<InstanceStatus> status;
VINEYARD_CHECK_OK(client.InstanceStatus(status));
LOG(INFO) << "memory usage:" << status->memory_usage;
client.Disconnect();
}

Expand Down

0 comments on commit 82fc618

Please sign in to comment.