Skip to content

Commit

Permalink
ONNX: Add Floor and Ceil (huggingface#2235)
Browse files Browse the repository at this point in the history
  • Loading branch information
mokulus authored Jun 2, 2024
1 parent 1ec3b2c commit 03344d3
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 0 deletions.
10 changes: 10 additions & 0 deletions candle-onnx/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -877,6 +877,16 @@ pub fn simple_eval(
let output = input.relu()?;
values.insert(node.output[0].clone(), output);
}
"Ceil" => {
let input = get(&node.input[0])?;
let output = input.ceil()?;
values.insert(node.output[0].clone(), output);
}
"Floor" => {
let input = get(&node.input[0])?;
let output = input.floor()?;
values.insert(node.output[0].clone(), output);
}
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#Constant
"Constant" => {
let value = match node.attribute.iter().find(|attr| attr.name == "value") {
Expand Down
152 changes: 152 additions & 0 deletions candle-onnx/tests/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2556,3 +2556,155 @@ fn test_where() -> Result<()> {

Ok(())
}

#[test]
fn test_floor() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Floor".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
// some values taken from https://numpy.org/doc/stable/reference/generated/numpy.floor.html
vec![
f64::NAN,
f64::INFINITY,
f64::NEG_INFINITY,
-1.7,
-1.5,
-0.2,
0.2,
1.5,
1.7,
2.0,
],
&[10],
&Device::Cpu,
)?;

let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);

let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);

let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

let results = z.to_vec1::<f64>()?;

assert!(results[0].is_nan());
assert_eq!(
results[1..],
vec![
f64::INFINITY,
f64::NEG_INFINITY,
-2.,
-2.,
-1.,
0.,
1.,
1.,
2.
]
);

Ok(())
}

#[test]
fn test_ceil() -> Result<()> {
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
node: vec![NodeProto {
op_type: "Ceil".to_string(),
domain: "".to_string(),
attribute: vec![],
input: vec![INPUT_X.to_string()],
output: vec![OUTPUT_Z.to_string()],
name: "".to_string(),
doc_string: "".to_string(),
}],
name: "".to_string(),
initializer: vec![],
input: vec![ValueInfoProto {
name: INPUT_X.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
output: vec![ValueInfoProto {
name: OUTPUT_Z.to_string(),
doc_string: "".to_string(),
r#type: None,
}],
value_info: vec![],
doc_string: "".to_string(),
sparse_initializer: vec![],
quantization_annotation: vec![],
}));
let x = Tensor::from_vec(
// some values taken from https://numpy.org/doc/stable/reference/generated/numpy.ceil.html
vec![
f64::NAN,
f64::INFINITY,
f64::NEG_INFINITY,
-1.7,
-1.5,
-0.2,
0.2,
1.5,
1.7,
2.0,
],
&[10],
&Device::Cpu,
)?;

let mut inputs: HashMap<String, Tensor> = HashMap::new();
inputs.insert(INPUT_X.to_string(), x);

let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
assert_eq!(eval.len(), 1);

let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");

let results = z.to_vec1::<f64>()?;

assert!(results[0].is_nan());
assert_eq!(
results[1..],
vec![
f64::INFINITY,
f64::NEG_INFINITY,
-1.,
-1.,
-0.,
1.,
2.,
2.,
2.
]
);

Ok(())
}

0 comments on commit 03344d3

Please sign in to comment.