From 844d45cde4c5068fc45aaebd5858c8370a93ca20 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 21 Sep 2024 15:03:19 +0200 Subject: [PATCH] Bugfix for the metal elu kernel. (#2490) * Bugfix for the metal elu kernel. * Add a test. --- candle-core/tests/tensor_tests.rs | 13 +++++++++++++ candle-metal-kernels/src/affine.metal | 2 +- 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index 567b49f1db..4a76035c17 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -193,6 +193,19 @@ fn unary_op(device: &Device) -> Result<()> { tensor.sign()?.to_vec1::()?, [-1., -1., -1., 0., 0., 1., 1., 1., 1.] ); + let tensor = Tensor::new(&[-1.0f32, 0., -2., 3.], device)?; + let y = tensor.elu(2.)?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [-1.2642, 0.0000, -1.7293, 3.0000] + ); + // This test failed on metal prior to the following PR: + // https://github.com/huggingface/candle/pull/2490 + let y = tensor.reshape((2, 2))?.t()?.elu(2.)?.flatten_all()?; + assert_eq!( + test_utils::to_vec1_round(&y, 4)?, + [-1.2642, -1.7293, 0.0000, 3.0000] + ); Ok(()) } diff --git a/candle-metal-kernels/src/affine.metal b/candle-metal-kernels/src/affine.metal index cbbb03e2a2..e5229f55ee 100644 --- a/candle-metal-kernels/src/affine.metal +++ b/candle-metal-kernels/src/affine.metal @@ -105,7 +105,7 @@ kernel void FN_NAME##_strided( \ return; \ } \ const TYPENAME x = input[get_strided_index(id, num_dims, dims, strides)]; \ - output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \ + output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \ } \