added Lhotse online augmentation tutorial for SE #10944

Merged: 4 commits, Oct 26, 2024
119 changes: 119 additions & 0 deletions examples/audio/conf/masking_with_online_augmentation.yaml
@@ -0,0 +1,119 @@
name: "masking_with_online_augmenatation"

model:
sample_rate: 16000
skip_nan_grad: false
num_outputs: 1

train_ds:
use_lhotse: true # enable Lhotse data loader
cuts_path: ??? # path to Lhotse cuts manifest with speech signals for augmentation (including custom "target_recording" field with the same signals)
truncate_duration: 4.0 # Number of STFT time frames = 1 + truncate_duration // encoder.hop_length = 256
truncate_offset_type: random # if the file is longer than truncate_duration, use random offset to select a subsegment
batch_size: 64 # batch size may be increased based on the available memory
shuffle: true
num_workers: 8
pin_memory: true
rir_enabled: true # enable room impulse response augmentation
rir_path: ??? # path to Lhotse recordings manifest with room impulse response signals
noise_path: ??? # path to Lhotse cuts manifest with noise signals

validation_ds:
use_lhotse: true # enable Lhotse data loader
cuts_path: ??? # path to Lhotse cuts manifest with noisy speech signals (including custom "target_recording" field with the clean signals)
batch_size: 64 # batch size may be increased based on the available memory
shuffle: false
num_workers: 4
pin_memory: true

test_ds:
use_lhotse: true # enable Lhotse data loader
cuts_path: ??? # path to Lhotse cuts manifest with noisy speech signals (including custom "target_recording" field with the clean signals)
batch_size: 1 # batch size may be increased based on the available memory
shuffle: false
num_workers: 4
pin_memory: true

encoder:
_target_: nemo.collections.audio.modules.transforms.AudioToSpectrogram
fft_length: 512 # Length of the window and FFT for calculating spectrogram
hop_length: 256 # Hop length for calculating spectrogram

decoder:
_target_: nemo.collections.audio.modules.transforms.SpectrogramToAudio
fft_length: 512 # Length of the window and FFT for calculating spectrogram
hop_length: 256 # Hop length for calculating spectrogram

mask_estimator:
_target_: nemo.collections.audio.modules.masking.MaskEstimatorRNN
num_outputs: ${model.num_outputs}
num_subbands: 257 # Number of subbands of the input spectrogram
num_features: 256 # Number of features at RNN input
num_layers: 5 # Number of RNN layers
bidirectional: true # Use bi-directional RNN

mask_processor:
_target_: nemo.collections.audio.modules.masking.MaskReferenceChannel # Apply mask on the reference channel
ref_channel: 0 # Reference channel for the output

loss:
_target_: nemo.collections.audio.losses.SDRLoss
scale_invariant: true # Use scale-invariant SDR

metrics:
val:
sdr: # output SDR
_target_: torchmetrics.audio.SignalDistortionRatio
test:
sdr_ch0: # SDR on output channel 0
_target_: torchmetrics.audio.SignalDistortionRatio
channel: 0

optim:
name: adamw
lr: 1e-4
# optimizer arguments
betas: [0.9, 0.98]
weight_decay: 1e-3

trainer:
devices: -1 # number of GPUs, -1 would use all available GPUs
num_nodes: 1
max_epochs: -1
max_steps: -1 # computed at runtime if not set
val_check_interval: 1.0 # Set to 0.25 to check 4 times per epoch, or an int for number of iterations
accelerator: auto
strategy: ddp
accumulate_grad_batches: 1
gradient_clip_val: null
precision: 32 # Should be set to 16 for O1 and O2 to enable the AMP.
log_every_n_steps: 25 # Interval of logging.
enable_progress_bar: true
num_sanity_val_steps: 0 # number of steps to perform validation steps for sanity check the validation process before starting the training, setting to 0 disables it
check_val_every_n_epoch: 1 # number of evaluations on validation every n epochs
sync_batchnorm: true
enable_checkpointing: False # Provided by exp_manager
logger: false # Provided by exp_manager

exp_manager:
exp_dir: null
name: ${name}
create_tensorboard_logger: true
create_checkpoint_callback: true
checkpoint_callback_params:
# in case of multiple validation sets, first one is used
monitor: "val_loss"
mode: "min"
save_top_k: 5
always_save_nemo: true # saves the checkpoints as nemo files instead of PTL checkpoints

resume_from_checkpoint: null # The path to a checkpoint file to continue the training, restores the whole state including the epoch, step, LR schedulers, apex, etc.
# you need to set these two to true to continue the training
resume_if_exists: false
resume_ignore_no_checkpoint: false

# You may use this section to create a W&B logger
create_wandb_logger: false
wandb_logger_kwargs:
name: null
project: null
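
The cuts manifests above carry a custom "target_recording" field next to each cut. A minimal sketch of how such manifests could be prepared with Lhotse (all paths and IDs are hypothetical; Lhotse supports custom Recording fields in a cut's `custom` dict and loads them lazily):

from lhotse import CutSet, MonoCut, Recording, RecordingSet

# hypothetical input/target pairs; for train_ds with online augmentation,
# both entries would point to the same clean speech file
pairs = [("audio/speech_0001.wav", "audio/speech_0001.wav")]

cuts = []
for idx, (input_path, target_path) in enumerate(pairs):
    input_recording = Recording.from_file(input_path)
    cuts.append(
        MonoCut(
            id=f"cut_{idx}",
            start=0.0,
            duration=input_recording.duration,
            channel=0,
            recording=input_recording,
            # the custom field referenced by the config
            custom={"target_recording": Recording.from_file(target_path)},
        )
    )
CutSet.from_cuts(cuts).to_file("train_cuts.jsonl.gz")

# the RIR manifest is a plain RecordingSet; the noise manifest is a plain CutSet
RecordingSet.from_recordings([Recording.from_file("rir_0001.wav")]).to_file("rir_recordings.jsonl.gz")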
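
As a quick check of the frame-count comment in train_ds, the STFT geometry can be verified directly (a standalone sketch assuming a center-padded STFT, as in the torch.stft defaults):

import torch

sample_rate = 16000
truncate_duration = 4.0
fft_length = 512
hop_length = 256

num_samples = int(truncate_duration * sample_rate)  # 64000
spec = torch.stft(
    torch.randn(1, num_samples),
    n_fft=fft_length,
    hop_length=hop_length,
    window=torch.hann_window(fft_length),
    return_complex=True,
)
# (batch, subbands, frames) = (1, 257, 251):
# 257 = fft_length // 2 + 1 subbands (matching mask_estimator.num_subbands),
# 251 = 1 + num_samples // hop_length time frames
print(spec.shape)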
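
The validation metric is the standard torchmetrics SDR (the `channel: 0` argument in the test metric is presumably consumed by NeMo's metric setup rather than by torchmetrics itself, an assumption worth verifying). A quick usage example on dummy tensors:

import torch
from torchmetrics.audio import SignalDistortionRatio

sdr = SignalDistortionRatio()
target = torch.randn(1, 16000)                 # reference signal
preds = target + 0.1 * torch.randn(1, 16000)   # estimate with a small additive error
print(sdr(preds, target))                      # SDR in dB; higher is better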
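
To train with this config, the usual NeMo recipe applies. A minimal sketch, assuming the EncMaskDecAudioToAudioModel class and the standard exp_manager flow (the model class and all paths below are assumptions, not confirmed by this PR):

import lightning.pytorch as pl  # on older NeMo versions: import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.audio.models import EncMaskDecAudioToAudioModel
from nemo.utils.exp_manager import exp_manager

cfg = OmegaConf.load("examples/audio/conf/masking_with_online_augmentation.yaml")
# fill in the ??? placeholders before training (hypothetical paths)
cfg.model.train_ds.cuts_path = "train_cuts.jsonl.gz"
cfg.model.train_ds.rir_path = "rir_recordings.jsonl.gz"
cfg.model.train_ds.noise_path = "noise_cuts.jsonl.gz"
cfg.model.validation_ds.cuts_path = "val_cuts.jsonl.gz"
cfg.model.test_ds.cuts_path = "test_cuts.jsonl.gz"

trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
model = EncMaskDecAudioToAudioModel(cfg=cfg.model, trainer=trainer)
trainer.fit(model)

Equivalently, the same paths can be passed as Hydra-style overrides on the command line if the repository's audio training script is used.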
7 changes: 7 additions & 0 deletions nemo/collections/audio/data/audio_to_audio_lhotse.py
@@ -55,6 +55,13 @@ def __getitem__(self, cuts: CutSet) -> dict[str, torch.Tensor]:
        retained_cuts = [
            cut._first_non_padding_cut if isinstance(cut, MixedCut) else cut for cut in retained_padded_cuts
        ]

        # when online augmentation is applied, some retained cuts may still be MixedCuts,
        # containing the original speech together with noise and other augmentation tracks;
        # take the first non-padding cut, which is expected to be the clean speech signal
        for n, cut in enumerate(retained_cuts):
            if isinstance(cut, MixedCut):
                retained_cuts[n] = cut._first_non_padding_cut
        # rebuild the CutSet from the unwrapped cuts
        retained_cuts = CutSet.from_cuts(retained_cuts)

        if _key_available(retained_cuts, self.TARGET_KEY):
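
The unwrapping above is needed because online augmentation wraps the clean cut in a MixedCut together with the added noise. An illustrative, self-contained sketch of the Lhotse-side behavior (synthetic audio; `_first_non_padding_cut` is a private Lhotse property, used here only to mirror the dataset code):

import numpy as np
import soundfile as sf
from lhotse import MonoCut, Recording

# write synthetic "speech" and "noise" so the example is self-contained
sr = 16000
sf.write("speech.wav", (0.1 * np.random.randn(4 * sr)).astype(np.float32), sr)
sf.write("noise.wav", (0.1 * np.random.randn(4 * sr)).astype(np.float32), sr)

speech = MonoCut(id="speech", start=0.0, duration=4.0, channel=0,
                 recording=Recording.from_file("speech.wav"))
noise = MonoCut(id="noise", start=0.0, duration=4.0, channel=0,
                recording=Recording.from_file("noise.wav"))

# mixing noise into the speech cut yields a MixedCut
mixed = speech.mix(noise, snr=10)
print(type(mixed).__name__)               # MixedCut
# the first non-padding track is the original clean speech cut
print(mixed._first_non_padding_cut.id)    # "speech"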