diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index 2563607ac7..06f6cd3715 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -718,6 +718,7 @@ impl BackendStorage for MetalStorage { } let name = match (self.dtype, t.dtype()) { (DType::U8, DType::F32) => "where_u8_f32", + (DType::U32, DType::F32) => "where_u32_f32", (DType::U8, DType::BF16) => "where_u8_bf16", (DType::U8, DType::F16) => "where_u8_f16", (DType::U8, DType::I64) => "where_u8_i64", diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 960ae1dfea..8c38e74ac6 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1023,6 +1023,27 @@ fn where_cond() { ); assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); } +#[test] +fn where_cond_u32_f32() { + let shape = vec![6]; + let cond = vec![0u32, 1, 0, 0, 1, 1]; + let cond_l = (vec![1], 0); + let left_true = vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]; + let left_l = (vec![1], 0); + let right_false = vec![-1.0f32, -2.0, -3.0, -4.0, -5.0, -6.0]; + let right_l = (vec![1], 0); + let results = run_where_cond( + &shape, + &cond, + cond_l, + &left_true, + left_l, + &right_false, + right_l, + "where_u32_f32", + ); + assert_eq!(approx(results, 4), vec![-1.0f32, 2.0, -3.0, -4.0, 5.0, 6.0]); +} fn run_gemm( (b, m, n, k): (usize, usize, usize, usize), diff --git a/candle-nn/benches/benchmarks/layer_norm.rs b/candle-nn/benches/benchmarks/layer_norm.rs index 0be5c450ca..4a5fe667be 100644 --- a/candle-nn/benches/benchmarks/layer_norm.rs +++ b/candle-nn/benches/benchmarks/layer_norm.rs @@ -5,7 +5,7 @@ use criterion::{black_box, criterion_group, Criterion}; use std::time::Instant; fn run(input: &Tensor, weight: &Tensor, bias: &Tensor) { - let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(&input); + let _ = LayerNorm::new(weight.clone(), bias.clone(), 1e-5).forward(input); } const B: usize = 1;