From 49940d898aa95d84878fbee288ba72630d894044 Mon Sep 17 00:00:00 2001 From: Sevag H Date: Sun, 10 Sep 2023 17:31:45 -0400 Subject: [PATCH] Feat/wiener em (#3) * Add wiener-em --- CMakeLists.txt | 2 +- README.md | 24 +- scripts/umx_pytorch_inference.py | 35 ++- src/wiener.cpp | 420 +++++++++++++++++++++++++++++++ src/wiener.hpp | 168 +++++++++++++ umx.cpp | 23 +- 6 files changed, 642 insertions(+), 30 deletions(-) create mode 100644 src/wiener.cpp create mode 100644 src/wiener.hpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 498d06e..321725a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,7 +1,7 @@ # cmake file to compile src/ # link against included submodules libnyquist -cmake_minimum_required(VERSION 3.0) +cmake_minimum_required(VERSION 3.5) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE Release) diff --git a/README.md b/README.md index f9e2c13..562b126 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # umx.cpp +**:boom: :dizzy: 2023-09-10 update: Wiener-EM is now implemented for maximum performance!** + C++17 implementation of [Open-Unmix](https://github.com/sigsep/open-unmix-pytorch) (UMX), a PyTorch neural network for music demixing. It uses [libnyquist](https://github.com/ddiakopoulos/libnyquist) to load audio files, the [ggml](https://github.com/ggerganov/ggml) file format to serialize the PyTorch weights of `umxhq` and `umxl` to a binary file format, and [Eigen](https://eigen.tuxfamily.org/index.php?title=Main_Page) (+ OpenMP) to implement the inference of Open-Unmix. @@ -8,9 +10,9 @@ The float32 weights of UMX are quantized to uint16 during the conversion to the ## Performance -The demixed output wav files (and their SDR score) of the main program [`umx.cpp`](./umx.cpp) are mostly identical to the PyTorch models (with the post-processing Wiener-EM step disabled): +The demixed output wav files (and their SDR score) of the main program [`umx.cpp`](./umx.cpp) are mostly identical to the PyTorch models: ``` -# first, standard pytorch inference (no wiener-em) +# first, standard pytorch inference $ python ./scripts/umx_pytorch_inference.py \ --model=umxl \ --dest-dir=./umx-py-xl-out \ @@ -28,23 +30,23 @@ $ python ./scripts/evaluate-demixed-output.py \ ./umx-py-xl-out \ 'Punkdisco - Oral Hygiene' -vocals ==> SDR: 7.377 SIR: 16.028 ISR: 15.628 SAR: 8.376 -drums ==> SDR: 8.086 SIR: 12.205 ISR: 17.904 SAR: 9.055 -bass ==> SDR: 5.459 SIR: 8.830 ISR: 13.361 SAR: 10.543 -other ==> SDR: 1.442 SIR: 1.144 ISR: 5.199 SAR: 2.842 +vocals ==> SDR: 7.695 SIR: 17.312 ISR: 16.426 SAR: 8.322 +drums ==> SDR: 8.899 SIR: 14.054 ISR: 14.941 SAR: 9.428 +bass ==> SDR: 8.338 SIR: 14.352 ISR: 14.171 SAR: 10.971 +other ==> SDR: 2.017 SIR: 6.266 ISR: 6.821 SAR: 2.410 $ python ./scripts/evaluate-demixed-output.py \ --musdb-root="/MUSDB18-HQ" \ ./umx-cpp-xl-out \ 'Punkdisco - Oral Hygiene' -vocals ==> SDR: 7.377 SIR: 16.028 ISR: 15.628 SAR: 8.376 -drums ==> SDR: 8.086 SIR: 12.205 ISR: 17.904 SAR: 9.055 -bass ==> SDR: 5.459 SIR: 8.830 ISR: 13.361 SAR: 10.543 -other ==> SDR: 1.442 SIR: 1.144 ISR: 5.199 SAR: 2.842 +vocals ==> SDR: 7.750 SIR: 17.510 ISR: 16.195 SAR: 8.321 +drums ==> SDR: 9.010 SIR: 14.149 ISR: 14.900 SAR: 9.416 +bass ==> SDR: 8.349 SIR: 14.348 ISR: 14.160 SAR: 10.990 +other ==> SDR: 1.987 SIR: 6.282 ISR: 6.674 SAR: 2.461 ``` -In runtime, this is actually slower than the PyTorch inference (and probably much slower than a possible Torch C++ inference implementation). For a 4:23 song, PyTorch takes 13s and umx.cpp takes 22s. 
+At runtime, this is actually slower than the PyTorch inference (and probably much slower than a possible Torch C++ inference implementation).
 
 ## Motivation
 
diff --git a/scripts/umx_pytorch_inference.py b/scripts/umx_pytorch_inference.py
index bb9472d..ea2f92c 100644
--- a/scripts/umx_pytorch_inference.py
+++ b/scripts/umx_pytorch_inference.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 import openunmix
+from openunmix.filtering import wiener
 import torch
 import torchaudio.backend.sox_io_backend
 import torchaudio
@@ -46,24 +47,44 @@
 mag_spec = torch.abs(torch.view_as_complex(spec))
 phase_spec = torch.angle(torch.view_as_complex(spec))
 
+out_mag_specs = []
+
 # UMX forward inference
 for target_name, target_model in model.items():
     print(f"Inference for target {target_name}")
     out_mag_spec = target_model(mag_spec)
     print(type(out_mag_spec))
     print(out_mag_spec.shape)
+    out_mag_specs.append(torch.unsqueeze(out_mag_spec, dim=-1))
+
+out_mag_spec_concat = torch.cat(out_mag_specs, dim=-1)
+print(f"shape, dtype: {out_mag_spec_concat.shape}, {out_mag_spec_concat.dtype}")
+
+# Convert back to complex tensor
+#out_spec = out_mag_spec * torch.exp(1j * phase_spec)
+# do wiener filtering
+
+wiener_mag_inp = out_mag_spec_concat[0, ...].permute(2, 1, 0, 3)
+wiener_spec_inp = spec[0, ...].permute(2, 1, 0, 3)
+
+out_specs = wiener(wiener_mag_inp, wiener_spec_inp)
 
-    # Convert back to complex tensor
-    out_spec = out_mag_spec * torch.exp(1j * phase_spec)
+# out_specs: torch.Size([44, 2049, 2, 2, 4])
+#   nb_frames, nb_bins, nb_channels, 2, targets
+#   0          1        2            3  4
+# permute:
+#   4          2        1            0  3
 
-    # get istft
-    out_audio = istft(torch.view_as_real(out_spec))
-    print(out_audio.shape)
-    out_audio = torch.squeeze(out_audio, dim=0)
+# samples, targets, channels, nb_bins, nb_frames, 2
+out_specs = torch.unsqueeze(out_specs.permute(4, 2, 1, 0, 3), dim=0)
+
+# get istft
+out_audios = istft(out_specs)[0]
+print(out_audios.shape)
 
+for i, target_name in enumerate(model.keys()):
     # write to file in directory
     if args.dest_dir is not None:
         os.makedirs(args.dest_dir, exist_ok=True)
-        torchaudio.save(os.path.join(args.dest_dir, f'target_{target_digit_map[target_name]}.wav'), out_audio, sample_rate=44100)
+        torchaudio.save(os.path.join(args.dest_dir, f'target_{target_digit_map[target_name]}.wav'), out_audios[i], sample_rate=44100)
 
 print("Goodbye!")
diff --git a/src/wiener.cpp b/src/wiener.cpp
new file mode 100644
index 0000000..d7d1c8d
--- /dev/null
+++ b/src/wiener.cpp
@@ -0,0 +1,420 @@
+#include "wiener.hpp"
+#include "dsp.hpp"
+#include <algorithm>
+#include <array>
+#include <cmath>
+#include <complex>
+#include <iostream>
+#include <vector>
+
+// Compute the maximum absolute value of a complex 3D tensor,
+// returned as max(1, max_abs/scale_factor)
+static float find_max_abs(const Eigen::Tensor3dXcf &data, float scale_factor) {
+    float max_val_im = -1.0f;
+    for (int i = 0; i < data.dimension(0); ++i) {
+        for (int j = 0; j < data.dimension(1); ++j) {
+            for (int k = 0; k < data.dimension(2); ++k) {
+                max_val_im = std::max(max_val_im, std::sqrt(std::norm(data(i, j, k))));
+            }
+        }
+    }
+    return std::max(1.0f, max_val_im/scale_factor);
+}
+
+// Invert, in place, the 2x2 complex matrix stored in each (frame, bin) slot of M
+static void invert5D(umxcpp::Tensor5D& M) {
+    for (auto& frame : M.data) {
+        for (auto& bin : frame) {
+            std::complex<float> a(bin[0][0][0], bin[0][0][1]);
+            std::complex<float> b(bin[0][1][0], bin[0][1][1]);
+            std::complex<float> c(bin[1][0][0], bin[1][0][1]);
+            std::complex<float> d(bin[1][1][0], bin[1][1][1]);
+
+            // Compute the determinant
+            std::complex<float> det = a * d - b * c;
+
+            // Compute the inverse determinant
+            // INEXPLICABLE 4.0 factor!
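+            // (For a 2x2 matrix [[a, b], [c, d]] the inverse is 1/det * [[d, -b], [-c, a]],
+            //  and conj(det)/std::norm(det) equals 1/det since std::norm is the squared
+            //  magnitude. The extra 4.0f here is cancelled later by the /4.0f applied
+            //  when the Wiener gain is scaled by v.)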
+            std::complex<float> invDet = 4.0f*std::conj(det)/std::norm(det);
+
+            // Compute the inverse matrix
+            std::complex<float> tmp00 = invDet * d;
+            std::complex<float> tmp01 = -invDet * b;
+            std::complex<float> tmp10 = -invDet * c;
+            std::complex<float> tmp11 = invDet * a;
+
+            // Update the original tensor
+            bin[0][0][0] = tmp00.real();
+            bin[0][0][1] = tmp00.imag();
+
+            bin[0][1][0] = tmp01.real();
+            bin[0][1][1] = tmp01.imag();
+
+            bin[1][0][0] = tmp10.real();
+            bin[1][0][1] = tmp10.imag();
+
+            bin[1][1][0] = tmp11.real();
+            bin[1][1][1] = tmp11.imag();
+        }
+    }
+}
+
+// Compute the empirical covariance for a source.
+// forward decl
+static umxcpp::Tensor5D calculateCovariance(
+    const Eigen::Tensor3dXcf &y_j,
+    const int pos,
+    const int t_end
+);
+
+// Sum a (frames, bins, channels, channels, re/im) tensor over the frame axis
+static umxcpp::Tensor4D sumAlongFirstDimension(const umxcpp::Tensor5D& tensor5d) {
+    int nb_frames = tensor5d.data.size();
+    int nb_bins = tensor5d.data[0].size();
+    int nb_channels1 = tensor5d.data[0][0].size();
+    int nb_channels2 = tensor5d.data[0][0][0].size();
+    int nb_reim = tensor5d.data[0][0][0][0].size();
+
+    // Initialize a 4D tensor filled with zeros
+    umxcpp::Tensor4D result(nb_bins, nb_channels1, nb_channels2, nb_reim);
+
+    for (int frame = 0; frame < nb_frames; ++frame) {
+        for (int bin = 0; bin < nb_bins; ++bin) {
+            for (int ch1 = 0; ch1 < nb_channels1; ++ch1) {
+                for (int ch2 = 0; ch2 < nb_channels2; ++ch2) {
+                    for (int reim = 0; reim < nb_reim; ++reim) {
+                        result.data[bin][ch1][ch2][reim] += tensor5d.data[frame][bin][ch1][ch2][reim];
+                    }
+                }
+            }
+        }
+    }
+
+    return result;
+}
+
+// Wiener filter function
+std::array<Eigen::Tensor3dXcf, 4>
+umxcpp::wiener_filter(Eigen::Tensor3dXcf &mix_stft,
+                      const std::array<Eigen::Tensor3dXf, 4> &targets_mag_spectrograms)
+{
+    // first just do naive mix-phase
+    std::array<Eigen::Tensor3dXcf, 4> y;
+
+    Eigen::Tensor3dXf mix_phase = mix_stft.unaryExpr(
+        [](const std::complex<float> &c) { return std::arg(c); });
+
+    std::cout << "Wiener-EM: Getting first estimates from naive mix-phase" << std::endl;
+
+    for (int target = 0; target < 4; ++target) {
+        y[target] = umxcpp::polar_to_complex(targets_mag_spectrograms[target], mix_phase);
+    }
+
+    std::cout << "Wiener-EM: Scaling down by max_abs" << std::endl;
+
+    // we need to refine the estimates; scale them down for numerical stability
+    float max_abs = find_max_abs(mix_stft, WIENER_SCALE_FACTOR);
+
+    // Dividing mix_stft by max_abs
+    for (int i = 0; i < mix_stft.dimension(1); ++i) {
+        for (int j = 0; j < mix_stft.dimension(2); ++j) {
+            mix_stft(0, i, j) = std::complex<float>{
+                mix_stft(0, i, j).real()/max_abs,
+                mix_stft(0, i, j).imag()/max_abs};
+            mix_stft(1, i, j) = std::complex<float>{
+                mix_stft(1, i, j).real()/max_abs,
+                mix_stft(1, i, j).imag()/max_abs};
+        }
+    }
+
+    // Dividing y by max_abs
+    for (int source = 0; source < 4; ++source) {
+        for (int i = 0; i < mix_stft.dimension(1); ++i) {
+            for (int j = 0; j < mix_stft.dimension(2); ++j) {
+                y[source](0, i, j) = std::complex<float>{
+                    y[source](0, i, j).real()/max_abs,
+                    y[source](0, i, j).imag()/max_abs};
+                y[source](1, i, j) = std::complex<float>{
+                    y[source](1, i, j).real()/max_abs,
+                    y[source](1, i, j).imag()/max_abs};
+            }
+        }
+    }
+
+    // call expectation maximization
+    // y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]
+
+    const int nb_channels = 2;
+    const int nb_frames = mix_stft.dimension(1);
+    const int nb_bins = mix_stft.dimension(2);
+    const int nb_sources = 4;
+    const float eps = WIENER_EPS;
+
+    std::cout << "Wiener-EM: Initialize tensors" << std::endl;
+
+    // Create and initialize the regularization tensor (nb_channels x nb_channels x re/im)
+    umxcpp::Tensor3D regularization(nb_channels, nb_channels, 2);
+    // Fill the real part of the diagonal with sqrt(eps)
+    regularization.fill_diagonal(std::sqrt(eps));
+
+    std::vector<Tensor4D> R; // A vector to hold each source's covariance matrix
+    for (int j = 0; j < nb_sources; ++j) {
+        R.emplace_back(Tensor4D(nb_bins, nb_channels, nb_channels, 2));
+    }
+
+    Tensor1D weight(nb_bins);                   // A 1D tensor (vector) of zeros
+    Tensor3D v(nb_frames, nb_bins, nb_sources); // A 3D tensor of zeros
+
+    for (int it = 0; it < WIENER_ITERATIONS; ++it) {
+        std::cout << "Wiener-EM: iteration: " << it << std::endl;
+
+        // update the PSD as the average spectrogram over channels
+        // PSD container is v
+        std::cout << "\tUpdate PSD `v`" << std::endl;
+        for (int frame = 0; frame < nb_frames; ++frame) {
+            for (int bin = 0; bin < nb_bins; ++bin) {
+                for (int source = 0; source < nb_sources; ++source) {
+                    float sumSquare = 0.0f;
+                    for (int channel = 0; channel < nb_channels; ++channel) {
+                        float realPart = y[source](channel, frame, bin).real();
+                        float imagPart = y[source](channel, frame, bin).imag();
+
+                        sumSquare += (realPart * realPart) + (imagPart * imagPart);
+                    }
+                    // Divide by the number of channels to get the average
+                    v.data[frame][bin][source] = sumSquare / nb_channels;
+                }
+            }
+        }
+
+        for (int source = 0; source < nb_sources; ++source) {
+            R[source].setZero();     // zero out this source's covariance accumulator
+            weight.fill(WIENER_EPS); // initialize the normalization weights with a small epsilon
+
+            int pos = 0;
+            int batchSize = WIENER_EM_BATCH_SIZE > 0 ? WIENER_EM_BATCH_SIZE : nb_frames;
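+            // Accumulate the empirical covariance of y[source] in batches of frames;
+            // after the batch loop, R[source](bin) holds
+            //   sum_t y(t, bin) y(t, bin)^H / sum_t v(t, bin, source),
+            // i.e. the PSD-normalized spatial covariance used by the EM update.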
+
+            while (pos < nb_frames) {
+                std::cout << "\tCovariance loop for source: " << source << ", pos: " << pos << std::endl;
+                int t_end = std::min(nb_frames, pos + batchSize);
+
+                umxcpp::Tensor5D tempR = calculateCovariance(y[source], pos, t_end);
+
+                // Sum the calculated covariance into R[j]
+                // Sum along the first (time/frame) dimension to get a 4D tensor
+                umxcpp::Tensor4D tempR4D = sumAlongFirstDimension(tempR);
+
+                // Add to existing R[j]; (R[j], tempR4D have the same dimensions)
+                for (std::size_t bin = 0; bin < R[source].data.size(); ++bin) {
+                    for (std::size_t ch1 = 0; ch1 < R[source].data[0].size(); ++ch1) {
+                        for (std::size_t ch2 = 0; ch2 < R[source].data[0][0].size(); ++ch2) {
+                            for (std::size_t reim = 0; reim < R[source].data[0][0][0].size(); ++reim) {
+                                R[source].data[bin][ch1][ch2][reim] += tempR4D.data[bin][ch1][ch2][reim];
+                            }
+                        }
+                    }
+                }
+
+                // Update the weight with the v values summed across the frames of this batch
+                for (int t = pos; t < t_end; ++t) {
+                    for (int bin = 0; bin < nb_bins; ++bin) {
+                        weight.data[bin] += v.data[t][bin][source];
+                    }
+                }
+
+                pos = t_end;
+            }
+
+            // Normalize R[j] by weight
+            for (int bin = 0; bin < nb_bins; ++bin) {
+                for (int ch1 = 0; ch1 < nb_channels; ++ch1) {
+                    for (int ch2 = 0; ch2 < nb_channels; ++ch2) {
+                        for (int k = 0; k < 2; ++k) {
+                            R[source].data[bin][ch1][ch2][k] /= weight.data[bin];
+                        }
+                    }
+                }
+            }
+
+            // Reset the weight for the next iteration
+            weight.fill(0.0f);
+        }
+
+        int pos = 0;
+        int batchSize = WIENER_EM_BATCH_SIZE > 0 ? WIENER_EM_BATCH_SIZE : nb_frames;
+        while (pos < nb_frames) {
+            int t_end = std::min(nb_frames, pos + batchSize);
+
+            std::cout << "\tMix covariance loop for pos: " << pos << std::endl;
+
+            // Reset y values to zero for this batch of frames before re-estimating them
+            for (int source = 0; source < 4; ++source) {
+                for (int i = pos; i < t_end; ++i) {
+                    for (int j = 0; j < nb_bins; ++j) {
+                        y[source](0, i, j) = std::complex<float>{0.0f, 0.0f};
+                        y[source](1, i, j) = std::complex<float>{0.0f, 0.0f};
+                    }
+                }
+            }
+            int nb_frames_chunk = t_end-pos;
+
+            // Compute mix covariance matrix Cxx
+            //Tensor3D Cxx = regularization;
+            Tensor5D Cxx(nb_frames_chunk, nb_bins, nb_channels, nb_channels, 2);
+
+            // broadcast the regularization term and accumulate v * R for each source
+            for (int frame = 0; frame < nb_frames_chunk; ++frame) {
+                for (int bin = 0; bin < nb_bins; ++bin) {
+                    for (int source = 0; source < nb_sources; ++source) {
+                        float multiplier = v.data[frame+pos][bin][source];
+                        for (int ch1 = 0; ch1 < nb_channels; ++ch1) {
+                            for (int ch2 = 0; ch2 < nb_channels; ++ch2) {
+                                for (int re_im = 0; re_im < 2; ++re_im) {
+                                    Cxx.data[frame][bin][ch1][ch2][re_im] += regularization.data[ch1][ch2][re_im] + multiplier * R[source].data[bin][ch1][ch2][re_im];
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+
+            // Invert Cxx
+            std::cout << "\tInvert Cxx and Wiener gain calculation" << std::endl;
+            invert5D(Cxx); // invert the per-(frame, bin) 2x2 matrices in place
+            Tensor5D inv_Cxx = Cxx; // deep copy via the implicit copy constructor (nested std::vector)
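+
+            // For each source, the multichannel Wiener gain is
+            //   gain_j(t, f) = v_j(t, f) * R_j(f) * inv(Cxx(t, f)),
+            // applied below to the (scaled) mixture STFT to update y[source].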
+
+            // Separate the sources
+            for (int source = 0; source < nb_sources; ++source) {
+                // Initialize with zeros
+                // create gain with broadcast size of inv_Cxx
+                Tensor5D gain(nb_frames_chunk, nb_bins, nb_channels, nb_channels, 2);
+                gain.setZero();
+
+                for (int frame = 0; frame < nb_frames_chunk; ++frame) {
+                    for (int bin = 0; bin < nb_bins; ++bin) {
+                        for (int ch1 = 0; ch1 < nb_channels; ++ch1) {
+                            for (int ch2 = 0; ch2 < nb_channels; ++ch2) {
+                                for (int ch3 = 0; ch3 < nb_channels; ++ch3) {
+                                    auto a = R[source].data[bin][ch1][ch3];
+                                    auto b = inv_Cxx.data[frame][bin][ch3][ch2];
+
+                                    // complex multiply-accumulate: gain += R * inv_Cxx
+                                    gain.data[frame][bin][ch1][ch2][0] += a[0]*b[0] - a[1]*b[1];
+                                    gain.data[frame][bin][ch1][ch2][1] += a[0]*b[1] + a[1]*b[0];
+                                }
+                            }
+                        }
+                    }
+                }
+
+                // Element-wise multiplication with v
+                for (int frame = 0; frame < nb_frames_chunk; ++frame) {
+                    for (int bin = 0; bin < nb_bins; ++bin) {
+                        for (int ch1 = 0; ch1 < nb_channels; ++ch1) {
+                            for (int ch2 = 0; ch2 < nb_channels; ++ch2) {
+                                for (int re_im = 0; re_im < 2; ++re_im) { // last dimension is real/imaginary
+                                    // undoing the inv_Cxx factor of 4.0f
+                                    gain.data[frame][bin][ch1][ch2][re_im] *= v.data[frame+pos][bin][source]/4.0f;
+                                }
+                            }
+                        }
+                    }
+                }
+
+                std::cout << "\tApply gain to y, source: " << source << ", pos: " << pos << std::endl;
+                // apply it to the mixture
+                for (int frame = 0; frame < nb_frames_chunk; ++frame) {
+                    for (int bin = 0; bin < nb_bins; ++bin) {
+                        for (int ch1 = 0; ch1 < nb_channels; ++ch1) {
+                            for (int ch2 = 0; ch2 < nb_channels; ++ch2) {
+                                float sample_real = y[source](ch2, frame+pos, bin).real();
+                                float sample_imag = y[source](ch2, frame+pos, bin).imag();
+
+                                float a_real = gain.data[frame][bin][ch2][ch1][0];
+                                float a_imag = gain.data[frame][bin][ch2][ch1][1];
+
+                                float b_real = mix_stft(ch1, frame+pos, bin).real();
+                                float b_imag = mix_stft(ch1, frame+pos, bin).imag();
+
+                                y[source](ch2, frame+pos, bin) = std::complex<float>{
+                                    sample_real + (a_real*b_real - a_imag*b_imag),
+                                    sample_imag + (a_real*b_imag + a_imag*b_real)};
+                            }
+                        }
+                    }
+                }
+            }
+
+            pos = t_end;
+        }
+    }
+
+    // scale y back up by max_abs
+    for (int source = 0; source < 4; ++source) {
+        for (int i = 0; i < mix_stft.dimension(1); ++i) {
+            for (int j = 0; j < mix_stft.dimension(2); ++j) {
+                y[source](0, i, j) = std::complex<float>{
+                    y[source](0, i, j).real()*max_abs,
+                    y[source](0, i, j).imag()*max_abs};
+                y[source](1, i, j) = std::complex<float>{
+                    y[source](1, i, j).real()*max_abs,
+                    y[source](1, i, j).imag()*max_abs};
+            }
+        }
+    }
+
+    return y;
+}
+
+// Compute the empirical covariance for a source.
+/*
+ * y_j shape: 2, nb_frames_total, 2049
+ * pos-t_end = nb_frames (i.e. a chunk of y_j)
+ *
+ * returns Cj:
+ * shape: nb_frames, nb_bins, nb_channels, nb_channels, realim
+ */
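+// For each time-frequency bin this accumulates the outer product
+//   C_j(t, f) = y_j(t, f) * y_j(t, f)^H
+// of the 2-channel source estimate, stored as separate real and imaginary parts.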
+static umxcpp::Tensor5D calculateCovariance(
+    const Eigen::Tensor3dXcf &y_j,
+    const int pos,
+    const int t_end
+) {
+    //int nb_frames = y_j.dimension(1);
+    int nb_frames = t_end-pos;
+    int nb_bins = y_j.dimension(2);
+    int nb_channels = 2;
+
+    // Initialize Cj tensor with zeros
+    umxcpp::Tensor5D Cj(nb_frames, nb_bins, nb_channels, nb_channels, 2);
+    Cj.setZero();
+
+    for (int frame = 0; frame < nb_frames; ++frame) {
+        for (int bin = 0; bin < nb_bins; ++bin) {
+            for (int ch1 = 0; ch1 < nb_channels; ++ch1) {
+                for (int ch2 = 0; ch2 < nb_channels; ++ch2) {
+                    std::complex<float> a = y_j(ch1, frame+pos, bin);
+                    std::complex<float> b = std::conj(y_j(ch2, frame+pos, bin));
+
+                    float a_real = a.real();
+                    float a_imag = a.imag();
+
+                    float b_real = b.real();
+                    float b_imag = b.imag();
+
+                    // Update the tensor
+                    // _mul_add y_j, conj(y_j) -> y_j = a, conj = b
+                    Cj.data[frame][bin][ch1][ch2][0] += (a_real*b_real - a_imag*b_imag);
+                    Cj.data[frame][bin][ch1][ch2][1] += (a_real*b_imag + a_imag*b_real);
+                }
+            }
+        }
+    }
+
+    return Cj;
+}
diff --git a/src/wiener.hpp b/src/wiener.hpp
new file mode 100644
index 0000000..4ab999f
--- /dev/null
+++ b/src/wiener.hpp
@@ -0,0 +1,168 @@
+#ifndef WIENER_HPP
+#define WIENER_HPP
+
+#include "tensor.hpp"
+#include "dsp.hpp"
+#include <array>
+#include <cstddef>
+#include <vector>
+
+namespace umxcpp {
+const float WIENER_EPS = 1e-10f;
+const float WIENER_SCALE_FACTOR = 10.0f;
+
+// try a smaller batch for memory issues
+const int WIENER_EM_BATCH_SIZE = 200;
+const int WIENER_ITERATIONS = 1;
+
+std::array<Eigen::Tensor3dXcf, 4>
+wiener_filter(Eigen::Tensor3dXcf &mix_spectrogram,
+              const std::array<Eigen::Tensor3dXf, 4> &targets_mag_spectrograms);
+
+struct Tensor5D {
+    std::vector<std::vector<std::vector<std::vector<std::vector<float>>>>> data;
+
+    Tensor5D(int dim1, int dim2, int dim3, int dim4, int dim5) {
+        data.resize(dim1);
+        for (int i = 0; i < dim1; ++i) {
+            data[i].resize(dim2);
+            for (int j = 0; j < dim2; ++j) {
+                data[i][j].resize(dim3);
+                for (int k = 0; k < dim3; ++k) {
+                    data[i][j][k].resize(dim4);
+                    for (int l = 0; l < dim4; ++l) {
+                        data[i][j][k][l].resize(dim5, 0.0f); // Initializing with 0
+                    }
+                }
+            }
+        }
+    }
+
+    // Method to fill the diagonal (real part) with param for a specific 3D slice
+    void fill_diagonal(int dim1, int dim2, float param = 1.0f) {
+        for (std::size_t i = 0; i < data[0][0].size(); ++i) {
+            for (std::size_t j = 0; j < data[0][0][0].size(); ++j) {
+                if (i == j) {
+                    data[dim1][dim2][i][j][0] = param;
+                }
+            }
+        }
+    }
+
+    // Method to scale the tensor by a scalar
+    void scale_by(float scalar) {
+        for (std::size_t i = 0; i < data.size(); ++i) {
+            for (std::size_t j = 0; j < data[0].size(); ++j) {
+                for (std::size_t k = 0; k < data[0][0].size(); ++k) {
+                    for (std::size_t l = 0; l < data[0][0][0].size(); ++l) {
+                        for (std::size_t m = 0; m < data[0][0][0][0].size(); ++m) {
+                            data[i][j][k][l][m] *= scalar;
+                        }
+                    }
+                }
+            }
+        }
+    }
+
+    void setZero() {
+        for (std::size_t i = 0; i < data.size(); ++i) {
+            for (std::size_t j = 0; j < data[0].size(); ++j) {
+                for (std::size_t k = 0; k < data[0][0].size(); ++k) {
+                    for (std::size_t l = 0; l < data[0][0][0].size(); ++l) {
+                        for (std::size_t m = 0; m < data[0][0][0][0].size(); ++m) {
+                            data[i][j][k][l][m] = 0.0f;
+                        }
+                    }
+                }
+            }
+        }
+    }
+};
+
+struct Tensor4D {
+    std::vector<std::vector<std::vector<std::vector<float>>>> data;
+
+    Tensor4D(int dim1, int dim2, int dim3, int dim4) {
+        resize(dim1, dim2, dim3, dim4);
+    }
+
+    void resize(int dim1, int dim2, int dim3, int dim4) {
+        data.resize(dim1);
+        for (int i = 0; i < dim1; ++i) {
+            data[i].resize(dim2);
+            for (int j = 0; j < dim2; ++j) {
+                data[i][j].resize(dim3);
+                for (int k = 0; k < dim3; ++k) {
+                    data[i][j][k].resize(dim4, 0.0f); // Initializing with 0
+                }
+            }
+        }
+    }
+
+    void setZero() {
+        for (std::size_t i = 0; i < data.size(); ++i) {
+            for (std::size_t j = 0; j < data[0].size(); ++j) {
+                for (std::size_t k = 0; k < data[0][0].size(); ++k) {
+                    for (std::size_t l = 0; l < data[0][0][0].size(); ++l) {
+                        data[i][j][k][l] = 0.0f;
+                    }
+                }
+            }
+        }
+    }
+};
+
+// Tensor3D
+struct Tensor3D {
+    std::vector<std::vector<std::vector<float>>> data;
+
+    Tensor3D(int dim1, int dim2, int dim3) {
+        resize(dim1, dim2, dim3);
+    }
+
+    void resize(int dim1, int dim2, int dim3) {
+        data.resize(dim1);
+        for (int i = 0; i < dim1; ++i) {
+            data[i].resize(dim2);
+            for (int j = 0; j < dim2; ++j) {
+                data[i][j].resize(dim3, 0.0f); // Initializing with 0
+            }
+        }
+    }
+
+    // Method to fill diagonal with param
+    void fill_diagonal(float param) {
+        for (std::size_t i = 0; i < data[0].size(); ++i) {
+            for (std::size_t j = 0; j < data[0][0].size(); ++j) {
+                if (i == j) {
+                    data[i][j][0] = param;
+                }
+            }
+        }
+    }
+};
+
+// Tensor1D
+struct Tensor1D {
+    std::vector<float> data;
+
+    Tensor1D(int dim1) {
+        resize(dim1);
+    }
+
+    void resize(int dim1) {
+        data.resize(dim1, 0.0f); // Initializing with 0
+    }
+
+    void fill(float value) {
+        for (std::size_t i = 0; i < data.size(); ++i) {
+            data[i] = value;
+        }
+    }
+};
+} // namespace umxcpp
+
+#endif // WIENER_HPP
diff --git a/umx.cpp b/umx.cpp
index 8b26278..ef4d7f1 100644
--- a/umx.cpp
+++ b/umx.cpp
@@ -11,6 +11,7 @@
 #include
 #include
 #include
+#include "wiener.hpp"
 
 using namespace umxcpp;
 
@@ -72,10 +73,6 @@ int main(int argc, const char **argv)
     // now let's get a stereo magnitude spectrogram
     Eigen::Tensor3dXf mix_mag = spectrogram.abs();
 
-    std::cout << "Computing STFT phase" << std::endl;
-    Eigen::Tensor3dXf mix_phase = spectrogram.unaryExpr(
-        [](const std::complex<float> &c) { return std::arg(c); });
-
     // apply umx inference to the magnitude spectrogram
     // first create a ggml_tensor for the input
 
@@ -117,6 +114,8 @@ int main(int argc, const char **argv)
     std::array<Eigen::MatrixXf, 4> x_outputs = umx_inference(&model, x, hidden_size);
 
+    std::array<Eigen::Tensor3dXf, 4> mag_targets;
+
 #pragma omp parallel for
     for (int target = 0; target < 4; ++target)
     {
@@ -126,7 +125,7 @@
             << std::endl;
 
         // copy mix-mag
-        Eigen::Tensor3dXf mix_mag_target(mix_mag);
+        mag_targets[target] = mix_mag;
 
         // element-wise multiplication, taking into account the stacked outputs
        // of the neural network
@@ -136,17 +135,18 @@
 #pragma omp parallel for
         for (std::size_t j = 0; j < mix_mag.dimension(2); j++)
         {
-            mix_mag_target(0, i, j) *= x_outputs[target](i, j);
-            mix_mag_target(1, i, j) *=
+            mag_targets[target](0, i, j) *= x_outputs[target](i, j);
+            mag_targets[target](1, i, j) *=
                 x_outputs[target](i, j + mix_mag.dimension(2));
         }
     }
+    }
 
-    // now let's get a stereo waveform back first with phase
-    // initial estimate
-    Eigen::MatrixXf audio_target =
-        istft(polar_to_complex(mix_mag_target, mix_phase));
+    // now let's get a stereo waveform back with wiener filtering
+    std::array<Eigen::Tensor3dXcf, 4> complex_targets = umxcpp::wiener_filter(spectrogram, mag_targets);
 
+#pragma omp parallel for
+    for (int target = 0; target < 4; ++target) {
     // now write the 4 audio waveforms to files in the output dir
     // using libnyquist
     // join out_dir with "/target_0.wav"
@@ -163,6 +163,7 @@
         std::cout << "Writing wav file " << p_target << " to " << out_dir
                   << std::endl;
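+        // invert the STFT of this target's Wiener-filtered spectrogram back to audio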
+        Eigen::MatrixXf audio_target = istft(complex_targets[target]);
         umxcpp::write_audio_file(audio_target, p_target);
     }
 }