From 2649df21da2de79cd5eebd366bf8e49c8d81d737 Mon Sep 17 00:00:00 2001 From: skewballfox Date: Sun, 1 Jan 2023 20:31:46 -0600 Subject: [PATCH 1/3] migrating to new api --- py-speechsauce/speechsauce/__init__.py | 58 ++++++++++++ py-speechsauce/src/lib.rs | 78 ++++++++++------ speechsauce/Cargo.toml | 2 + speechsauce/src/config.rs | 59 ++++++------ speechsauce/src/feature.rs | 124 ++++--------------------- speechsauce/src/lib.rs | 59 ++++-------- 6 files changed, 180 insertions(+), 200 deletions(-) create mode 100644 py-speechsauce/speechsauce/__init__.py diff --git a/py-speechsauce/speechsauce/__init__.py b/py-speechsauce/speechsauce/__init__.py new file mode 100644 index 0000000..e1b5d15 --- /dev/null +++ b/py-speechsauce/speechsauce/__init__.py @@ -0,0 +1,58 @@ +from functools import lru_cache + + +@lru_cache(maxsize=32) +def _get_speech_config( + sampling_frequency, + frame_length=0.020, + frame_stride=0.01, + num_cepstral=13, + num_filters=40, + fft_length=512, + low_frequency=0, + high_frequency=None, + dc_elimination=True, +): + """pay no attention to the man behind the curtain + + this function returns a config object to be used by the rust code, avoids recomputing elements where possible + """ + pass + + +def mfcc( + signal, + sampling_frequency, + frame_length=0.020, + frame_stride=0.01, + num_cepstral=13, + num_filters=40, + fft_length=512, + low_frequency=0, + high_frequency=None, + dc_elimination=True, +): + """Compute MFCC features from an audio signal. + Args: + signal (array): the audio signal from which to compute features. + Should be an N x 1 array + sampling_frequency (int): the sampling frequency of the signal + we are working with. + frame_length (float): the length of each frame in seconds. + Default is 0.020s + frame_stride (float): the step between successive frames in seconds. + Default is 0.02s (means no overlap) + num_filters (int): the number of filters in the filterbank, + default 40. + fft_length (int): number of FFT points. Default is 512. + low_frequency (float): lowest band edge of mel filters. + In Hz, default is 0. + high_frequency (float): highest band edge of mel filters. + In Hz, default is samplerate/2 + num_cepstral (int): Number of cepstral coefficients. + dc_elimination (bool): hIf the first dc component should + be eliminated or not. + Returns: + array: A numpy array of size (num_frames x num_cepstral) containing mfcc features. + """ + pass diff --git a/py-speechsauce/src/lib.rs b/py-speechsauce/src/lib.rs index f91e826..011abd8 100644 --- a/py-speechsauce/src/lib.rs +++ b/py-speechsauce/src/lib.rs @@ -1,10 +1,15 @@ - +use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2}; use pyo3::prelude::*; -use numpy::{IntoPyArray, PyReadonlyArray1, PyReadonlyArray2, PyArray2, PyArray1}; -use speechsauce::{feature,processing}; +use speechsauce::{config::SpeechConfig, feature, processing}; +#[pyclass] +#[repr(transparent)] +#[derive(Clone)] +pub struct PySpeechConfig { + pub speech_config: SpeechConfig, +} #[pymodule] -fn speechsauce(_py: Python<'_>, m: &PyModule) -> PyResult<()>{ +fn speechsauce(_py: Python<'_>, m: &PyModule) -> PyResult<()> { /// Compute MFCC features from an audio signal. /// Args: /// signal : the audio signal from which to compute features. @@ -29,37 +34,58 @@ fn speechsauce(_py: Python<'_>, m: &PyModule) -> PyResult<()>{ /// array: A numpy array of size (num_frames x num_cepstral) containing mfcc features. #[pyfn(m)] fn mfcc<'py>( - py: Python<'py>, + py: Python<'py>, signal: PyReadonlyArray1, - sampling_frequency: usize, - frame_length: f64, // =0.020, - frame_stride: f64, // =0.01, - num_cepstral: usize, // =13, - num_filters: usize, // =40, - fft_length: usize, // =512, - low_frequency: f64, // =0, - high_frequency: Option, // =None, - dc_elimination: bool, //True - ) -> &'py PyArray2{ - feature::mfcc(signal.as_array(), sampling_frequency, frame_length, frame_stride, num_cepstral, num_filters, fft_length, low_frequency, high_frequency, dc_elimination).into_pyarray(py) + config: PySpeechConfig, + ) -> &'py PyArray2 { + feature::mfcc(signal.as_array(), config.speech_config) } - + //TODO: #14 make signal a mutable borrow (PyReadWriteArray) once the next version of numpy-rust is released #[pyfn(m)] fn preemphasis<'py>( - py: Python<'py>, - signal: PyReadonlyArray1, - shift: isize, - cof: f64 - ) -> &'py PyArray1{ + py: Python<'py>, + signal: PyReadonlyArray1, + shift: isize, + cof: f64, + ) -> &'py PyArray1 { processing::preemphasis(signal.as_array().to_owned(), shift, cof).into_pyarray(py) } #[pyfn(m)] - fn cmvn<'py>(py: Python<'py>, vec: PyReadonlyArray2, variance_normalization: bool)-> &'py PyArray2 - { + fn cmvn<'py>( + py: Python<'py>, + vec: PyReadonlyArray2, + variance_normalization: bool, + ) -> &'py PyArray2 { processing::cmvn(vec.as_array(), variance_normalization).into_pyarray(py) } - + + fn _speech_config<'py>( + py: Python<'py>, + sampling_frequency: usize, + frame_length: f64, // =0.020, + frame_stride: f64, // =0.01, + num_cepstral: usize, // =13, + num_filters: usize, // =40, + fft_length: usize, // =512, + low_frequency: f64, // =0, + high_frequency: Option, // =None, + dc_elimination: bool, //True + ) -> &'py PySpeechConfig { + PySpeechConfig { + speech_config: SpeechConfig::new( + sampling_frequency, + fft_length, + frame_length, + frame_stride, + num_cepstral, + num_filters, + low_frequency, + high_frequency, + dc_elimination, + ), + } + } Ok(()) -} \ No newline at end of file +} diff --git a/speechsauce/Cargo.toml b/speechsauce/Cargo.toml index e66bd69..5ee4511 100644 --- a/speechsauce/Cargo.toml +++ b/speechsauce/Cargo.toml @@ -14,3 +14,5 @@ num-traits = "0.2.15" ndarray={version="^0.15",features=["approx"]} ndarray-rand = "0.14.0" + + diff --git a/speechsauce/src/config.rs b/speechsauce/src/config.rs index 303f4f9..81ee332 100644 --- a/speechsauce/src/config.rs +++ b/speechsauce/src/config.rs @@ -4,7 +4,7 @@ use ndrustfft::{DctHandler, R2cFftHandler}; use crate::feature::filterbanks; #[derive(Default)] -pub struct MfccConfigBuilder { +pub struct SpeechConfigBuilder { ///sampling frequency of the signal sample_rate: usize, /// number of FFT points. @@ -25,9 +25,9 @@ pub struct MfccConfigBuilder { dc_elimination: bool, } -impl MfccConfigBuilder { - fn new(sample_rate: usize) -> MfccConfigBuilder { - MfccConfigBuilder { +impl SpeechConfigBuilder { + pub fn new(sample_rate: usize) -> SpeechConfigBuilder { + SpeechConfigBuilder { sample_rate, fft_points: 512, frame_length: 0.02, @@ -40,43 +40,43 @@ impl MfccConfigBuilder { } } - pub fn high_freq(mut self, high_frequency: f64) -> MfccConfigBuilder { + pub fn high_freq(mut self, high_frequency: f64) -> SpeechConfigBuilder { self.high_frequency = high_frequency; self } - pub fn dc_elimination(mut self, dc_elimination: bool) -> MfccConfigBuilder { + pub fn dc_elimination(mut self, dc_elimination: bool) -> SpeechConfigBuilder { self.dc_elimination = dc_elimination; self } - pub fn low_freq(mut self, low_frequency: f64) -> MfccConfigBuilder { + pub fn low_freq(mut self, low_frequency: f64) -> SpeechConfigBuilder { self.low_frequency = low_frequency; self } - pub fn num_cepstral(mut self, num_cepstral: usize) -> MfccConfigBuilder { + pub fn num_cepstral(mut self, num_cepstral: usize) -> SpeechConfigBuilder { self.num_cepstral = num_cepstral; self } - pub fn frame_stride(mut self, frame_stride: f64) -> MfccConfigBuilder { + pub fn frame_stride(mut self, frame_stride: f64) -> SpeechConfigBuilder { self.frame_stride = frame_stride; self } - pub fn frame_length(mut self, frame_length: f64) -> MfccConfigBuilder { + pub fn frame_length(mut self, frame_length: f64) -> SpeechConfigBuilder { self.frame_length = frame_length; self } - pub fn fft_points(mut self, fft_points: usize) -> MfccConfigBuilder { + pub fn fft_points(mut self, fft_points: usize) -> SpeechConfigBuilder { self.fft_points = fft_points; self } - pub fn build(self) -> MfccConfig { - MfccConfig::new( + pub fn build(self) -> SpeechConfig { + SpeechConfig::new( self.sample_rate, self.fft_points, self.frame_length, @@ -90,33 +90,34 @@ impl MfccConfigBuilder { } } -pub struct MfccConfig { +#[derive(Clone)] +pub struct SpeechConfig { ///sampling frequency of the signal - sample_rate: usize, + pub sample_rate: usize, /// number of FFT points. - fft_points: usize, + pub fft_points: usize, /// the length of each frame in seconds. - frame_length: f64, // =0.020, + pub frame_length: f64, // =0.020, /// the step between successive frames in seconds. - frame_stride: f64, // =0.01, + pub frame_stride: f64, // =0.01, /// Number of cepstral coefficients. - num_cepstral: usize, // =13, + pub num_cepstral: usize, // =13, /// the number of filters in the filterbank - num_filters: usize, // =40, + pub num_filters: usize, // =40, ///lowest band edge of mel filters in Hz - low_frequency: f64, + pub low_frequency: f64, ///highest band edge of mel filters in Hz. - high_frequency: f64, + pub high_frequency: f64, /// If the first dc component should be eliminated or not - dc_elimination: bool, + pub dc_elimination: bool, ///for - dct_handler: DctHandler, - fft_handler: R2cFftHandler, + pub dct_handler: DctHandler, + pub fft_handler: R2cFftHandler, /// Mel-filterbanks - filter_banks: Array2, + pub filter_banks: Array2, } -impl MfccConfig { +impl SpeechConfig { pub fn new( sample_rate: usize, fft_points: usize, @@ -150,7 +151,7 @@ impl MfccConfig { } } - pub fn builder() -> MfccConfigBuilder { - MfccConfigBuilder::default() + pub fn builder() -> SpeechConfigBuilder { + SpeechConfigBuilder::default() } } diff --git a/speechsauce/src/feature.rs b/speechsauce/src/feature.rs index 146efb8..eaa0d80 100644 --- a/speechsauce/src/feature.rs +++ b/speechsauce/src/feature.rs @@ -1,3 +1,4 @@ +use crate::config::SpeechConfig; /// This module provides functions for calculating the main speech /// features that the package is aimed to extract as well as the required elements. use crate::functions::{frequency_to_mel, mel_arr_to_frequency, triangle, zero_handling}; @@ -5,8 +6,8 @@ use crate::processing::stack_frames; use crate::util::ArrayLog; use ndarray::{ - concatenate, s, Array, Array1, Array2, Array3, ArrayView1, ArrayViewMut1, Axis, Dimension, - NewAxis, Slice, + concatenate, s, Array, Array1, Array2, Array3, ArrayBase, ArrayView1, ArrayViewMut1, Axis, Dim, + Dimension, Ix2, NewAxis, Slice, }; use ndrustfft::{nddct2, DctHandler}; @@ -80,49 +81,11 @@ pub(crate) fn filterbanks( /// Args: /// signal : the audio signal from which to compute features. /// Should be an N x 1 array -/// sampling_frequency : the sampling frequency of the signal -/// we are working with. -/// frame_length : -/// Default is 0.020s -/// frame_stride : the step between successive frames in seconds. -/// Default is 0.02s (means no overlap) -/// num_filters : the number of filters in the filterbank, -/// default 40. -/// fft_length : number of FFT points. Default is 512. -/// low_frequency : lowest band edge of mel filters. -/// In Hz, default is 0. -/// high_frequency (float): highest band edge of mel filters. -/// In Hz, default is samplerate/2 -/// num_cepstral : Number of cepstral coefficients. -/// dc_elimination : If the first dc component should -/// be eliminated or not. -/// Returns: -/// array: An array of size (num_frames x num_cepstral) containing mfcc features. -pub fn mfcc( - signal: ArrayView1, - sampling_frequency: usize, - frame_length: f64, // =0.020, - frame_stride: f64, // =0.02, - num_cepstral: usize, // =13, - num_filters: usize, // =40, - fft_length: usize, // =512, - low_frequency: f64, // =0, - high_frequency: Option, // =None, - dc_elimination: bool, //True -) -> Array2 { - let (mut feature, energy) = mfe( - signal, - sampling_frequency, - frame_length, - frame_stride, - num_filters, - fft_length, - low_frequency, - high_frequency, - ); +pub fn mfcc(signal: ArrayView1, speech_config: &SpeechConfig) -> Array2 { + let (mut feature, energy) = mfe(signal, &speech_config); if feature.is_empty() { - return Array::::zeros((0_usize, num_cepstral)); + return Array::::zeros((0_usize, speech_config.num_cepstral)); } feature = feature.log(); //feature second axis equal to num_filters @@ -152,11 +115,11 @@ pub fn mfcc( .slice_axis_mut(Axis(1), Slice::new(1, None, 1)) .mapv_inplace(|x| x * (1. / (2. * n).sqrt())); - transformed_feature = transformed_feature.slice_move(s![.., ..num_cepstral]); + transformed_feature = transformed_feature.slice_move(s![.., ..speech_config.num_cepstral]); // replace first cepstral coefficient with log of frame energy for DC // elimination. - if dc_elimination { + if speech_config.dc_elimination { //>>>x = np.array([[1,2,3,4],[5,6,7,8]]) //>>>x[:,0] //array([1, 5]) @@ -194,50 +157,29 @@ fn _f_it(x: usize) -> Array2 { /// Returns: /// array: features - the energy of fiterbank of size num_frames x num_filters. /// The energy of each frame: num_frames x 1 -pub fn mfe( - signal: ArrayView1, - sampling_frequency: usize, - frame_length: f64, /*=0.020*/ - frame_stride: f64, /*=0.01*/ - num_filters: usize, /*=40*/ - fft_length: usize, /*=512*/ - low_frequency: f64, /*=0*/ - high_frequency: Option, /*None*/ -) -> (Array2, Array1) { +pub fn mfe(signal: ArrayView1, speech_config: &SpeechConfig) -> (Array2, Array1) { // // Stack frames let frames = stack_frames( signal, - sampling_frequency, - frame_length, - frame_stride, + speech_config.sample_rate, + speech_config.frame_length, + speech_config.frame_stride, None, false, ); - // getting the high frequency - let high_frequency = high_frequency.unwrap_or(sampling_frequency as f64 / 2.); - // calculation of the power spectrum - let power_spectrum = crate::processing::power_spectrum(frames, fft_length); - let coefficients = power_spectrum.shape()[1]; + let power_spectrum = crate::processing::power_spectrum(frames, speech_config.fft_points); + // this stores the total energy in each frame let frame_energies = power_spectrum.sum_axis(Axis(1)); // Handling zero energies. let frame_energies = zero_handling(frame_energies); - // Extracting the filterbank - let filter_banks = filterbanks( - num_filters, - coefficients, - sampling_frequency as f64, - Some(low_frequency), - Some(high_frequency), - ); - // Filterbank energies - let features = power_spectrum.dot(&filter_banks.reversed_axes()); + let features = power_spectrum.dot(&speech_config.filter_banks.view().reversed_axes()); let features = crate::functions::zero_handling(features); (features, frame_energies) @@ -247,41 +189,11 @@ pub fn mfe( /// Args: /// signal : the audio signal from which to compute features. /// Should be an N x 1 array -/// sampling_frequency : the sampling frequency of the signal -/// we are working with. -/// frame_length : the length of each frame in seconds. -/// Default is 0.020s -/// frame_stride : the step between successive frames in seconds. -/// Default is 0.02s (means no overlap) -/// num_filters : the number of filters in the filterbank, -/// default 40. -/// fft_length : number of FFT points. Default is 512. -/// low_frequency : lowest band edge of mel filters. -/// In Hz, default is 0. -/// high_frequency : highest band edge of mel filters. -/// In Hz, default is samplerate/2 +/// speech_config: the configuration for the speech processing functions /// Returns: /// array: Features - The log energy of fiterbank of size num_frames x num_filters frame_log_energies. The log energy of each frame num_frames x 1 -fn lmfe( - signal: ArrayView1, - sampling_frequency: usize, - frame_length: f64, /*=0.020*/ - frame_stride: f64, /*=0.01*/ - num_filters: usize, /*=40*/ - fft_length: usize, /*=512*/ - low_frequency: f64, /*=0*/ - high_frequency: Option, /*None*/ -) -> Array2 { - let (feature, _frame_energies) = mfe( - signal, - sampling_frequency, - frame_length, - frame_stride, - num_filters, - fft_length, - low_frequency, - high_frequency, - ); +fn lmfe(signal: ArrayView1, speech_config: &SpeechConfig) -> Array2 { + let (feature, _) = mfe(signal, speech_config); feature.log() } diff --git a/speechsauce/src/lib.rs b/speechsauce/src/lib.rs index e086061..a315959 100644 --- a/speechsauce/src/lib.rs +++ b/speechsauce/src/lib.rs @@ -7,6 +7,7 @@ pub mod util; #[cfg(test)] mod tests { + use crate::config::{SpeechConfig, SpeechConfigBuilder}; use crate::feature::{mfcc, mfe}; use ndarray::Array; use ndarray_rand::rand_distr::{Normal, Uniform}; @@ -24,6 +25,9 @@ mod tests { 16000 } + fn default_config(sample_rate: usize) -> SpeechConfig { + SpeechConfigBuilder::new(sample_rate).build() + } fn get_num_frames( signal_len: usize, sample_rate: usize, @@ -87,32 +91,19 @@ mod tests { let num_cepstral: usize = 13; let sampling_frequency = 16000; - let frame_length = 0.02; - let frame_stride = 0.01; - let num_filters = 40; - let fft_length = 512; - let low_frequency = 0.; - let hi_frequency = None; - let dc_elimination = true; let signal = create_signal(); - let mfcc = mfcc( - signal.view(), - sampling_frequency, - frame_length, - frame_stride, - num_cepstral, - num_filters, - fft_length, - low_frequency, - hi_frequency, - dc_elimination, - ); + let speech_config = default_config(sampling_frequency); + let mfcc = mfcc(signal.view(), &speech_config); for &val in mfcc.iter() { assert!(!val.is_nan()); } - let num_frames = - get_num_frames(signal.len(), sampling_frequency, frame_length, frame_stride); + let num_frames = get_num_frames( + signal.len(), + sampling_frequency, + speech_config.frame_length, + speech_config.frame_stride, + ); assert_eq!(mfcc.shape()[0], num_frames); assert_eq!(mfcc.shape()[1], num_cepstral); @@ -120,31 +111,21 @@ mod tests { #[test] fn test_mfe() { let sampling_frequency = 16000; - let frame_length = 0.02; - let frame_stride = 0.01; - let num_filters = 40; - let fft_length = 512; - let low_frequency = 0.; - let hi_frequency = None; let signal = create_signal(); - let (features, frame_energies) = mfe( - signal.view(), + let speech_config = default_config(sampling_frequency); + let (features, frame_energies) = mfe(signal.view(), &speech_config); + //supposed number of frames (I think) + let num_frames = get_num_frames( + signal.len(), sampling_frequency, - frame_length, - frame_stride, - num_filters, - fft_length, - low_frequency, - hi_frequency, + speech_config.frame_length, + speech_config.frame_stride, ); - //supposed number of frames (I think) - let num_frames = - get_num_frames(signal.len(), sampling_frequency, frame_length, frame_stride); //test shape of outputs assert!(features.shape()[0] == num_frames); - assert!(features.shape()[1] == num_filters); + assert!(features.shape()[1] == speech_config.num_filters); assert!(frame_energies.shape()[0] == num_frames); } } From a0005d41c424dc4e345a53bf7d031b56542449ea Mon Sep 17 00:00:00 2001 From: skewballfox Date: Sun, 1 Jan 2023 20:54:46 -0600 Subject: [PATCH 2/3] working to implement lru_cache on python call arguments --- py-speechsauce/src/lib.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/py-speechsauce/src/lib.rs b/py-speechsauce/src/lib.rs index 011abd8..299c269 100644 --- a/py-speechsauce/src/lib.rs +++ b/py-speechsauce/src/lib.rs @@ -1,4 +1,4 @@ -use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2}; +use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2, ToPyArray}; use pyo3::prelude::*; use speechsauce::{config::SpeechConfig, feature, processing}; #[pyclass] @@ -38,7 +38,8 @@ fn speechsauce(_py: Python<'_>, m: &PyModule) -> PyResult<()> { signal: PyReadonlyArray1, config: PySpeechConfig, ) -> &'py PyArray2 { - feature::mfcc(signal.as_array(), config.speech_config) + let PySpeechConfig { speech_config } = config; + feature::mfcc(signal.as_array(), &speech_config).to_pyarray(py) } //TODO: #14 make signal a mutable borrow (PyReadWriteArray) once the next version of numpy-rust is released @@ -73,7 +74,7 @@ fn speechsauce(_py: Python<'_>, m: &PyModule) -> PyResult<()> { high_frequency: Option, // =None, dc_elimination: bool, //True ) -> &'py PySpeechConfig { - PySpeechConfig { + &'py PySpeechConfig { speech_config: SpeechConfig::new( sampling_frequency, fft_length, @@ -82,7 +83,7 @@ fn speechsauce(_py: Python<'_>, m: &PyModule) -> PyResult<()> { num_cepstral, num_filters, low_frequency, - high_frequency, + high_frequency.unwrap_or(sampling_frequency as f64 / 2.0), dc_elimination, ), } From 7ff3ee55b168ded3adad9fa82c0d257ddf91645b Mon Sep 17 00:00:00 2001 From: skewballfox Date: Thu, 5 Jan 2023 12:23:33 -0600 Subject: [PATCH 3/3] I think this works, will test later tonight --- .gitignore | 3 ++ py-speechsauce/Cargo.toml | 5 +++- py-speechsauce/speechsauce/__init__.py | 28 +++++++++++++++++-- py-speechsauce/src/lib.rs | 38 +++++++++++++++++++------- 4 files changed, 61 insertions(+), 13 deletions(-) diff --git a/.gitignore b/.gitignore index c382288..5706b6e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ /target Cargo.lock /py-speechsauce/.venv +py-speechsauce/speechsauce/*.abi3.so +py-speechsauce/src/lib.rs +py-speechsauce/**/__pycache__/* diff --git a/py-speechsauce/Cargo.toml b/py-speechsauce/Cargo.toml index 2e9fc38..e784515 100644 --- a/py-speechsauce/Cargo.toml +++ b/py-speechsauce/Cargo.toml @@ -5,9 +5,12 @@ edition = "2018" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib] -name = "speechsauce" +name = "speechsauce_python" crate-type = ["cdylib"] +[package.metadata.maturin] +name = "speechsauce._internal" + [dependencies] ndarray="^0.15" pyo3 = {version= "0.16.5", features=["extension-module","abi3-py37"]} diff --git a/py-speechsauce/speechsauce/__init__.py b/py-speechsauce/speechsauce/__init__.py index e1b5d15..8fb3c6d 100644 --- a/py-speechsauce/speechsauce/__init__.py +++ b/py-speechsauce/speechsauce/__init__.py @@ -1,4 +1,7 @@ from functools import lru_cache +from ._internal import mfcc as internal_mfcc, _speech_config, cmvn, preemphasis + +__all__ = ["mfcc", "preemphasis", "cmvn"] @lru_cache(maxsize=32) @@ -17,7 +20,17 @@ def _get_speech_config( this function returns a config object to be used by the rust code, avoids recomputing elements where possible """ - pass + return _speech_config( + sampling_frequency, + frame_length, + frame_stride, + num_cepstral, + num_filters, + fft_length, + low_frequency, + high_frequency, + dc_elimination, + ) def mfcc( @@ -55,4 +68,15 @@ def mfcc( Returns: array: A numpy array of size (num_frames x num_cepstral) containing mfcc features. """ - pass + config = _get_speech_config( + sampling_frequency, + frame_length, + frame_stride, + num_cepstral, + num_filters, + fft_length, + low_frequency, + high_frequency, + dc_elimination, + ) + return internal_mfcc(signal, config) diff --git a/py-speechsauce/src/lib.rs b/py-speechsauce/src/lib.rs index 299c269..b712894 100644 --- a/py-speechsauce/src/lib.rs +++ b/py-speechsauce/src/lib.rs @@ -1,15 +1,28 @@ +use std::sync::Arc; + use numpy::{IntoPyArray, PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2, ToPyArray}; -use pyo3::prelude::*; +use pyo3::{callback::IntoPyCallbackOutput, prelude::*}; use speechsauce::{config::SpeechConfig, feature, processing}; #[pyclass] #[repr(transparent)] #[derive(Clone)] -pub struct PySpeechConfig { - pub speech_config: SpeechConfig, +pub struct PySpeechConfig(SpeechConfig); + +impl IntoPyCallbackOutput for PySpeechConfig { + fn convert(self, py: Python<'_>) -> PyResult { + Ok(self) + } +} + +impl IntoPy for PySpeechConfig { + fn into_py(self, py: Python<'_>) -> SpeechConfig { + self.0 + } } #[pymodule] fn speechsauce(_py: Python<'_>, m: &PyModule) -> PyResult<()> { + m.add_class::()?; /// Compute MFCC features from an audio signal. /// Args: /// signal : the audio signal from which to compute features. @@ -36,9 +49,11 @@ fn speechsauce(_py: Python<'_>, m: &PyModule) -> PyResult<()> { fn mfcc<'py>( py: Python<'py>, signal: PyReadonlyArray1, - config: PySpeechConfig, + config: Py, ) -> &'py PyArray2 { - let PySpeechConfig { speech_config } = config; + let cell = config.as_ref(py); + let obj_ref = cell.borrow(); + let speech_config = &obj_ref.0; feature::mfcc(signal.as_array(), &speech_config).to_pyarray(py) } @@ -62,6 +77,7 @@ fn speechsauce(_py: Python<'_>, m: &PyModule) -> PyResult<()> { processing::cmvn(vec.as_array(), variance_normalization).into_pyarray(py) } + #[pyfn(m)] fn _speech_config<'py>( py: Python<'py>, sampling_frequency: usize, @@ -73,9 +89,10 @@ fn speechsauce(_py: Python<'_>, m: &PyModule) -> PyResult<()> { low_frequency: f64, // =0, high_frequency: Option, // =None, dc_elimination: bool, //True - ) -> &'py PySpeechConfig { - &'py PySpeechConfig { - speech_config: SpeechConfig::new( + ) -> Py { + Py::new( + py, + PySpeechConfig(SpeechConfig::new( sampling_frequency, fft_length, frame_length, @@ -85,8 +102,9 @@ fn speechsauce(_py: Python<'_>, m: &PyModule) -> PyResult<()> { low_frequency, high_frequency.unwrap_or(sampling_frequency as f64 / 2.0), dc_elimination, - ), - } + )), + ) + .unwrap() } Ok(()) }