Skip to content

Commit

Permalink
Add some preliminary ONNX support (huggingface#1260)
Browse files Browse the repository at this point in the history
* Add the onnx protos.

* Move the reading bits.

* Install protoc on the CI.

* Install protoc on the cuda CI too.

* Use clap for the onnx tool.

* Tweak the CI protoc install.

* Add some simple evaluation function.

* Add some binary operator support.
  • Loading branch information
LaurentMazare authored Nov 4, 2023
1 parent bfe9511 commit 8cbb9d0
Show file tree
Hide file tree
Showing 10 changed files with 1,033 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/ci_cuda.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:
- name: Install Rust Stable
run: curl https://sh.rustup.rs -sSf | sh -s -- -y
- uses: Swatinem/rust-cache@v2
- run: apt-get update -y && apt-get install libssl-dev -y
- run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y
- name: Test (cuda)
run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda
stop-runner:
Expand Down
4 changes: 4 additions & 0 deletions .github/workflows/rust-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ jobs:
rust: [stable]
steps:
- uses: actions/checkout@v2
- uses: arduino/setup-protoc@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
Expand All @@ -35,6 +36,7 @@ jobs:
rust: [stable]
steps:
- uses: actions/checkout@v2
- uses: arduino/setup-protoc@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
Expand All @@ -50,6 +52,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: arduino/setup-protoc@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
Expand All @@ -66,6 +69,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: arduino/setup-protoc@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
Expand Down
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ members = [
"candle-examples",
"candle-book",
"candle-nn",
"candle-onnx",
"candle-pyo3",
"candle-transformers",
"candle-wasm-examples/*",
Expand Down
22 changes: 22 additions & 0 deletions candle-onnx/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
[package]
name = "candle-onnx"
# All package metadata is inherited from the workspace root Cargo.toml.
version.workspace = true
edition.workspace = true
description.workspace = true
repository.workspace = true
keywords.workspace = true
categories.workspace = true
license.workspace = true

[dependencies]
# The core tensor crate, re-exposed here under the shorter name `candle`.
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.3.0" }
# Protobuf runtime used by the module generated from src/onnx.proto3.
prost = "0.12.1"

[build-dependencies]
# Invoked from build.rs to compile src/onnx.proto3 (requires protoc on PATH).
prost-build = "0.12.1"

[dev-dependencies]
# Only the examples (e.g. onnx_basics) need anyhow and clap.
anyhow = { workspace = true }
clap = { workspace = true }

6 changes: 6 additions & 0 deletions candle-onnx/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
use std::io::Result;

// Build script: generate Rust code from the ONNX protobuf schema with prost.
// NOTE(review): this requires the `protoc` binary to be installed — the CI
// workflows install it via setup-protoc / protobuf-compiler.
fn main() -> Result<()> {
    let proto_files = ["src/onnx.proto3"];
    let include_dirs = ["src/"];
    prost_build::compile_protos(&proto_files, &include_dirs)?;
    Ok(())
}
56 changes: 56 additions & 0 deletions candle-onnx/examples/onnx_basics.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
use anyhow::Result;
use candle::{Device, Tensor};

use clap::{Parser, Subcommand};

// Subcommands of the onnx_basics example. Plain `//` comments are used on
// purpose: `///` doc comments would be picked up by clap as help text.
#[derive(Subcommand, Debug, Clone)]
enum Command {
    // Print the parsed model proto and then each node of its graph.
    Print {
        #[arg(long)]
        file: String,
    },
    // Run candle_onnx::simple_eval on the model with fixed dummy inputs.
    SimpleEval {
        #[arg(long)]
        file: String,
    },
}

// Command-line arguments for the onnx_basics example; the subcommand selects
// what to do with the ONNX file. (`//` rather than `///` so clap's generated
// help text is unaffected.)
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
pub struct Args {
    #[command(subcommand)]
    command: Command,
}

pub fn main() -> Result<()> {
let args = Args::parse();
match args.command {
Command::Print { file } => {
let model = candle_onnx::read_file(file)?;
println!("{model:?}");
let graph = model.graph.unwrap();
for node in graph.node.iter() {
println!("{node:?}");
}
}
Command::SimpleEval { file } => {
let model = candle_onnx::read_file(file)?;
let inputs = model
.graph
.as_ref()
.unwrap()
.input
.iter()
.map(|name| {
let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?;
Ok((name.name.clone(), value))
})
.collect::<Result<_>>()?;
let outputs = candle_onnx::simple_eval(&model, inputs)?;
for (name, value) in outputs.iter() {
println!("{name}: {value:?}")
}
}
}
Ok(())
}
81 changes: 81 additions & 0 deletions candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use crate::onnx;
use candle::{Result, Tensor};
use std::collections::HashMap;

pub type Value = Tensor;

/// Directly evaluates an ONNX `ModelProto` on the provided named `inputs`,
/// returning the graph's output values keyed by output name.
///
/// Longer-term, we should first convert the proto to an intermediate
/// representation of the compute graph so as to make multiple evaluations
/// more efficient. An example upside of this would be to remove intermediary
/// values when they are not needed anymore.
///
/// # Errors
/// Fails when the proto has no graph, when a node references a value that has
/// not been provided or computed yet, when an `op_type` is unsupported, or
/// when a declared graph output was never produced.
pub fn simple_eval(
    model: &onnx::ModelProto,
    inputs: HashMap<String, Value>,
) -> Result<HashMap<String, Value>> {
    let graph = match &model.graph {
        None => candle::bail!("no graph defined in proto"),
        Some(graph) => graph,
    };
    // TODO: validate the inputs.
    let mut values = inputs;
    // The nodes are topologically sorted so we can just process them in order.
    for node in graph.node.iter() {
        // Resolve an input name to a previously provided/computed value.
        let get = |input_name: &str| match values.get(input_name) {
            Some(value) => Ok(value),
            None => candle::bail!("cannot find {input_name} for op {}", node.name),
        };
        // TODO: Validate node.input for each operator.
        match node.op_type.as_str() {
            // Binary operators, with ONNX-style (numpy) broadcasting.
            // BUG FIX: the second operand is node.input[1]; it was previously
            // read from node.input[0], so e.g. "Sub" computed x - x.
            "Add" | "Sub" | "Mul" | "Div" | "MatMul" => {
                let input0 = get(&node.input[0])?;
                let input1 = get(&node.input[1])?;
                let output = match node.op_type.as_str() {
                    "Add" => input0.broadcast_add(input1)?,
                    "Sub" => input0.broadcast_sub(input1)?,
                    "Mul" => input0.broadcast_mul(input1)?,
                    "Div" => input0.broadcast_div(input1)?,
                    "MatMul" => input0.broadcast_matmul(input1)?,
                    // Guarded by the outer match arm.
                    _ => unreachable!(),
                };
                values.insert(node.output[0].clone(), output);
            }
            // Unary operators.
            "Gelu" => {
                let input = get(&node.input[0])?;
                let output = input.gelu_erf()?;
                values.insert(node.output[0].clone(), output);
            }
            "Relu" => {
                let input = get(&node.input[0])?;
                let output = input.relu()?;
                values.insert(node.output[0].clone(), output);
            }
            op_type => candle::bail!("unsupported op_type {op_type} for op {}", node.name),
        }
    }
    // Extract the declared graph outputs, moving them out of the value map.
    graph
        .output
        .iter()
        .map(|output| match values.remove(&output.name) {
            None => candle::bail!("cannot find output {}", output.name),
            Some(value) => Ok((output.name.clone(), value)),
        })
        .collect()
}
14 changes: 14 additions & 0 deletions candle-onnx/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
use candle::Result;
use prost::Message;

pub mod onnx {
include!(concat!(env!("OUT_DIR"), "/onnx.rs"));
}

mod eval;
pub use eval::simple_eval;

pub fn read_file<P: AsRef<std::path::Path>>(p: P) -> Result<onnx::ModelProto> {
let buf = std::fs::read(p)?;
onnx::ModelProto::decode(buf.as_slice()).map_err(candle::Error::wrap)
}
Loading

0 comments on commit 8cbb9d0

Please sign in to comment.