diff --git a/Cargo.toml b/Cargo.toml index 43d94c889..bb6550b3f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "blas" -version = "0.22.0" +version = "0.22.1" license = "Apache-2.0/MIT" authors = [ "Andrew Straw ", @@ -19,11 +19,15 @@ keywords = ["linear-algebra"] [dependencies] libc = "0.2" +half = {version="2.3.1", optional=true} [dependencies.num-complex] version = "0.4" default-features = false [dependencies.blas-sys] -version = "0.7" +version = "0.7.2" default-features = false + +[features] +half = ["blas-sys/half", "dep:half"] diff --git a/src/lib.rs b/src/lib.rs index 2a2bba011..36fae2a55 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2822,3 +2822,37 @@ pub unsafe fn ztrsm( &ldb, ) } + +#[cfg(feature = "half")] +#[inline] +pub unsafe fn hgemm( + transa: u8, + transb: u8, + m: i32, + n: i32, + k: i32, + alpha: half::f16, + a: &[half::f16], + lda: i32, + b: &[half::f16], + ldb: i32, + beta: half::f16, + c: &mut [half::f16], + ldc: i32, +) { + ffi::hgemm_( + &(transa as c_char), + &(transb as c_char), + &m, + &n, + &k, + &alpha, + a.as_ptr(), + &lda, + b.as_ptr(), + &ldb, + &beta, + c.as_mut_ptr(), + &ldc, + ) +}