Skip to content

Commit

Permalink
Embed the mel filters in the whisper binary. (huggingface#373)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Aug 9, 2023
1 parent 5b79b38 commit a3b1699
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 13 deletions.
17 changes: 4 additions & 13 deletions candle-examples/examples/whisper/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
extern crate intel_mkl_src;

use anyhow::{Error as E, Result};
use candle::{safetensors::Load, DType, Device, Tensor};
use candle::{DType, Device, Tensor};
use candle_nn::{ops::softmax, VarBuilder};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
Expand Down Expand Up @@ -243,13 +243,6 @@ struct Args {
/// The seed to use when generating random samples.
#[arg(long, default_value_t = 299792458)]
seed: u64,

/// The mel filters in safetensors format.
#[arg(
long,
default_value = "candle-examples/examples/whisper/mel_filters.safetensors"
)]
filters: String,
}

fn main() -> Result<()> {
Expand Down Expand Up @@ -301,11 +294,9 @@ fn main() -> Result<()> {
};
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;

let mel_filters = unsafe { candle::safetensors::MmapedFile::new(args.filters)? };
let mel_filters = mel_filters.deserialize()?;
let mel_filters = mel_filters.tensor("mel_80")?.load(&device)?;
println!("loaded mel filters {:?}", mel_filters.shape());
let mel_filters = mel_filters.flatten_all()?.to_vec1::<f32>()?;
let mel_bytes = include_bytes!("melfilters.bytes");
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);

let mut input = std::fs::File::open(input)?;
let (header, data) = wav::read(&mut input)?;
Expand Down
Binary file not shown.

0 comments on commit a3b1699

Please sign in to comment.