Skip to content

Commit

Permalink
Basic qmatmul parallelization (huggingface#492)
Browse files Browse the repository at this point in the history
* Basic `par_iter` parallelization

* Pass errors up

* Disable `avx` for x86 macs
  • Loading branch information
LLukas22 authored Aug 18, 2023
1 parent c78ce76 commit 109e95b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
8 changes: 4 additions & 4 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
[target.x86_64-unknown-linux-gnu]
rustflags = ["-C", "target-cpu=native"]

[target.aarch64-apple-darwin]
[build]
rustflags = ["-C", "target-cpu=native"]

[target.wasm32-unknown-unknown]
rustflags = ["-C", "target-feature=+simd128"]

[target.x86_64-apple-darwin]
rustflags = ["-C", "target-feature=-avx,-avx2"]
20 changes: 15 additions & 5 deletions candle-core/src/quantized/k_quants.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use super::GgmlDType;
use crate::Result;
use half::f16;
use rayon::prelude::*;

// Default to QK_K 256 rather than 64.
pub const QK_K: usize = 256;
Expand All @@ -13,7 +14,7 @@ pub const QK5_1: usize = 32;
pub const QK8_0: usize = 32;
pub const QK8_1: usize = 32;

pub trait GgmlType: Sized + Clone {
pub trait GgmlType: Sized + Clone + Send + Sync {
const DTYPE: GgmlDType;
const BLCK_SIZE: usize;
type VecDotType: GgmlType;
Expand Down Expand Up @@ -1030,10 +1031,19 @@ pub fn matmul<T: GgmlType>(
for row_idx in 0..m {
let lhs_row = &lhs_b[row_idx * k_in_lhs_blocks..(row_idx + 1) * k_in_lhs_blocks];
let dst_row = &mut dst[row_idx * n..(row_idx + 1) * n];
for (col_idx, dst) in dst_row.iter_mut().enumerate() {
let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks];
*dst = T::vec_dot(k, rhs_col, lhs_row)?;
}

let result: Result<Vec<_>> = dst_row
.into_par_iter()
.enumerate()
.with_min_len(128)
.with_max_len(512)
.map(|(col_idx, dst)| {
let rhs_col = &rhs_t[col_idx * k_in_rhs_blocks..(col_idx + 1) * k_in_rhs_blocks];
T::vec_dot(k, rhs_col, lhs_row).map(|value| *dst = value)
})
.collect();

result?;
}
Ok(())
}
Expand Down

0 comments on commit 109e95b

Please sign in to comment.