From 166cdb5c6da4cf684d5ffe78defbb576330e8b2d Mon Sep 17 00:00:00 2001 From: vansangpfiev Date: Fri, 1 Nov 2024 09:12:05 +0700 Subject: [PATCH] fix: use download event type to listen ws on client side (#1601) * fix: use download event type to listen ws on client side * fix: format * fix: remove unused --------- Co-authored-by: vansangpfiev --- engine/cli/commands/engine_install_cmd.cc | 7 +++--- engine/cli/commands/model_pull_cmd.cc | 2 +- engine/cli/utils/download_progress.cc | 29 ++++++++++++++++------- engine/cli/utils/download_progress.h | 2 +- 4 files changed, 27 insertions(+), 13 deletions(-) diff --git a/engine/cli/commands/engine_install_cmd.cc b/engine/cli/commands/engine_install_cmd.cc index a0d008c60..f046d89e1 100644 --- a/engine/cli/commands/engine_install_cmd.cc +++ b/engine/cli/commands/engine_install_cmd.cc @@ -35,8 +35,9 @@ bool EngineInstallCmd::Exec(const std::string& engine, DownloadProgress dp; dp.Connect(host_, port_); // engine can be small, so need to start ws first - auto dp_res = std::async(std::launch::deferred, - [&dp, &engine] { return dp.Handle(engine); }); + auto dp_res = std::async(std::launch::deferred, [&dp] { + return dp.Handle(DownloadType::Engine); + }); CLI_LOG("Validating download items, please wait..") httplib::Client cli(host_ + ":" + std::to_string(port_)); @@ -68,7 +69,7 @@ bool EngineInstallCmd::Exec(const std::string& engine, bool check_cuda_download = !system_info_utils::GetCudaVersion().empty(); if (check_cuda_download) { - if (!dp.Handle("cuda")) + if (!dp.Handle(DownloadType::CudaToolkit)) return false; } diff --git a/engine/cli/commands/model_pull_cmd.cc b/engine/cli/commands/model_pull_cmd.cc index 8d6757d61..ad8938146 100644 --- a/engine/cli/commands/model_pull_cmd.cc +++ b/engine/cli/commands/model_pull_cmd.cc @@ -149,7 +149,7 @@ std::optional ModelPullCmd::Exec(const std::string& host, int port, reinterpret_cast(console_ctrl_handler), true); #endif dp.Connect(host, port); - if (!dp.Handle(model_id)) + if (!dp.Handle(DownloadType::Model)) return std::nullopt; if (force_stop) return std::nullopt; diff --git a/engine/cli/utils/download_progress.cc b/engine/cli/utils/download_progress.cc index e77e43beb..9c38d4bdf 100644 --- a/engine/cli/utils/download_progress.cc +++ b/engine/cli/utils/download_progress.cc @@ -4,10 +4,23 @@ #include "common/event.h" #include "indicators/dynamic_progress.hpp" #include "indicators/progress_bar.hpp" +#include "utils/engine_constants.h" #include "utils/format_utils.h" #include "utils/json_helper.h" #include "utils/logging_utils.h" +namespace { +std::string Repo2Engine(const std::string& r) { + if (r == kLlamaRepo) { + return kLlamaEngine; + } else if (r == kOnnxRepo) { + return kOnnxEngine; + } else if (r == kTrtLlmRepo) { + return kTrtLlmEngine; + } + return r; +}; +} // namespace bool DownloadProgress::Connect(const std::string& host, int port) { if (ws_) { CTL_INF("Already connected!"); @@ -21,7 +34,7 @@ bool DownloadProgress::Connect(const std::string& host, int port) { return true; } -bool DownloadProgress::Handle(const std::string& id) { +bool DownloadProgress::Handle(const DownloadType& event_type) { assert(!!ws_); std::unordered_map totals; status_ = DownloadStatus::DownloadStarted; @@ -30,7 +43,7 @@ bool DownloadProgress::Handle(const std::string& id) { std::vector> items; indicators::show_console_cursor(false); auto handle_message = [this, &bars, &items, &totals, - id](const std::string& message) { + event_type](const std::string& message) { CTL_INF(message); auto pad_string = [](const std::string& str, @@ -50,8 +63,8 @@ bool DownloadProgress::Handle(const std::string& id) { auto ev = cortex::event::GetDownloadEventFromJson( json_helper::ParseJsonString(message)); - // Ignore other task ids - if (ev.download_task_.id != id) { + // Ignore other task type + if (ev.download_task_.type != event_type) { return; } @@ -63,7 +76,7 @@ bool DownloadProgress::Handle(const std::string& id) { indicators::option::BarWidth{50}, indicators::option::Start{"["}, indicators::option::Fill{"="}, indicators::option::Lead{">"}, indicators::option::End{"]"}, - indicators::option::PrefixText{pad_string(i.id)}, + indicators::option::PrefixText{pad_string(Repo2Engine(i.id))}, indicators::option::ForegroundColor{indicators::Color::white}, indicators::option::ShowRemainingTime{true})); bars->push_back(*(items.back())); @@ -80,7 +93,7 @@ bool DownloadProgress::Handle(const std::string& id) { if (ev.type_ == DownloadStatus::DownloadStarted || ev.type_ == DownloadStatus::DownloadUpdated) { (*bars)[i].set_option(indicators::option::PrefixText{ - pad_string(it.id) + + pad_string(Repo2Engine(it.id)) + std::to_string( int(static_cast(downloaded) / totals[it.id] * 100)) + '%'}); @@ -94,8 +107,8 @@ bool DownloadProgress::Handle(const std::string& id) { auto total_str = format_utils::BytesToHumanReadable(totals[it.id]); (*bars)[i].set_option( indicators::option::PostfixText{total_str + "/" + total_str}); - (*bars)[i].set_option( - indicators::option::PrefixText{pad_string(it.id) + "100%"}); + (*bars)[i].set_option(indicators::option::PrefixText{ + pad_string(Repo2Engine(it.id)) + "100%"}); (*bars)[i].set_progress(100); CTL_INF("Download success"); diff --git a/engine/cli/utils/download_progress.h b/engine/cli/utils/download_progress.h index 4f71e6d84..98fe85654 100644 --- a/engine/cli/utils/download_progress.h +++ b/engine/cli/utils/download_progress.h @@ -10,7 +10,7 @@ class DownloadProgress { public: bool Connect(const std::string& host, int port); - bool Handle(const std::string& id); + bool Handle(const DownloadType& event_type); void ForceStop() { force_stop_ = true; }