
Commit

added Lhotse online augmentation tutorial for SE
Signed-off-by: Rauf <[email protected]>
nasretdinovr committed Oct 21, 2024
1 parent ff4e519 commit 1f1e094
Showing 3 changed files with 1,110 additions and 0 deletions.
119 changes: 119 additions & 0 deletions examples/audio/conf/masking_with_online_augmentation.yaml
@@ -0,0 +1,119 @@
name: "masking_with_online_augmenatation"

model:
sample_rate: 16000
skip_nan_grad: false
num_outputs: 1

train_ds:
use_lhotse: true # enable Lhotse data loader
cuts_path: ??? # path to Lhotse cuts manifest with speech signals for augmentation (including custom "target_recording" field with the same signals)
truncate_duration: 4.0 # Number of STFT time frames = 1 + truncate_duration // encoder.hop_length = 256
truncate_offset_type: random # if the file is longer than truncate_duration, use random offset to select a subsegment
batch_size: 64 # batch size may be increased based on the available memory
shuffle: true
num_workers: 8
pin_memory: true
rir_enabled: true # enable room impulse response augmentation
rir_path: ??? # path to Lhotse recordings manifest with room impulse response signals
noise_path: ??? # path to Lhotse cuts manifest with noise signals
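The training manifest carries the clean signal twice: once as the cut's recording and once in the custom "target_recording" field, since the noisy input is created on the fly from the noise and RIR manifests. A minimal sketch of building such a manifest with Lhotse (file names are hypothetical):

from lhotse import CutSet, Recording

paths = ["speech_0001.wav", "speech_0002.wav"]  # hypothetical clean speech files
cuts = []
for p in paths:
    rec = Recording.from_file(p)
    cut = rec.to_cut()
    # the target is the same clean signal; noise and RIR are applied online during training
    cut.custom = {"target_recording": rec}
    cuts.append(cut)
CutSet.from_cuts(cuts).to_file("train_cuts.jsonl.gz")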

  validation_ds:
    use_lhotse: true # enable Lhotse data loader
    cuts_path: ??? # path to Lhotse cuts manifest with noisy speech signals (including custom "target_recording" field with the clean signals)
    batch_size: 64 # batch size may be increased based on the available memory
    shuffle: false
    num_workers: 4
    pin_memory: true

  test_ds:
    use_lhotse: true # enable Lhotse data loader
    cuts_path: ??? # path to Lhotse cuts manifest with noisy speech signals (including custom "target_recording" field with the clean signals)
    batch_size: 1 # batch size may be increased based on the available memory
    shuffle: false
    num_workers: 4
    pin_memory: true

  encoder:
    _target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram
    fft_length: 512 # Length of the window and FFT for calculating spectrogram
    hop_length: 256 # Hop length for calculating spectrogram

  decoder:
    _target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio
    fft_length: 512 # Length of the window and FFT for calculating spectrogram
    hop_length: 256 # Hop length for calculating spectrogram

  mask_estimator:
    _target_: nemo.collections.audio.modules.masking.MaskEstimatorRNN
    num_outputs: ${model.num_outputs}
    num_subbands: 257 # Number of subbands of the input spectrogram
    num_features: 256 # Number of features at RNN input
    num_layers: 5 # Number of RNN layers
    bidirectional: true # Use bi-directional RNN
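These numbers are tied together: a 512-point FFT yields 512 // 2 + 1 = 257 subbands, matching num_subbands above, and a 4.0 s segment at 16 kHz yields 1 + 64000 // 256 = 251 time frames. A quick shape check with PyTorch (a sketch assuming AudioToSpectrogram follows the standard centered STFT convention):

import torch

x = torch.randn(1, int(4.0 * 16000))  # one 4-second segment at 16 kHz
spec = torch.stft(x, n_fft=512, hop_length=256, window=torch.hann_window(512), return_complex=True)
print(spec.shape)  # torch.Size([1, 257, 251]) -> 257 subbands, 251 time frames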

  mask_processor:
    _target_: nemo.collections.audio.modules.masking.MaskReferenceChannel # Apply mask on the reference channel
    ref_channel: 0 # Reference channel for the output

  loss:
    _target_: nemo.collections.audio.losses.SDRLoss
    scale_invariant: true # Use scale-invariant SDR
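Scale-invariant SDR first rescales the reference by the optimal gain, so the loss ignores the overall level of the estimate; training then minimizes its negative. A self-contained sketch of the quantity being optimized (NeMo's SDRLoss may differ in details such as epsilon handling):

import torch

def si_sdr(estimate: torch.Tensor, reference: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    # project the estimate onto the reference: the optimal scale makes the metric gain-invariant
    alpha = (estimate * reference).sum(-1, keepdim=True) / (reference.pow(2).sum(-1, keepdim=True) + eps)
    target = alpha * reference
    distortion = estimate - target
    return 10 * torch.log10(target.pow(2).sum(-1) / (distortion.pow(2).sum(-1) + eps))

# training minimizes the negative SI-SDR averaged over the batch
loss = -si_sdr(torch.randn(2, 16000), torch.randn(2, 16000)).mean()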

  metrics:
    val:
      sdr: # output SDR
        _target_: torchmetrics.audio.SignalDistortionRatio
    test:
      sdr_ch0: # SDR on output channel 0
        _target_: torchmetrics.audio.SignalDistortionRatio
        channel: 0
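The metrics are instantiated from torchmetrics and updated by NeMo on validation and test batches; the extra channel key appears to be consumed by NeMo's metric handling to select the output channel. Standalone, the configured metric behaves as follows (random tensors stand in for real signals):

import torch
from torchmetrics.audio import SignalDistortionRatio

sdr = SignalDistortionRatio()
estimate, target = torch.randn(2, 16000), torch.randn(2, 16000)
print(sdr(estimate, target))  # batch-averaged SDR in dB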

  optim:
    name: adamw
    lr: 1e-4
    # optimizer arguments
    betas: [0.9, 0.98]
    weight_decay: 1e-3

trainer:
  devices: -1 # number of GPUs; -1 uses all available GPUs
  num_nodes: 1
  max_epochs: -1
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # set to 0.25 to check 4 times per epoch, or to an integer for a number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: null
  precision: 32 # should be set to 16 for O1 and O2 to enable AMP
  log_every_n_steps: 25 # logging interval
  enable_progress_bar: true
  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training; 0 disables it
  check_val_every_n_epoch: 1 # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: false # provided by exp_manager
  logger: false # provided by exp_manager

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, the first one is used
    monitor: "val_loss"
    mode: "min"
    save_top_k: 5
    always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints

  resume_from_checkpoint: null # path to a checkpoint file for continuing training; restores the whole state, including the epoch, step, LR schedulers, apex, etc.
  # set the following two to true to continue training from a checkpoint
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  # You may use this section to create a W&B logger
  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null
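To train with this config, the ??? fields must be filled in. A hedged sketch of a programmatic launch, assuming the mask-based enhancement model class EncMaskDecAudioToAudioModel from NeMo's audio collection (manifest paths are hypothetical):

import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.audio.models import EncMaskDecAudioToAudioModel
from nemo.utils.exp_manager import exp_manager

cfg = OmegaConf.load("examples/audio/conf/masking_with_online_augmentation.yaml")
cfg.model.train_ds.cuts_path = "train_cuts.jsonl.gz"  # hypothetical manifests
cfg.model.train_ds.rir_path = "rir_recordings.jsonl.gz"
cfg.model.train_ds.noise_path = "noise_cuts.jsonl.gz"
cfg.model.validation_ds.cuts_path = "val_cuts.jsonl.gz"
cfg.model.test_ds.cuts_path = "test_cuts.jsonl.gz"

trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model = EncMaskDecAudioToAudioModel(cfg=cfg.model, trainer=trainer)
trainer.fit(model)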
7 changes: 7 additions & 0 deletions nemo/collections/audio/data/audio_to_audio_lhotse.py
@@ -55,6 +55,13 @@ def __getitem__(self, cuts: CutSet) -> dict[str, torch.Tensor]:
        retained_cuts = [
            cut._first_non_padding_cut if isinstance(cut, MixedCut) else cut for cut in retained_padded_cuts
        ]

        # if online augmentation is applied, some retained cuts may still be MixedCuts
        # (containing the original speech, noise, and augmentation effects);
        # take their first non-padding cut, which is expected to be the clean speech signal
        for n, cut in enumerate(retained_cuts):
            if isinstance(cut, MixedCut):
                retained_cuts[n] = cut._first_non_padding_cut
        # create the cutset
        retained_cuts = CutSet.from_cuts(retained_cuts)

        if _key_available(retained_cuts, self.TARGET_KEY):
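The second unwrapping pass is needed because, with online augmentation, the first non-padding cut extracted from a padded MixedCut can itself be a MixedCut wrapping the clean speech together with the mixed-in noise. An illustration with Lhotse (manifest names are hypothetical; mixing parameters are assumptions):

from lhotse import CutSet

clean = CutSet.from_file("train_cuts.jsonl.gz")  # hypothetical manifests
noise = CutSet.from_file("noise_cuts.jsonl.gz")
augmented = clean.mix(cuts_to_mix=noise, snr=10)  # each resulting cut is a MixedCut

cut = next(iter(augmented))
print(type(cut).__name__)                         # MixedCut
print(type(cut._first_non_padding_cut).__name__)  # MonoCut -- the clean speech track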