Skip to content

Commit

Permalink
Merge pull request #7 from JacobLinCool/gradio-5
Browse files Browse the repository at this point in the history
feat: support gradio 5
  • Loading branch information
JacobLinCool authored Oct 12, 2024
2 parents 8228f83 + 725ccb6 commit 8b6a788
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 5 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "gradio"
version = "0.3.0"
version = "0.3.1"
edition = "2021"
authors = ["Jacob Lin <[email protected]>"]
description = "Gradio Client in Rust."
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Gradio Client in Rust.
- [x] Command-line interface
- [x] Synchronous and asynchronous API

> Supposed to work with Gradio 4, other versions are not tested.
> Supposed to work with Gradio 5 & 4, other versions are not tested.
## Documentation

Expand Down
32 changes: 32 additions & 0 deletions examples/whisper-turbo.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
use gradio::{Client, ClientOptions, PredictionInput};

#[tokio::main]
async fn main() {
if std::env::args().len() < 2 {
println!("Please provide an audio file path as an argument");
std::process::exit(1);
}
let args: Vec<String> = std::env::args().collect();
let file_path = &args[1];
println!("File: {}", file_path);

// Gradio v5
let client = Client::new("hf-audio/whisper-large-v3-turbo", ClientOptions::default())
.await
.unwrap();

let output = client
.predict(
"/predict",
vec![
PredictionInput::from_file(file_path),
PredictionInput::from_value("transcribe"),
],
)
.await
.unwrap();
println!(
"Output: {}",
output[0].clone().as_value().unwrap().as_str().unwrap()
);
}
10 changes: 7 additions & 3 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ impl Client {

let http_client = Client::build_http_client(&options.hf_token)?;

let (api_root, space_id) =
let (mut api_root, space_id) =
Client::resolve_app_reference(&http_client, app_reference).await?;

if let Some((username, password)) = &options.auth {
Expand All @@ -95,6 +95,10 @@ impl Client {
}

let config = Client::fetch_config(&http_client, &api_root).await?;
if let Some(ref api_prefix) = config.api_prefix {
api_root.push_str(api_prefix);
}

let api_info = Client::fetch_api_info(&http_client, &api_root).await?;

Ok(Self {
Expand Down Expand Up @@ -229,9 +233,9 @@ impl Client {
let json = res.json::<serde_json::Value>().await?;
let config: AppConfigVersionOnly = serde_json::from_value(json.clone())?;

if !config.version.starts_with("4.") {
if !config.version.starts_with("5.") && !config.version.starts_with("4.") {
eprintln!(
"Warning: This client is supposed to work with Gradio 4. The current version of the app is {}, which may cause issues.",
"Warning: This client is supposed to work with Gradio 5 & 4. The current version of the app is {}, which may cause issues.",
config.version
);
}
Expand Down
1 change: 1 addition & 0 deletions src/structs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ pub struct AppConfig {
pub theme_hash: Option<StringOrI64>,
pub username: Option<String>,
pub max_file_size: Option<i64>,
pub api_prefix: Option<String>,
#[serde(default)]
pub auth_required: Option<bool>,
#[serde(default)]
Expand Down

0 comments on commit 8b6a788

Please sign in to comment.