diff --git a/src/hbf.rs b/src/hbf.rs index 76f7f0d..757837f 100644 --- a/src/hbf.rs +++ b/src/hbf.rs @@ -1,3 +1,10 @@ +use core::{ + iter::Sum, + ops::{Add, Mul}, +}; + +use num_traits::Zero; + /// Filter input items into output items. pub trait Filter { /// Input/output item type. @@ -69,36 +76,41 @@ pub trait Filter { /// overhead) for blocks of 32 high-rate items or more, depending very much on architecture. #[derive(Clone, Debug, Copy)] -pub struct SymFir<'a, const M: usize, const N: usize> { - x: [f32; N], - taps: &'a [f32; M], +pub struct SymFir<'a, T, const M: usize, const N: usize> { + x: [T; N], + taps: &'a [T; M], } -impl<'a, const M: usize, const N: usize> SymFir<'a, M, N> { +impl<'a, T: Copy + Zero + Add + Mul + Sum, const M: usize, const N: usize> + SymFir<'a, T, M, N> +{ /// Create a new `SymFir`. /// /// # Args /// * `taps`: one-sided FIR coefficients, expluding center tap, oldest to one-before-center - pub fn new(taps: &'a [f32; M]) -> Self { + pub fn new(taps: &'a [T; M]) -> Self { debug_assert!(N >= M * 2); - Self { x: [0.0; N], taps } + Self { + x: [T::zero(); N], + taps, + } } /// Obtain a mutable reference to the input items buffer space. #[inline] - pub fn buf_mut(&mut self) -> &mut [f32] { + pub fn buf_mut(&mut self) -> &mut [T] { &mut self.x[2 * M - 1..] } /// Perform the FIR convolution and yield results iteratively. #[inline] - pub fn get(&self) -> impl Iterator + '_ { + pub fn get(&self) -> impl Iterator + '_ { self.x.windows(2 * M).map(|x| { let (old, new) = x.split_at(M); old.iter() .zip(new.iter().rev()) .zip(self.taps.iter()) - .map(|((xo, xn), tap)| (xo + xn) * tap) + .map(|((xo, xn), tap)| (*xo + *xn) * *tap) .sum() }) } @@ -122,27 +134,57 @@ impl<'a, const M: usize, const N: usize> SymFir<'a, M, N> { /// M: number of taps /// N: state size: N = 2*M - 1 + output.len() #[derive(Clone, Debug, Copy)] -pub struct HbfDec<'a, const M: usize, const N: usize> { - even: [f32; N], // This is an upper bound to N - M (unstable const expr) - odd: SymFir<'a, M, N>, +pub struct HbfDec<'a, T, const M: usize, const N: usize> { + even: [T; N], // This is an upper bound to N - M (unstable const expr) + odd: SymFir<'a, T, M, N>, } -impl<'a, const M: usize, const N: usize> HbfDec<'a, M, N> { +impl<'a, T: Zero + Copy + Add + Mul + Sum, const M: usize, const N: usize> + HbfDec<'a, T, M, N> +{ /// Create a new `HbfDec`. /// /// # Args /// * `taps`: The FIR filter coefficients. Only the non-zero (odd) taps /// from oldest to one-before-center. Normalized such that center tap is 1. - pub fn new(taps: &'a [f32; M]) -> Self { + pub fn new(taps: &'a [T; M]) -> Self { Self { - even: [0.0; N], + even: [T::zero(); N], odd: SymFir::new(taps), } } } -impl<'a, const M: usize, const N: usize> Filter for HbfDec<'a, M, N> { - type Item = f32; +trait Half { + fn half(self) -> Self; +} + +macro_rules! impl_half_f { + ($($t:ty)+) => {$( + impl Half for $t { + fn half(self) -> Self { + 0.5 * self + } + } + )+} +} +impl_half_f!(f32 f64); + +macro_rules! impl_half_i { + ($($t:ty)+) => {$( + impl Half for $t { + fn half(self) -> Self { + self >> 1 + } + } + )+} +} +impl_half_i!(i8 i16 i32 i64 i128); + +impl<'a, T: Copy + Zero + Add + Mul + Sum + Half, const M: usize, const N: usize> Filter + for HbfDec<'a, T, M, N> +{ + type Item = T; #[inline] fn block_size(&self) -> (usize, usize) { @@ -176,7 +218,7 @@ impl<'a, const M: usize, const N: usize> Filter for HbfDec<'a, M, N> { .iter_mut() .zip(self.even[..k].iter().zip(self.odd.get())) { - *yi = 0.5 * (even + odd); + *yi = (*even + odd).half(); } // keep state self.even.copy_within(k..k + M - 1, 0); @@ -192,27 +234,31 @@ impl<'a, const M: usize, const N: usize> Filter for HbfDec<'a, M, N> { /// M: number of taps /// N: state size: N = 2*M - 1 + input.len() #[derive(Clone, Debug, Copy)] -pub struct HbfInt<'a, const M: usize, const N: usize> { - fir: SymFir<'a, M, N>, +pub struct HbfInt<'a, T, const M: usize, const N: usize> { + fir: SymFir<'a, T, M, N>, } -impl<'a, const M: usize, const N: usize> HbfInt<'a, M, N> { +impl<'a, T: Copy + Zero + Add + Mul + Sum, const M: usize, const N: usize> + HbfInt<'a, T, M, N> +{ /// Non-zero (odd) taps from oldest to one-before-center. /// Normalized such that center tap is 1. - pub fn new(taps: &'a [f32; M]) -> Self { + pub fn new(taps: &'a [T; M]) -> Self { Self { fir: SymFir::new(taps), } } /// Obtain a mutable reference to the input items buffer space - pub fn buf_mut(&mut self) -> &mut [f32] { + pub fn buf_mut(&mut self) -> &mut [T] { self.fir.buf_mut() } } -impl<'a, const M: usize, const N: usize> Filter for HbfInt<'a, M, N> { - type Item = f32; +impl<'a, T: Copy + Zero + Add + Mul + Sum, const M: usize, const N: usize> Filter + for HbfInt<'a, T, M, N> +{ + type Item = T; #[inline] fn block_size(&self) -> (usize, usize) { @@ -369,10 +415,30 @@ pub const HBF_CASCADE_BLOCK: usize = 1 << 6; pub struct HbfDecCascade { depth: usize, stages: ( - HbfDec<'static, { HBF_TAPS.0.len() }, { 2 * HBF_TAPS.0.len() - 1 + HBF_CASCADE_BLOCK }>, - HbfDec<'static, { HBF_TAPS.1.len() }, { 2 * HBF_TAPS.1.len() - 1 + HBF_CASCADE_BLOCK * 2 }>, - HbfDec<'static, { HBF_TAPS.2.len() }, { 2 * HBF_TAPS.2.len() - 1 + HBF_CASCADE_BLOCK * 4 }>, - HbfDec<'static, { HBF_TAPS.3.len() }, { 2 * HBF_TAPS.3.len() - 1 + HBF_CASCADE_BLOCK * 8 }>, + HbfDec< + 'static, + f32, + { HBF_TAPS.0.len() }, + { 2 * HBF_TAPS.0.len() - 1 + HBF_CASCADE_BLOCK }, + >, + HbfDec< + 'static, + f32, + { HBF_TAPS.1.len() }, + { 2 * HBF_TAPS.1.len() - 1 + HBF_CASCADE_BLOCK * 2 }, + >, + HbfDec< + 'static, + f32, + { HBF_TAPS.2.len() }, + { 2 * HBF_TAPS.2.len() - 1 + HBF_CASCADE_BLOCK * 4 }, + >, + HbfDec< + 'static, + f32, + { HBF_TAPS.3.len() }, + { 2 * HBF_TAPS.3.len() - 1 + HBF_CASCADE_BLOCK * 8 }, + >, ), } @@ -478,10 +544,30 @@ impl Filter for HbfDecCascade { pub struct HbfIntCascade { depth: usize, pub stages: ( - HbfInt<'static, { HBF_TAPS.0.len() }, { 2 * HBF_TAPS.0.len() - 1 + HBF_CASCADE_BLOCK }>, - HbfInt<'static, { HBF_TAPS.1.len() }, { 2 * HBF_TAPS.1.len() - 1 + HBF_CASCADE_BLOCK * 2 }>, - HbfInt<'static, { HBF_TAPS.2.len() }, { 2 * HBF_TAPS.2.len() - 1 + HBF_CASCADE_BLOCK * 4 }>, - HbfInt<'static, { HBF_TAPS.3.len() }, { 2 * HBF_TAPS.3.len() - 1 + HBF_CASCADE_BLOCK * 8 }>, + HbfInt< + 'static, + f32, + { HBF_TAPS.0.len() }, + { 2 * HBF_TAPS.0.len() - 1 + HBF_CASCADE_BLOCK }, + >, + HbfInt< + 'static, + f32, + { HBF_TAPS.1.len() }, + { 2 * HBF_TAPS.1.len() - 1 + HBF_CASCADE_BLOCK * 2 }, + >, + HbfInt< + 'static, + f32, + { HBF_TAPS.2.len() }, + { 2 * HBF_TAPS.2.len() - 1 + HBF_CASCADE_BLOCK * 4 }, + >, + HbfInt< + 'static, + f32, + { HBF_TAPS.3.len() }, + { 2 * HBF_TAPS.3.len() - 1 + HBF_CASCADE_BLOCK * 8 }, + >, ), } @@ -580,7 +666,7 @@ mod test { #[test] fn test() { - let mut h = HbfDec::<1, 5>::new(&[0.5]); + let mut h = HbfDec::<_, 1, 5>::new(&[0.5]); assert_eq!(h.process_block(None, &mut []), &[]); let mut x = [1.0; 8]; @@ -588,7 +674,7 @@ mod test { let x = h.process_block(None, &mut x); assert_eq!(x, [0.75, 1.0, 1.0, 1.0]); - let mut h = HbfDec::<{ HBF_TAPS.3.len() }, 11>::new(&HBF_TAPS.3); + let mut h = HbfDec::<_, { HBF_TAPS.3.len() }, 11>::new(&HBF_TAPS.3); let mut x: Vec<_> = (0..8).map(|i| i as f32).collect(); assert_eq!((2, x.len()), h.block_size()); let x = h.process_block(None, &mut x); @@ -666,23 +752,24 @@ mod test { #[test] #[ignore] fn insn_dec() { - const N: usize = HBF_TAPS.3.len(); - let mut h = HbfDec::::new(&HBF_TAPS.3); + const N: usize = HBF_TAPS.4.len(); + assert_eq!(N, 3); + let mut h = HbfDec::<_, N, { 2 * N - 1 + (1 << 4) }>::new(&HBF_TAPS.4); let mut x = [9.0; 1 << 5]; for _ in 0..1 << 25 { h.process_block(None, &mut x); } } - /// 1k block size, single stage, 15 mul (59 tap) decimator + /// 1k block size, single stage, 23 mul (91 tap) decimator /// 4.9 insn: > 1 GS/s #[test] #[ignore] fn insn_dec2() { const N: usize = HBF_TAPS.0.len(); - assert_eq!(N, 15); + assert_eq!(N, 23); const M: usize = 1 << 10; - let mut h = HbfDec::::new(&HBF_TAPS.0); + let mut h = HbfDec::<_, N, { 2 * N - 1 + M }>::new(&HBF_TAPS.0); let mut x = [9.0; M]; for _ in 0..1 << 20 { h.process_block(None, &mut x);