diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index b6a751b..a18c90d 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -168,7 +168,7 @@ jobs:
           ccache-dir: "/home/runner/.ccache"
         - os: "mac"
           name: "amd64"
-          runs-on: "macos-12"
+          runs-on: "macos-selfhosted-12"
           cmake-flags: "-DCORTEXLLAMA_VERSION=${{needs.create-draft-release.outputs.version}} -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_COMMON=ON -DGGML_METAL=OFF"
           run-e2e: true
           vulkan: false
diff --git a/.github/workflows/nightly-build.yml b/.github/workflows/nightly-build.yml
index 826ec91..a32bfc9 100644
--- a/.github/workflows/nightly-build.yml
+++ b/.github/workflows/nightly-build.yml
@@ -167,7 +167,7 @@ jobs:
           ccache-dir: "/home/runner/.ccache"
         - os: "mac"
           name: "amd64"
-          runs-on: "macos-12"
+          runs-on: "macos-selfhosted-12"
           cmake-flags: "-DCORTEXLLAMA_VERSION=${{needs.create-draft-release.outputs.version}} -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_COMMON=ON -DGGML_METAL=OFF"
           run-e2e: true
           vulkan: false
diff --git a/.github/workflows/template-e2e-weekend-test.yml b/.github/workflows/template-e2e-weekend-test.yml
index 0e57b4a..b694762 100644
--- a/.github/workflows/template-e2e-weekend-test.yml
+++ b/.github/workflows/template-e2e-weekend-test.yml
@@ -50,7 +50,7 @@ jobs:
 
         - os: "mac"
           name: "amd64"
-          runs-on: "macos-12"
+          runs-on: "macos-selfhosted-12"
           cmake-flags: "-DCORTEXLLAMA_VERSION=${{github.event.pull_request.head.sha}} -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_COMMON=ON -DGGML_METAL=OFF"
           run-e2e: true
           vulkan: false
diff --git a/.github/workflows/template-quality-gate-pr.yml b/.github/workflows/template-quality-gate-pr.yml
index 78f0c57..8dc3e34 100644
--- a/.github/workflows/template-quality-gate-pr.yml
+++ b/.github/workflows/template-quality-gate-pr.yml
@@ -134,7 +134,7 @@ jobs:
           ccache-dir: "/home/runner/.ccache"
         - os: "mac"
           name: "amd64"
-          runs-on: "macos-12"
+          runs-on: "macos-selfhosted-12"
           cmake-flags: "-DCORTEXLLAMA_VERSION=${{github.event.pull_request.head.sha}} -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_COMMON=ON -DGGML_METAL=OFF"
           run-e2e: true
           vulkan: false
diff --git a/.github/workflows/template-quality-gate-submodule.yml b/.github/workflows/template-quality-gate-submodule.yml
index 18bd7af..67dc490 100644
--- a/.github/workflows/template-quality-gate-submodule.yml
+++ b/.github/workflows/template-quality-gate-submodule.yml
@@ -134,7 +134,7 @@ jobs:
           ccache-dir: "/home/runner/.ccache"
         - os: "mac"
           name: "amd64"
-          runs-on: "macos-12"
+          runs-on: "macos-selfhosted-12"
           cmake-flags: "-DCORTEXLLAMA_VERSION=${{github.event.pull_request.head.sha}} -DBUILD_SHARED_LIBS=OFF -DLLAMA_BUILD_COMMON=ON -DGGML_METAL=OFF"
           run-e2e: true
           vulkan: false
diff --git a/base/cortex-common/enginei.h b/base/cortex-common/enginei.h
index 2c7a918..6c9f496 100644
--- a/base/cortex-common/enginei.h
+++ b/base/cortex-common/enginei.h
@@ -63,7 +63,8 @@ class EngineI {
   virtual bool IsSupported(const std::string& f) {
     if (f == "HandleChatCompletion" || f == "HandleEmbedding" ||
         f == "LoadModel" || f == "UnloadModel" || f == "GetModelStatus" ||
-        f == "GetModels" || f == "SetFileLogger" || f == "SetLogLevel") {
+        f == "GetModels" || f == "SetFileLogger" || f == "SetLogLevel" ||
+        f == "StopInferencing") {
       return true;
     }
     return false;
@@ -77,4 +78,6 @@ class EngineI {
   virtual void SetFileLogger(int max_log_lines,
                              const std::string& log_path) = 0;
   virtual void SetLogLevel(trantor::Logger::LogLevel log_level) = 0;
+
+  virtual void StopInferencing(const std::string& model_id) = 0;
 };
diff --git a/examples/server/server.cc b/examples/server/server.cc
index 159d4f1..ef52d67 100644
--- a/examples/server/server.cc
+++ b/examples/server/server.cc
@@ -23,6 +23,14 @@ class Server {
     }
   }
 
+  void ForceStopInferencing(const std::string& model_id) {
+    if (engine_) {
+      engine_->StopInferencing(model_id);
+    } else {
+      LOG_WARN << "Engine is null";
+    }
+  }
+
  public:
   std::unique_ptr dylib_;
   EngineI* engine_;
@@ -122,9 +130,10 @@ int main(int argc, char** argv) {
   };
 
   auto process_stream_res = [&server](httplib::Response& resp,
-                                      std::shared_ptr q) {
+                                      std::shared_ptr q,
+                                      const std::string& model_id) {
     const auto chunked_content_provider =
-        [&server, q](size_t size, httplib::DataSink& sink) {
+        [&server, q, model_id](size_t size, httplib::DataSink& sink) {
           while (true) {
             auto [status, res] = q->wait_and_pop();
             auto str = res["data"].asString();
@@ -132,7 +141,8 @@
 
             if (!sink.write(str.c_str(), str.size())) {
               LOG_WARN << "Failed to write";
-              // return false;
+              server.ForceStopInferencing(model_id);
+              return false;
             }
             if (status["has_error"].asBool() || status["is_done"].asBool()) {
               LOG_INFO << "Done";
@@ -183,6 +193,7 @@
     auto req_body = std::make_shared();
     r.Parse(req.body, *req_body);
     bool is_stream = (*req_body).get("stream", false).asBool();
+    std::string model_id = (*req_body).get("model", "invalid_model").asString();
     // This is an async call, need to use queue
     auto q = std::make_shared();
     server.engine_->HandleChatCompletion(
@@ -190,7 +201,7 @@
           q->push(std::make_pair(status, res));
         });
     if (is_stream) {
-      process_stream_res(resp, q);
+      process_stream_res(resp, q, model_id);
     } else {
       process_non_stream_res(resp, *q);
     }
diff --git a/src/llama_engine.cc b/src/llama_engine.cc
index 58d1889..379db88 100644
--- a/src/llama_engine.cc
+++ b/src/llama_engine.cc
@@ -8,7 +8,6 @@
 #include "llama_utils.h"
 #include "trantor/utils/Logger.h"
 
-
 #if defined(_WIN32)
 #include
 #include
@@ -537,6 +536,10 @@ void LlamaEngine::SetLogLevel(trantor::Logger::LogLevel log_level) {
   trantor::Logger::setLogLevel(log_level);
 }
 
+void LlamaEngine::StopInferencing(const std::string& model_id) {
+  AddForceStopInferenceModel(model_id);
+}
+
 void LlamaEngine::SetFileLogger(int max_log_lines,
                                 const std::string& log_path) {
   if (!async_file_logger_) {
@@ -959,12 +962,19 @@ void LlamaEngine::HandleInferenceImpl(
     LOG_INFO << "Request " << request_id << ": "
              << "Streamed, waiting for respone";
     auto state = CreateInferenceState(si.ctx);
+    auto model_id = completion.model_id;
 
     // Queued task
-    si.q->runTaskInQueue([cb = std::move(callback), state, data, request_id,
-                          n_probs, include_usage]() {
+    si.q->runTaskInQueue([this, cb = std::move(callback), state, data,
+                          request_id, n_probs, include_usage, model_id]() {
       state->task_id = state->llama.RequestCompletion(data, false, false, -1);
       while (state->llama.model_loaded_external) {
+        if (HasForceStopInferenceModel(model_id)) {
+          LOG_INFO << "Force stop inferencing for model: " << model_id;
+          state->llama.RequestCancel(state->task_id);
+          RemoveForceStopInferenceModel(model_id);
+          break;
+        }
         TaskResult result = state->llama.NextResult(state->task_id);
         if (!result.error) {
           std::string to_send;
@@ -1287,6 +1297,28 @@ bool LlamaEngine::ShouldInitBackend() const {
   return true;
 }
 
+void LlamaEngine::AddForceStopInferenceModel(const std::string& id) {
+  std::lock_guard l(fsi_mtx_);
+  if (force_stop_inference_models_.find(id) ==
+      force_stop_inference_models_.end()) {
+    LOG_INFO << "Added force stop inferencing model: " << id;
+    force_stop_inference_models_.insert(id);
+  }
+}
+void LlamaEngine::RemoveForceStopInferenceModel(const std::string& id) {
+  std::lock_guard l(fsi_mtx_);
+  if (force_stop_inference_models_.find(id) !=
+      force_stop_inference_models_.end()) {
+    force_stop_inference_models_.erase(id);
+  }
+}
+
+bool LlamaEngine::HasForceStopInferenceModel(const std::string& id) const {
+  std::lock_guard l(fsi_mtx_);
+  return force_stop_inference_models_.find(id) !=
+         force_stop_inference_models_.end();
+}
+
 extern "C" {
 EngineI* get_engine() {
   return new LlamaEngine();
diff --git a/src/llama_engine.h b/src/llama_engine.h
index 4e7d2f9..750e482 100644
--- a/src/llama_engine.h
+++ b/src/llama_engine.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include
+#include
 #include "chat_completion_request.h"
 #include "cortex-common/enginei.h"
 #include "file_logger.h"
@@ -44,6 +45,7 @@ class LlamaEngine : public EngineI {
   void SetFileLogger(int max_log_lines, const std::string& log_path) final;
   void SetLogLevel(trantor::Logger::LogLevel log_level =
                        trantor::Logger::LogLevel::kInfo) final;
+  void StopInferencing(const std::string& model_id) final;
 
  private:
   bool LoadModelImpl(std::shared_ptr jsonBody);
@@ -59,6 +61,10 @@
   void WarmUpModel(const std::string& model_id);
   bool ShouldInitBackend() const;
 
+  void AddForceStopInferenceModel(const std::string& id);
+  void RemoveForceStopInferenceModel(const std::string& id);
+  bool HasForceStopInferenceModel(const std::string& id) const;
+
  private:
   struct ServerInfo {
     LlamaServerContext ctx;
@@ -78,6 +84,9 @@
 
   // key: model_id, value: ServerInfo
   std::unordered_map server_map_;
+  // lock the force_stop_inference_models_
+  mutable std::mutex fsi_mtx_;
+  std::unordered_set force_stop_inference_models_;
 
   std::atomic no_of_requests_ = 0;
   std::atomic no_of_chats_ = 0;
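
For reference, below is a minimal, self-contained sketch of the cancellation pattern this patch introduces: the HTTP layer records a model id in a mutex-guarded set when the client's sink stops accepting writes, and the queued streaming task polls that set between results, cancels itself, and consumes the entry. The ForceStopRegistry class, the "my-model" id, and the worker loop are illustrative stand-ins invented for this sketch; they are not code from the repository.

// sketch_force_stop.cc -- illustrative only, not part of the patch above.
#include <chrono>
#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <unordered_set>

// Mirrors the Add/Has/RemoveForceStopInferenceModel trio: a mutex-guarded
// set of model ids that acts as a one-shot "please stop" signal.
class ForceStopRegistry {
 public:
  void Add(const std::string& id) {
    std::lock_guard<std::mutex> l(mtx_);
    ids_.insert(id);
  }
  bool Has(const std::string& id) const {
    std::lock_guard<std::mutex> l(mtx_);
    return ids_.count(id) != 0;
  }
  void Remove(const std::string& id) {
    std::lock_guard<std::mutex> l(mtx_);
    ids_.erase(id);
  }

 private:
  mutable std::mutex mtx_;
  std::unordered_set<std::string> ids_;
};

int main() {
  ForceStopRegistry registry;
  const std::string model_id = "my-model";  // hypothetical model id

  // Worker thread: stands in for the queued streaming task in
  // LlamaEngine::HandleInferenceImpl, which checks the registry once per
  // generated result and cancels the task when a stop request is present.
  std::thread worker([&] {
    while (true) {
      if (registry.Has(model_id)) {
        std::cout << "Force stop inferencing for model: " << model_id << "\n";
        registry.Remove(model_id);  // consume the one-shot request
        break;                      // stands in for RequestCancel + break
      }
      // Stands in for producing and streaming the next token.
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  });

  // Caller side: stands in for Server::ForceStopInferencing, which the patch
  // invokes when the HTTP sink can no longer be written (client disconnected).
  std::this_thread::sleep_for(std::chrono::milliseconds(50));
  registry.Add(model_id);

  worker.join();
  return 0;
}

Note that in the patch the registry is keyed by model id rather than task id, so a stop request is consumed by whichever in-flight streaming task for that model observes it first.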