diff --git a/wenet/tts/vits/models.py b/wenet/tts/vits/models.py index 2f52914ac7..f43867f99d 100644 --- a/wenet/tts/vits/models.py +++ b/wenet/tts/vits/models.py @@ -15,6 +15,7 @@ from wenet.tts.vits.commons import init_weights, get_padding from wenet.tts.vits.losses import generator_loss, discriminator_loss, feature_loss, kl_loss from wenet.tts.vits.mel_processing import mel_spectrogram_torch +from wenet.utils.mask import make_pad_mask class StochasticDurationPredictor(nn.Module): @@ -791,6 +792,8 @@ def __init__(self, n_vocab, spec_channels, **kwargs): def forward(self, batch: dict, device: torch.device): x = batch['target'].to(device) x_lengths = batch['target_lengths'].to(device) + x_mask = make_pad_mask(x_lengths) + x = x.masked_fill(x_mask, 0) # change pad value(IGNORE_ID = -1) to 0 spec = batch['feats'].to(device) spec_lengths = batch['feats_lengths'].to(device) spec = spec.transpose(1, 2)