Commit

Feat/wiener em (#3)
* Add wiener-em
sevagh authored Sep 10, 2023
1 parent 690ec94 commit 49940d8
Showing 6 changed files with 642 additions and 30 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -1,7 +1,7 @@
 # cmake file to compile src/
 # link against included submodules libnyquist

-cmake_minimum_required(VERSION 3.0)
+cmake_minimum_required(VERSION 3.5)

 if(NOT CMAKE_BUILD_TYPE)
 set(CMAKE_BUILD_TYPE Release)
24 changes: 13 additions & 11 deletions README.md
@@ -1,5 +1,7 @@
 # umx.cpp

+**:boom: :dizzy: 2023-09-10 update: Wiener-EM is now implemented for maximum performance!**
+
 C++17 implementation of [Open-Unmix](https://github.com/sigsep/open-unmix-pytorch) (UMX), a PyTorch neural network for music demixing.

 It uses [libnyquist](https://github.com/ddiakopoulos/libnyquist) to load audio files, the [ggml](https://github.com/ggerganov/ggml) file format to serialize the PyTorch weights of `umxhq` and `umxl` to a binary file format, and [Eigen](https://eigen.tuxfamily.org/index.php?title=Main_Page) (+ OpenMP) to implement the inference of Open-Unmix.
@@ -8,9 +10,9 @@ The float32 weights of UMX are quantized to uint16 during the conversion to the

 ## Performance

-The demixed output wav files (and their SDR score) of the main program [`umx.cpp`](./umx.cpp) are mostly identical to the PyTorch models (with the post-processing Wiener-EM step disabled):
+The demixed output wav files (and their SDR score) of the main program [`umx.cpp`](./umx.cpp) are mostly identical to the PyTorch models:
 ```
-# first, standard pytorch inference (no wiener-em)
+# first, standard pytorch inference
 $ python ./scripts/umx_pytorch_inference.py \
   --model=umxl \
   --dest-dir=./umx-py-xl-out \
@@ -28,23 +30,23 @@ $ python ./scripts/evaluate-demixed-output.py \
   ./umx-py-xl-out \
   'Punkdisco - Oral Hygiene'
-vocals ==> SDR: 7.377 SIR: 16.028 ISR: 15.628 SAR: 8.376
-drums ==> SDR: 8.086 SIR: 12.205 ISR: 17.904 SAR: 9.055
-bass ==> SDR: 5.459 SIR: 8.830 ISR: 13.361 SAR: 10.543
-other ==> SDR: 1.442 SIR: 1.144 ISR: 5.199 SAR: 2.842
+vocals ==> SDR: 7.695 SIR: 17.312 ISR: 16.426 SAR: 8.322
+drums ==> SDR: 8.899 SIR: 14.054 ISR: 14.941 SAR: 9.428
+bass ==> SDR: 8.338 SIR: 14.352 ISR: 14.171 SAR: 10.971
+other ==> SDR: 2.017 SIR: 6.266 ISR: 6.821 SAR: 2.410
 $ python ./scripts/evaluate-demixed-output.py \
   --musdb-root="/MUSDB18-HQ" \
   ./umx-cpp-xl-out \
   'Punkdisco - Oral Hygiene'
-vocals ==> SDR: 7.377 SIR: 16.028 ISR: 15.628 SAR: 8.376
-drums ==> SDR: 8.086 SIR: 12.205 ISR: 17.904 SAR: 9.055
-bass ==> SDR: 5.459 SIR: 8.830 ISR: 13.361 SAR: 10.543
-other ==> SDR: 1.442 SIR: 1.144 ISR: 5.199 SAR: 2.842
+vocals ==> SDR: 7.750 SIR: 17.510 ISR: 16.195 SAR: 8.321
+drums ==> SDR: 9.010 SIR: 14.149 ISR: 14.900 SAR: 9.416
+bass ==> SDR: 8.349 SIR: 14.348 ISR: 14.160 SAR: 10.990
+other ==> SDR: 1.987 SIR: 6.282 ISR: 6.674 SAR: 2.461
 ```

-In runtime, this is actually slower than the PyTorch inference (and probably much slower than a possible Torch C++ inference implementation). For a 4:23 song, PyTorch takes 13s and umx.cpp takes 22s.
+In runtime, this is actually slower than the PyTorch inference (and probably much slower than a possible Torch C++ inference implementation).

 ## Motivation
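The README announces Wiener-EM without explaining the idea behind it. A minimal sketch of the core principle, under stated assumptions: `wiener_mask_separate` is a hypothetical helper, and this shows only the single-pass soft-mask step that initializes Wiener filtering, not the full EM iteration that `openunmix.filtering.wiener` performs.

```python
import numpy as np

def wiener_mask_separate(target_mags, mix_stft, eps=1e-10):
    """Hypothetical helper: single-pass Wiener-style soft masking.

    Each target's complex spectrogram is the mixture STFT scaled,
    per time-frequency bin, by the ratio of that target's estimated
    power to the summed estimated power of all targets.

    target_mags: (frames, bins, targets) real magnitude estimates
    mix_stft:    (frames, bins) complex mixture STFT
    returns:     (frames, bins, targets) complex target estimates
    """
    power = target_mags ** 2
    total = power.sum(axis=-1, keepdims=True) + eps
    masks = power / total  # soft masks, summing to ~1 across targets
    return masks * mix_stft[..., None]
```

Because the masks sum to roughly one in every bin, the per-target estimates add back up to the mixture; the EM refinement in open-unmix goes further by re-estimating per-target power spectra and spatial covariances over several iterations.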
35 changes: 28 additions & 7 deletions scripts/umx_pytorch_inference.py
@@ -1,5 +1,6 @@
 #!/usr/bin/env python
 import openunmix
+from openunmix.filtering import wiener
 import torch
 import torchaudio.backend.sox_io_backend
 import torchaudio
@@ -46,24 +47,44 @@
 mag_spec = torch.abs(torch.view_as_complex(spec))
 phase_spec = torch.angle(torch.view_as_complex(spec))

+out_mag_specs = []
+
 # UMX forward inference
 for target_name, target_model in model.items():
     print(f"Inference for target {target_name}")
     out_mag_spec = target_model(mag_spec)
+    print(type(out_mag_spec))
+    print(out_mag_spec.shape)
+    out_mag_specs.append(torch.unsqueeze(out_mag_spec, dim=-1))
+
+out_mag_spec_concat = torch.cat(out_mag_specs, dim=-1)
+print(f"shape, dtype: {out_mag_spec_concat.shape}, {out_mag_spec_concat.dtype}")

-# Convert back to complex tensor
-out_spec = out_mag_spec * torch.exp(1j * phase_spec)
+# Convert back to complex tensor
+#out_spec = out_mag_spec * torch.exp(1j * phase_spec)
+# do wiener filtering
+
+wiener_mag_inp = out_mag_spec_concat[0, ...].permute(2, 1, 0, 3)
+wiener_spec_inp = spec[0, ...].permute(2, 1, 0, 3)
+
+out_specs = wiener(wiener_mag_inp, wiener_spec_inp)
+
+# out_specs: torch.Size([44, 2049, 2, 2, 4])
+# nb_frames, nb_bins, nb_channels, 2, targets
+#     0         1          2       3     4
+# permute:
+#     4         2          1       0     3

-# get istft
-out_audio = istft(torch.view_as_real(out_spec))
-print(out_audio.shape)
-out_audio = torch.squeeze(out_audio, dim=0)
+# samples, targets, channels, nb_bins, nb_frames, 2
+out_specs = torch.unsqueeze(out_specs.permute(4, 2, 1, 0, 3), dim=0)
+out_audios = istft(out_specs)[0]
+print(out_audios.shape)

+# get istft
 for i, target_name in enumerate(model.keys()):
     # write to file in directory
     if args.dest_dir is not None:
         os.makedirs(args.dest_dir, exist_ok=True)
-        torchaudio.save(os.path.join(args.dest_dir, f'target_{target_digit_map[target_name]}.wav'), out_audio, sample_rate=44100)
+        torchaudio.save(os.path.join(args.dest_dir, f'target_{target_digit_map[target_name]}.wav'), out_audios[i], sample_rate=44100)

 print("Goodbye!")
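The permutes in this script are pure shape bookkeeping between UMX's `(batch, channels, bins, frames)` layout and the frames-first layout `wiener` consumes. A standalone sketch with dummy tensors (sizes taken from the diff's shape comment; the `wiener` call itself is stubbed with a random tensor so only the layout logic is shown):

```python
import torch

# Illustrative sizes from the diff's comment: 44 frames, 2049 bins,
# 2 channels, 4 targets (dummy data only).
nb_frames, nb_bins, nb_channels, nb_targets = 44, 2049, 2, 4

# UMX-style stacked magnitudes: (batch, channels, bins, frames, targets)
out_mag_spec_concat = torch.rand(1, nb_channels, nb_bins, nb_frames, nb_targets)
# Mixture STFT viewed as real/imag pairs: (batch, channels, bins, frames, 2)
spec = torch.rand(1, nb_channels, nb_bins, nb_frames, 2)

# wiener() expects frames-first layouts, hence the permutes:
wiener_mag_inp = out_mag_spec_concat[0, ...].permute(2, 1, 0, 3)
# -> (frames, bins, channels, targets)
wiener_spec_inp = spec[0, ...].permute(2, 1, 0, 3)
# -> (frames, bins, channels, 2)

# wiener() would return (frames, bins, channels, 2, targets);
# stand in a dummy result to show the inverse permute for the ISTFT:
out_specs = torch.rand(nb_frames, nb_bins, nb_channels, 2, nb_targets)
istft_inp = torch.unsqueeze(out_specs.permute(4, 2, 1, 0, 3), dim=0)
# -> (1, targets, channels, bins, frames, 2), ready for a batched ISTFT
```

Each permute is just the inverse mapping of the axis-index table in the diff's comments; no values are changed, only strides.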