Skip to content

Commit

Permalink
fix: use download event type to listen ws on client side (#1601)
Browse files Browse the repository at this point in the history
* fix: use download event type to listen ws on client side

* fix: format

* fix: remove unused

---------

Co-authored-by: vansangpfiev <[email protected]>
  • Loading branch information
vansangpfiev and sangjanai authored Nov 1, 2024
1 parent 601437d commit 166cdb5
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 13 deletions.
7 changes: 4 additions & 3 deletions engine/cli/commands/engine_install_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ bool EngineInstallCmd::Exec(const std::string& engine,
DownloadProgress dp;
dp.Connect(host_, port_);
// engine can be small, so need to start ws first
auto dp_res = std::async(std::launch::deferred,
[&dp, &engine] { return dp.Handle(engine); });
auto dp_res = std::async(std::launch::deferred, [&dp] {
return dp.Handle(DownloadType::Engine);
});
CLI_LOG("Validating download items, please wait..")

httplib::Client cli(host_ + ":" + std::to_string(port_));
Expand Down Expand Up @@ -68,7 +69,7 @@ bool EngineInstallCmd::Exec(const std::string& engine,

bool check_cuda_download = !system_info_utils::GetCudaVersion().empty();
if (check_cuda_download) {
if (!dp.Handle("cuda"))
if (!dp.Handle(DownloadType::CudaToolkit))
return false;
}

Expand Down
2 changes: 1 addition & 1 deletion engine/cli/commands/model_pull_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ std::optional<std::string> ModelPullCmd::Exec(const std::string& host, int port,
reinterpret_cast<PHANDLER_ROUTINE>(console_ctrl_handler), true);
#endif
dp.Connect(host, port);
if (!dp.Handle(model_id))
if (!dp.Handle(DownloadType::Model))
return std::nullopt;
if (force_stop)
return std::nullopt;
Expand Down
29 changes: 21 additions & 8 deletions engine/cli/utils/download_progress.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,23 @@
#include "common/event.h"
#include "indicators/dynamic_progress.hpp"
#include "indicators/progress_bar.hpp"
#include "utils/engine_constants.h"
#include "utils/format_utils.h"
#include "utils/json_helper.h"
#include "utils/logging_utils.h"

namespace {
std::string Repo2Engine(const std::string& r) {
if (r == kLlamaRepo) {
return kLlamaEngine;
} else if (r == kOnnxRepo) {
return kOnnxEngine;
} else if (r == kTrtLlmRepo) {
return kTrtLlmEngine;
}
return r;
};
} // namespace
bool DownloadProgress::Connect(const std::string& host, int port) {
if (ws_) {
CTL_INF("Already connected!");
Expand All @@ -21,7 +34,7 @@ bool DownloadProgress::Connect(const std::string& host, int port) {
return true;
}

bool DownloadProgress::Handle(const std::string& id) {
bool DownloadProgress::Handle(const DownloadType& event_type) {
assert(!!ws_);
std::unordered_map<std::string, uint64_t> totals;
status_ = DownloadStatus::DownloadStarted;
Expand All @@ -30,7 +43,7 @@ bool DownloadProgress::Handle(const std::string& id) {
std::vector<std::unique_ptr<indicators::ProgressBar>> items;
indicators::show_console_cursor(false);
auto handle_message = [this, &bars, &items, &totals,
id](const std::string& message) {
event_type](const std::string& message) {
CTL_INF(message);

auto pad_string = [](const std::string& str,
Expand All @@ -50,8 +63,8 @@ bool DownloadProgress::Handle(const std::string& id) {

auto ev = cortex::event::GetDownloadEventFromJson(
json_helper::ParseJsonString(message));
// Ignore other task ids
if (ev.download_task_.id != id) {
// Ignore other task type
if (ev.download_task_.type != event_type) {
return;
}

Expand All @@ -63,7 +76,7 @@ bool DownloadProgress::Handle(const std::string& id) {
indicators::option::BarWidth{50}, indicators::option::Start{"["},
indicators::option::Fill{"="}, indicators::option::Lead{">"},
indicators::option::End{"]"},
indicators::option::PrefixText{pad_string(i.id)},
indicators::option::PrefixText{pad_string(Repo2Engine(i.id))},
indicators::option::ForegroundColor{indicators::Color::white},
indicators::option::ShowRemainingTime{true}));
bars->push_back(*(items.back()));
Expand All @@ -80,7 +93,7 @@ bool DownloadProgress::Handle(const std::string& id) {
if (ev.type_ == DownloadStatus::DownloadStarted ||
ev.type_ == DownloadStatus::DownloadUpdated) {
(*bars)[i].set_option(indicators::option::PrefixText{
pad_string(it.id) +
pad_string(Repo2Engine(it.id)) +
std::to_string(
int(static_cast<double>(downloaded) / totals[it.id] * 100)) +
'%'});
Expand All @@ -94,8 +107,8 @@ bool DownloadProgress::Handle(const std::string& id) {
auto total_str = format_utils::BytesToHumanReadable(totals[it.id]);
(*bars)[i].set_option(
indicators::option::PostfixText{total_str + "/" + total_str});
(*bars)[i].set_option(
indicators::option::PrefixText{pad_string(it.id) + "100%"});
(*bars)[i].set_option(indicators::option::PrefixText{
pad_string(Repo2Engine(it.id)) + "100%"});
(*bars)[i].set_progress(100);

CTL_INF("Download success");
Expand Down
2 changes: 1 addition & 1 deletion engine/cli/utils/download_progress.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class DownloadProgress {
public:
bool Connect(const std::string& host, int port);

bool Handle(const std::string& id);
bool Handle(const DownloadType& event_type);

void ForceStop() { force_stop_ = true; }

Expand Down

0 comments on commit 166cdb5

Please sign in to comment.