From 17e0a47b38b4638c426842ea9b82824b557e2eb6 Mon Sep 17 00:00:00 2001
From: Wataru Ishida
Date: Thu, 14 Mar 2024 00:56:28 +0000
Subject: [PATCH] feat(fp16): use SIMD instructions when available

Signed-off-by: Wataru Ishida
---
 reduction_server/Cargo.toml            |   1 +
 reduction_server/src/main.rs           |   2 +
 reduction_server/src/reduce.rs         | 165 +++++++++++++++++
 reduction_server/src/reduce/aarch64.rs | 233 +++++++++++++++++++++++++
 reduction_server/src/ring.rs           |   1 +
 reduction_server/src/server.rs         |   1 +
 reduction_server/src/utils.rs          |  83 ---------
 7 files changed, 403 insertions(+), 83 deletions(-)
 create mode 100644 reduction_server/src/reduce.rs
 create mode 100644 reduction_server/src/reduce/aarch64.rs

diff --git a/reduction_server/Cargo.toml b/reduction_server/Cargo.toml
index 0db13c6..034c200 100644
--- a/reduction_server/Cargo.toml
+++ b/reduction_server/Cargo.toml
@@ -7,6 +7,7 @@ edition = "2021"
 
 [dependencies]
 aligned_box = "0.2.1"
+cfg-if = "1.0.0"
 clap = { version = "4.4.14", features = ["derive"] }
 env_logger = "0.10.1"
 half = { version = "2.3.1", features = ["num-traits"] }
diff --git a/reduction_server/src/main.rs b/reduction_server/src/main.rs
index 1e86f41..6731597 100644
--- a/reduction_server/src/main.rs
+++ b/reduction_server/src/main.rs
@@ -7,6 +7,7 @@
 #![feature(c_variadic)]
 #![feature(portable_simd)]
 #![feature(min_specialization)]
+#![feature(test)]
 
 use clap::Parser;
 
@@ -16,6 +17,7 @@ mod partitioned_vec;
 mod client;
 mod server;
 mod ring;
+mod reduce;
 
 use utils::Args;
 use server::server;
diff --git a/reduction_server/src/reduce.rs b/reduction_server/src/reduce.rs
new file mode 100644
index 0000000..491293f
--- /dev/null
+++ b/reduction_server/src/reduce.rs
@@ -0,0 +1,165 @@
+/*
+ * Copyright (c) 2024, the Optcast Authors. All rights reserved.
+ *
+ * See LICENSE for license information
+ */
+
+use aligned_box::AlignedBox;
+use half::f16;
+
+use crate::utils::{alignment, Float};
+
+#[cfg(all(target_arch = "aarch64", target_feature = "fp16"))]
+mod aarch64;
+
+#[cfg(not(all(target_arch = "aarch64", target_feature = "fp16")))]
+use half::slice::HalfFloatSliceExt;
+
+#[allow(dead_code)]
+pub(crate) struct WorkingMemory {
+    recv_bufs: Vec<AlignedBox<[f32]>>,
+    send_buf: AlignedBox<[f32]>,
+}
+
+#[allow(dead_code)]
+impl WorkingMemory {
+    pub(crate) fn new(count: usize, num_recv: usize) -> Self {
+        let recv_bufs = (0..num_recv)
+            .map(|_| AlignedBox::<[f32]>::slice_from_default(alignment(count), count).unwrap())
+            .collect::<Vec<_>>();
+        let send_buf = AlignedBox::<[f32]>::slice_from_default(alignment(count), count).unwrap();
+        Self {
+            recv_bufs,
+            send_buf,
+        }
+    }
+}
+
+pub(crate) trait Reduce<T> {
+    fn reduce(
+        &mut self,
+        recv_bufs: &Vec<&[T]>,
+        work_mem: Option<&mut WorkingMemory>,
+    ) -> Result<(), ()>;
+}
+
+impl<T> Reduce<T> for [T] {
+    default fn reduce(&mut self, _: &Vec<&[T]>, _: Option<&mut WorkingMemory>) -> Result<(), ()> {
+        Err(())
+    }
+}
+
+impl Reduce<f16> for [f16] {
+    #[allow(unused_variables)]
+    fn reduce(
+        &mut self,
+        recv_bufs: &Vec<&[f16]>,
+        work_mem: Option<&mut WorkingMemory>,
+    ) -> Result<(), ()> {
+        cfg_if::cfg_if! {
+            if #[cfg(all(
+                target_arch = "aarch64",
+                target_feature = "fp16"
+            ))] {
+                for (i, recv) in recv_bufs.iter().enumerate() {
+                    if i == 0 {
+                        self.copy_from_slice(recv);
+                    } else {
+                        unsafe { aarch64::add_assign_f16_aligned_slice(self, recv); }
+                    }
+                }
+            }
+            else {
+                let work_mem = work_mem.unwrap();
+                for (i, recv) in recv_bufs.iter().enumerate() {
+                    recv.convert_to_f32_slice(&mut work_mem.recv_bufs[i].as_mut());
+                }
+                work_mem.send_buf.reduce(
+                    &work_mem
+                        .recv_bufs
+                        .iter()
+                        .map(|v| {
+                            let slice_ref: &[f32] = &**v;
+                            slice_ref
+                        })
+                        .collect(),
+                    None,
+                )?;
+                self.as_mut()
+                    .convert_from_f32_slice(&work_mem.send_buf.as_ref());
+            }
+        }
+        Ok(())
+    }
+}
+
+// impl<T> Reduce<T> for AlignedBox<[T]> can't compile
+// error: cannot specialize on trait `SimdElement`
+//   --> src/main.rs:139:17
+//    |
+// 139 |     impl<T> Reduce<T> for AlignedBox<[T]> {
+impl Reduce<f32> for [f32] {
+    fn reduce(&mut self, recv_bufs: &Vec<&[f32]>, _: Option<&mut WorkingMemory>) -> Result<(), ()> {
+        let (_, send, _) = self.as_simd_mut::<4>();
+        for (i, recv) in recv_bufs.iter().enumerate() {
+            let (_, recv, _) = recv.as_ref().as_simd::<4>();
+            if i == 0 {
+                send.copy_from_slice(&recv.as_ref());
+            } else {
+                for j in 0..send.len() {
+                    send[j] += recv[j];
+                }
+            }
+        }
+        Ok(())
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    extern crate test;
+
+    fn bench_reduce<T>(b: &mut test::Bencher)
+    where
+        T: Float
+            + std::fmt::Debug
+            + std::ops::AddAssign
+            + std::default::Default
+            + std::clone::Clone,
+    {
+        let count = 1024;
+        let num_recv = 4;
+        let mut work_mem = WorkingMemory::new(count, num_recv);
+        let mut recv_bufs = vec![];
+        for _ in 0..num_recv {
+            recv_bufs.push(
+                AlignedBox::<[T]>::slice_from_value(alignment(count), count, T::default()).unwrap(),
+            );
+        }
+        let mut send_buf =
+            AlignedBox::<[T]>::slice_from_value(alignment(count), count, T::default()).unwrap();
+        b.iter(|| {
+            send_buf.reduce(
+                &recv_bufs
+                    .iter()
+                    .map(|v| {
+                        let slice_ref: &[T] = &**v;
+                        slice_ref
+                    })
+                    .collect(),
+                Some(&mut work_mem),
+            )
+        });
+    }
+
+    #[bench]
+    fn bench_f16_reduce(b: &mut test::Bencher) {
+        bench_reduce::<f16>(b);
+    }
+
+    #[bench]
+    fn bench_f32_reduce(b: &mut test::Bencher) {
+        bench_reduce::<f32>(b);
+    }
+}
diff --git a/reduction_server/src/reduce/aarch64.rs b/reduction_server/src/reduce/aarch64.rs
new file mode 100644
index 0000000..de2cec6
--- /dev/null
+++ b/reduction_server/src/reduce/aarch64.rs
@@ -0,0 +1,233 @@
+#![allow(dead_code)]
+
+use core::{
+    arch::{aarch64::uint16x8_t, asm},
+    mem::MaybeUninit,
+    ptr,
+};
+
+use half::f16;
+
+#[target_feature(enable = "fp16")]
+#[inline]
+unsafe fn fadd_f16x8(a: &uint16x8_t, b: &uint16x8_t) -> uint16x8_t {
+    let result: uint16x8_t;
+    asm!(
+        "fadd {result:v}.8h, {a:v}.8h, {b:v}.8h",
+        a = in(vreg) *a,
+        b = in(vreg) *b,
+        result = out(vreg) result,
+        options(pure, nomem, nostack));
+    result
+}
+
+#[target_feature(enable = "fp16")]
+#[inline]
+unsafe fn fadd_assign_f16x8(a: &mut uint16x8_t, b: &uint16x8_t) {
+    asm!(
+        "fadd {a:v}.8h, {a:v}.8h, {b:v}.8h",
+        a = inlateout(vreg) *a,
+        b = in(vreg) *b,
+        options(pure, nomem, nostack));
+}
+
+#[target_feature(enable = "fp16")]
+// SAFETY: a and b must be aligned to 128 bits
+unsafe fn add_f16x8_aligned(a: &[f16; 8], b: &[f16; 8]) -> [f16; 8] {
+    let a = a.as_ptr() as *const uint16x8_t;
+    let b = b.as_ptr() as *const uint16x8_t;
+    let result = unsafe { fadd_f16x8(&*a, &*b) };
+    *(&result as *const uint16x8_t).cast()
+}
+
+// SAFETY: a and b must be aligned to 128 bits
+pub(super) unsafe fn add_f16_aligned_slice(a: &[f16], b: &[f16], result: &mut [f16]) {
+    assert_eq!(a.len(), b.len());
+    assert_eq!(a.len(), result.len());
+    assert_eq!(a.len() % 8, 0);
+    for i in (0..a.len()).step_by(8) {
+        let a = unsafe { &*(a.as_ptr().add(i) as *const [f16; 8]) };
+        let b = unsafe { &*(b.as_ptr().add(i) as *const [f16; 8]) };
+        let result = unsafe { &mut *(result.as_mut_ptr().add(i) as *mut [f16; 8]) };
+        *result = unsafe { add_f16x8_aligned(a, b) };
+    }
+}
+
+#[target_feature(enable = "fp16")]
+// SAFETY: a and b must be aligned to 128 bits
+unsafe fn add_assign_f16x8_aligned(a: &mut [f16; 8], b: &[f16; 8]) {
+    let a = a.as_mut_ptr() as *mut uint16x8_t;
+    let b = b.as_ptr() as *const uint16x8_t;
+    unsafe { fadd_assign_f16x8(&mut *a, &*b) };
+}
+
+// SAFETY: a and b must be aligned to 128 bits
+pub(super) unsafe fn add_assign_f16_aligned_slice(a: &mut [f16], b: &[f16]) {
+    assert_eq!(a.len(), b.len());
+    assert_eq!(a.len() % 8, 0);
+    for i in (0..a.len()).step_by(8) {
+        let a = unsafe { &mut *(a.as_mut_ptr().add(i) as *mut [f16; 8]) };
+        let b = unsafe { &*(b.as_ptr().add(i) as *const [f16; 8]) };
+        unsafe { add_assign_f16x8_aligned(a, b) };
+    }
+}
+
+#[target_feature(enable = "fp16")]
+unsafe fn add_f16x8(a: &[f16; 8], b: &[f16; 8]) -> [f16; 8] {
+    let mut aa = MaybeUninit::<uint16x8_t>::uninit();
+    ptr::copy_nonoverlapping(a.as_ptr(), aa.as_mut_ptr().cast(), 8);
+    let mut bb = MaybeUninit::<uint16x8_t>::uninit();
+    ptr::copy_nonoverlapping(b.as_ptr(), bb.as_mut_ptr().cast(), 8);
+    let result = unsafe { fadd_f16x8(&aa.assume_init(), &bb.assume_init()) };
+    *(&result as *const uint16x8_t).cast()
+}
+
+#[target_feature(enable = "fp16")]
+unsafe fn add_assign_f16x8(a: &mut [f16; 8], b: &[f16; 8]) {
+    let mut aa = MaybeUninit::<uint16x8_t>::uninit();
+    ptr::copy_nonoverlapping(a.as_ptr(), aa.as_mut_ptr().cast(), 8);
+    let mut aa = aa.assume_init();
+    let mut bb = MaybeUninit::<uint16x8_t>::uninit();
+    ptr::copy_nonoverlapping(b.as_ptr(), bb.as_mut_ptr().cast(), 8);
+    let bb = bb.assume_init();
+    asm!(
+        "fadd {aa:v}.8h, {aa:v}.8h, {bb:v}.8h",
+        aa = inlateout(vreg) aa,
+        bb = in(vreg) bb,
+        options(pure, nomem, nostack));
+    ptr::copy_nonoverlapping(&aa as *const uint16x8_t as *const f16, a.as_mut_ptr(), 8);
+}
+
+// test
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::utils::alignment;
+    use aligned_box::AlignedBox;
+
+    #[test]
+    fn test_add_f16x8() {
+        // create two arrays of f16
+        let a: [f16; 8] = [
+            f16::from_f32(1.0),
+            f16::from_f32(2.0),
+            f16::from_f32(3.0),
+            f16::from_f32(4.0),
+            f16::from_f32(1.0),
+            f16::from_f32(2.0),
+            f16::from_f32(3.0),
+            f16::from_f32(4.0),
+        ];
+        let b: [f16; 8] = [
+            f16::from_f32(5.0),
+            f16::from_f32(6.0),
+            f16::from_f32(7.0),
+            f16::from_f32(8.0),
+            f16::from_f32(5.0),
+            f16::from_f32(6.0),
+            f16::from_f32(7.0),
+            f16::from_f32(8.0),
+        ];
+        // call the function
+        let result = unsafe { add_f16x8(&a, &b) };
+
+        assert_eq!(
+            result,
+            [
+                f16::from_f32(6.0),
+                f16::from_f32(8.0),
+                f16::from_f32(10.0),
+                f16::from_f32(12.0),
+                f16::from_f32(6.0),
+                f16::from_f32(8.0),
+                f16::from_f32(10.0),
+                f16::from_f32(12.0)
+            ]
+        );
+    }
+
+    #[test]
+    fn test_add_f16x8_aligned() {
+        let a =
+            AlignedBox::<[f16]>::slice_from_value(alignment(128), 128, f16::from_f32(1.0)).unwrap();
+        let b =
+            AlignedBox::<[f16]>::slice_from_value(alignment(128), 128, f16::from_f32(2.0)).unwrap();
+
+        // forcibly convert &a to &[f16; 8]
+        let a = unsafe { &*(a.as_ptr() as *const [f16; 8]) };
+        let b = unsafe { &*(b.as_ptr() as *const [f16; 8]) };
+
+        let result = unsafe { add_f16x8_aligned(&a, &b) };
+
+        assert_eq!(
+            result,
+            [
+                f16::from_f32(3.0),
+                f16::from_f32(3.0),
+                f16::from_f32(3.0),
+                f16::from_f32(3.0),
+                f16::from_f32(3.0),
+                f16::from_f32(3.0),
+                f16::from_f32(3.0),
+                f16::from_f32(3.0),
+            ]
+        );
+    }
+
+    #[test]
+    fn test_add_f16_aligned_slice() {
+        let a =
+            AlignedBox::<[f16]>::slice_from_value(alignment(128), 128, f16::from_f32(1.0)).unwrap();
+        let b =
+            AlignedBox::<[f16]>::slice_from_value(alignment(128), 128, f16::from_f32(2.0)).unwrap();
+        let mut result = AlignedBox::<[f16]>::slice_from_default(alignment(128), 128).unwrap();
+
+        unsafe { add_f16_aligned_slice(&a, &b, &mut result) };
+
+        // check all elements are 3.0
+        for i in 0..result.len() {
+            assert_eq!(result[i], f16::from_f32(3.0));
+        }
+    }
+
+    #[test]
+    fn test_add_assign_f16x8() {
+        // create two arrays of f16
+        let mut a: [f16; 8] = [
+            f16::from_f32(1.0),
+            f16::from_f32(2.0),
+            f16::from_f32(3.0),
+            f16::from_f32(4.0),
+            f16::from_f32(1.0),
+            f16::from_f32(2.0),
+            f16::from_f32(3.0),
+            f16::from_f32(4.0),
+        ];
+        let b: [f16; 8] = [
+            f16::from_f32(5.0),
+            f16::from_f32(6.0),
+            f16::from_f32(7.0),
+            f16::from_f32(8.0),
+            f16::from_f32(5.0),
+            f16::from_f32(6.0),
+            f16::from_f32(7.0),
+            f16::from_f32(8.0),
+        ];
+        // call the function
+        unsafe { add_assign_f16x8(&mut a, &b) };
+
+        assert_eq!(
+            a,
+            [
+                f16::from_f32(6.0),
+                f16::from_f32(8.0),
+                f16::from_f32(10.0),
+                f16::from_f32(12.0),
+                f16::from_f32(6.0),
+                f16::from_f32(8.0),
+                f16::from_f32(10.0),
+                f16::from_f32(12.0)
+            ]
+        );
+    }
+}
diff --git a/reduction_server/src/ring.rs b/reduction_server/src/ring.rs
index a839c5b..fbf8d02 100644
--- a/reduction_server/src/ring.rs
+++ b/reduction_server/src/ring.rs
@@ -15,6 +15,7 @@ use half::f16;
 use log::{info, trace};
 
 use crate::utils::*;
+use crate::reduce::{Reduce, WorkingMemory};
 use crate::nccl_net;
 use crate::nccl_net::Comm;
 
diff --git a/reduction_server/src/server.rs b/reduction_server/src/server.rs
index c246368..aae13ee 100644
--- a/reduction_server/src/server.rs
+++ b/reduction_server/src/server.rs
@@ -15,6 +15,7 @@ use log::{info, trace, warn};
 use half::f16;
 
 use crate::utils::*;
+use crate::reduce::{Reduce, WorkingMemory};
 use crate::nccl_net;
 use crate::nccl_net::Comm;
 
diff --git a/reduction_server/src/utils.rs b/reduction_server/src/utils.rs
index ea6b335..5f36c8c 100644
--- a/reduction_server/src/utils.rs
+++ b/reduction_server/src/utils.rs
@@ -7,10 +7,8 @@
 use std::fmt::Debug;
 use std::time::Duration;
 
-use aligned_box::AlignedBox;
 use clap::{Parser, ValueEnum};
 use half::f16;
-use half::slice::HalfFloatSliceExt;
 use log::info;
 use num_traits::FromPrimitive;
 
@@ -131,87 +129,6 @@ pub(crate) fn print_stat(args: &Args, elapsed: &Duration) {
     );
 }
 
-pub(crate) struct WorkingMemory {
-    recv_bufs: Vec<AlignedBox<[f32]>>,
-    send_buf: AlignedBox<[f32]>,
-}
-
-impl WorkingMemory {
-    pub(crate) fn new(count: usize, num_recv: usize) -> Self {
-        let recv_bufs = (0..num_recv)
-            .map(|_| AlignedBox::<[f32]>::slice_from_default(alignment(count), count).unwrap())
-            .collect::<Vec<_>>();
-        let send_buf = AlignedBox::<[f32]>::slice_from_default(alignment(count), count).unwrap();
-        Self {
-            recv_bufs,
-            send_buf,
-        }
-    }
-}
-
-pub(crate) trait Reduce<T> {
-    fn reduce(
-        &mut self,
-        recv_bufs: &Vec<&[T]>,
-        work_mem: Option<&mut WorkingMemory>,
-    ) -> Result<(), ()>;
-}
-
-impl<T> Reduce<T> for [T] {
-    default fn reduce(&mut self, _: &Vec<&[T]>, _: Option<&mut WorkingMemory>) -> Result<(), ()> {
-        Err(())
-    }
-}
-
-impl Reduce<f16> for [f16] {
-    fn reduce(
-        &mut self,
-        recv_bufs: &Vec<&[f16]>,
-        work_mem: Option<&mut WorkingMemory>,
-    ) -> Result<(), ()> {
-        let work_mem = work_mem.unwrap();
-        for (i, recv) in recv_bufs.iter().enumerate() {
-            recv.convert_to_f32_slice(&mut work_mem.recv_bufs[i].as_mut());
-        }
-        work_mem.send_buf.reduce(
-            &work_mem
-                .recv_bufs
-                .iter()
-                .map(|v| {
-                    let slice_ref: &[f32] = &**v;
-                    slice_ref
-                })
-                .collect(),
-            None,
-        )?;
-        self.as_mut()
-            .convert_from_f32_slice(&work_mem.send_buf.as_ref());
-        Ok(())
-    }
-}
-
-// impl<T> Reduce<T> for AlignedBox<[T]> can't compile
-// error: cannot specialize on trait `SimdElement`
-//   --> src/main.rs:139:17
-//    |
-// 139 |     impl<T> Reduce<T> for AlignedBox<[T]> {
-impl Reduce<f32> for [f32] {
-    fn reduce(&mut self, recv_bufs: &Vec<&[f32]>, _: Option<&mut WorkingMemory>) -> Result<(), ()> {
-        let (_, send, _) = self.as_simd_mut::<4>();
-        for (i, recv) in recv_bufs.iter().enumerate() {
-            let (_, recv, _) = recv.as_ref().as_simd::<4>();
-            if i == 0 {
-                send.copy_from_slice(&recv.as_ref());
-            } else {
-                for j in 0..send.len() {
-                    send[j] += recv[j];
-                }
-            }
-        }
-        Ok(())
-    }
-}
-
 pub(crate) fn vec_of_none<T>(n: usize) -> Vec<Option<T>> {
     std::iter::repeat_with(|| None).take(n).collect()
 }
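-- 
Usage sketch (illustrative, not part of the patch). The snippet below drives
the new `Reduce` trait for f16 the same way the benchmarks in reduce.rs do.
The function name `reduce_two_peers`, the element count, and the values are
invented for the example; it assumes crate-internal code (the trait and
WorkingMemory are pub(crate)), buffers allocated through AlignedBox, and a
length that is a multiple of 8, which the aarch64 fp16 path asserts at runtime.

    use aligned_box::AlignedBox;
    use half::f16;

    use crate::reduce::{Reduce, WorkingMemory};
    use crate::utils::alignment;

    fn reduce_two_peers() {
        let count = 1024;
        let a = AlignedBox::<[f16]>::slice_from_value(alignment(count), count, f16::from_f32(1.0))
            .unwrap();
        let b = AlignedBox::<[f16]>::slice_from_value(alignment(count), count, f16::from_f32(2.0))
            .unwrap();
        let mut out = AlignedBox::<[f16]>::slice_from_default(alignment(count), count).unwrap();

        // On aarch64 with fp16, reduce() adds in half precision in-register and
        // never touches WorkingMemory; on other targets it widens each receive
        // buffer to f32 inside WorkingMemory, reduces there, and narrows back.
        let mut work_mem = WorkingMemory::new(count, 2);
        out.reduce(&vec![&*a, &*b], Some(&mut work_mem)).unwrap();
        assert_eq!(out[0], f16::from_f32(3.0));
    }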