Skip to content

Commit

Permalink
Improvements (#2)
Browse files Browse the repository at this point in the history
* Vendor Eigen, use Eigen Matrix/Tensor for waveform/spectrogram classes
  • Loading branch information
sevagh authored Jul 28, 2023
1 parent 5af9ab9 commit 690ec94
Show file tree
Hide file tree
Showing 12 changed files with 487 additions and 440 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@
[submodule "vendor/libnyquist"]
path = vendor/libnyquist
url = https://github.com/ddiakopoulos/libnyquist
[submodule "vendor/eigen"]
path = vendor/eigen
url = https://gitlab.com/libeigen/eigen
7 changes: 3 additions & 4 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ set(LIBNYQUIST_BUILD_EXAMPLE OFF CACHE BOOL "Disable libnyquist example")
add_subdirectory(vendor/libnyquist)

# add library Eigen3
find_package(Eigen3 REQUIRED)
include_directories(${EIGEN3_INCLUDE_DIR})
include_directories(vendor/eigen)

# add OpenBLAS for blas + lapack
find_package(BLAS REQUIRED)
Expand All @@ -57,9 +56,9 @@ include_directories(src)
# include src/*.cpp and src/*.c as source files
file(GLOB SOURCES "src/*.cpp")

# compile library, link against libnyquist and Eigen3
# compile library, link against libnyquist
add_library(umx.cpp.lib SHARED ${SOURCES})
target_link_libraries(umx.cpp.lib libnyquist Eigen3::Eigen ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES})
target_link_libraries(umx.cpp.lib libnyquist ${BLAS_LIBRARIES} ${LAPACK_LIBRARIES} lapacke)
if(OPENMP_FOUND)
target_link_libraries(umx.cpp.lib ${OpenMP_CXX_LIBRARIES})
endif()
Expand Down
203 changes: 100 additions & 103 deletions src/dsp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using namespace nqr;

static constexpr float PI = 3.14159265359F;

umxcpp::StereoWaveform umxcpp::load_audio(std::string filename)
Eigen::MatrixXf umxcpp::load_audio(std::string filename)
{
// load a wav file with libnyquist
std::shared_ptr<AudioData> fileData = std::make_shared<AudioData>();
Expand Down Expand Up @@ -47,34 +47,32 @@ umxcpp::StereoWaveform umxcpp::load_audio(std::string filename)
size_t N = fileData->samples.size() / fileData->channelCount;

// create a struct to hold two float vectors for left and right channels
umxcpp::StereoWaveform ret;
ret.left.resize(N);
ret.right.resize(N);
Eigen::MatrixXf ret(2, N);

if (fileData->channelCount == 1)
{
// Mono case
for (size_t i = 0; i < N; ++i)
{
ret.left[i] = fileData->samples[i]; // left channel
ret.right[i] = fileData->samples[i]; // right channel
ret(0, i) = fileData->samples[i]; // left channel
ret(1, i) = fileData->samples[i]; // right channel
}
}
else
{
// Stereo case
for (size_t i = 0; i < N; ++i)
{
ret.left[i] = fileData->samples[2 * i]; // left channel
ret.right[i] = fileData->samples[2 * i + 1]; // right channel
ret(0, i) = fileData->samples[2 * i]; // left channel
ret(1, i) = fileData->samples[2 * i + 1]; // right channel
}
}

return ret;
}

// write a function to write a StereoWaveform to a wav file
void umxcpp::write_audio_file(const umxcpp::StereoWaveform &waveform,
void umxcpp::write_audio_file(const Eigen::MatrixXf &waveform,
std::string filename)
{
// create a struct to hold the audio data
Expand All @@ -87,18 +85,13 @@ void umxcpp::write_audio_file(const umxcpp::StereoWaveform &waveform,
fileData->channelCount = 2;

// set the number of samples
fileData->samples.resize(waveform.left.size() * 2);
fileData->samples.resize(waveform.cols() * 2);

// write the left channel
for (size_t i = 0; i < waveform.left.size(); ++i)
for (size_t i = 0; i < waveform.cols(); ++i)
{
fileData->samples[2 * i] = waveform.left[i];
}

// write the right channel
for (size_t i = 0; i < waveform.right.size(); ++i)
{
fileData->samples[2 * i + 1] = waveform.right[i];
fileData->samples[2 * i] = waveform(0, i);
fileData->samples[2 * i + 1] = waveform(1, i);
}

int encoderStatus =
Expand Down Expand Up @@ -135,134 +128,138 @@ std::vector<float> hann_window(int window_size)
}

// reflect padding
std::vector<float> pad_signal(const std::vector<float> &signal, int n_fft)
void pad_signal(std::vector<float> &signal, int n_fft)
{
int pad = n_fft / 2;
std::vector<float> pad_start(signal.begin(), signal.begin() + pad);
std::vector<float> pad_end(signal.end() - pad, signal.end());
std::reverse(pad_start.begin(), pad_start.end());
std::reverse(pad_end.begin(), pad_end.end());
std::vector<float> padded_signal = signal;
padded_signal.insert(padded_signal.begin(), pad_start.begin(),
pad_start.end());
padded_signal.insert(padded_signal.end(), pad_end.begin(), pad_end.end());
return padded_signal;
signal.insert(signal.begin(), pad_start.begin(), pad_start.end());
signal.insert(signal.end(), pad_end.begin(), pad_end.end());
}

// reflect unpadding
std::vector<float> unpad_signal(const std::vector<float> &signal, int n_fft)
void unpad_signal(std::vector<float> &signal, int n_fft)
{
int pad = n_fft / 2;
std::vector<float> unpadded_signal = signal;
unpadded_signal.erase(unpadded_signal.begin(),
unpadded_signal.begin() +
pad); // remove 'pad' elements from the start
unpadded_signal.erase(
unpadded_signal.end() - pad,
unpadded_signal.end()); // remove 'pad' elements from the end
return unpadded_signal;
signal.erase(signal.begin(),
signal.begin() + pad); // remove 'pad' elements from the start

auto it = signal.end() - pad;
signal.erase(it, signal.end()); // remove 'pad' elements from the end
}

umxcpp::StereoSpectrogramR umxcpp::magnitude(const StereoSpectrogramC &spec)
Eigen::Tensor3dXcf umxcpp::polar_to_complex(const Eigen::Tensor3dXf &magnitude,
const Eigen::Tensor3dXf &phase)
{
// compute the magnitude of a complex spectrogram
StereoSpectrogramR ret;
ret.left.resize(spec.left.size());
ret.right.resize(spec.right.size());
// Assert dimensions are the same
assert(magnitude.dimensions() == phase.dimensions());

for (std::size_t i = 0; i < spec.left.size(); ++i)
{
ret.left[i].resize(spec.left[i].size());
ret.right[i].resize(spec.right[i].size());
// Get dimensions for convenience
int dim1 = magnitude.dimension(0);
int dim2 = magnitude.dimension(1);
int dim3 = magnitude.dimension(2);

for (std::size_t j = 0; j < spec.left[i].size(); ++j)
// Initialize complex spectrogram tensor
Eigen::Tensor3dXcf complex_spectrogram(dim1, dim2, dim3);

// Iterate over all indices and apply the transformation
for (int i = 0; i < dim1; ++i)
{
for (int j = 0; j < dim2; ++j)
{
// compute the magnitude on the std::complex<float>
ret.left[i][j] = std::abs(spec.left[i][j]);
ret.right[i][j] = std::abs(spec.right[i][j]);
for (int k = 0; k < dim3; ++k)
{
float mag = magnitude(i, j, k);
float ph = phase(i, j, k);
complex_spectrogram(i, j, k) = std::polar(mag, ph);
}
}
}

return ret;
return complex_spectrogram;
}

// repeat the above magnitude function but adapt for phase
umxcpp::StereoSpectrogramR umxcpp::phase(const StereoSpectrogramC &spec)
Eigen::Tensor3dXcf umxcpp::stft(const Eigen::MatrixXf &audio)
{
// compute the phase of a complex spectrogram
StereoSpectrogramR ret;
ret.left.resize(spec.left.size());
ret.right.resize(spec.right.size());
auto window = hann_window(FFT_WINDOW_SIZE);

for (std::size_t i = 0; i < spec.left.size(); ++i)
{
ret.left[i].resize(spec.left[i].size());
ret.right[i].resize(spec.right[i].size());
// apply padding equivalent to center padding with center=True
// in torch.stft:
// https://pytorch.org/docs/stable/generated/torch.stft.html

for (std::size_t j = 0; j < spec.left[i].size(); ++j)
{
// compute phase using std::complex<float>
ret.left[i][j] = std::arg(spec.left[i][j]);
ret.right[i][j] = std::arg(spec.right[i][j]);
}
}
std::vector<float> audio_left(audio.row(0).size());
Eigen::VectorXf row_vec = audio.row(0);
std::copy_n(row_vec.data(), row_vec.size(), audio_left.begin());

return ret;
}
std::vector<float> audio_right(audio.row(1).size());
row_vec = audio.row(1);
std::copy_n(row_vec.data(), row_vec.size(), audio_right.begin());

umxcpp::StereoSpectrogramC umxcpp::combine(const StereoSpectrogramR &mag,
const StereoSpectrogramR &phase)
{
// combine magnitude and phase into a complex spectrogram
StereoSpectrogramC ret;
ret.left.resize(mag.left.size());
ret.right.resize(mag.right.size());
pad_signal(audio_left, FFT_WINDOW_SIZE);
pad_signal(audio_right, FFT_WINDOW_SIZE);

for (std::size_t i = 0; i < mag.left.size(); ++i)
{
ret.left[i].resize(mag.left[i].size());
ret.right[i].resize(mag.right[i].size());
auto stft_left =
stft_inner(audio_left, window, FFT_WINDOW_SIZE, FFT_HOP_SIZE);
auto stft_right =
stft_inner(audio_right, window, FFT_WINDOW_SIZE, FFT_HOP_SIZE);

// get the size of rows and cols
int rows = stft_left.size();
int cols = stft_left[0].size();

Eigen::Tensor3dXcf spec(2, rows, cols);

for (std::size_t j = 0; j < mag.left[i].size(); ++j)
for (int i = 0; i < rows; ++i)
{
for (int j = 0; j < cols; ++j)
{
// compute the complex number from the polar form
ret.left[i][j] = std::polar(mag.left[i][j], phase.left[i][j]);
ret.right[i][j] = std::polar(mag.right[i][j], phase.right[i][j]);
spec(0, i, j) = stft_left[i][j];
spec(1, i, j) = stft_right[i][j];
}
}

return ret;
return spec;
}

umxcpp::StereoSpectrogramC umxcpp::stft(const StereoWaveform &audio)
Eigen::MatrixXf umxcpp::istft(const Eigen::Tensor3dXcf &spec)
{
StereoSpectrogramC spec;
auto window = hann_window(FFT_WINDOW_SIZE);

// apply padding equivalent to center padding with center=True
// in torch.stft:
// https://pytorch.org/docs/stable/generated/torch.stft.html
auto chn_left = pad_signal(audio.left, FFT_WINDOW_SIZE);
auto chn_right = pad_signal(audio.right, FFT_WINDOW_SIZE);
int rows = spec.dimension(1);
int cols = spec.dimension(2);

spec.left = stft_inner(chn_left, window, FFT_WINDOW_SIZE, FFT_HOP_SIZE);
spec.right = stft_inner(chn_right, window, FFT_WINDOW_SIZE, FFT_HOP_SIZE);
// Create the nested vectors
std::vector<std::vector<std::complex<float>>> stft_left(
rows, std::vector<std::complex<float>>(cols));
std::vector<std::vector<std::complex<float>>> stft_right(
rows, std::vector<std::complex<float>>(cols));

return spec;
}
// Populate the nested vectors
for (int i = 0; i < rows; ++i)
{
for (int j = 0; j < cols; ++j)
{
stft_left[i][j] = spec(0, i, j);
stft_right[i][j] = spec(1, i, j);
}
}

umxcpp::StereoWaveform umxcpp::istft(const StereoSpectrogramC &spec)
{
StereoWaveform audio;
auto window = hann_window(FFT_WINDOW_SIZE);
std::vector<float> chn_left =
istft_inner(stft_left, window, FFT_WINDOW_SIZE, FFT_HOP_SIZE);
std::vector<float> chn_right =
istft_inner(stft_right, window, FFT_WINDOW_SIZE, FFT_HOP_SIZE);

unpad_signal(chn_left, FFT_WINDOW_SIZE);
unpad_signal(chn_right, FFT_WINDOW_SIZE);

auto chn_left =
istft_inner(spec.left, window, FFT_WINDOW_SIZE, FFT_HOP_SIZE);
auto chn_right =
istft_inner(spec.right, window, FFT_WINDOW_SIZE, FFT_HOP_SIZE);
Eigen::MatrixXf audio(2, chn_left.size());

audio.left = unpad_signal(chn_left, FFT_WINDOW_SIZE);
audio.right = unpad_signal(chn_right, FFT_WINDOW_SIZE);
audio.row(0) =
Eigen::Map<Eigen::MatrixXf>(chn_left.data(), 1, chn_left.size());
audio.row(1) =
Eigen::Map<Eigen::MatrixXf>(chn_right.data(), 1, chn_right.size());

return audio;
}
Expand Down
Loading

0 comments on commit 690ec94

Please sign in to comment.