fix: stop inflight chat completion in stream mode (#315)
* fix: stop inflight chat completion in stream mode

* Merge branch 'main' into 'fix/stop-inflight-chat-completion'

* chore: change runner from macos-12 to macos-selfhosted-12

---------

Co-authored-by: vansangpfiev <[email protected]>
vansangpfiev and sangjanai authored Dec 5, 2024
1 parent 08ef284 commit 749872e
Showing 9 changed files with 68 additions and 13 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
@@ -168,7 +168,7 @@ jobs:
     ccache-dir: "/home/runner/.ccache"
   - os: "mac"
     name: "amd64"
-    runs-on: "macos-12"
+    runs-on: "macos-selfhosted-12"
     cmake-flags: "-DCORTEXLLAMA_VERSION=${{needs.create-draft-release.outputs.version}} -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_COMMON=ON -DGGML_METAL=OFF"
     run-e2e: true
     vulkan: false
2 changes: 1 addition & 1 deletion .github/workflows/nightly-build.yml
@@ -167,7 +167,7 @@ jobs:
     ccache-dir: "/home/runner/.ccache"
   - os: "mac"
     name: "amd64"
-    runs-on: "macos-12"
+    runs-on: "macos-selfhosted-12"
     cmake-flags: "-DCORTEXLLAMA_VERSION=${{needs.create-draft-release.outputs.version}} -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_COMMON=ON -DGGML_METAL=OFF"
     run-e2e: true
     vulkan: false
2 changes: 1 addition & 1 deletion .github/workflows/template-e2e-weekend-test.yml
@@ -50,7 +50,7 @@ jobs:

   - os: "mac"
     name: "amd64"
-    runs-on: "macos-12"
+    runs-on: "macos-selfhosted-12"
     cmake-flags: "-DCORTEXLLAMA_VERSION=${{github.event.pull_request.head.sha}} -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_COMMON=ON -DGGML_METAL=OFF"
     run-e2e: true
     vulkan: false
2 changes: 1 addition & 1 deletion .github/workflows/template-quality-gate-pr.yml
@@ -134,7 +134,7 @@ jobs:
     ccache-dir: "/home/runner/.ccache"
   - os: "mac"
     name: "amd64"
-    runs-on: "macos-12"
+    runs-on: "macos-selfhosted-12"
     cmake-flags: "-DCORTEXLLAMA_VERSION=${{github.event.pull_request.head.sha}} -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_COMMON=ON -DGGML_METAL=OFF"
     run-e2e: true
     vulkan: false
2 changes: 1 addition & 1 deletion .github/workflows/template-quality-gate-submodule.yml
@@ -134,7 +134,7 @@ jobs:
     ccache-dir: "/home/runner/.ccache"
   - os: "mac"
     name: "amd64"
-    runs-on: "macos-12"
+    runs-on: "macos-selfhosted-12"
     cmake-flags: "-DCORTEXLLAMA_VERSION=${{github.event.pull_request.head.sha}} -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_COMMON=ON -DGGML_METAL=OFF"
     run-e2e: true
     vulkan: false
5 changes: 4 additions & 1 deletion base/cortex-common/enginei.h
@@ -63,7 +63,8 @@ class EngineI {
   virtual bool IsSupported(const std::string& f) {
     if (f == "HandleChatCompletion" || f == "HandleEmbedding" ||
         f == "LoadModel" || f == "UnloadModel" || f == "GetModelStatus" ||
-        f == "GetModels" || f == "SetFileLogger" || f == "SetLogLevel") {
+        f == "GetModels" || f == "SetFileLogger" || f == "SetLogLevel" ||
+        f == "StopInferencing") {
       return true;
     }
     return false;
@@ -77,4 +78,6 @@ class EngineI {
   virtual void SetFileLogger(int max_log_lines,
                              const std::string& log_path) = 0;
   virtual void SetLogLevel(trantor::Logger::LogLevel log_level) = 0;
+
+  virtual void StopInferencing(const std::string& model_id) = 0;
 };
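Since IsSupported now advertises "StopInferencing", a host that loads the engine through dylib can probe the capability and degrade gracefully on engine builds that predate the new pure-virtual. A minimal, hypothetical host-side sketch (the CancelStream helper is illustrative only, not part of this commit):

    // Hypothetical host-side guard, not part of this commit: check the
    // capability string before calling the new pure-virtual, since an engine
    // binary built against the older EngineI has no StopInferencing slot.
    #include <string>

    #include "cortex-common/enginei.h"

    void CancelStream(EngineI* engine, const std::string& model_id) {
      if (engine && engine->IsSupported("StopInferencing")) {
        engine->StopInferencing(model_id);  // flags the in-flight completion to stop
      }
    }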
19 changes: 15 additions & 4 deletions examples/server/server.cc
@@ -23,6 +23,14 @@ class Server {
     }
   }

+  void ForceStopInferencing(const std::string& model_id) {
+    if (engine_) {
+      engine_->StopInferencing(model_id);
+    } else {
+      LOG_WARN << "Engine is null";
+    }
+  }
+
  public:
   std::unique_ptr<dylib> dylib_;
   EngineI* engine_;
@@ -122,17 +130,19 @@ int main(int argc, char** argv) {
   };

   auto process_stream_res = [&server](httplib::Response& resp,
-                                      std::shared_ptr<SyncQueue> q) {
+                                      std::shared_ptr<SyncQueue> q,
+                                      const std::string& model_id) {
     const auto chunked_content_provider =
-        [&server, q](size_t size, httplib::DataSink& sink) {
+        [&server, q, model_id](size_t size, httplib::DataSink& sink) {
           while (true) {
             auto [status, res] = q->wait_and_pop();
             auto str = res["data"].asString();
             LOG_TRACE << "data: " << str;

             if (!sink.write(str.c_str(), str.size())) {
               LOG_WARN << "Failed to write";
-              // return false;
+              server.ForceStopInferencing(model_id);
+              return false;
             }
             if (status["has_error"].asBool() || status["is_done"].asBool()) {
               LOG_INFO << "Done";
@@ -183,14 +193,15 @@ int main(int argc, char** argv) {
     auto req_body = std::make_shared<Json::Value>();
     r.Parse(req.body, *req_body);
     bool is_stream = (*req_body).get("stream", false).asBool();
+    std::string model_id = (*req_body).get("model", "invalid_model").asString();
     // This is an async call, need to use queue
     auto q = std::make_shared<SyncQueue>();
     server.engine_->HandleChatCompletion(
         req_body, [&server, q](Json::Value status, Json::Value res) {
           q->push(std::make_pair(status, res));
         });
     if (is_stream) {
-      process_stream_res(resp, q);
+      process_stream_res(resp, q, model_id);
     } else {
       process_non_stream_res(resp, *q);
     }
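For context, the cancellation trigger here is cpp-httplib's chunked-provider contract: sink.write returns false once the client has disconnected, and returning false from the provider ends the response. A stripped-down sketch of that pattern, with a generic on_client_gone callback standing in for server.ForceStopInferencing(model_id) (assumes header name "httplib.h" and recent cpp-httplib semantics; not the exact server code):

    // Simplified sketch of the disconnect-driven cancellation pattern above.
    // on_client_gone stands in for server.ForceStopInferencing(model_id).
    #include <functional>
    #include <string>

    #include "httplib.h"

    void StreamWithCancel(httplib::Response& resp,
                          std::function<std::string()> next_chunk,
                          std::function<void()> on_client_gone) {
      resp.set_chunked_content_provider(
          "text/event-stream",
          [next_chunk, on_client_gone](size_t /*offset*/, httplib::DataSink& sink) {
            std::string chunk = next_chunk();
            if (chunk.empty()) {  // producer finished
              sink.done();
              return true;
            }
            if (!sink.write(chunk.c_str(), chunk.size())) {
              on_client_gone();   // client went away: propagate the stop upstream
              return false;       // abort the response, as the fix does
            }
            return true;          // keep streaming
          });
    }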
38 changes: 35 additions & 3 deletions src/llama_engine.cc
@@ -8,7 +8,6 @@
 #include "llama_utils.h"
 #include "trantor/utils/Logger.h"

-
 #if defined(_WIN32)
 #include <windows.h>
 #include <codecvt>
@@ -537,6 +536,10 @@ void LlamaEngine::SetLogLevel(trantor::Logger::LogLevel log_level) {
   trantor::Logger::setLogLevel(log_level);
 }

+void LlamaEngine::StopInferencing(const std::string& model_id) {
+  AddForceStopInferenceModel(model_id);
+}
+
 void LlamaEngine::SetFileLogger(int max_log_lines,
                                 const std::string& log_path) {
   if (!async_file_logger_) {
@@ -959,12 +962,19 @@ void LlamaEngine::HandleInferenceImpl(
     LOG_INFO << "Request " << request_id << ": "
              << "Streamed, waiting for respone";
     auto state = CreateInferenceState(si.ctx);
+    auto model_id = completion.model_id;

     // Queued task
-    si.q->runTaskInQueue([cb = std::move(callback), state, data, request_id,
-                          n_probs, include_usage]() {
+    si.q->runTaskInQueue([this, cb = std::move(callback), state, data,
+                          request_id, n_probs, include_usage, model_id]() {
       state->task_id = state->llama.RequestCompletion(data, false, false, -1);
       while (state->llama.model_loaded_external) {
+        if (HasForceStopInferenceModel(model_id)) {
+          LOG_INFO << "Force stop inferencing for model: " << model_id;
+          state->llama.RequestCancel(state->task_id);
+          RemoveForceStopInferenceModel(model_id);
+          break;
+        }
         TaskResult result = state->llama.NextResult(state->task_id);
         if (!result.error) {
           std::string to_send;
@@ -1287,6 +1297,28 @@ bool LlamaEngine::ShouldInitBackend() const {
   return true;
 }

+void LlamaEngine::AddForceStopInferenceModel(const std::string& id) {
+  std::lock_guard l(fsi_mtx_);
+  if (force_stop_inference_models_.find(id) ==
+      force_stop_inference_models_.end()) {
+    LOG_INFO << "Added force stop inferencing model: " << id;
+    force_stop_inference_models_.insert(id);
+  }
+}
+void LlamaEngine::RemoveForceStopInferenceModel(const std::string& id) {
+  std::lock_guard l(fsi_mtx_);
+  if (force_stop_inference_models_.find(id) !=
+      force_stop_inference_models_.end()) {
+    force_stop_inference_models_.erase(id);
+  }
+}
+
+bool LlamaEngine::HasForceStopInferenceModel(const std::string& id) const {
+  std::lock_guard l(fsi_mtx_);
+  return force_stop_inference_models_.find(id) !=
+         force_stop_inference_models_.end();
+}
+
 extern "C" {
 EngineI* get_engine() {
   return new LlamaEngine();
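The handshake these helpers implement is a mutex-guarded set shared between the HTTP thread (which inserts the model id) and the queued generation task (which polls, cancels, and erases it). Below is a self-contained sketch of that pattern; ForceStopRegistry, the folded check-and-erase, and the fake token loop are illustrative stand-ins for the engine's task queue and llama.cpp context, not its actual types:

    // Standalone illustration of the force-stop handshake (names are made up):
    // one thread "generates" tokens in a loop, another requests cancellation.
    #include <chrono>
    #include <iostream>
    #include <mutex>
    #include <string>
    #include <thread>
    #include <unordered_set>

    class ForceStopRegistry {
     public:
      void Add(const std::string& id) {          // ~ AddForceStopInferenceModel
        std::lock_guard<std::mutex> l(mtx_);
        ids_.insert(id);
      }
      bool HasAndClear(const std::string& id) {  // ~ Has... followed by Remove...
        std::lock_guard<std::mutex> l(mtx_);
        return ids_.erase(id) > 0;
      }

     private:
      std::mutex mtx_;
      std::unordered_set<std::string> ids_;
    };

    int main() {
      ForceStopRegistry registry;
      std::thread worker([&registry] {
        for (int token = 0; token < 1000; ++token) {  // pretend decode loop
          if (registry.HasAndClear("my-model")) {     // checked once per token
            std::cout << "force stop requested, cancelling\n";  // ~ RequestCancel
            break;
          }
          std::this_thread::sleep_for(std::chrono::milliseconds(10));
        }
      });
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
      registry.Add("my-model");  // ~ StopInferencing(model_id) from the HTTP thread
      worker.join();
      return 0;
    }

Folding the check and the erase into one locked call keeps the sketch simple; the engine's separate Has/Remove calls are equally safe here, since only the generation loop for that model consumes the flag.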
9 changes: 9 additions & 0 deletions src/llama_engine.h
@@ -1,6 +1,7 @@
 #pragma once

 #include <trantor/utils/AsyncFileLogger.h>
+#include <unordered_set>
 #include "chat_completion_request.h"
 #include "cortex-common/enginei.h"
 #include "file_logger.h"
@@ -44,6 +45,7 @@ class LlamaEngine : public EngineI {
   void SetFileLogger(int max_log_lines, const std::string& log_path) final;
   void SetLogLevel(trantor::Logger::LogLevel log_level =
                        trantor::Logger::LogLevel::kInfo) final;
+  void StopInferencing(const std::string& model_id) final;

  private:
   bool LoadModelImpl(std::shared_ptr<Json::Value> jsonBody);
@@ -59,6 +61,10 @@ class LlamaEngine : public EngineI {
   void WarmUpModel(const std::string& model_id);
   bool ShouldInitBackend() const;

+  void AddForceStopInferenceModel(const std::string& id);
+  void RemoveForceStopInferenceModel(const std::string& id);
+  bool HasForceStopInferenceModel(const std::string& id) const;
+
  private:
   struct ServerInfo {
     LlamaServerContext ctx;
@@ -78,6 +84,9 @@ class LlamaEngine : public EngineI {

   // key: model_id, value: ServerInfo
   std::unordered_map<std::string, ServerInfo> server_map_;
+  // lock the force_stop_inference_models_
+  mutable std::mutex fsi_mtx_;
+  std::unordered_set<std::string> force_stop_inference_models_;

   std::atomic<int> no_of_requests_ = 0;
   std::atomic<int> no_of_chats_ = 0;
