From 6eea45a761fc1636b5e8012d02bdaa93321652ca Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sun, 15 Sep 2024 21:27:46 +0100 Subject: [PATCH] Add a couple cast metal kernels. (#2479) --- candle-core/src/metal_backend/mod.rs | 39 ++++++++++++++++++++++------ 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/candle-core/src/metal_backend/mod.rs b/candle-core/src/metal_backend/mod.rs index a60c475bbc..9c980db804 100644 --- a/candle-core/src/metal_backend/mod.rs +++ b/candle-core/src/metal_backend/mod.rs @@ -412,19 +412,42 @@ impl BackendStorage for MetalStorage { .map_err(MetalError::from)?; } else { let kernel_name = match (self.dtype, dtype) { - (DType::U32, DType::F16) => "cast_u32_f16_strided", + (DType::BF16, DType::F16) => "cast_bf16_f16_strided", + (DType::BF16, DType::F32) => "cast_bf16_f32_strided", + (DType::BF16, DType::I64) => "cast_bf16_i64_strided", + (DType::BF16, DType::U32) => "cast_bf16_u32_strided", + (DType::BF16, DType::U8) => "cast_bf16_u8_strided", + + (DType::F16, DType::BF16) => "cast_f16_bf16_strided", + (DType::F16, DType::F32) => "cast_f16_f32_strided", + (DType::F16, DType::I64) => "cast_f16_i64_strided", + (DType::F16, DType::U32) => "cast_f16_u32_strided", + (DType::F16, DType::U8) => "cast_f16_u8_strided", + + (DType::F32, DType::BF16) => "cast_f32_bf16_strided", + (DType::F32, DType::F16) => "cast_f32_f16_strided", + (DType::F32, DType::I64) => "cast_f32_i64_strided", + (DType::F32, DType::U32) => "cast_f32_u32_strided", + (DType::F32, DType::U8) => "cast_f32_u8_strided", + + (DType::I64, DType::F32) => "cast_i64_f32_strided", + (DType::I64, DType::BF16) => "cast_i64_bf16_strided", + (DType::I64, DType::F16) => "cast_i64_f16_strided", + (DType::I64, DType::U32) => "cast_i64_u32_strided", + (DType::I64, DType::U8) => "cast_i64_u8_strided", + (DType::U32, DType::BF16) => "cast_u32_bf16_strided", + (DType::U32, DType::F16) => "cast_u32_f16_strided", (DType::U32, DType::F32) => "cast_u32_f32_strided", - (DType::U32, DType::U8) => "cast_u32_u8_strided", (DType::U32, DType::I64) => "cast_u32_i64_strided", - (DType::U8, DType::U32) => "cast_u8_u32_strided", + (DType::U32, DType::U8) => "cast_u32_u8_strided", + + (DType::U8, DType::BF16) => "cast_u8_bf16_strided", + (DType::U8, DType::F16) => "cast_u8_f16_strided", (DType::U8, DType::F32) => "cast_u8_f32_strided", (DType::U8, DType::I64) => "cast_u8_i64_strided", - (DType::F32, DType::F16) => "cast_f32_f16_strided", - (DType::F16, DType::F32) => "cast_f16_f32_strided", - (DType::I64, DType::F32) => "cast_i64_f32_strided", - (DType::F32, DType::BF16) => "cast_f32_bf16_strided", - (DType::BF16, DType::F32) => "cast_bf16_f32_strided", + (DType::U8, DType::U32) => "cast_u8_u32_strided", + (left, right) => { crate::bail!("Metal strided to_dtype {left:?} {right:?} not implemented") }