Skip to content

Commit

Permalink
metal : fix build and some more comments (#10229)
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov authored Nov 9, 2024
1 parent bb38cdd commit 39a334a
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
2 changes: 2 additions & 0 deletions ggml/src/ggml-metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -3041,6 +3041,8 @@ static void ggml_metal_encode_node(

bool use_vec_kernel = false;

// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
// for now avoiding mainly to keep the number of templates/kernels a bit lower
if (ne01 >= 4 || (ne00%128 != 0)) {
switch (src1->type) {
case GGML_TYPE_F16:
Expand Down
8 changes: 4 additions & 4 deletions ggml/src/ggml-metal.metal
Original file line number Diff line number Diff line change
Expand Up @@ -3356,8 +3356,8 @@ kernel void kernel_flash_attn_ext_vec(
const short D4 = D/4;
const short D16 = D/16;
const short NW = N_SIMDWIDTH;
const short NL = NW/4;
const short SH = 2*C; // shared memory per simdgroup
const short NL = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
const short SH = 2*C; // shared memory per simdgroup

const short T = D + nsg*SH; // shared memory size per query in (half)

Expand Down Expand Up @@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec(

// Q*K^T
{
// each simdgroup processes 1 query and 4 keys
// each simdgroup processes 1 query and 4 (NW/NL) keys
for (short cc = 0; cc < C/4; ++cc) {
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };

Expand Down Expand Up @@ -3646,7 +3646,7 @@ kernel void kernel_flash_attn_ext_vec(
half, half4, half4x4, \
half4x4

typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;

template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>;
#if defined(GGML_METAL_USE_BF16)
Expand Down

0 comments on commit 39a334a

Please sign in to comment.