Skip to content

Commit

Permalink
hbf: generic over sample type
Browse files Browse the repository at this point in the history
  • Loading branch information
jordens committed Oct 30, 2023
1 parent f2382c8 commit 8fdeb15
Showing 1 changed file with 127 additions and 40 deletions.
167 changes: 127 additions & 40 deletions src/hbf.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
use core::{
iter::Sum,
ops::{Add, Mul},
};

use num_traits::Zero;

/// Filter input items into output items.
pub trait Filter {
/// Input/output item type.
Expand Down Expand Up @@ -69,36 +76,41 @@ pub trait Filter {
/// overhead) for blocks of 32 high-rate items or more, depending very much on architecture.
#[derive(Clone, Debug, Copy)]
pub struct SymFir<'a, const M: usize, const N: usize> {
x: [f32; N],
taps: &'a [f32; M],
pub struct SymFir<'a, T, const M: usize, const N: usize> {
x: [T; N],
taps: &'a [T; M],
}

impl<'a, const M: usize, const N: usize> SymFir<'a, M, N> {
impl<'a, T: Copy + Zero + Add + Mul<Output = T> + Sum, const M: usize, const N: usize>
SymFir<'a, T, M, N>
{
/// Create a new `SymFir`.
///
/// # Args
/// * `taps`: one-sided FIR coefficients, excluding center tap, oldest to one-before-center
pub fn new(taps: &'a [f32; M]) -> Self {
pub fn new(taps: &'a [T; M]) -> Self {
debug_assert!(N >= M * 2);
Self { x: [0.0; N], taps }
Self {
x: [T::zero(); N],
taps,
}
}

/// Obtain a mutable reference to the input items buffer space.
#[inline]
pub fn buf_mut(&mut self) -> &mut [f32] {
pub fn buf_mut(&mut self) -> &mut [T] {
&mut self.x[2 * M - 1..]
}

/// Perform the FIR convolution and yield results iteratively.
#[inline]
pub fn get(&self) -> impl Iterator<Item = f32> + '_ {
pub fn get(&self) -> impl Iterator<Item = T> + '_ {
self.x.windows(2 * M).map(|x| {
let (old, new) = x.split_at(M);
old.iter()
.zip(new.iter().rev())
.zip(self.taps.iter())
.map(|((xo, xn), tap)| (xo + xn) * tap)
.map(|((xo, xn), tap)| (*xo + *xn) * *tap)
.sum()
})
}
Expand All @@ -122,27 +134,57 @@ impl<'a, const M: usize, const N: usize> SymFir<'a, M, N> {
/// M: number of taps
/// N: state size: N = 2*M - 1 + output.len()
#[derive(Clone, Debug, Copy)]
pub struct HbfDec<'a, const M: usize, const N: usize> {
even: [f32; N], // This is an upper bound to N - M (unstable const expr)
odd: SymFir<'a, M, N>,
pub struct HbfDec<'a, T, const M: usize, const N: usize> {
even: [T; N], // This is an upper bound to N - M (unstable const expr)
odd: SymFir<'a, T, M, N>,
}

impl<'a, const M: usize, const N: usize> HbfDec<'a, M, N> {
impl<'a, T: Zero + Copy + Add + Mul<Output = T> + Sum, const M: usize, const N: usize>
HbfDec<'a, T, M, N>
{
/// Create a new `HbfDec`.
///
/// # Args
/// * `taps`: The FIR filter coefficients. Only the non-zero (odd) taps
/// from oldest to one-before-center. Normalized such that center tap is 1.
pub fn new(taps: &'a [f32; M]) -> Self {
pub fn new(taps: &'a [T; M]) -> Self {
Self {
even: [0.0; N],
even: [T::zero(); N],
odd: SymFir::new(taps),
}
}
}

impl<'a, const M: usize, const N: usize> Filter for HbfDec<'a, M, N> {
type Item = f32;
/// Scale a sample by one half.
///
/// Abstracts the `0.5 * x` step so that code generic over the sample type
/// can halve both floating point and integer values.
trait Half {
    /// Return `self` scaled by one half.
    fn half(self) -> Self;
}

// Floating point samples: plain multiplication by 0.5.
macro_rules! impl_half_f {
    ($($t:ty)+) => {$(
        impl Half for $t {
            // Called once per output sample; keep it inlinable like the
            // other small hot-path methods in this file.
            #[inline]
            fn half(self) -> Self {
                0.5 * self
            }
        }
    )+}
}
impl_half_f!(f32 f64);

// Signed integer samples: arithmetic shift right by one bit.
// Note that this rounds toward negative infinity (e.g. `-3` becomes `-2`),
// unlike `/ 2`, which rounds toward zero.
macro_rules! impl_half_i {
    ($($t:ty)+) => {$(
        impl Half for $t {
            // Called once per output sample; keep it inlinable like the
            // other small hot-path methods in this file.
            #[inline]
            fn half(self) -> Self {
                self >> 1
            }
        }
    )+}
}
impl_half_i!(i8 i16 i32 i64 i128);

impl<'a, T: Copy + Zero + Add + Mul<Output = T> + Sum + Half, const M: usize, const N: usize> Filter
for HbfDec<'a, T, M, N>
{
type Item = T;

#[inline]
fn block_size(&self) -> (usize, usize) {
Expand Down Expand Up @@ -176,7 +218,7 @@ impl<'a, const M: usize, const N: usize> Filter for HbfDec<'a, M, N> {
.iter_mut()
.zip(self.even[..k].iter().zip(self.odd.get()))
{
*yi = 0.5 * (even + odd);
*yi = (*even + odd).half();
}
// keep state
self.even.copy_within(k..k + M - 1, 0);
Expand All @@ -192,27 +234,31 @@ impl<'a, const M: usize, const N: usize> Filter for HbfDec<'a, M, N> {
/// M: number of taps
/// N: state size: N = 2*M - 1 + input.len()
#[derive(Clone, Debug, Copy)]
pub struct HbfInt<'a, const M: usize, const N: usize> {
fir: SymFir<'a, M, N>,
pub struct HbfInt<'a, T, const M: usize, const N: usize> {
fir: SymFir<'a, T, M, N>,
}

impl<'a, const M: usize, const N: usize> HbfInt<'a, M, N> {
impl<'a, T: Copy + Zero + Add + Mul<Output = T> + Sum, const M: usize, const N: usize>
HbfInt<'a, T, M, N>
{
/// Non-zero (odd) taps from oldest to one-before-center.
/// Normalized such that center tap is 1.
pub fn new(taps: &'a [f32; M]) -> Self {
pub fn new(taps: &'a [T; M]) -> Self {
Self {
fir: SymFir::new(taps),
}
}

/// Obtain a mutable reference to the input items buffer space
pub fn buf_mut(&mut self) -> &mut [f32] {
pub fn buf_mut(&mut self) -> &mut [T] {
self.fir.buf_mut()
}
}

impl<'a, const M: usize, const N: usize> Filter for HbfInt<'a, M, N> {
type Item = f32;
impl<'a, T: Copy + Zero + Add + Mul<Output = T> + Sum, const M: usize, const N: usize> Filter
for HbfInt<'a, T, M, N>
{
type Item = T;

#[inline]
fn block_size(&self) -> (usize, usize) {
Expand Down Expand Up @@ -369,10 +415,30 @@ pub const HBF_CASCADE_BLOCK: usize = 1 << 6;
pub struct HbfDecCascade {
depth: usize,
stages: (
HbfDec<'static, { HBF_TAPS.0.len() }, { 2 * HBF_TAPS.0.len() - 1 + HBF_CASCADE_BLOCK }>,
HbfDec<'static, { HBF_TAPS.1.len() }, { 2 * HBF_TAPS.1.len() - 1 + HBF_CASCADE_BLOCK * 2 }>,
HbfDec<'static, { HBF_TAPS.2.len() }, { 2 * HBF_TAPS.2.len() - 1 + HBF_CASCADE_BLOCK * 4 }>,
HbfDec<'static, { HBF_TAPS.3.len() }, { 2 * HBF_TAPS.3.len() - 1 + HBF_CASCADE_BLOCK * 8 }>,
HbfDec<
'static,
f32,
{ HBF_TAPS.0.len() },
{ 2 * HBF_TAPS.0.len() - 1 + HBF_CASCADE_BLOCK },
>,
HbfDec<
'static,
f32,
{ HBF_TAPS.1.len() },
{ 2 * HBF_TAPS.1.len() - 1 + HBF_CASCADE_BLOCK * 2 },
>,
HbfDec<
'static,
f32,
{ HBF_TAPS.2.len() },
{ 2 * HBF_TAPS.2.len() - 1 + HBF_CASCADE_BLOCK * 4 },
>,
HbfDec<
'static,
f32,
{ HBF_TAPS.3.len() },
{ 2 * HBF_TAPS.3.len() - 1 + HBF_CASCADE_BLOCK * 8 },
>,
),
}

Expand Down Expand Up @@ -478,10 +544,30 @@ impl Filter for HbfDecCascade {
pub struct HbfIntCascade {
depth: usize,
pub stages: (
HbfInt<'static, { HBF_TAPS.0.len() }, { 2 * HBF_TAPS.0.len() - 1 + HBF_CASCADE_BLOCK }>,
HbfInt<'static, { HBF_TAPS.1.len() }, { 2 * HBF_TAPS.1.len() - 1 + HBF_CASCADE_BLOCK * 2 }>,
HbfInt<'static, { HBF_TAPS.2.len() }, { 2 * HBF_TAPS.2.len() - 1 + HBF_CASCADE_BLOCK * 4 }>,
HbfInt<'static, { HBF_TAPS.3.len() }, { 2 * HBF_TAPS.3.len() - 1 + HBF_CASCADE_BLOCK * 8 }>,
HbfInt<
'static,
f32,
{ HBF_TAPS.0.len() },
{ 2 * HBF_TAPS.0.len() - 1 + HBF_CASCADE_BLOCK },
>,
HbfInt<
'static,
f32,
{ HBF_TAPS.1.len() },
{ 2 * HBF_TAPS.1.len() - 1 + HBF_CASCADE_BLOCK * 2 },
>,
HbfInt<
'static,
f32,
{ HBF_TAPS.2.len() },
{ 2 * HBF_TAPS.2.len() - 1 + HBF_CASCADE_BLOCK * 4 },
>,
HbfInt<
'static,
f32,
{ HBF_TAPS.3.len() },
{ 2 * HBF_TAPS.3.len() - 1 + HBF_CASCADE_BLOCK * 8 },
>,
),
}

Expand Down Expand Up @@ -580,15 +666,15 @@ mod test {

#[test]
fn test() {
let mut h = HbfDec::<1, 5>::new(&[0.5]);
let mut h = HbfDec::<_, 1, 5>::new(&[0.5]);
assert_eq!(h.process_block(None, &mut []), &[]);

let mut x = [1.0; 8];
assert_eq!((2, x.len()), h.block_size());
let x = h.process_block(None, &mut x);
assert_eq!(x, [0.75, 1.0, 1.0, 1.0]);

let mut h = HbfDec::<{ HBF_TAPS.3.len() }, 11>::new(&HBF_TAPS.3);
let mut h = HbfDec::<_, { HBF_TAPS.3.len() }, 11>::new(&HBF_TAPS.3);
let mut x: Vec<_> = (0..8).map(|i| i as f32).collect();
assert_eq!((2, x.len()), h.block_size());
let x = h.process_block(None, &mut x);
Expand Down Expand Up @@ -666,23 +752,24 @@ mod test {
#[test]
#[ignore]
fn insn_dec() {
const N: usize = HBF_TAPS.3.len();
let mut h = HbfDec::<N, { 2 * N - 1 + (1 << 4) }>::new(&HBF_TAPS.3);
const N: usize = HBF_TAPS.4.len();
assert_eq!(N, 3);
let mut h = HbfDec::<_, N, { 2 * N - 1 + (1 << 4) }>::new(&HBF_TAPS.4);
let mut x = [9.0; 1 << 5];
for _ in 0..1 << 25 {
h.process_block(None, &mut x);
}
}

/// 1k block size, single stage, 15 mul (59 tap) decimator
/// 1k block size, single stage, 23 mul (91 tap) decimator
/// 4.9 insn: > 1 GS/s
#[test]
#[ignore]
fn insn_dec2() {
const N: usize = HBF_TAPS.0.len();
assert_eq!(N, 15);
assert_eq!(N, 23);
const M: usize = 1 << 10;
let mut h = HbfDec::<N, { 2 * N - 1 + M }>::new(&HBF_TAPS.0);
let mut h = HbfDec::<_, N, { 2 * N - 1 + M }>::new(&HBF_TAPS.0);
let mut x = [9.0; M];
for _ in 0..1 << 20 {
h.process_block(None, &mut x);
Expand Down

0 comments on commit 8fdeb15

Please sign in to comment.