Skip to content

Commit

Permalink
feat(cli): show progress
Browse files Browse the repository at this point in the history
  • Loading branch information
JacobLinCool committed Aug 28, 2024
1 parent 28241c5 commit 8d1491d
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 23 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,5 @@ Cargo.lock
# Added by cargo

/target

/outputs
77 changes: 54 additions & 23 deletions src/bin/gr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::path::PathBuf;

use anyhow::Result;
use clap::{arg, Command};
use gradio::{Client, ClientOptions, PredictionInput};
use gradio::{Client, ClientOptions, PredictionInput, PredictionOutput};

#[tokio::main]
async fn main() -> Result<()> {
Expand Down Expand Up @@ -111,29 +111,60 @@ async fn run_command(
}

let http_client = client.http_client.clone();
let output = client.predict(&route, data).await?;
for (i, ret) in endpoint.returns.iter().enumerate() {
let value = output.get(i).expect("Missing return value");
let name = if let Some(label) = &ret.label {
label
} else if let Some(name) = &ret.parameter_name {
name
} else {
"unnamed"
};

if value.is_file() {
let file = value.clone().as_file()?;
if let Some(outdir) = outdir {
let mut fp = PathBuf::from(outdir);
fp.push(format!("{}.{}", name, file.suggest_extension()));
file.save_to_path(&fp, Some(http_client.clone())).await?;
println!("{}: {}", name, fp.display());
} else {
println!("{}: {}", name, file.url.unwrap_or("".to_string()));
let mut prediction = client.submit(&route, data).await.unwrap();
while let Some(event) = prediction.next().await {
let event = event.unwrap();
match event {
gradio::structs::QueueDataMessage::InQueue {
rank, queue_size, ..
} => {
eprintln!("Queueing: {}/{}", rank + 1, queue_size);
}
} else {
println!("{}: {}", name, value.clone().as_value()?);
gradio::structs::QueueDataMessage::Processing { progress_data, .. } => {
if progress_data.is_none() {
continue;
}
let progress_data = progress_data.unwrap();
if !progress_data.is_empty() {
let progress_data = &progress_data[0];
eprintln!(
"Processing: {}/{} {}",
progress_data.index + 1,
progress_data.length.unwrap(),
progress_data.unit
);
}
}
gradio::structs::QueueDataMessage::Completed { output, .. } => {
let output: Vec<PredictionOutput> = output.try_into().unwrap();

for (i, ret) in endpoint.returns.iter().enumerate() {
let value = output.get(i).expect("Missing return value");
let name = if let Some(label) = &ret.label {
label
} else if let Some(name) = &ret.parameter_name {
name
} else {
"unnamed"
};

if value.is_file() {
let file = value.clone().as_file()?;
if let Some(outdir) = outdir {
let mut fp = PathBuf::from(outdir);
fp.push(format!("{}.{}", name, file.suggest_extension()));
file.save_to_path(&fp, Some(http_client.clone())).await?;
println!("{}: {}", name, fp.display());
} else {
println!("{}: {}", name, file.url.unwrap_or("".to_string()));
}
} else {
println!("{}: {}", name, value.clone().as_value()?);
}
}
break;
}
_ => {}
}
}

Expand Down

0 comments on commit 8d1491d

Please sign in to comment.