Skip to content

Commit

Permalink
Move the tensor-tools binary in a separate crate. (huggingface#1969)
Browse files Browse the repository at this point in the history
  • Loading branch information
LaurentMazare authored Mar 30, 2024
1 parent b190fd8 commit 3144150
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 22 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ members = [
"candle-transformers",
"candle-wasm-examples/*",
"candle-wasm-tests",
"tensor-tools",
]
exclude = [
"candle-flash-attn",
Expand Down
2 changes: 1 addition & 1 deletion candle-examples/examples/quantized-t5/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ generate quantized weight files from the original safetensors file by using the
`tensor-tools` command line utility via:

```bash
$ cargo run --example tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
$ cargo run --bin tensor-tools --release -- quantize --quantization q6k PATH/TO/T5/model.safetensors /tmp/model.gguf
```

## Using custom models
Expand Down
16 changes: 16 additions & 0 deletions tensor-tools/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
[package]
name = "tensor-tools"
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true

[dependencies]
anyhow = { workspace = true }
candle = { workspace = true }
clap = { workspace = true }
rayon = { workspace = true }
safetensors = { workspace = true }
42 changes: 21 additions & 21 deletions candle-core/examples/tensor-tools.rs → tensor-tools/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use candle_core::quantized::{gguf_file, GgmlDType, QTensor};
use candle_core::{Device, Result};
use candle::quantized::{gguf_file, GgmlDType, QTensor};
use candle::{Device, Result};
use clap::{Parser, Subcommand, ValueEnum};
use rayon::prelude::*;

Expand Down Expand Up @@ -177,10 +177,10 @@ fn run_print(
device: &Device,
) -> Result<()> {
if full {
candle_core::display::set_print_options_full();
candle::display::set_print_options_full();
}
if let Some(line_width) = line_width {
candle_core::display::set_line_width(line_width)
candle::display::set_line_width(line_width)
}
let format = match format {
Some(format) => format,
Expand All @@ -196,7 +196,7 @@ fn run_print(
};
match format {
Format::Npz => {
let tensors = candle_core::npy::NpzTensors::new(file)?;
let tensors = candle::npy::NpzTensors::new(file)?;
for name in names.iter() {
println!("==== {name} ====");
match tensors.get(name)? {
Expand All @@ -206,8 +206,8 @@ fn run_print(
}
}
Format::Safetensors => {
use candle_core::safetensors::Load;
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
use candle::safetensors::Load;
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
let tensors: std::collections::HashMap<_, _> = tensors.tensors().into_iter().collect();
for name in names.iter() {
println!("==== {name} ====");
Expand All @@ -221,7 +221,7 @@ fn run_print(
}
}
Format::Pth => {
let pth_file = candle_core::pickle::PthTensors::new(file, None)?;
let pth_file = candle::pickle::PthTensors::new(file, None)?;
for name in names.iter() {
println!("==== {name} ====");
match pth_file.get(name)? {
Expand All @@ -233,11 +233,11 @@ fn run_print(
}
}
Format::Pickle => {
candle_core::bail!("pickle format is not supported for print")
candle::bail!("pickle format is not supported for print")
}
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
for name in names.iter() {
println!("==== {name} ====");
match content.tensors.get(name) {
Expand Down Expand Up @@ -287,7 +287,7 @@ fn run_ls(
};
match format {
Format::Npz => {
let tensors = candle_core::npy::NpzTensors::new(file)?;
let tensors = candle::npy::NpzTensors::new(file)?;
let mut names = tensors.names();
names.sort();
for name in names {
Expand All @@ -299,12 +299,12 @@ fn run_ls(
}
}
Format::Safetensors => {
let tensors = unsafe { candle_core::safetensors::MmapedSafetensors::new(file)? };
let tensors = unsafe { candle::safetensors::MmapedSafetensors::new(file)? };
let mut tensors = tensors.tensors();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, view) in tensors.iter() {
let dtype = view.dtype();
let dtype = match candle_core::DType::try_from(dtype) {
let dtype = match candle::DType::try_from(dtype) {
Ok(dtype) => format!("{dtype:?}"),
Err(_) => format!("{dtype:?}"),
};
Expand All @@ -313,7 +313,7 @@ fn run_ls(
}
}
Format::Pth => {
let mut tensors = candle_core::pickle::read_pth_tensor_info(file, verbose, None)?;
let mut tensors = candle::pickle::read_pth_tensor_info(file, verbose, None)?;
tensors.sort_by(|a, b| a.name.cmp(&b.name));
for tensor_info in tensors.iter() {
println!(
Expand All @@ -330,15 +330,15 @@ fn run_ls(
Format::Pickle => {
let file = std::fs::File::open(file)?;
let mut reader = std::io::BufReader::new(file);
let mut stack = candle_core::pickle::Stack::empty();
let mut stack = candle::pickle::Stack::empty();
stack.read_loop(&mut reader)?;
for (i, obj) in stack.stack().iter().enumerate() {
println!("{i} {obj:?}");
}
}
Format::Ggml => {
let mut file = std::fs::File::open(file)?;
let content = candle_core::quantized::ggml_file::Content::read(&mut file, device)?;
let content = candle::quantized::ggml_file::Content::read(&mut file, device)?;
let mut tensors = content.tensors.into_iter().collect::<Vec<_>>();
tensors.sort_by(|a, b| a.0.cmp(&b.0));
for (name, qtensor) in tensors.iter() {
Expand Down Expand Up @@ -374,7 +374,7 @@ fn run_quantize_safetensors(
let mut out_file = std::fs::File::create(out_file)?;
let mut tensors = std::collections::HashMap::new();
for in_file in in_files.iter() {
let in_tensors = candle_core::safetensors::load(in_file, &Device::Cpu)?;
let in_tensors = candle::safetensors::load(in_file, &Device::Cpu)?;
tensors.extend(in_tensors)
}
println!("tensors: {}", tensors.len());
Expand Down Expand Up @@ -416,7 +416,7 @@ fn run_dequantize(
let tensor = tensor.dequantize(device)?;
tensors.insert(tensor_name.to_string(), tensor);
}
candle_core::safetensors::save(&tensors, out_file)?;
candle::safetensors::save(&tensors, out_file)?;
Ok(())
}

Expand All @@ -428,11 +428,11 @@ fn run_quantize(
device: &Device,
) -> Result<()> {
if in_files.is_empty() {
candle_core::bail!("no specified input files")
candle::bail!("no specified input files")
}
if let Some(extension) = out_file.extension() {
if extension == "safetensors" {
candle_core::bail!("the generated file cannot use the safetensors extension")
candle::bail!("the generated file cannot use the safetensors extension")
}
}
if let Some(extension) = in_files[0].extension() {
Expand All @@ -442,7 +442,7 @@ fn run_quantize(
}

if in_files.len() != 1 {
candle_core::bail!("only a single in-file can be used when quantizing gguf files")
candle::bail!("only a single in-file can be used when quantizing gguf files")
}

// Open the out file early so as to fail directly on missing directories etc.
Expand Down

0 comments on commit 3144150

Please sign in to comment.