Add the DAC model. (huggingface#2433)
* Add the DAC model.

* More quantization support.

* Handle DAC decoding.

* Plug the DAC decoding in parler-tts.
LaurentMazare authored Aug 19, 2024
1 parent 58197e1 commit 236b29f
Showing 7 changed files with 404 additions and 8 deletions.
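For orientation before the diffs: the commit makes the parler-tts example produce audible output by decoding the generated DAC codebook indices to PCM and writing a wav file. Below is a minimal sketch of that flow, assuming the `audio_encoder`/`decode_codes` API used in the main.rs diff further down; the `codes_to_wav` helper itself is illustrative and not part of the commit.

    use anyhow::Result;
    use candle::{DType, Device, IndexOp, Tensor};
    use candle_transformers::models::parler_tts::Model;

    // Illustrative helper: decode generated DAC codes to PCM and write a wav,
    // mirroring the statements this commit adds to main.rs.
    fn codes_to_wav(
        model: &Model,
        codes: &Tensor,
        device: &Device,
        sample_rate: u32,
        out_file: &str,
    ) -> Result<()> {
        // Add a batch dimension before handing the codes to the DAC decoder.
        let codes = codes.to_dtype(DType::I64)?.unsqueeze(0)?;
        let pcm = model.audio_encoder.decode_codes(&codes.to_device(device)?)?;
        // Select the first batch item and channel to get a 1D sample vector.
        let pcm = pcm.i((0, 0))?;
        let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
        let pcm = pcm.to_vec1::<f32>()?;
        let mut out = std::fs::File::create(out_file)?;
        candle_examples::wav::write_pcm_as_wav(&mut out, &pcm, sample_rate)?;
        Ok(())
    }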
1 change: 1 addition & 0 deletions .gitignore
@@ -42,3 +42,4 @@ candle-wasm-examples/**/config*.json
 .idea/*
 __pycache__
 out.safetensors
+out.wav
1 change: 1 addition & 0 deletions candle-examples/examples/parler-tts/decode.py
@@ -5,6 +5,7 @@
 
 tensors = load_file("out.safetensors")
 dac_model = DACModel.from_pretrained("parler-tts/dac_44khZ_8kbps")
+print(dac_model.model)
 output_ids = tensors["codes"][None, None]
 print(output_ids, "\n", output_ids.shape)
 batch_size = 1
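The Python script above cross-checks the generated codes against the reference transformers DAC implementation. The Rust side saves those codes with `save_safetensors("codes", "out.safetensors")` (see main.rs below); reading them back in candle for such a comparison could look like this sketch, where the `load_codes` helper is ours and not part of the commit.

    use candle::{Device, Tensor};

    // Load the codes tensor that the parler-tts example saves to disk.
    fn load_codes(device: &Device) -> anyhow::Result<Tensor> {
        let tensors = candle::safetensors::load("out.safetensors", device)?;
        tensors
            .get("codes")
            .cloned()
            .ok_or_else(|| anyhow::anyhow!("no `codes` tensor in out.safetensors"))
    }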
26 changes: 19 additions & 7 deletions candle-examples/examples/parler-tts/main.rs
@@ -7,7 +7,7 @@ extern crate accelerate_src;
 use anyhow::Error as E;
 use clap::Parser;
 
-use candle::{DType, Tensor};
+use candle::{DType, IndexOp, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::models::parler_tts::{Config, Model};
 use tokenizers::Tokenizer;
@@ -36,7 +36,7 @@ struct Args {
     description: String,
 
     /// The temperature used to generate samples.
-    #[arg(long, default_value_t = 1.0)]
+    #[arg(long, default_value_t = 0.0)]
     temperature: f64,
 
     /// Nucleus sampling probability cutoff.
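Note the default temperature change from 1.0 to 0.0 in the hunk above. In candle's `LogitsProcessor`, a (near-)zero temperature falls back to greedy argmax decoding, so the example now produces deterministic output unless a positive `--temperature` is passed. A small sketch of the two modes, using the same constructor as the main.rs diff below (the seed values are arbitrary):

    use candle_transformers::generation::LogitsProcessor;

    fn main() {
        // Temperature 0.0 (the new default): greedy argmax decoding.
        let _greedy = LogitsProcessor::new(0, Some(0.0), None);
        // Positive temperature, optionally with nucleus top-p: stochastic sampling.
        let _sampled = LogitsProcessor::new(0, Some(0.7), Some(0.9));
    }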
@@ -82,6 +82,10 @@ struct Args {
 
     #[arg(long, default_value_t = 512)]
     max_steps: usize,
+
+    /// The output wav file.
+    #[arg(long, default_value = "out.wav")]
+    out_file: String,
 }
 
 fn main() -> anyhow::Result<()> {
@@ -152,24 +156,32 @@ fn main() -> anyhow::Result<()> {
         .get_ids()
         .to_vec();
     let description_tokens = Tensor::new(description_tokens, &device)?.unsqueeze(0)?;
     println!("{description_tokens}");
 
     let prompt_tokens = tokenizer
         .encode(args.prompt, true)
         .map_err(E::msg)?
         .get_ids()
         .to_vec();
     let prompt_tokens = Tensor::new(prompt_tokens, &device)?.unsqueeze(0)?;
     println!("{prompt_tokens}");
 
     let lp = candle_transformers::generation::LogitsProcessor::new(
         args.seed,
         Some(args.temperature),
         args.top_p,
     );
     println!("starting generation...");
     let codes = model.generate(&prompt_tokens, &description_tokens, lp, args.max_steps)?;
-    println!("{codes}");
+    println!("generated codes\n{codes}");
     let codes = codes.to_dtype(DType::I64)?;
     codes.save_safetensors("codes", "out.safetensors")?;
+    let codes = codes.unsqueeze(0)?;
+    let pcm = model
+        .audio_encoder
+        .decode_codes(&codes.to_device(&device)?)?;
+    println!("{pcm}");
+    let pcm = pcm.i((0, 0))?;
+    let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
+    let pcm = pcm.to_vec1::<f32>()?;
+    let mut output = std::fs::File::create(&args.out_file)?;
+    candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, config.audio_encoder.sampling_rate)?;
 
     Ok(())
 }
(The remaining 4 changed files, covering the new DAC model implementation and the additional quantization support, are not rendered in this view.)
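With these changes the example runs end to end in Rust: generation, DAC decoding, loudness normalization, and wav output, with decode.py kept only as a cross-check. An invocation could look like the following (flag names are from the Args struct above; the exact cargo features depend on your setup):

    cargo run --example parler-tts --release -- --prompt "Hey, how are you doing today?" --out-file out.wav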
