From 8cbb9d0e6ce57a8dbfc685f3121ed9d01b02726e Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 4 Nov 2023 06:36:05 +0100 Subject: [PATCH] Add some preliminary ONNX support (#1260) * Add the onnx protos. * Move the reading bits. * Install protoc on the CI. * Install protoc on the cuda CI too. * Use clap for the onnx tool. * Tweak the CI protoc install. * Add some simple evalution function. * Add some binary operator support. --- .github/workflows/ci_cuda.yaml | 2 +- .github/workflows/rust-ci.yml | 4 + Cargo.toml | 1 + candle-onnx/Cargo.toml | 22 + candle-onnx/build.rs | 6 + candle-onnx/examples/onnx_basics.rs | 56 ++ candle-onnx/src/eval.rs | 81 +++ candle-onnx/src/lib.rs | 14 + candle-onnx/src/onnx.proto3 | 836 ++++++++++++++++++++++++++++ test.onnx | 12 + 10 files changed, 1033 insertions(+), 1 deletion(-) create mode 100644 candle-onnx/Cargo.toml create mode 100644 candle-onnx/build.rs create mode 100644 candle-onnx/examples/onnx_basics.rs create mode 100644 candle-onnx/src/eval.rs create mode 100644 candle-onnx/src/lib.rs create mode 100644 candle-onnx/src/onnx.proto3 create mode 100644 test.onnx diff --git a/.github/workflows/ci_cuda.yaml b/.github/workflows/ci_cuda.yaml index 8953c44432..ec792a25a4 100644 --- a/.github/workflows/ci_cuda.yaml +++ b/.github/workflows/ci_cuda.yaml @@ -59,7 +59,7 @@ jobs: - name: Install Rust Stable run: curl https://sh.rustup.rs -sSf | sh -s -- -y - uses: Swatinem/rust-cache@v2 - - run: apt-get update -y && apt-get install libssl-dev -y + - run: apt-get update -y && apt-get install libssl-dev protobuf-compiler -y - name: Test (cuda) run: PATH=$PATH:/usr/local/cuda-11.8/bin/ /root/.cargo/bin/cargo test --features cuda stop-runner: diff --git a/.github/workflows/rust-ci.yml b/.github/workflows/rust-ci.yml index 2ca53b2348..b435bdfa1a 100644 --- a/.github/workflows/rust-ci.yml +++ b/.github/workflows/rust-ci.yml @@ -16,6 +16,7 @@ jobs: rust: [stable] steps: - uses: actions/checkout@v2 + - uses: arduino/setup-protoc@v2 - uses: actions-rs/toolchain@v1 with: profile: minimal @@ -35,6 +36,7 @@ jobs: rust: [stable] steps: - uses: actions/checkout@v2 + - uses: arduino/setup-protoc@v2 - uses: actions-rs/toolchain@v1 with: profile: minimal @@ -50,6 +52,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 + - uses: arduino/setup-protoc@v2 - uses: actions-rs/toolchain@v1 with: profile: minimal @@ -66,6 +69,7 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 + - uses: arduino/setup-protoc@v2 - uses: actions-rs/toolchain@v1 with: profile: minimal diff --git a/Cargo.toml b/Cargo.toml index 89ffe530b9..5a7ea75940 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,6 +5,7 @@ members = [ "candle-examples", "candle-book", "candle-nn", + "candle-onnx", "candle-pyo3", "candle-transformers", "candle-wasm-examples/*", diff --git a/candle-onnx/Cargo.toml b/candle-onnx/Cargo.toml new file mode 100644 index 0000000000..a4817e438a --- /dev/null +++ b/candle-onnx/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "candle-onnx" +version.workspace = true +edition.workspace = true +description.workspace = true +repository.workspace = true +keywords.workspace = true +categories.workspace = true +license.workspace = true + +[dependencies] +candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" } +candle-nn = { path = "../candle-nn", version = "0.3.0" } +prost = "0.12.1" + +[build-dependencies] +prost-build = "0.12.1" + +[dev-dependencies] +anyhow = { workspace = true } +clap = { workspace = true } + diff --git a/candle-onnx/build.rs b/candle-onnx/build.rs new file mode 100644 index 0000000000..79e7a39d7e --- /dev/null +++ b/candle-onnx/build.rs @@ -0,0 +1,6 @@ +use std::io::Result; + +fn main() -> Result<()> { + prost_build::compile_protos(&["src/onnx.proto3"], &["src/"])?; + Ok(()) +} diff --git a/candle-onnx/examples/onnx_basics.rs b/candle-onnx/examples/onnx_basics.rs new file mode 100644 index 0000000000..b91cbee6b2 --- /dev/null +++ b/candle-onnx/examples/onnx_basics.rs @@ -0,0 +1,56 @@ +use anyhow::Result; +use candle::{Device, Tensor}; + +use clap::{Parser, Subcommand}; + +#[derive(Subcommand, Debug, Clone)] +enum Command { + Print { + #[arg(long)] + file: String, + }, + SimpleEval { + #[arg(long)] + file: String, + }, +} + +#[derive(Parser, Debug)] +#[command(author, version, about, long_about = None)] +pub struct Args { + #[command(subcommand)] + command: Command, +} + +pub fn main() -> Result<()> { + let args = Args::parse(); + match args.command { + Command::Print { file } => { + let model = candle_onnx::read_file(file)?; + println!("{model:?}"); + let graph = model.graph.unwrap(); + for node in graph.node.iter() { + println!("{node:?}"); + } + } + Command::SimpleEval { file } => { + let model = candle_onnx::read_file(file)?; + let inputs = model + .graph + .as_ref() + .unwrap() + .input + .iter() + .map(|name| { + let value = Tensor::new(&[-3.2, 2.7], &Device::Cpu)?; + Ok((name.name.clone(), value)) + }) + .collect::>()?; + let outputs = candle_onnx::simple_eval(&model, inputs)?; + for (name, value) in outputs.iter() { + println!("{name}: {value:?}") + } + } + } + Ok(()) +} diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs new file mode 100644 index 0000000000..fe112fddab --- /dev/null +++ b/candle-onnx/src/eval.rs @@ -0,0 +1,81 @@ +use crate::onnx; +use candle::{Result, Tensor}; +use std::collections::HashMap; + +pub type Value = Tensor; + +// This function provides a direct evaluation of the proto. +// Longer-term, we should first convert the proto to an intermediate representation of the compute +// graph so as to make multiple evaluations more efficient. +// An example upside of this would be to remove intermediary values when they are not needed +// anymore. +pub fn simple_eval( + model: &onnx::ModelProto, + inputs: HashMap, +) -> Result> { + let graph = match &model.graph { + None => candle::bail!("no graph defined in proto"), + Some(graph) => graph, + }; + // TODO: validate the inputs. + let mut values = inputs; + // The nodes are topologically sorted so we can just process them in order. + for node in graph.node.iter() { + let get = |input_name: &str| match values.get(input_name) { + Some(value) => Ok(value), + None => candle::bail!("cannot find {input_name} for op {}", node.name), + }; + // TODO: Validate node.input for each operator. + match node.op_type.as_str() { + "Add" => { + let input0 = get(&node.input[0])?; + let input1 = get(&node.input[0])?; + let output = input0.broadcast_add(input1)?; + values.insert(node.output[0].clone(), output); + } + "Sub" => { + let input0 = get(&node.input[0])?; + let input1 = get(&node.input[0])?; + let output = input0.broadcast_sub(input1)?; + values.insert(node.output[0].clone(), output); + } + "Mul" => { + let input0 = get(&node.input[0])?; + let input1 = get(&node.input[0])?; + let output = input0.broadcast_mul(input1)?; + values.insert(node.output[0].clone(), output); + } + "Div" => { + let input0 = get(&node.input[0])?; + let input1 = get(&node.input[0])?; + let output = input0.broadcast_div(input1)?; + values.insert(node.output[0].clone(), output); + } + "MatMul" => { + let input0 = get(&node.input[0])?; + let input1 = get(&node.input[0])?; + let output = input0.broadcast_matmul(input1)?; + values.insert(node.output[0].clone(), output); + } + "Gelu" => { + let input = get(&node.input[0])?; + let output = input.gelu_erf()?; + values.insert(node.output[0].clone(), output); + } + "Relu" => { + let input = get(&node.input[0])?; + let output = input.relu()?; + values.insert(node.output[0].clone(), output); + } + op_type => candle::bail!("unsupported op_type {op_type} for op {}", node.name), + } + } + graph + .output + .iter() + .map(|output| match values.remove(&output.name) { + None => candle::bail!("cannot find output {}", output.name), + Some(value) => Ok((output.name.clone(), value)), + }) + .collect() +} diff --git a/candle-onnx/src/lib.rs b/candle-onnx/src/lib.rs new file mode 100644 index 0000000000..3b36c4cf70 --- /dev/null +++ b/candle-onnx/src/lib.rs @@ -0,0 +1,14 @@ +use candle::Result; +use prost::Message; + +pub mod onnx { + include!(concat!(env!("OUT_DIR"), "/onnx.rs")); +} + +mod eval; +pub use eval::simple_eval; + +pub fn read_file>(p: P) -> Result { + let buf = std::fs::read(p)?; + onnx::ModelProto::decode(buf.as_slice()).map_err(candle::Error::wrap) +} diff --git a/candle-onnx/src/onnx.proto3 b/candle-onnx/src/onnx.proto3 new file mode 100644 index 0000000000..f47006f8c9 --- /dev/null +++ b/candle-onnx/src/onnx.proto3 @@ -0,0 +1,836 @@ +// +// WARNING: This file is automatically generated! Please edit onnx.in.proto. +// + + +// SPDX-License-Identifier: Apache-2.0 + + +syntax = "proto3"; + +package onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. + // For the IR, we are using simple numbers starting with 0x00000001, + // which was the version we published on Oct 10, 2017. + IR_VERSION_2017_10_10 = 0x0000000000000001; + + // IR_VERSION 2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x0000000000000002; + + // IR VERSION 3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION_2017_11_3 = 0x0000000000000003; + + // IR VERSION 4 published on Jan 22, 2019 + // - Relax constraint that initializers should be a subset of graph inputs + // - Add type BFLOAT16 + IR_VERSION_2019_1_22 = 0x0000000000000004; + + // IR VERSION 5 published on March 18, 2019 + // - Add message TensorAnnotation. + // - Add quantization annotation in GraphProto to map tensor with its scale and zero point quantization parameters. + IR_VERSION_2019_3_18 = 0x0000000000000005; + + // IR VERSION 6 published on Sep 19, 2019 + // - Add support for sparse tensor constants stored in model. + // - Add message SparseTensorProto + // - Add sparse initializers + IR_VERSION_2019_9_19 = 0x0000000000000006; + + // IR VERSION 7 published on May 8, 2020 + // - Add support to allow function body graph to rely on multiple external opreator sets. + // - Add a list to promote inference graph's initializers to global and + // mutable variables. Global variables are visible in all graphs of the + // stored models. + // - Add message TrainingInfoProto to store initialization + // method and training algorithm. The execution of TrainingInfoProto + // can modify the values of mutable variables. + // - Implicitly add inference graph into each TrainingInfoProto's algorithm. + IR_VERSION_2020_5_8 = 0x0000000000000007; + + // IR VERSION 8 published on July 30, 2021 + // Introduce TypeProto.SparseTensor + // Introduce TypeProto.Optional + // Added a list of FunctionProtos local to the model + // Deprecated since_version and operator status from FunctionProto + IR_VERSION_2021_7_30 = 0x0000000000000008; + + // IR VERSION 9 published on May 5, 2023 + // Added AttributeProto to FunctionProto so that default attribute values can be set. + // Added FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ. + IR_VERSION = 0x0000000000000009; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + reserved 12, 16 to 19; + reserved "v"; + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + SPARSE_TENSOR = 11; + TYPE_PROTO = 13; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + SPARSE_TENSORS = 12; + TYPE_PROTOS = 14; + } + + // The name field MUST be present for this version of the IR. + string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field heuristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accommodate proto3 implementations. + AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + float f = 2; // float + int64 i = 3; // int + bytes s = 4; // UTF-8 string + TensorProto t = 5; // tensor value + GraphProto g = 6; // graph + SparseTensorProto sparse_tensor = 22; // sparse tensor value + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + TypeProto tp = 14; // type proto + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph + repeated SparseTensorProto sparse_tensors = 23; // list of sparse tensors + repeated TypeProto type_protos = 15;// list of type protos +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + string name = 1; // namespace Value + // This field MUST be present in this version of the IR for + // inputs and outputs of the top-level graph. + TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. +// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + string doc_string = 6; +} + +// Training information +// TrainingInfoProto stores information for training a model. +// In particular, this defines two functionalities: an initialization-step +// and a training-algorithm-step. Initialization resets the model +// back to its original state as if no training has been performed. +// Training algorithm improves the model based on input data. +// +// The semantics of the initialization-step is that the initializers +// in ModelProto.graph and in TrainingInfoProto.algorithm are first +// initialized as specified by the initializers in the graph, and then +// updated by the "initialization_binding" in every instance in +// ModelProto.training_info. +// +// The field "algorithm" defines a computation graph which represents a +// training algorithm's step. After the execution of a +// TrainingInfoProto.algorithm, the initializers specified by "update_binding" +// may be immediately updated. If the targeted training algorithm contains +// consecutive update steps (such as block coordinate descent methods), +// the user needs to create a TrainingInfoProto for each step. +message TrainingInfoProto { + // This field describes a graph to compute the initial tensors + // upon starting the training process. Initialization graph has no input + // and can have multiple outputs. Usually, trainable tensors in neural + // networks are randomly initialized. To achieve that, for each tensor, + // the user can put a random number operator such as RandomNormal or + // RandomUniform in TrainingInfoProto.initialization.node and assign its + // random output to the specific tensor using "initialization_binding". + // This graph can also set the initializers in "algorithm" in the same + // TrainingInfoProto; a use case is resetting the number of training + // iteration to zero. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. Thus, no initializer would be changed by default. + GraphProto initialization = 1; + + // This field represents a training algorithm step. Given required inputs, + // it computes outputs to update initializers in its own or inference graph's + // initializer lists. In general, this field contains loss node, gradient node, + // optimizer node, increment of iteration count. + // + // An execution of the training algorithm step is performed by executing the + // graph obtained by combining the inference graph (namely "ModelProto.graph") + // and the "algorithm" graph. That is, the actual + // input/initializer/output/node/value_info/sparse_initializer list of + // the training graph is the concatenation of + // "ModelProto.graph.input/initializer/output/node/value_info/sparse_initializer" + // and "algorithm.input/initializer/output/node/value_info/sparse_initializer" + // in that order. This combined graph must satisfy the normal ONNX conditions. + // Now, let's provide a visualization of graph combination for clarity. + // Let the inference graph (i.e., "ModelProto.graph") be + // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d + // and the "algorithm" graph be + // tensor_d -> Add -> tensor_e + // The combination process results + // tensor_a, tensor_b -> MatMul -> tensor_c -> Sigmoid -> tensor_d -> Add -> tensor_e + // + // Notice that an input of a node in the "algorithm" graph may reference the + // output of a node in the inference graph (but not the other way round). Also, inference + // node cannot reference inputs of "algorithm". With these restrictions, inference graph + // can always be run independently without training information. + // + // By default, this field is an empty graph and its evaluation does not + // produce any output. Evaluating the default training step never + // update any initializers. + GraphProto algorithm = 2; + + // This field specifies the bindings from the outputs of "initialization" to + // some initializers in "ModelProto.graph.initializer" and + // the "algorithm.initializer" in the same TrainingInfoProto. + // See "update_binding" below for details. + // + // By default, this field is empty and no initializer would be changed + // by the execution of "initialization". + repeated StringStringEntryProto initialization_binding = 3; + + // Gradient-based training is usually an iterative procedure. In one gradient + // descent iteration, we apply + // + // x = x - r * g + // + // where "x" is the optimized tensor, "r" stands for learning rate, and "g" is + // gradient of "x" with respect to a chosen loss. To avoid adding assignments + // into the training graph, we split the update equation into + // + // y = x - r * g + // x = y + // + // The user needs to save "y = x - r * g" into TrainingInfoProto.algorithm. To + // tell that "y" should be assigned to "x", the field "update_binding" may + // contain a key-value pair of strings, "x" (key of StringStringEntryProto) + // and "y" (value of StringStringEntryProto). + // For a neural network with multiple trainable (mutable) tensors, there can + // be multiple key-value pairs in "update_binding". + // + // The initializers appears as keys in "update_binding" are considered + // mutable variables. This implies some behaviors + // as described below. + // + // 1. We have only unique keys in all "update_binding"s so that two + // variables may not have the same name. This ensures that one + // variable is assigned up to once. + // 2. The keys must appear in names of "ModelProto.graph.initializer" or + // "TrainingInfoProto.algorithm.initializer". + // 3. The values must be output names of "algorithm" or "ModelProto.graph.output". + // 4. Mutable variables are initialized to the value specified by the + // corresponding initializer, and then potentially updated by + // "initializer_binding"s and "update_binding"s in "TrainingInfoProto"s. + // + // This field usually contains names of trainable tensors + // (in ModelProto.graph), optimizer states such as momentums in advanced + // stochastic gradient methods (in TrainingInfoProto.graph), + // and number of training iterations (in TrainingInfoProto.graph). + // + // By default, this field is empty and no initializer would be changed + // by the execution of "algorithm". + repeated StringStringEntryProto update_binding = 4; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto's. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + string domain = 4; + + // The version of the graph encoded. See Version enum below. + int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; + + // Training-specific information. Sequentially executing all stored + // `TrainingInfoProto.algorithm`s and assigning their outputs following + // the corresponding `TrainingInfoProto.update_binding`s is one training + // iteration. Similarly, to initialize the model + // (as if training hasn't happened), the user should sequentially execute + // all stored `TrainingInfoProto.initialization`s and assigns their outputs + // using `TrainingInfoProto.initialization_binding`s. + // + // If this field is empty, the training behavior of the model is undefined. + repeated TrainingInfoProto training_info = 20; + + // A list of function protos local to the model. + // + // Name of the function "FunctionProto.name" should be unique within the domain "FunctionProto.domain". + // In case of any conflicts the behavior (whether the model local functions are given higher priority, + // or standard operator sets are given higher priotity or this is treated as error) is defined by + // the runtimes. + // + // The operator sets imported by FunctionProto should be compatible with the ones + // imported by ModelProto and other model local FunctionProtos. + // Example, if same operator set say 'A' is imported by a FunctionProto and ModelProto + // or by 2 FunctionProtos then versions for the operator set may be different but, + // the operator schema returned for op_type, domain, version combination + // for both the versions should be same for every node in the function body. + // + // One FunctionProto can reference other FunctionProto in the model, however, recursive reference + // is not allowed. + repeated FunctionProto functions = 25; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + string key = 1; + string value = 2; +}; + +message TensorAnnotation { + string tensor_name = 1; + // pairs to annotate tensor specified by above. + // The keys used in the mapping below must be pre-defined in ONNX spec. + // For example, for 8-bit linear quantization case, 'SCALE_TENSOR', 'ZERO_POINT_TENSOR' will be pre-defined as + // quantization parameter keys. + repeated StringStringEntryProto quant_parameter_tensor_names = 2; +} + + + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. + string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each initializer (both TensorProto as well SparseTensorProto) MUST have a name. + // The name MUST be unique across both initializer and sparse_initializer, + // but the name MAY also appear in the input list. + repeated TensorProto initializer = 5; + + // Initializers (see above) stored in sparse format. + repeated SparseTensorProto sparse_initializer = 15; + + // A human-readable documentation for this graph. Markdown is allowed. + string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // This field carries information to indicate the mapping among a tensor and its + // quantization parameter tensors. For example: + // For tensor 'a', it may have {'SCALE_TENSOR', 'a_scale'} and {'ZERO_POINT_TENSOR', 'a_zero_point'} annotated, + // which means, tensor 'a_scale' and tensor 'a_zero_point' are scale and zero point of tensor 'a' in the model. + repeated TensorAnnotation quantization_annotation = 14; + + reserved 3, 4, 6 to 9; + reserved "ir_version", "producer_version", "producer_tag", "domain"; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // IEEE754 half-precision floating-point format (16 bits wide). + // This format has 1 sign bit, 5 exponent bits, and 10 mantissa bits. + FLOAT16 = 10; + + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + + // Non-IEEE floating-point format based on IEEE754 single-precision + // floating-point number truncated to 16 bits. + // This format has 1 sign bit, 8 exponent bits, and 7 mantissa bits. + BFLOAT16 = 16; + + // Non-IEEE floating-point format based on papers + // FP8 Formats for Deep Learning, https://arxiv.org/abs/2209.05433, + // 8-bit Numerical Formats For Deep Neural Networks, https://arxiv.org/pdf/2206.02915.pdf. + // Operators supported FP8 are Cast, CastLike, QuantizeLinear, DequantizeLinear. + // The computation usually happens inside a block quantize / dequantize + // fused by the runtime. + FLOAT8E4M3FN = 17; // float 8, mostly used for coefficients, supports nan, not inf + FLOAT8E4M3FNUZ = 18; // float 8, mostly used for coefficients, supports nan, not inf, no negative zero + FLOAT8E5M2 = 19; // follows IEEE 754, supports nan, inf, mostly used for gradients + FLOAT8E5M2FNUZ = 20; // follows IEEE 754, supports nan, inf, mostly used for gradients, no negative zero + + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + // This field MUST have a valid TensorProto.DataType value + int32 data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + int64 begin = 1; + int64 end = 2; + } + Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component appearing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, float8, and float16 values + // float16 and float8 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, UINT8, BOOL, FLOAT16, BFLOAT16, FLOAT8E4M3FN, FLOAT8E4M3FNUZ, FLOAT8E5M2, FLOAT8E5M2FNUZ + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. + string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + bytes raw_data = 9; + + // Data can be stored inside the protobuf file using type-specific fields or raw_data. + // Alternatively, raw bytes data can be stored in an external file, using the external_data field. + // external_data stores key-value pairs describing data location. Recognized keys are: + // - "location" (required) - POSIX filesystem path relative to the directory where the ONNX + // protobuf model was stored + // - "offset" (optional) - position of byte at which stored data begins. Integer stored as string. + // Offset values SHOULD be multiples 4096 (page size) to enable mmap support. + // - "length" (optional) - number of bytes containing data. Integer stored as string. + // - "checksum" (optional) - SHA1 digest of file specified in under 'location' key. + repeated StringStringEntryProto external_data = 13; + + // Location of the data for this tensor. MUST be one of: + // - DEFAULT - data stored inside the protobuf message. Data is stored in raw_data (if set) otherwise in type-specified field. + // - EXTERNAL - data stored in an external location as described by external_data field. + enum DataLocation { + DEFAULT = 0; + EXTERNAL = 1; + } + + // If value not set, data is stored in raw_data (if set) otherwise in type-specified field. + DataLocation data_location = 14; + + // For double + // Complex128 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component appearing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// A serialized sparse-tensor value +message SparseTensorProto { + // The sequence of non-default values are encoded as a tensor of shape [NNZ]. + // The default-value is zero for numeric tensors, and empty-string for string tensors. + // values must have a non-empty name present which serves as a name for SparseTensorProto + // when used in sparse_initializer list. + TensorProto values = 1; + + // The indices of the non-default values, which may be stored in one of two formats. + // (a) Indices can be a tensor of shape [NNZ, rank] with the [i,j]-th value + // corresponding to the j-th index of the i-th value (in the values tensor). + // (b) Indices can be a tensor of shape [NNZ], in which case the i-th value + // must be the linearized-index of the i-th value (in the values tensor). + // The linearized-index can be converted into an index tuple (k_1,...,k_rank) + // using the shape provided below. + // The indices must appear in ascending order without duplication. + // In the first format, the ordering is lexicographic-ordering: + // e.g., index-value [1,4] must appear before [2,1] + TensorProto indices = 2; + + // The shape of the underlying dense-tensor: [dim_1, dim_2, ... dim_rank] + repeated int64 dims = 3; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + // Refer to https://github.com/onnx/onnx/blob/main/docs/DimensionDenotation.md#denotation-definition + // for pre-defined dimension denotations. + string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + // repeated T + message Sequence { + // The type and optional shape of each element of the sequence. + // This field MUST be present for this version of the IR. + TypeProto elem_type = 1; + }; + + // map + message Map { + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + // This field MUST refer to an integral type ([U]INT{8|16|32|64}) or STRING + int32 key_type = 1; + // This field MUST be present for this version of the IR. + TypeProto value_type = 2; + }; + + // wrapper for Tensor, Sequence, or Map + message Optional { + // The type and optional shape of the element wrapped. + // This field MUST be present for this version of the IR. + // Possible values correspond to OptionalProto.DataType enum + TypeProto elem_type = 1; + }; + + + message SparseTensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST have a valid TensorProto.DataType value + // This field MUST be present for this version of the IR. + int32 elem_type = 1; + TensorShapeProto shape = 2; + } + + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + // NOTE: DNN-only implementations of ONNX MAY elect to not support non-tensor values + // as input and output to graphs and nodes. These types are needed to naturally + // support classical ML operators. DNN operators SHOULD restrict their input + // and output types to tensors. + + // The type of a sequence. + Sequence sequence_type = 4; + + // The type of a map. + Map map_type = 5; + + // The type of an optional. + Optional optional_type = 9; + + + // Type of the sparse tensor + SparseTensor sparse_tensor_type = 8; + + } + + // An optional denotation can be used to denote the whole + // type with a standard semantic description as to what is + // stored inside. Refer to https://github.com/onnx/onnx/blob/main/docs/TypeDenotation.md#type-denotation-definition + // for pre-defined type denotations. + string denotation = 6; +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + int64 version = 2; +} + +// Operator/function status. +enum OperatorStatus { + EXPERIMENTAL = 0; + STABLE = 1; +} + +message FunctionProto { + // The name of the function, similar usage of op_type in OperatorProto. + // Combined with FunctionProto.domain, this forms the unique identity of + // the FunctionProto. + string name = 1; + + // Deprecated since IR Version 8 + // optional int64 since_version = 2; + reserved 2; + reserved "since_version"; + + // Deprecated since IR Version 8 + // optional OperatorStatus status = 3; + reserved 3; + reserved "status"; + + // The inputs and outputs of the function. + repeated string input = 4; + repeated string output = 5; + + // The attribute parameters of the function. + // It is for function parameters without default values. + repeated string attribute = 6; + + // The attribute protos of the function. + // It is for function attributes with default values. + // A function attribute shall be represented either as + // a string attribute or an AttributeProto, not both. + repeated AttributeProto attribute_proto = 11; + + // The nodes in the function. + repeated NodeProto node = 7; + // A human-readable documentation for this function. Markdown is allowed. + string doc_string = 8; + + // The OperatorSets this function body (graph) relies on. + // + // All nodes in the function body (graph) will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. This means at most one version can be relied + // for one domain. + // + // The operator sets imported by FunctionProto should be compatible with the ones + // imported by ModelProto. Example, if same operator set say 'A' is imported by FunctionProto + // and ModelProto then versions for the operator set may be different but, + // the operator schema returned for op_type, domain, version combination + // for both the versions should be same. + + repeated OperatorSetIdProto opset_import = 9; + + // The domain which this function belongs to. Combined with FunctionProto.name, this forms the unique identity of + // the FunctionProto. + string domain = 10; +} + +// For using protobuf-lite +option optimize_for = LITE_RUNTIME; + diff --git a/test.onnx b/test.onnx new file mode 100644 index 0000000000..b69364d1a5 --- /dev/null +++ b/test.onnx @@ -0,0 +1,12 @@ + backend-test:J + +xytest"Relu +SingleReluZ +x +  + +b +y +  + +B \ No newline at end of file