diff --git a/Cargo.toml b/Cargo.toml index d71cc4bc2e..313c68f9c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "candle-transformers", "candle-wasm-examples/*", "candle-wasm-tests", + "tensor-tools", ] exclude = [ "candle-flash-attn", diff --git a/candle-examples/examples/quantized-t5/README.md b/candle-examples/examples/quantized-t5/README.md index 8b8179ebd6..c86e746d90 100644 --- a/candle-examples/examples/quantized-t5/README.md +++ b/candle-examples/examples/quantized-t5/README.md @@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the `tensor-tools` command line utility via: ```bash -$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf +$ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf ``` ## Using custom models diff --git a/tensor-tools/Cargo.toml b/tensor-tools/Cargo.toml new file mode 100644 index 0000000000..eecd7e4353 --- /dev/null +++ b/tensor-tools/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "tensor-tools" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true + +[dependencies] +anyhow = { workspace = true } +candle = { workspace = true } +clap = { workspace = true } +rayon = { workspace = true } +safetensors = { workspace = true } diff --git a/candle-core/examples/tensor-tools.rs b/tensor-tools/src/main.rs similarity index 90% rename from candle-core/examples/tensor-tools.rs rename to tensor-tools/src/main.rs index 5dc49cd891..ad351171f5 100644 --- a/candle-core/examples/tensor-tools.rs +++ b/tensor-tools/src/main.rs @@ -1,5 +1,5 @@ -use candle_core::quantized::{gguf_file, GgmlDType, QTensor}; -use candle_core::{Device, Result}; +use candle::quantized::{gguf_file, GgmlDType, QTensor}; +use candle::{Device, Result}; 
use clap::{Parser, Subcommand, ValueEnum}; use rayon::prelude::*; @@ -177,10 +177,10 @@ fn run_print( device: &Device, ) -> Result<()> { if full { - candle_core::display::set_print_options_full(); + candle::display::set_print_options_full(); } if let Some(line_width) = line_width { - candle_core::display::set_line_width(line_width) + candle::display::set_line_width(line_width) } let format = match format { Some(format) => format, @@ -196,7 +196,7 @@ fn run_print( }; match format { Format::Npz => { - let tensors = candle_core::npy::NpzTensors::new(file)?; + let tensors = candle::npy::NpzTensors::new(file)?; for name in names.iter() { println!("==== {name} ===="); match tensors.get(name)? { @@ -206,8 +206,8 @@ fn run_print( } } Format::Safetensors => { - use candle_core::safetensors::Load; - let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? }; + use candle::safetensors::Load; + let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? }; let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect(); for name in names.iter() { println!("==== {name} ===="); @@ -221,7 +221,7 @@ fn run_print( } } Format::Pth => { - let pth_file = candle_core::pickle::PthTensors::new(file, None)?; + let pth_file = candle::pickle::PthTensors::new(file, None)?; for name in names.iter() { println!("==== {name} ===="); match pth_file.get(name)? 
{ @@ -233,11 +233,11 @@ fn run_print( } } Format::Pickle => { - candle_core::bail!("pickle format is not supported for print") + candle::bail!("pickle format is not supported for print") } Format::Ggml => { let mut file = std::fs::File::open(file)?; - let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?; + let content = candle::quantized::ggml_file::Content::read(&mut file, device)?; for name in names.iter() { println!("==== {name} ===="); match content.tensors.get(name) { @@ -287,7 +287,7 @@ fn run_ls( }; match format { Format::Npz => { - let tensors = candle_core::npy::NpzTensors::new(file)?; + let tensors = candle::npy::NpzTensors::new(file)?; let mut names = tensors.names(); names.sort(); for name in names { @@ -299,12 +299,12 @@ fn run_ls( } } Format::Safetensors => { - let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? }; + let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? }; let mut tensors = tensors.tensors(); tensors.sort_by(|a, b| a.0.cmp(&b.0)); for (name, view) in tensors.iter() { let dtype = view.dtype(); - let dtype = match candle_core::DType::try_from(dtype) { + let dtype = match candle::DType::try_from(dtype) { Ok(dtype) => format!("{dtype:?}"), Err(_) => format!("{dtype:?}"), }; @@ -313,7 +313,7 @@ fn run_ls( } } Format::Pth => { - let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?; + let mut tensors = candle::pickle::read_pth_tensor_info(file, verbose, None)?; tensors.sort_by(|a, b| a.name.cmp(&b.name)); for tensor_info in tensors.iter() { println!( @@ -330,7 +330,7 @@ fn run_ls( Format::Pickle => { let file = std::fs::File::open(file)?; let mut reader = std::io::BufReader::new(file); - let mut stack = candle_core::pickle::Stack::empty(); + let mut stack = candle::pickle::Stack::empty(); stack.read_loop(&mut reader)?; for (i, obj) in stack.stack().iter().enumerate() { println!("{i} {obj:?}"); @@ -338,7 +338,7 @@ fn run_ls( } 
Format::Ggml => { let mut file = std::fs::File::open(file)?; - let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?; + let content = candle::quantized::ggml_file::Content::read(&mut file, device)?; let mut tensors = content.tensors.into_iter().collect::<Vec<_>>(); tensors.sort_by(|a, b| a.0.cmp(&b.0)); for (name, qtensor) in tensors.iter() { @@ -374,7 +374,7 @@ fn run_quantize_safetensors( let mut out_file = std::fs::File::create(out_file)?; let mut tensors = std::collections::HashMap::new(); for in_file in in_files.iter() { - let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?; + let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?; tensors.extend(in_tensors) } println!("tensors: {}", tensors.len()); @@ -416,7 +416,7 @@ fn run_dequantize( let tensor = tensor.dequantize(device)?; tensors.insert(tensor_name.to_string(), tensor); } - candle_core::safetensors::save(&tensors, out_file)?; + candle::safetensors::save(&tensors, out_file)?; Ok(()) } @@ -428,11 +428,11 @@ fn run_quantize( device: &Device, ) -> Result<()> { if in_files.is_empty() { - candle_core::bail!("no specified input files") + candle::bail!("no specified input files") } if let Some(extension) = out_file.extension() { if extension == "safetensors" { - candle_core::bail!("the generated file cannot use the safetensors extension") + candle::bail!("the generated file cannot use the safetensors extension") } } if let Some(extension) = in_files[0].extension() { @@ -442,7 +442,7 @@ fn run_quantize( } if in_files.len() != 1 { - candle_core::bail!("only a single in-file can be used when quantizing gguf files") + candle::bail!("only a single in-file can be used when quantizing gguf files") } // Open the out file early so as to fail directly on missing directories etc.