Skip to content

Commit

Permalink
Support the flux-dev model too. (huggingface#2395)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Aug 4, 2024
1 parent c0a559d commit 89eae41
Showing 1 changed file with 37 additions and 9 deletions.
46 changes: 37 additions & 9 deletions candle-examples/examples/flux/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,15 @@ struct Args {

#[arg(long)]
decode_only: Option<String>,

#[arg(long, value_enum, default_value = "schnell")]
model: Model,
}

#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
enum Model {
Schnell,
Dev,
}

fn run(args: Args) -> Result<()> {
Expand All @@ -50,6 +59,7 @@ fn run(args: Args) -> Result<()> {
width,
tracing,
decode_only,
model,
} = args;
let width = width.unwrap_or(1360);
let height = height.unwrap_or(768);
Expand All @@ -63,9 +73,13 @@ fn run(args: Args) -> Result<()> {
};

let api = hf_hub::api::sync::Api::new()?;
let bf_repo = api.repo(hf_hub::Repo::model(
"black-forest-labs/FLUX.1-schnell".to_string(),
));
let bf_repo = {
let name = match model {
Model::Dev => "black-forest-labs/FLUX.1-dev",
Model::Schnell => "black-forest-labs/FLUX.1-schnell",
};
api.repo(hf_hub::Repo::model(name.to_string()))
};
let device = candle_examples::device(cpu)?;
let dtype = device.bf16_default_to_f32();
let img = match decode_only {
Expand Down Expand Up @@ -132,16 +146,27 @@ fn run(args: Args) -> Result<()> {
};
println!("CLIP\n{clip_emb}");
let img = {
let model_file = bf_repo.get("flux1-schnell.sft")?;
let model_file = match model {
Model::Schnell => bf_repo.get("flux1-schnell.sft")?,
Model::Dev => bf_repo.get("flux1-dev.sft")?,
};
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = flux::model::Config::schnell();
let model = flux::model::Flux::new(&cfg, vb)?;

let cfg = match model {
Model::Dev => flux::model::Config::dev(),
Model::Schnell => flux::model::Config::schnell(),
};
let img = flux::sampling::get_noise(1, height, width, &device)?.to_dtype(dtype)?;
let state = flux::sampling::State::new(&t5_emb, &clip_emb, &img)?;
let timesteps = match model {
Model::Dev => {
flux::sampling::get_schedule(50, Some((state.img.dim(1)?, 0.5, 1.15)))
}
Model::Schnell => flux::sampling::get_schedule(4, None),
};
let model = flux::model::Flux::new(&cfg, vb)?;

println!("{state:?}");
let timesteps = flux::sampling::get_schedule(4, None); // no shift for flux-schnell
println!("{timesteps:?}");
flux::sampling::denoise(
&model,
Expand All @@ -166,7 +191,10 @@ fn run(args: Args) -> Result<()> {
let img = {
let model_file = bf_repo.get("ae.sft")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], dtype, &device)? };
let cfg = flux::autoencoder::Config::schnell();
let cfg = match model {
Model::Dev => flux::autoencoder::Config::dev(),
Model::Schnell => flux::autoencoder::Config::schnell(),
};
let model = flux::autoencoder::AutoEncoder::new(&cfg, vb)?;
model.decode(&img)?
};
Expand Down

0 comments on commit 89eae41

Please sign in to comment.