fix: small lock improvement (#19)
Co-authored-by: vansangpfiev <[email protected]>
vansangpfiev and sangjanai authored May 24, 2024
1 parent 7647191 commit c637824
Showing 2 changed files with 38 additions and 21 deletions.
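
The common thread in both files: critical sections are narrowed so the mutex is held only while a shared queue is actually mutated, and the condition variable is signalled after the lock is released. A minimal before/after sketch of that pattern (a standalone illustration with made-up names, not code from this repository):

#include <condition_variable>
#include <mutex>
#include <vector>

std::vector<int> queue_items;
std::mutex queue_mutex;
std::condition_variable queue_cv;

// Before: the lock spans the whole function, so unrelated setup work
// and the notify both run while the mutex is held.
void EnqueueWide(int item) {
  std::unique_lock<std::mutex> lock(queue_mutex);
  // ... setup work that never touches queue_items ...
  queue_items.push_back(item);
  queue_cv.notify_one();
}

// After: only the push is guarded; the woken waiter can take the
// mutex immediately instead of colliding with the notifier.
void EnqueueNarrow(int item) {
  // ... setup work runs without the lock ...
  {
    std::lock_guard<std::mutex> lock(queue_mutex);
    queue_items.push_back(item);
  }
  queue_cv.notify_one();
}

Notifying after unlock is safe here because the waiting side re-checks its predicate under the lock: either it sees the freshly pushed item, or the wakeup arrives late and it simply goes back to sleep.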
55 changes: 36 additions & 19 deletions src/llama_server_context.cc
@@ -250,7 +250,6 @@ json LlamaServerContext::GetModelProps() {
 
 int LlamaServerContext::RequestCompletion(json data, bool infill,
                                           bool embedding, int multitask_id) {
-  std::unique_lock<std::mutex> lock(mutex_tasks);
   TaskServer task;
   task.id = id_gen++;
   task.target_id = 0;
@@ -263,12 +262,14 @@ int LlamaServerContext::RequestCompletion(json data, bool infill,
   // when a completion task's prompt array is not a singleton, we split it
   // into multiple requests
   if (task.data.at("prompt").size() > 1) {
-    lock.unlock();  // entering new func scope
     return SplitMultipromptTask(task);
   }
 
   // otherwise, it's a single-prompt task, we actually queue it
-  queue_tasks.push_back(task);
+  {
+    std::lock_guard<std::mutex> lock(mutex_tasks);
+    queue_tasks.push_back(task);
+  }
   condition_tasks.notify_one();
   return task.id;
 }
@@ -303,12 +304,14 @@ TaskResult LlamaServerContext::NextResult(int task_id) {
 }
 
 void LlamaServerContext::RequestCancel(int task_id) {
-  std::unique_lock<std::mutex> lock(mutex_tasks);
   TaskServer task;
   task.id = id_gen++;
   task.type = TaskType::kCancelTask;
   task.target_id = task_id;
-  queue_tasks.push_back(task);
+  {
+    std::lock_guard<std::mutex> lock(mutex_tasks);
+    queue_tasks.push_back(task);
+  }
   condition_tasks.notify_one();
 }

@@ -820,13 +823,15 @@ void LlamaServerContext::SendError(int id_task, int id_multi,
 }
 
 void LlamaServerContext::AddMultiTask(int id, std::vector<int>& sub_ids) {
-  std::lock_guard<std::mutex> lock(mutex_tasks);
   TaskMulti multi;
   multi.id = id;
   std::copy(
       sub_ids.begin(), sub_ids.end(),
       std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end()));
-  queue_multitasks.push_back(multi);
+  {
+    std::lock_guard<std::mutex> lock(mutex_tasks);
+    queue_multitasks.push_back(multi);
+  }
   condition_tasks.notify_one();
 }

@@ -880,7 +885,6 @@ json LlamaServerContext::GetFormatedGeneration(LlamaClientSlot& slot) {
 
 void LlamaServerContext::SendPartialResponse(LlamaClientSlot& slot,
                                              CompletionTokenOutput tkn) {
-  std::unique_lock<std::mutex> lock(mutex_results);
   TaskResult res;
   res.id = slot.task_id;
   res.multitask_id = slot.multitask_id;
@@ -916,12 +920,14 @@ void LlamaServerContext::SendPartialResponse(LlamaClientSlot& slot,
     res.result_json["model"] = slot.oaicompat_model;
   }
 
-  queue_results.push_back(res);
+  {
+    std::lock_guard<std::mutex> lock(mutex_results);
+    queue_results.push_back(res);
+  }
   condition_results.notify_all();
 }
 
 void LlamaServerContext::SendFinalResponse(LlamaClientSlot& slot) {
-  std::unique_lock<std::mutex> lock(mutex_results);
   TaskResult res;
   res.id = slot.task_id;
   res.multitask_id = slot.multitask_id;
@@ -972,12 +978,14 @@ void LlamaServerContext::SendFinalResponse(LlamaClientSlot& slot) {
     UpdateMultiTask(slot.multitask_id, slot.task_id, res);
   }
 
-  queue_results.push_back(res);
+  {
+    std::lock_guard<std::mutex> lock(mutex_results);
+    queue_results.push_back(res);
+  }
   condition_results.notify_all();
 }
 
 void LlamaServerContext::SendEmbedding(LlamaClientSlot& slot) {
-  std::unique_lock<std::mutex> lock(mutex_results);
   TaskResult res;
   res.id = slot.task_id;
   res.multitask_id = slot.multitask_id;
@@ -1015,7 +1023,10 @@ void LlamaServerContext::SendEmbedding(LlamaClientSlot& slot) {
       {"embedding", embd_res},
   };
 
-  queue_results.push_back(res);
+  {
+    std::lock_guard<std::mutex> lock(mutex_results);
+    queue_results.push_back(res);
+  }
   condition_results.notify_all();
 }
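
A detail the diff preserves: task producers call condition_tasks.notify_one(), while every result producer calls condition_results.notify_all(). The likely reason is that several callers can block in NextResult, each waiting for a different task id, so all of them must wake and re-check. A hedged sketch of that consumer side (Result, WaitForResult, and the container are illustrative, not this repository's API):

#include <algorithm>
#include <condition_variable>
#include <mutex>
#include <vector>

struct Result { int task_id; };
std::vector<Result> results;
std::mutex results_mutex;
std::condition_variable results_cv;

// notify_one could wake the "wrong" waiter, which would go back to
// sleep while the right one stays parked; notify_all wakes everyone
// and each thread re-evaluates its own predicate under the lock.
Result WaitForResult(int task_id) {
  std::unique_lock<std::mutex> lock(results_mutex);
  std::vector<Result>::iterator it;
  results_cv.wait(lock, [&] {
    it = std::find_if(results.begin(), results.end(),
                      [&](const Result& r) { return r.task_id == task_id; });
    return it != results.end();
  });
  Result r = *it;
  results.erase(it);
  return r;
}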

@@ -1111,10 +1122,15 @@ int LlamaServerContext::SplitMultipromptTask(TaskServer& multiprompt_task) {
 }
 
 void LlamaServerContext::ProcessTasks() {
-  std::unique_lock<std::mutex> lock(mutex_tasks);
-  while (!queue_tasks.empty()) {
+  while (true) {
+    std::unique_lock<std::mutex> l(mutex_tasks);
+    if (queue_tasks.empty()) {
+      l.unlock();
+      break;
+    }
     TaskServer task = queue_tasks.front();
     queue_tasks.erase(queue_tasks.begin());
+    l.unlock();
     switch (task.type) {
       case TaskType::kCompletionTask: {
         LlamaClientSlot* slot = GetSlot(json_value(task.data, "slot_id", -1));
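
The reworked loop takes the lock once per task: it pops under the mutex, unlocks, and only then dispatches, so a slow handler never blocks RequestCompletion or RequestCancel from enqueuing. A simplified sketch of that drain pattern (Task and HandleTask are placeholders):

#include <mutex>
#include <vector>

struct Task { int id; };
std::vector<Task> pending;
std::mutex pending_mutex;

void HandleTask(const Task& t);  // potentially slow; must run unlocked

void DrainTasks() {
  while (true) {
    std::unique_lock<std::mutex> lock(pending_mutex);
    if (pending.empty()) {
      break;  // the unique_lock destructor releases the mutex
    }
    Task task = pending.front();
    pending.erase(pending.begin());
    lock.unlock();     // drop the lock before the slow part
    HandleTask(task);  // producers can enqueue concurrently
  }
}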
@@ -1155,6 +1171,7 @@
 
   // remove finished multitasks from the queue of multitasks, and add the
   // corresponding result to the result queue
+  std::lock_guard<std::mutex> l(mutex_tasks);
   auto queue_iterator = queue_multitasks.begin();
   while (queue_iterator != queue_multitasks.end()) {
     if (queue_iterator->subtasks_remaining.empty()) {
@@ -1172,8 +1189,10 @@
       }
       aggregate_result.result_json = json{"results", result_jsons};
 
-      std::lock_guard<std::mutex> lock(mutex_results);
-      queue_results.push_back(aggregate_result);
+      {
+        std::lock_guard<std::mutex> lock(mutex_results);
+        queue_results.push_back(aggregate_result);
+      }
       condition_results.notify_all();
 
       queue_iterator = queue_multitasks.erase(queue_iterator);
@@ -1211,8 +1230,6 @@ bool LlamaServerContext::UpdateSlots() {
                "cache";
       KvCacheClear();
     }
-    // std::this_thread::sleep_for(std::chrono::milliseconds(5));
-    // TODO: Need to implement queueing using CV for better performance
     std::unique_lock<std::mutex> lock(mutex_tasks);
     condition_tasks.wait(lock, [&] {
       return (!queue_tasks.empty() && model_loaded_external) ||
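
UpdateSlots drops the commented-out sleep/poll idea in favor of the condition-variable wait it already had; the predicate form re-checks the wake condition under the lock, so spurious wakeups and missed notifications are both handled. A minimal sketch of the idiom (the flags and names are illustrative, and the real predicate is more involved):

#include <condition_variable>
#include <mutex>
#include <vector>

std::vector<int> work_queue;
std::mutex work_mutex;
std::condition_variable work_cv;
bool model_loaded = true;

void WaitForWork() {
  std::unique_lock<std::mutex> lock(work_mutex);
  // wait() atomically releases the lock while sleeping and reacquires
  // it before evaluating the predicate, closing the check-then-sleep race.
  work_cv.wait(lock, [&] { return !work_queue.empty() && model_loaded; });
  // The lock is held here, so it is safe to pop from work_queue.
}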
4 changes: 2 additions & 2 deletions src/llama_server_context.h
@@ -117,7 +117,7 @@ struct LlamaServerContext {
   bool all_slots_are_idle = false;
   bool add_bos_token = true;
 
-  int32_t id_gen;
+  std::atomic<int32_t> id_gen;
   int32_t n_ctx;  // total context for all clients / slots
 
   // Internal
@@ -138,7 +138,7 @@
   std::vector<TaskServer> queue_tasks;
   std::vector<TaskResult> queue_results;
   std::vector<TaskMulti> queue_multitasks;
-  std::mutex mutex_tasks;  // also guards id_gen, and queue_multitasks
+  std::mutex mutex_tasks;  // also guards queue_multitasks
   std::condition_variable condition_tasks;
   std::mutex mutex_results;
   std::condition_variable condition_results;
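
The header change is what lets the .cc side drop locks around task.id = id_gen++: on a std::atomic<int32_t>, post-increment is a single atomic read-modify-write, so concurrent producers always draw distinct ids and mutex_tasks no longer needs to guard the counter (hence the trimmed comment). A small standalone sketch:

#include <atomic>
#include <cstdint>

std::atomic<int32_t> id_gen{0};

int32_t NextId() {
  // Equivalent to id_gen.fetch_add(1): atomically returns the old
  // value and increments, so no two callers observe the same id.
  return id_gen++;
}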
