forked from huggingface/candle
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add some preliminary ONNX support (huggingface#1260)
* Add the onnx protos. * Move the reading bits. * Install protoc on the CI. * Install protoc on the cuda CI too. * Use clap for the onnx tool. * Tweak the CI protoc install. * Add some simple evalution function. * Add some binary operator support.
- Loading branch information
1 parent
bfe9511
commit 8cbb9d0
Showing
10 changed files
with
1,033 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
[package] | ||
name = "candle-onnx" | ||
version.workspace = true | ||
edition.workspace = true | ||
description.workspace = true | ||
repository.workspace = true | ||
keywords.workspace = true | ||
categories.workspace = true | ||
license.workspace = true | ||
|
||
[dependencies] | ||
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" } | ||
candle-nn = { path = "../candle-nn", version = "0.3.0" } | ||
prost = "0.12.1" | ||
|
||
[build-dependencies] | ||
prost-build = "0.12.1" | ||
|
||
[dev-dependencies] | ||
anyhow = { workspace = true } | ||
clap = { workspace = true } | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
use std::io::Result; | ||
|
||
fn main() -> Result<()> { | ||
prost_build::compile_protos(&["src/onnx.proto3"], &["src/"])?; | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
use anyhow::Result; | ||
use candle::{Device, Tensor}; | ||
|
||
use clap::{Parser, Subcommand}; | ||
|
||
#[derive(Subcommand, Debug, Clone)] | ||
enum Command { | ||
Print { | ||
#[arg(long)] | ||
file: String, | ||
}, | ||
SimpleEval { | ||
#[arg(long)] | ||
file: String, | ||
}, | ||
} | ||
|
||
#[derive(Parser, Debug)] | ||
#[command(author, version, about, long_about = None)] | ||
pub struct Args { | ||
#[command(subcommand)] | ||
command: Command, | ||
} | ||
|
||
pub fn main() -> Result<()> { | ||
let args = Args::parse(); | ||
match args.command { | ||
Command::Print { file } => { | ||
let model = candle_onnx::read_file(file)?; | ||
println!("{model:?}"); | ||
let graph = model.graph.unwrap(); | ||
for node in graph.node.iter() { | ||
println!("{node:?}"); | ||
} | ||
} | ||
Command::SimpleEval { file } => { | ||
let model = candle_onnx::read_file(file)?; | ||
let inputs = model | ||
.graph | ||
.as_ref() | ||
.unwrap() | ||
.input | ||
.iter() | ||
.map(|name| { | ||
let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?; | ||
Ok((name.name.clone(), value)) | ||
}) | ||
.collect::<Result<_>>()?; | ||
let outputs = candle_onnx::simple_eval(&model, inputs)?; | ||
for (name, value) in outputs.iter() { | ||
println!("{name}: {value:?}") | ||
} | ||
} | ||
} | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
use crate::onnx; | ||
use candle::{Result, Tensor}; | ||
use std::collections::HashMap; | ||
|
||
pub type Value = Tensor; | ||
|
||
// This function provides a direct evaluation of the proto. | ||
// Longer-term, we should first convert the proto to an intermediate representation of the compute | ||
// graph so as to make multiple evaluations more efficient. | ||
// An example upside of this would be to remove intermediary values when they are not needed | ||
// anymore. | ||
pub fn simple_eval( | ||
model: &onnx::ModelProto, | ||
inputs: HashMap<String, Value>, | ||
) -> Result<HashMap<String, Value>> { | ||
let graph = match &model.graph { | ||
None => candle::bail!("no graph defined in proto"), | ||
Some(graph) => graph, | ||
}; | ||
// TODO: validate the inputs. | ||
let mut values = inputs; | ||
// The nodes are topologically sorted so we can just process them in order. | ||
for node in graph.node.iter() { | ||
let get = |input_name: &str| match values.get(input_name) { | ||
Some(value) => Ok(value), | ||
None => candle::bail!("cannot find {input_name} for op {}", node.name), | ||
}; | ||
// TODO: Validate node.input for each operator. | ||
match node.op_type.as_str() { | ||
"Add" => { | ||
let input0 = get(&node.input[0])?; | ||
let input1 = get(&node.input[0])?; | ||
let output = input0.broadcast_add(input1)?; | ||
values.insert(node.output[0].clone(), output); | ||
} | ||
"Sub" => { | ||
let input0 = get(&node.input[0])?; | ||
let input1 = get(&node.input[0])?; | ||
let output = input0.broadcast_sub(input1)?; | ||
values.insert(node.output[0].clone(), output); | ||
} | ||
"Mul" => { | ||
let input0 = get(&node.input[0])?; | ||
let input1 = get(&node.input[0])?; | ||
let output = input0.broadcast_mul(input1)?; | ||
values.insert(node.output[0].clone(), output); | ||
} | ||
"Div" => { | ||
let input0 = get(&node.input[0])?; | ||
let input1 = get(&node.input[0])?; | ||
let output = input0.broadcast_div(input1)?; | ||
values.insert(node.output[0].clone(), output); | ||
} | ||
"MatMul" => { | ||
let input0 = get(&node.input[0])?; | ||
let input1 = get(&node.input[0])?; | ||
let output = input0.broadcast_matmul(input1)?; | ||
values.insert(node.output[0].clone(), output); | ||
} | ||
"Gelu" => { | ||
let input = get(&node.input[0])?; | ||
let output = input.gelu_erf()?; | ||
values.insert(node.output[0].clone(), output); | ||
} | ||
"Relu" => { | ||
let input = get(&node.input[0])?; | ||
let output = input.relu()?; | ||
values.insert(node.output[0].clone(), output); | ||
} | ||
op_type => candle::bail!("unsupported op_type {op_type} for op {}", node.name), | ||
} | ||
} | ||
graph | ||
.output | ||
.iter() | ||
.map(|output| match values.remove(&output.name) { | ||
None => candle::bail!("cannot find output {}", output.name), | ||
Some(value) => Ok((output.name.clone(), value)), | ||
}) | ||
.collect() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
use candle::Result; | ||
use prost::Message; | ||
|
||
pub mod onnx { | ||
include!(concat!(env!("OUT_DIR"), "/onnx.rs")); | ||
} | ||
|
||
mod eval; | ||
pub use eval::simple_eval; | ||
|
||
pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> { | ||
let buf = std::fs::read(p)?; | ||
onnx::ModelProto::decode(buf.as_slice()).map_err(candle::Error::wrap) | ||
} |
Oops, something went wrong.