From 659629bd7581eb7b56a8655bdd6cbe1d4879d543 Mon Sep 17 00:00:00 2001 From: Kikkia Date: Wed, 17 Jan 2024 22:31:49 -0800 Subject: [PATCH 01/16] add json support to the TTS server api. This allows you to send a json request body with your text and will allow for longer texts that otherwise exceed the maximum http url length --- TTS/server/server.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/TTS/server/server.py b/TTS/server/server.py index 6b2141a9aa..58b0ddca69 100644 --- a/TTS/server/server.py +++ b/TTS/server/server.py @@ -191,10 +191,19 @@ def details(): @app.route("/api/tts", methods=["GET", "POST"]) def tts(): with lock: - text = request.headers.get("text") or request.values.get("text", "") - speaker_idx = request.headers.get("speaker-id") or request.values.get("speaker_id", "") - language_idx = request.headers.get("language-id") or request.values.get("language_id", "") - style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "") + try: + data = request.get_json() + text = data.get("text", "") + speaker_idx = data.get("speaker-id", "") + language_idx = data.get("language-id", "") + style_wav = data.get("style-wav", "") + except: + # Fallback to headers and form data if JSON data is not present + text = request.headers.get("text") or request.values.get("text", "") + speaker_idx = request.headers.get("speaker-id") or request.values.get("speaker_id", "") + language_idx = request.headers.get("language-id") or request.values.get("language_id", "") + style_wav = request.headers.get("style-wav") or request.values.get("style_wav", "") + style_wav = style_wav_uri_to_dict(style_wav) print(f" > Model input: {text}") From 8c4d0142b71e2b40d9cdaa433263e18960bd9c58 Mon Sep 17 00:00:00 2001 From: Subuday Date: Sun, 11 Feb 2024 21:02:20 +0000 Subject: [PATCH 02/16] Add MatchaTTS backbone --- TTS/tts/configs/matcha_tts.py | 9 ++++++++ TTS/tts/models/matcha_tts.py | 30 ++++++++++++++++++++++++++ 
tests/tts_tests2/test_matcha_tts.py | 33 +++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+) create mode 100644 TTS/tts/configs/matcha_tts.py create mode 100644 TTS/tts/models/matcha_tts.py create mode 100644 tests/tts_tests2/test_matcha_tts.py diff --git a/TTS/tts/configs/matcha_tts.py b/TTS/tts/configs/matcha_tts.py new file mode 100644 index 0000000000..15bb91b829 --- /dev/null +++ b/TTS/tts/configs/matcha_tts.py @@ -0,0 +1,9 @@ +from dataclasses import dataclass, field + +from TTS.tts.configs.shared_configs import BaseTTSConfig + + +@dataclass +class MatchaTTSConfig(BaseTTSConfig): + model: str = "matcha_tts" + num_chars: int = None diff --git a/TTS/tts/models/matcha_tts.py b/TTS/tts/models/matcha_tts.py new file mode 100644 index 0000000000..08c0022b19 --- /dev/null +++ b/TTS/tts/models/matcha_tts.py @@ -0,0 +1,30 @@ +import torch + +from TTS.tts.configs.matcha_tts import MatchaTTSConfig +from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.text.tokenizer import TTSTokenizer + + +class MatchaTTS(BaseTTS): + + def __init__( + self, + config: MatchaTTSConfig, + ap: "AudioProcessor" = None, + tokenizer: "TTSTokenizer" = None, + ): + super().__init__(config, ap, tokenizer) + + def forward(self): + pass + + @torch.no_grad() + def inference(self): + pass + + @staticmethod + def init_from_config(config: "MatchaTTSConfig"): + pass + + def load_checkpoint(self, checkpoint_path): + pass diff --git a/tests/tts_tests2/test_matcha_tts.py b/tests/tts_tests2/test_matcha_tts.py new file mode 100644 index 0000000000..1939efbd60 --- /dev/null +++ b/tests/tts_tests2/test_matcha_tts.py @@ -0,0 +1,33 @@ +import unittest + +import torch + +from TTS.tts.configs.matcha_tts import MatchaTTSConfig +from TTS.tts.models.matcha_tts import MatchaTTS + +torch.manual_seed(1) +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +c = MatchaTTSConfig() + + +class TestMatchTTS(unittest.TestCase): + @staticmethod + def _create_inputs(batch_size=8): + 
input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device) + input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device) + input_lengths[-1] = 128 + mel_spec = torch.rand(batch_size, 30, c.audio["num_mels"]).to(device) + mel_lengths = torch.randint(20, 30, (batch_size,)).long().to(device) + speaker_ids = torch.randint(0, 5, (batch_size,)).long().to(device) + return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids + + def _test_forward(self, batch_size): + input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(batch_size) + config = MatchaTTSConfig(num_chars=32) + model = MatchaTTS(config).to(device) + + model.train() + + def test_forward(self): + self._test_forward(1) From 7314b1cbec969647ba777af59ee96d0d475cb264 Mon Sep 17 00:00:00 2001 From: Subuday Date: Mon, 12 Feb 2024 19:39:22 +0000 Subject: [PATCH 03/16] Implement model forward --- TTS/tts/layers/matcha_tts/decoder.py | 24 +++++++++++ TTS/tts/models/matcha_tts.py | 59 +++++++++++++++++++++++++++- tests/tts_tests2/test_matcha_tts.py | 2 + 3 files changed, 83 insertions(+), 2 deletions(-) create mode 100644 TTS/tts/layers/matcha_tts/decoder.py diff --git a/TTS/tts/layers/matcha_tts/decoder.py b/TTS/tts/layers/matcha_tts/decoder.py new file mode 100644 index 0000000000..de7f52dc24 --- /dev/null +++ b/TTS/tts/layers/matcha_tts/decoder.py @@ -0,0 +1,24 @@ +import torch +from torch import nn +import torch.nn.functional as F + + +class Decoder(nn.Module): + def __init__(self): + super().__init__() + self.sigma_min = 1e-5 + + def forward(self, x_1, mean, mask): + """ + Shapes: + - x_1: :math:`[B, C, T]` + - mean: :math:`[B, C ,T]` + - mask: :math:`[B, 1, T]` + """ + t = torch.rand([x_1.size(0), 1, 1], device=x_1.device, dtype=x_1.dtype) + x_0 = torch.randn_like(x_1) + x_t = (1 - (1 - self.sigma_min) * t) * x_0 + t * x_1 + u_t = x_1 - (1 - self.sigma_min) * x_0 + v_t = torch.randn_like(u_t) + loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * 
u_t.shape[1]) + return loss diff --git a/TTS/tts/models/matcha_tts.py b/TTS/tts/models/matcha_tts.py index 08c0022b19..9bc3e0ffc4 100644 --- a/TTS/tts/models/matcha_tts.py +++ b/TTS/tts/models/matcha_tts.py @@ -1,7 +1,12 @@ +from dataclasses import field +import math import torch from TTS.tts.configs.matcha_tts import MatchaTTSConfig +from TTS.tts.layers.glow_tts.encoder import Encoder +from TTS.tts.layers.matcha_tts.decoder import Decoder from TTS.tts.models.base_tts import BaseTTS +from TTS.tts.utils.helpers import maximum_path, sequence_mask from TTS.tts.utils.text.tokenizer import TTSTokenizer @@ -14,9 +19,59 @@ def __init__( tokenizer: "TTSTokenizer" = None, ): super().__init__(config, ap, tokenizer) + self.encoder = Encoder( + self.config.num_chars, + out_channels=80, + hidden_channels=192, + hidden_channels_dp=256, + encoder_type='rel_pos_transformer', + encoder_params={ + "kernel_size": 3, + "dropout_p": 0.1, + "num_layers": 6, + "num_heads": 2, + "hidden_channels_ffn": 768, + } + ) - def forward(self): - pass + self.decoder = Decoder() + + def forward(self, x, x_lengths, y, y_lengths): + """ + Args: + x (torch.Tensor): + Input text sequence ids. :math:`[B, T_en]` + + x_lengths (torch.Tensor): + Lengths of input text sequences. :math:`[B]` + + y (torch.Tensor): + Target mel-spectrogram frames. :math:`[B, T_de, C_mel]` + + y_lengths (torch.Tensor): + Lengths of target mel-spectrogram frames. 
:math:`[B]` + """ + y = y.transpose(1, 2) + y_max_length = y.size(2) + + o_mean, o_log_scale, o_log_dur, o_mask = self.encoder(x, x_lengths, g=None) + + y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(o_mask.dtype) + attn_mask = torch.unsqueeze(o_mask, -1) * torch.unsqueeze(y_mask, 2) + + with torch.no_grad(): + o_scale = torch.exp(-2 * o_log_scale) + logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) + logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (y**2)) + logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), y) + logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) + logp = logp1 + logp2 + logp3 + logp4 + attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() + + # Align encoded text with mel-spectrogram and get mu_y segment + c_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(1, 2) + + _ = self.decoder(x_1=y, mean=c_mean, mask=y_mask) @torch.no_grad() def inference(self): diff --git a/tests/tts_tests2/test_matcha_tts.py b/tests/tts_tests2/test_matcha_tts.py index 1939efbd60..bc94c6b4f1 100644 --- a/tests/tts_tests2/test_matcha_tts.py +++ b/tests/tts_tests2/test_matcha_tts.py @@ -29,5 +29,7 @@ def _test_forward(self, batch_size): model.train() + model.forward(input_dummy, input_lengths, mel_spec, mel_lengths) + def test_forward(self): self._test_forward(1) From b5467b8051c5ba3e639abc2efd36456c3ef1802d Mon Sep 17 00:00:00 2001 From: Subuday Date: Mon, 12 Feb 2024 21:44:29 +0000 Subject: [PATCH 04/16] Add UNet backbone --- TTS/tts/layers/matcha_tts/UNet.py | 64 ++++++++++++++++++++++++++++ TTS/tts/layers/matcha_tts/decoder.py | 9 +++- tests/tts_tests2/test_matcha_tts.py | 1 + 3 files changed, 73 insertions(+), 1 deletion(-) create mode 100644 TTS/tts/layers/matcha_tts/UNet.py diff --git a/TTS/tts/layers/matcha_tts/UNet.py b/TTS/tts/layers/matcha_tts/UNet.py new file mode 100644 index 0000000000..07616290cd --- /dev/null +++ 
b/TTS/tts/layers/matcha_tts/UNet.py @@ -0,0 +1,64 @@ +import math +from einops import pack +import torch +from torch import nn + + +class PositionalEncoding(torch.nn.Module): + def __init__(self, channels): + super().__init__() + self.channels = channels + + def forward(self, x, scale=1000): + if x.ndim < 1: + x = x.unsqueeze(0) + emb = math.log(10000) / (self.channels // 2 - 1) + emb = torch.exp(torch.arange(self.channels // 2, device=x.device).float() * -emb) + emb = scale * x.unsqueeze(1) * emb.unsqueeze(0) + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class UNet(nn.Module): + def __init__( + self, + in_channels: int, + model_channels: int, + out_channels: int, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.time_encoder = PositionalEncoding(in_channels) + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + nn.Linear(in_channels, time_embed_dim), + nn.SiLU(), + nn.Linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList([]) + self.middle_blocks = nn.ModuleList([]) + self.output_blocks = nn.ModuleList([]) + + self.conv = nn.Conv1d(model_channels, self.out_channels, 1) + + def forward(self, x_t, mean, mask, t): + t = self.time_encoder(t) + t = self.time_embed(t) + + x_t = pack([x_t, mean], "b * t")[0] + + for _ in self.input_blocks: + pass + + for _ in self.middle_blocks: + pass + + for _ in self.output_blocks: + pass + + output = self.conv(x_t) + + return output * mask \ No newline at end of file diff --git a/TTS/tts/layers/matcha_tts/decoder.py b/TTS/tts/layers/matcha_tts/decoder.py index de7f52dc24..e78d34cf98 100644 --- a/TTS/tts/layers/matcha_tts/decoder.py +++ b/TTS/tts/layers/matcha_tts/decoder.py @@ -2,11 +2,18 @@ from torch import nn import torch.nn.functional as F +from TTS.tts.layers.matcha_tts.UNet import UNet + class Decoder(nn.Module): def __init__(self): super().__init__() self.sigma_min = 1e-5 + self.predictor = UNet( + 
in_channels=80, + model_channels=160, + out_channels=80, + ) def forward(self, x_1, mean, mask): """ @@ -19,6 +26,6 @@ def forward(self, x_1, mean, mask): x_0 = torch.randn_like(x_1) x_t = (1 - (1 - self.sigma_min) * t) * x_0 + t * x_1 u_t = x_1 - (1 - self.sigma_min) * x_0 - v_t = torch.randn_like(u_t) + v_t = self.predictor(x_t, mean, mask, t.squeeze()) loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * u_t.shape[1]) return loss diff --git a/tests/tts_tests2/test_matcha_tts.py b/tests/tts_tests2/test_matcha_tts.py index bc94c6b4f1..5fbe95377f 100644 --- a/tests/tts_tests2/test_matcha_tts.py +++ b/tests/tts_tests2/test_matcha_tts.py @@ -33,3 +33,4 @@ def _test_forward(self, batch_size): def test_forward(self): self._test_forward(1) + self._test_forward(3) From 0f7a7edb9bdb7d75291144ee8b276331a4f16df5 Mon Sep 17 00:00:00 2001 From: Subuday Date: Wed, 14 Feb 2024 21:21:07 +0000 Subject: [PATCH 05/16] Add conv block to UNet --- TTS/tts/layers/matcha_tts/UNet.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/TTS/tts/layers/matcha_tts/UNet.py b/TTS/tts/layers/matcha_tts/UNet.py index 07616290cd..642a854545 100644 --- a/TTS/tts/layers/matcha_tts/UNet.py +++ b/TTS/tts/layers/matcha_tts/UNet.py @@ -18,6 +18,23 @@ def forward(self, x, scale=1000): emb = torch.cat((emb.sin(), emb.cos()), dim=-1) return emb +class ConvBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, num_groups=8): + super().__init__() + self.block = nn.Sequential( + nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1), + nn.GroupNorm(num_groups, out_channels), + nn.Mish() + ) + + def forward(self, x, mask=None): + if mask is not None: + x = x * mask + output = self.block(x) + if mask is not None: + output = output * mask + return output + class UNet(nn.Module): def __init__( @@ -42,6 +59,7 @@ def __init__( self.middle_blocks = nn.ModuleList([]) self.output_blocks = nn.ModuleList([]) + self.conv_block = ConvBlock1D(model_channels, 
model_channels) self.conv = nn.Conv1d(model_channels, self.out_channels, 1) def forward(self, x_t, mean, mask, t): @@ -59,6 +77,7 @@ def forward(self, x_t, mean, mask, t): for _ in self.output_blocks: pass + output = self.conv_block(x_t) output = self.conv(x_t) return output * mask \ No newline at end of file From fd6c0afbbf53f363d3a5f48bfcd646647d2c654b Mon Sep 17 00:00:00 2001 From: Subuday Date: Thu, 15 Feb 2024 08:40:04 +0000 Subject: [PATCH 06/16] Add ResNetBlock1D to UNet --- TTS/tts/layers/matcha_tts/UNet.py | 43 ++++++++++++++++++++++++---- TTS/tts/layers/matcha_tts/decoder.py | 1 + 2 files changed, 39 insertions(+), 5 deletions(-) diff --git a/TTS/tts/layers/matcha_tts/UNet.py b/TTS/tts/layers/matcha_tts/UNet.py index 642a854545..8547bb9b64 100644 --- a/TTS/tts/layers/matcha_tts/UNet.py +++ b/TTS/tts/layers/matcha_tts/UNet.py @@ -36,26 +36,58 @@ def forward(self, x, mask=None): return output +class ResNetBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, time_embed_channels, num_groups=8): + super().__init__() + self.block_1 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups) + self.mlp = nn.Sequential( + nn.Mish(), + nn.Linear(time_embed_channels, out_channels) + ) + self.block_2 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups) + self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1) + + def forward(self, x, mask, t): + h = self.block_1(x, mask) + h += self.mlp(t).unsqueeze(-1) + h = self.block_2(h, mask) + output = h + self.conv(x * mask) + return output + + class UNet(nn.Module): def __init__( self, in_channels: int, model_channels: int, out_channels: int, + num_blocks: int, ): super().__init__() self.in_channels = in_channels self.out_channels = out_channels self.time_encoder = PositionalEncoding(in_channels) - time_embed_dim = model_channels * 4 + time_embed_channels = model_channels * 4 self.time_embed = nn.Sequential( - nn.Linear(in_channels, time_embed_dim), + nn.Linear(in_channels, 
time_embed_channels), nn.SiLU(), - nn.Linear(time_embed_dim, time_embed_dim), + nn.Linear(time_embed_channels, time_embed_channels), ) self.input_blocks = nn.ModuleList([]) + block_in_channels = in_channels + for _ in range(num_blocks): + block = nn.ModuleList([]) + + block.append( + ResNetBlock1D( + in_channels=block_in_channels, + out_channels=model_channels, + time_embed_channels=time_embed_channels + ) + ) + self.middle_blocks = nn.ModuleList([]) self.output_blocks = nn.ModuleList([]) @@ -68,8 +100,9 @@ def forward(self, x_t, mean, mask, t): x_t = pack([x_t, mean], "b * t")[0] - for _ in self.input_blocks: - pass + for block in self.input_blocks: + res_net_block = block[0] + x_t = res_net_block(x_t, mask, t) for _ in self.middle_blocks: pass diff --git a/TTS/tts/layers/matcha_tts/decoder.py b/TTS/tts/layers/matcha_tts/decoder.py index e78d34cf98..b80c190e30 100644 --- a/TTS/tts/layers/matcha_tts/decoder.py +++ b/TTS/tts/layers/matcha_tts/decoder.py @@ -13,6 +13,7 @@ def __init__(self): in_channels=80, model_channels=160, out_channels=80, + num_blocks=2 ) def forward(self, x_1, mean, mask): From 8676ab30d9b6962621a1f1ba9b5ca886723f5bca Mon Sep 17 00:00:00 2001 From: Subuday Date: Thu, 15 Feb 2024 08:55:52 +0000 Subject: [PATCH 07/16] Fix appending a new block to input_blocks --- TTS/tts/layers/matcha_tts/UNet.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/TTS/tts/layers/matcha_tts/UNet.py b/TTS/tts/layers/matcha_tts/UNet.py index 8547bb9b64..0183c787d3 100644 --- a/TTS/tts/layers/matcha_tts/UNet.py +++ b/TTS/tts/layers/matcha_tts/UNet.py @@ -76,7 +76,7 @@ def __init__( ) self.input_blocks = nn.ModuleList([]) - block_in_channels = in_channels + block_in_channels = in_channels * 2 for _ in range(num_blocks): block = nn.ModuleList([]) @@ -88,6 +88,8 @@ def __init__( ) ) + self.input_blocks.append(block) + self.middle_blocks = nn.ModuleList([]) self.output_blocks = nn.ModuleList([]) From 5fd7ea93ea4fdfb384ae7658172b11ceba9c1801 Mon Sep 17 
00:00:00 2001 From: Subuday Date: Thu, 15 Feb 2024 13:24:30 +0000 Subject: [PATCH 08/16] Add upsampling and downsampling to UNet --- TTS/tts/layers/matcha_tts/UNet.py | 76 +++++++++++++++++++++++++--- TTS/tts/layers/matcha_tts/decoder.py | 2 +- 2 files changed, 71 insertions(+), 7 deletions(-) diff --git a/TTS/tts/layers/matcha_tts/UNet.py b/TTS/tts/layers/matcha_tts/UNet.py index 0183c787d3..142ef98f30 100644 --- a/TTS/tts/layers/matcha_tts/UNet.py +++ b/TTS/tts/layers/matcha_tts/UNet.py @@ -44,7 +44,7 @@ def __init__(self, in_channels, out_channels, time_embed_channels, num_groups=8) nn.Mish(), nn.Linear(time_embed_channels, out_channels) ) - self.block_2 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups) + self.block_2 = ConvBlock1D(in_channels=out_channels, out_channels=out_channels, num_groups=num_groups) self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1) def forward(self, x, mask, t): @@ -55,6 +55,24 @@ def forward(self, x, mask, t): return output +class Downsample1D(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1) + + def forward(self, x): + return self.conv(x) + + +class Upsample1D(nn.Module): + def __init__(self, channels): + super().__init__() + self.conv = nn.ConvTranspose1d(in_channels=channels, out_channels=channels, kernel_size=4, stride=2, padding=1) + + def forward(self, x): + return self.conv(x) + + class UNet(nn.Module): def __init__( self, @@ -77,21 +95,49 @@ def __init__( self.input_blocks = nn.ModuleList([]) block_in_channels = in_channels * 2 - for _ in range(num_blocks): + block_out_channels = model_channels + for level in range(num_blocks): block = nn.ModuleList([]) block.append( ResNetBlock1D( in_channels=block_in_channels, - out_channels=model_channels, + out_channels=block_out_channels, time_embed_channels=time_embed_channels ) ) + if level != num_blocks - 1: + 
block.append(Downsample1D(block_out_channels)) + else: + block.append(None) + + block_in_channels = block_out_channels self.input_blocks.append(block) self.middle_blocks = nn.ModuleList([]) + self.output_blocks = nn.ModuleList([]) + block_in_channels = block_out_channels * 2 + block_out_channels = model_channels + for level in range(num_blocks): + block = nn.ModuleList([]) + + block.append( + ResNetBlock1D( + in_channels=block_in_channels, + out_channels=block_out_channels, + time_embed_channels=time_embed_channels + ) + ) + + if level != num_blocks - 1: + block.append(Upsample1D(block_out_channels)) + else: + block.append(None) + + block_in_channels = block_out_channels * 2 + self.output_blocks.append(block) self.conv_block = ConvBlock1D(model_channels, model_channels) self.conv = nn.Conv1d(model_channels, self.out_channels, 1) @@ -102,15 +148,33 @@ def forward(self, x_t, mean, mask, t): x_t = pack([x_t, mean], "b * t")[0] + hidden_states = [] + mask_states = [mask] + for block in self.input_blocks: - res_net_block = block[0] + res_net_block, downsample = block + x_t = res_net_block(x_t, mask, t) + hidden_states.append(x_t) + + if downsample is not None: + x_t = downsample(x_t * mask) + mask = mask[:, :, ::2] + mask_states.append(mask) for _ in self.middle_blocks: pass - for _ in self.output_blocks: - pass + for block in self.output_blocks: + res_net_block, upsample = block + + x_t = pack([x_t, hidden_states.pop()], "b * t")[0] + mask = mask_states.pop() + x_t = res_net_block(x_t, mask, t) + + if upsample is not None: + x_t = upsample(x_t * mask) + output = self.conv_block(x_t) output = self.conv(x_t) diff --git a/TTS/tts/layers/matcha_tts/decoder.py b/TTS/tts/layers/matcha_tts/decoder.py index b80c190e30..c87da9d559 100644 --- a/TTS/tts/layers/matcha_tts/decoder.py +++ b/TTS/tts/layers/matcha_tts/decoder.py @@ -11,7 +11,7 @@ def __init__(self): self.sigma_min = 1e-5 self.predictor = UNet( in_channels=80, - model_channels=160, + model_channels=256, 
out_channels=80, num_blocks=2 ) From f15230bb6731e9dc38a4d4f00f7a532629969c74 Mon Sep 17 00:00:00 2001 From: Subuday Date: Thu, 15 Feb 2024 18:52:42 +0000 Subject: [PATCH 09/16] Add transformer block to UNet --- TTS/tts/layers/matcha_tts/UNet.py | 131 ++++++++++++++++++++++++++++-- 1 file changed, 124 insertions(+), 7 deletions(-) diff --git a/TTS/tts/layers/matcha_tts/UNet.py b/TTS/tts/layers/matcha_tts/UNet.py index 142ef98f30..b47db51cf5 100644 --- a/TTS/tts/layers/matcha_tts/UNet.py +++ b/TTS/tts/layers/matcha_tts/UNet.py @@ -1,7 +1,8 @@ import math -from einops import pack +from einops import pack, rearrange import torch from torch import nn +import conformer class PositionalEncoding(torch.nn.Module): @@ -71,6 +72,40 @@ def __init__(self, channels): def forward(self, x): return self.conv(x) + + +class ConformerBlock(conformer.ConformerBlock): + def __init__( + self, + dim: int, + dim_head: int = 64, + heads: int = 8, + ff_mult: int = 4, + conv_expansion_factor: int = 2, + conv_kernel_size: int = 31, + attn_dropout: float = 0., + ff_dropout: float = 0., + conv_dropout: float = 0., + conv_causal: bool = False, + ): + super().__init__( + dim=dim, + dim_head=dim_head, + heads=heads, + ff_mult=ff_mult, + conv_expansion_factor=conv_expansion_factor, + conv_kernel_size=conv_kernel_size, + attn_dropout=attn_dropout, + ff_dropout=ff_dropout, + conv_dropout=conv_dropout, + conv_causal=conv_causal, + ) + + def forward(self, x, mask,): + x = rearrange(x, "b c t -> b t c") + mask = rearrange(mask, "b 1 t -> b t") + output = super().forward(x=x, mask=mask.bool()) + return rearrange(output, "b t c -> b c t") class UNet(nn.Module): @@ -80,6 +115,12 @@ def __init__( model_channels: int, out_channels: int, num_blocks: int, + transformer_num_heads: int = 4, + transformer_dim_head: int = 64, + transformer_ff_mult: int = 1, + transformer_conv_expansion_factor: int = 2, + transformer_conv_kernel_size: int = 31, + transformer_dropout: float = 0.05, ): super().__init__() 
self.in_channels = in_channels @@ -107,6 +148,18 @@ def __init__( ) ) + block.append( + self._create_transformer_block( + block_out_channels, + dim_head=transformer_dim_head, + num_heads=transformer_num_heads, + ff_mult=transformer_ff_mult, + conv_expansion_factor=transformer_conv_expansion_factor, + conv_kernel_size=transformer_conv_kernel_size, + dropout=transformer_dropout, + ) + ) + if level != num_blocks - 1: block.append(Downsample1D(block_out_channels)) else: @@ -116,6 +169,30 @@ def __init__( self.input_blocks.append(block) self.middle_blocks = nn.ModuleList([]) + for i in range(2): + block = nn.ModuleList([]) + + block.append( + ResNetBlock1D( + in_channels=block_out_channels, + out_channels=block_out_channels, + time_embed_channels=time_embed_channels + ) + ) + + block.append( + self._create_transformer_block( + block_out_channels, + dim_head=transformer_dim_head, + num_heads=transformer_num_heads, + ff_mult=transformer_ff_mult, + conv_expansion_factor=transformer_conv_expansion_factor, + conv_kernel_size=transformer_conv_kernel_size, + dropout=transformer_dropout, + ) + ) + + self.middle_blocks.append(block) self.output_blocks = nn.ModuleList([]) block_in_channels = block_out_channels * 2 @@ -131,6 +208,18 @@ def __init__( ) ) + block.append( + self._create_transformer_block( + block_out_channels, + dim_head=transformer_dim_head, + num_heads=transformer_num_heads, + ff_mult=transformer_ff_mult, + conv_expansion_factor=transformer_conv_expansion_factor, + conv_kernel_size=transformer_conv_kernel_size, + dropout=transformer_dropout, + ) + ) + if level != num_blocks - 1: block.append(Upsample1D(block_out_channels)) else: @@ -142,6 +231,29 @@ def __init__( self.conv_block = ConvBlock1D(model_channels, model_channels) self.conv = nn.Conv1d(model_channels, self.out_channels, 1) + def _create_transformer_block( + self, + dim, + dim_head: int = 64, + num_heads: int = 4, + ff_mult: int = 1, + conv_expansion_factor: int = 2, + conv_kernel_size: int = 31, + 
dropout: float = 0.05, + ): + return ConformerBlock( + dim=dim, + dim_head=dim_head, + heads=num_heads, + ff_mult=ff_mult, + conv_expansion_factor=conv_expansion_factor, + conv_kernel_size=conv_kernel_size, + attn_dropout=dropout, + ff_dropout=dropout, + conv_dropout=dropout, + conv_causal=False, + ) + def forward(self, x_t, mean, mask, t): t = self.time_encoder(t) t = self.time_embed(t) @@ -152,9 +264,11 @@ def forward(self, x_t, mean, mask, t): mask_states = [mask] for block in self.input_blocks: - res_net_block, downsample = block + res_net_block, transformer, downsample = block x_t = res_net_block(x_t, mask, t) + x_t = transformer(x_t, mask) + hidden_states.append(x_t) if downsample is not None: @@ -162,20 +276,23 @@ def forward(self, x_t, mean, mask, t): mask = mask[:, :, ::2] mask_states.append(mask) - for _ in self.middle_blocks: - pass + for block in self.middle_blocks: + res_net_block, transformer = block + mask = mask_states[-1] + x_t = res_net_block(x_t, mask, t) + x_t = transformer(x_t, mask) for block in self.output_blocks: - res_net_block, upsample = block - + res_net_block, transformer, upsample = block + x_t = pack([x_t, hidden_states.pop()], "b * t")[0] mask = mask_states.pop() x_t = res_net_block(x_t, mask, t) + x_t = transformer(x_t, mask) if upsample is not None: x_t = upsample(x_t * mask) - output = self.conv_block(x_t) output = self.conv(x_t) From 8aeced16fce77e20918906cc5623e7673895e41e Mon Sep 17 00:00:00 2001 From: David Martin Rius Date: Wed, 28 Feb 2024 19:58:25 +0100 Subject: [PATCH 10/16] import the spacy language class dynamically with a English fallback when import error --- TTS/tts/layers/xtts/tokenizer.py | 24 ++++++++---------------- 1 file changed, 8 insertions(+), 16 deletions(-) diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index 1a3cc47aaf..fb941d7033 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -8,29 +8,21 @@ from hangul_romanize import Transliter 
from hangul_romanize.rule import academic from num2words import num2words -from spacy.lang.ar import Arabic from spacy.lang.en import English -from spacy.lang.es import Spanish -from spacy.lang.ja import Japanese -from spacy.lang.zh import Chinese +from spacy.util import get_lang_class + from tokenizers import Tokenizer from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words def get_spacy_lang(lang): - if lang == "zh": - return Chinese() - elif lang == "ja": - return Japanese() - elif lang == "ar": - return Arabic() - elif lang == "es": - return Spanish() - else: - # For most languages, Enlish does the job - return English() - + try: + lang_model = get_lang_class(lang)() + except ImportError: + # Fallback to English if the language model is not available + lang_model = English() + return lang_model def split_sentence(text, lang, text_split_length=250): """Preprocess the input text""" From 30a2d8d787d19277a5fc493397293497ec5f325c Mon Sep 17 00:00:00 2001 From: David Martin Rius Date: Wed, 28 Feb 2024 20:17:46 +0100 Subject: [PATCH 11/16] add requirements for spacy Thai and Vietnamese --- requirements.txt | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index 2944e6face..ba2ad1f514 100644 --- a/requirements.txt +++ b/requirements.txt @@ -54,4 +54,10 @@ encodec>=0.1.1 # deps for XTTS unidecode>=1.3.2 num2words -spacy[ja]>=3 \ No newline at end of file +spacy[ja]>=3 + +# spacy thai +pythainlp>=5.0.1 + +#spacy vietnamese +pyvi>=0.1.1 \ No newline at end of file From 3db0dec08aa1f55e6762684d7f6acf890e2e8d94 Mon Sep 17 00:00:00 2001 From: David Martin Rius Date: Wed, 28 Feb 2024 20:23:53 +0100 Subject: [PATCH 12/16] Add 2 functions to verify any spacy language can be instantiated. For now, the only one that needs special packages is Korean.
So, all languages work well except Korean --- TTS/tts/layers/xtts/tokenizer.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/TTS/tts/layers/xtts/tokenizer.py b/TTS/tts/layers/xtts/tokenizer.py index fb941d7033..5d580a55ef 100644 --- a/TTS/tts/layers/xtts/tokenizer.py +++ b/TTS/tts/layers/xtts/tokenizer.py @@ -15,11 +15,27 @@ from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words +import spacy + +# These 2 functions are to verify that any language can be instantiated +def get_spacy_available_langs(): + from pathlib import Path + spacy_path = Path(spacy.__file__.replace('__init__.py','')) + spacy_langs = spacy_path / 'lang' + SPACY_LANGS = [str(x).split('/')[-1] for x in spacy_langs.iterdir() if x.is_dir() and str(x).split('/')[-1] != '__pycache__'] + print("Available languages in Spacy:", SPACY_LANGS) + return SPACY_LANGS +def get_all_spacy_langs(): + SPACY_LANGS = get_spacy_available_langs() + spacy_lang_instances = [] + for lang in SPACY_LANGS: + spacy_lang_instances.append(get_spacy_lang(lang)) def get_spacy_lang(lang): try: lang_model = get_lang_class(lang)() - except ImportError: + except ImportError as e: + print("Error", e) # Fallback to English if the language model is not available lang_model = English() return lang_model From ea3ae40888048fb5485e26aa22e2699f967a315b Mon Sep 17 00:00:00 2001 From: David Martin Rius <0991592@gmail.com> Date: Tue, 5 Mar 2024 18:28:22 +0100 Subject: [PATCH 13/16] Update .models.json Fix bark model --- TTS/.models.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/TTS/.models.json b/TTS/.models.json index b349e7397b..a77ebea1cf 100644 --- a/TTS/.models.json +++ b/TTS/.models.json @@ -46,7 +46,7 @@ "hf_url": [ "https://coqui.gateway.scarf.sh/hf/bark/coarse_2.pt", "https://coqui.gateway.scarf.sh/hf/bark/fine_2.pt", - "https://coqui.gateway.scarf.sh/hf/text_2.pt", + "https://coqui.gateway.scarf.sh/hf/bark/text_2.pt",
"https://coqui.gateway.scarf.sh/hf/bark/config.json", "https://coqui.gateway.scarf.sh/hf/bark/hubert.pt", "https://coqui.gateway.scarf.sh/hf/bark/tokenizer.pth" From 64fdd0ed8b792af8bbd21b95024c8f31da314ac4 Mon Sep 17 00:00:00 2001 From: David Martin Rius <0991592@gmail.com> Date: Tue, 5 Mar 2024 18:31:08 +0100 Subject: [PATCH 14/16] Update manage.py fix: fairseq model --- TTS/utils/manage.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index 3a527f4609..5284005388 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -260,10 +260,13 @@ def set_model_url(model_item: Dict): def _set_model_item(self, model_name): # fetch model info from the dict if "fairseq" in model_name: - model_type = "tts_models" - lang = model_name.split("/")[1] + split = model_name.split("/") + model_type = split[0] + lang = split[1] + dataset = split[2] + model = split[3] model_item = { - "model_type": "tts_models", + "model_type": model_type, "license": "CC BY-NC 4.0", "default_vocoder": None, "author": "fairseq", From 275229a876e23c1fe3a7bf96d4712461e46f0af9 Mon Sep 17 00:00:00 2001 From: David Martin Rius <0991592@gmail.com> Date: Tue, 5 Mar 2024 18:36:47 +0100 Subject: [PATCH 15/16] Update synthesizer.py Configurable verbose output --- TTS/utils/synthesizer.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index b98647c30c..f72469a978 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -264,6 +264,7 @@ def tts( style_text=None, reference_wav=None, reference_speaker_name=None, + verbose: bool = True, split_sentences: bool = True, **kwargs, ) -> List[int]: @@ -278,6 +279,7 @@ def tts( style_text ([type], optional): transcription of style_wav for Capacitron. Defaults to None. reference_wav ([type], optional): reference waveform for voice conversion. Defaults to None. 
reference_speaker_name ([type], optional): speaker id of reference waveform. Defaults to None. + verbose (bool, optional): print verbose output. Defaults to True. split_sentences (bool, optional): split the input text into sentences. Defaults to True. **kwargs: additional arguments to pass to the TTS model. Returns: @@ -294,9 +296,11 @@ def tts( if text: sens = [text] if split_sentences: - print(" > Text splitted to sentences.") + if verbose: + print(" > Text splitted to sentences.") sens = self.split_into_sentences(text) - print(sens) + if verbose: + print(sens) # handle multi-speaker if "voice_dir" in kwargs: @@ -420,7 +424,8 @@ def tts( self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, ] if scale_factor[1] != 1: - print(" > interpolating tts model output.") + if verbose: + print(" > interpolating tts model output.") vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) else: vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable @@ -484,7 +489,8 @@ def tts( self.vocoder_config["audio"]["sample_rate"] / self.tts_model.ap.sample_rate, ] if scale_factor[1] != 1: - print(" > interpolating tts model output.") + if verbose: + print(" > interpolating tts model output.") vocoder_input = interpolate_vocoder_input(scale_factor, vocoder_input) else: vocoder_input = torch.tensor(vocoder_input).unsqueeze(0) # pylint: disable=not-callable @@ -497,9 +503,10 @@ def tts( waveform = waveform.numpy() wavs = waveform.squeeze() - # compute stats - process_time = time.time() - start_time - audio_time = len(wavs) / self.tts_config.audio["sample_rate"] - print(f" > Processing time: {process_time}") - print(f" > Real-time factor: {process_time / audio_time}") + if verbose: + # compute stats + process_time = time.time() - start_time + audio_time = len(wavs) / self.tts_config.audio["sample_rate"] + print(f" > Processing time: {process_time}") + print(f" > Real-time factor: {process_time / audio_time}") return 
wavs From 29a0e189b28ac9419772e80aaa259d86e7f4629a Mon Sep 17 00:00:00 2001 From: David Martin Rius <0991592@gmail.com> Date: Tue, 5 Mar 2024 22:38:59 +0100 Subject: [PATCH 16/16] Revert "Merge remote-tracking branch 'subuday/matcha_tts' into dev" This reverts commit f6a23c1d8a153b40d5c06b81dbfda9b66d0fc4d6, reversing changes made to 275229a876e23c1fe3a7bf96d4712461e46f0af9. --- TTS/tts/configs/matcha_tts.py | 9 - TTS/tts/layers/matcha_tts/UNet.py | 299 --------------------------- TTS/tts/layers/matcha_tts/decoder.py | 32 --- TTS/tts/models/matcha_tts.py | 85 -------- tests/tts_tests2/test_matcha_tts.py | 36 ---- 5 files changed, 461 deletions(-) delete mode 100644 TTS/tts/configs/matcha_tts.py delete mode 100644 TTS/tts/layers/matcha_tts/UNet.py delete mode 100644 TTS/tts/layers/matcha_tts/decoder.py delete mode 100644 TTS/tts/models/matcha_tts.py delete mode 100644 tests/tts_tests2/test_matcha_tts.py diff --git a/TTS/tts/configs/matcha_tts.py b/TTS/tts/configs/matcha_tts.py deleted file mode 100644 index 15bb91b829..0000000000 --- a/TTS/tts/configs/matcha_tts.py +++ /dev/null @@ -1,9 +0,0 @@ -from dataclasses import dataclass, field - -from TTS.tts.configs.shared_configs import BaseTTSConfig - - -@dataclass -class MatchaTTSConfig(BaseTTSConfig): - model: str = "matcha_tts" - num_chars: int = None diff --git a/TTS/tts/layers/matcha_tts/UNet.py b/TTS/tts/layers/matcha_tts/UNet.py deleted file mode 100644 index b47db51cf5..0000000000 --- a/TTS/tts/layers/matcha_tts/UNet.py +++ /dev/null @@ -1,299 +0,0 @@ -import math -from einops import pack, rearrange -import torch -from torch import nn -import conformer - - -class PositionalEncoding(torch.nn.Module): - def __init__(self, channels): - super().__init__() - self.channels = channels - - def forward(self, x, scale=1000): - if x.ndim < 1: - x = x.unsqueeze(0) - emb = math.log(10000) / (self.channels // 2 - 1) - emb = torch.exp(torch.arange(self.channels // 2, device=x.device).float() * -emb) - emb = scale * 
x.unsqueeze(1) * emb.unsqueeze(0) - emb = torch.cat((emb.sin(), emb.cos()), dim=-1) - return emb - -class ConvBlock1D(nn.Module): - def __init__(self, in_channels, out_channels, num_groups=8): - super().__init__() - self.block = nn.Sequential( - nn.Conv1d(in_channels, out_channels, kernel_size=3, padding=1), - nn.GroupNorm(num_groups, out_channels), - nn.Mish() - ) - - def forward(self, x, mask=None): - if mask is not None: - x = x * mask - output = self.block(x) - if mask is not None: - output = output * mask - return output - - -class ResNetBlock1D(nn.Module): - def __init__(self, in_channels, out_channels, time_embed_channels, num_groups=8): - super().__init__() - self.block_1 = ConvBlock1D(in_channels, out_channels, num_groups=num_groups) - self.mlp = nn.Sequential( - nn.Mish(), - nn.Linear(time_embed_channels, out_channels) - ) - self.block_2 = ConvBlock1D(in_channels=out_channels, out_channels=out_channels, num_groups=num_groups) - self.conv = nn.Conv1d(in_channels, out_channels, kernel_size=1) - - def forward(self, x, mask, t): - h = self.block_1(x, mask) - h += self.mlp(t).unsqueeze(-1) - h = self.block_2(h, mask) - output = h + self.conv(x * mask) - return output - - -class Downsample1D(nn.Module): - def __init__(self, channels): - super().__init__() - self.conv = nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=3, stride=2, padding=1) - - def forward(self, x): - return self.conv(x) - - -class Upsample1D(nn.Module): - def __init__(self, channels): - super().__init__() - self.conv = nn.ConvTranspose1d(in_channels=channels, out_channels=channels, kernel_size=4, stride=2, padding=1) - - def forward(self, x): - return self.conv(x) - - -class ConformerBlock(conformer.ConformerBlock): - def __init__( - self, - dim: int, - dim_head: int = 64, - heads: int = 8, - ff_mult: int = 4, - conv_expansion_factor: int = 2, - conv_kernel_size: int = 31, - attn_dropout: float = 0., - ff_dropout: float = 0., - conv_dropout: float = 0., - conv_causal: bool = 
False, - ): - super().__init__( - dim=dim, - dim_head=dim_head, - heads=heads, - ff_mult=ff_mult, - conv_expansion_factor=conv_expansion_factor, - conv_kernel_size=conv_kernel_size, - attn_dropout=attn_dropout, - ff_dropout=ff_dropout, - conv_dropout=conv_dropout, - conv_causal=conv_causal, - ) - - def forward(self, x, mask,): - x = rearrange(x, "b c t -> b t c") - mask = rearrange(mask, "b 1 t -> b t") - output = super().forward(x=x, mask=mask.bool()) - return rearrange(output, "b t c -> b c t") - - -class UNet(nn.Module): - def __init__( - self, - in_channels: int, - model_channels: int, - out_channels: int, - num_blocks: int, - transformer_num_heads: int = 4, - transformer_dim_head: int = 64, - transformer_ff_mult: int = 1, - transformer_conv_expansion_factor: int = 2, - transformer_conv_kernel_size: int = 31, - transformer_dropout: float = 0.05, - ): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - - self.time_encoder = PositionalEncoding(in_channels) - time_embed_channels = model_channels * 4 - self.time_embed = nn.Sequential( - nn.Linear(in_channels, time_embed_channels), - nn.SiLU(), - nn.Linear(time_embed_channels, time_embed_channels), - ) - - self.input_blocks = nn.ModuleList([]) - block_in_channels = in_channels * 2 - block_out_channels = model_channels - for level in range(num_blocks): - block = nn.ModuleList([]) - - block.append( - ResNetBlock1D( - in_channels=block_in_channels, - out_channels=block_out_channels, - time_embed_channels=time_embed_channels - ) - ) - - block.append( - self._create_transformer_block( - block_out_channels, - dim_head=transformer_dim_head, - num_heads=transformer_num_heads, - ff_mult=transformer_ff_mult, - conv_expansion_factor=transformer_conv_expansion_factor, - conv_kernel_size=transformer_conv_kernel_size, - dropout=transformer_dropout, - ) - ) - - if level != num_blocks - 1: - block.append(Downsample1D(block_out_channels)) - else: - block.append(None) - - block_in_channels = 
block_out_channels - self.input_blocks.append(block) - - self.middle_blocks = nn.ModuleList([]) - for i in range(2): - block = nn.ModuleList([]) - - block.append( - ResNetBlock1D( - in_channels=block_out_channels, - out_channels=block_out_channels, - time_embed_channels=time_embed_channels - ) - ) - - block.append( - self._create_transformer_block( - block_out_channels, - dim_head=transformer_dim_head, - num_heads=transformer_num_heads, - ff_mult=transformer_ff_mult, - conv_expansion_factor=transformer_conv_expansion_factor, - conv_kernel_size=transformer_conv_kernel_size, - dropout=transformer_dropout, - ) - ) - - self.middle_blocks.append(block) - - self.output_blocks = nn.ModuleList([]) - block_in_channels = block_out_channels * 2 - block_out_channels = model_channels - for level in range(num_blocks): - block = nn.ModuleList([]) - - block.append( - ResNetBlock1D( - in_channels=block_in_channels, - out_channels=block_out_channels, - time_embed_channels=time_embed_channels - ) - ) - - block.append( - self._create_transformer_block( - block_out_channels, - dim_head=transformer_dim_head, - num_heads=transformer_num_heads, - ff_mult=transformer_ff_mult, - conv_expansion_factor=transformer_conv_expansion_factor, - conv_kernel_size=transformer_conv_kernel_size, - dropout=transformer_dropout, - ) - ) - - if level != num_blocks - 1: - block.append(Upsample1D(block_out_channels)) - else: - block.append(None) - - block_in_channels = block_out_channels * 2 - self.output_blocks.append(block) - - self.conv_block = ConvBlock1D(model_channels, model_channels) - self.conv = nn.Conv1d(model_channels, self.out_channels, 1) - - def _create_transformer_block( - self, - dim, - dim_head: int = 64, - num_heads: int = 4, - ff_mult: int = 1, - conv_expansion_factor: int = 2, - conv_kernel_size: int = 31, - dropout: float = 0.05, - ): - return ConformerBlock( - dim=dim, - dim_head=dim_head, - heads=num_heads, - ff_mult=ff_mult, - conv_expansion_factor=conv_expansion_factor, - 
conv_kernel_size=conv_kernel_size, - attn_dropout=dropout, - ff_dropout=dropout, - conv_dropout=dropout, - conv_causal=False, - ) - - def forward(self, x_t, mean, mask, t): - t = self.time_encoder(t) - t = self.time_embed(t) - - x_t = pack([x_t, mean], "b * t")[0] - - hidden_states = [] - mask_states = [mask] - - for block in self.input_blocks: - res_net_block, transformer, downsample = block - - x_t = res_net_block(x_t, mask, t) - x_t = transformer(x_t, mask) - - hidden_states.append(x_t) - - if downsample is not None: - x_t = downsample(x_t * mask) - mask = mask[:, :, ::2] - mask_states.append(mask) - - for block in self.middle_blocks: - res_net_block, transformer = block - mask = mask_states[-1] - x_t = res_net_block(x_t, mask, t) - x_t = transformer(x_t, mask) - - for block in self.output_blocks: - res_net_block, transformer, upsample = block - - x_t = pack([x_t, hidden_states.pop()], "b * t")[0] - mask = mask_states.pop() - x_t = res_net_block(x_t, mask, t) - x_t = transformer(x_t, mask) - - if upsample is not None: - x_t = upsample(x_t * mask) - - output = self.conv_block(x_t) - output = self.conv(x_t) - - return output * mask \ No newline at end of file diff --git a/TTS/tts/layers/matcha_tts/decoder.py b/TTS/tts/layers/matcha_tts/decoder.py deleted file mode 100644 index c87da9d559..0000000000 --- a/TTS/tts/layers/matcha_tts/decoder.py +++ /dev/null @@ -1,32 +0,0 @@ -import torch -from torch import nn -import torch.nn.functional as F - -from TTS.tts.layers.matcha_tts.UNet import UNet - - -class Decoder(nn.Module): - def __init__(self): - super().__init__() - self.sigma_min = 1e-5 - self.predictor = UNet( - in_channels=80, - model_channels=256, - out_channels=80, - num_blocks=2 - ) - - def forward(self, x_1, mean, mask): - """ - Shapes: - - x_1: :math:`[B, C, T]` - - mean: :math:`[B, C ,T]` - - mask: :math:`[B, 1, T]` - """ - t = torch.rand([x_1.size(0), 1, 1], device=x_1.device, dtype=x_1.dtype) - x_0 = torch.randn_like(x_1) - x_t = (1 - (1 - self.sigma_min) 
* t) * x_0 + t * x_1 - u_t = x_1 - (1 - self.sigma_min) * x_0 - v_t = self.predictor(x_t, mean, mask, t.squeeze()) - loss = F.mse_loss(v_t, u_t, reduction="sum") / (torch.sum(mask) * u_t.shape[1]) - return loss diff --git a/TTS/tts/models/matcha_tts.py b/TTS/tts/models/matcha_tts.py deleted file mode 100644 index 9bc3e0ffc4..0000000000 --- a/TTS/tts/models/matcha_tts.py +++ /dev/null @@ -1,85 +0,0 @@ -from dataclasses import field -import math -import torch - -from TTS.tts.configs.matcha_tts import MatchaTTSConfig -from TTS.tts.layers.glow_tts.encoder import Encoder -from TTS.tts.layers.matcha_tts.decoder import Decoder -from TTS.tts.models.base_tts import BaseTTS -from TTS.tts.utils.helpers import maximum_path, sequence_mask -from TTS.tts.utils.text.tokenizer import TTSTokenizer - - -class MatchaTTS(BaseTTS): - - def __init__( - self, - config: MatchaTTSConfig, - ap: "AudioProcessor" = None, - tokenizer: "TTSTokenizer" = None, - ): - super().__init__(config, ap, tokenizer) - self.encoder = Encoder( - self.config.num_chars, - out_channels=80, - hidden_channels=192, - hidden_channels_dp=256, - encoder_type='rel_pos_transformer', - encoder_params={ - "kernel_size": 3, - "dropout_p": 0.1, - "num_layers": 6, - "num_heads": 2, - "hidden_channels_ffn": 768, - } - ) - - self.decoder = Decoder() - - def forward(self, x, x_lengths, y, y_lengths): - """ - Args: - x (torch.Tensor): - Input text sequence ids. :math:`[B, T_en]` - - x_lengths (torch.Tensor): - Lengths of input text sequences. :math:`[B]` - - y (torch.Tensor): - Target mel-spectrogram frames. :math:`[B, T_de, C_mel]` - - y_lengths (torch.Tensor): - Lengths of target mel-spectrogram frames. 
:math:`[B]` - """ - y = y.transpose(1, 2) - y_max_length = y.size(2) - - o_mean, o_log_scale, o_log_dur, o_mask = self.encoder(x, x_lengths, g=None) - - y_mask = torch.unsqueeze(sequence_mask(y_lengths, y_max_length), 1).to(o_mask.dtype) - attn_mask = torch.unsqueeze(o_mask, -1) * torch.unsqueeze(y_mask, 2) - - with torch.no_grad(): - o_scale = torch.exp(-2 * o_log_scale) - logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - o_log_scale, [1]).unsqueeze(-1) - logp2 = torch.matmul(o_scale.transpose(1, 2), -0.5 * (y**2)) - logp3 = torch.matmul((o_mean * o_scale).transpose(1, 2), y) - logp4 = torch.sum(-0.5 * (o_mean**2) * o_scale, [1]).unsqueeze(-1) - logp = logp1 + logp2 + logp3 + logp4 - attn = maximum_path(logp, attn_mask.squeeze(1)).unsqueeze(1).detach() - - # Align encoded text with mel-spectrogram and get mu_y segment - c_mean = torch.matmul(attn.squeeze(1).transpose(1, 2), o_mean.transpose(1, 2)).transpose(1, 2) - - _ = self.decoder(x_1=y, mean=c_mean, mask=y_mask) - - @torch.no_grad() - def inference(self): - pass - - @staticmethod - def init_from_config(config: "MatchaTTSConfig"): - pass - - def load_checkpoint(self, checkpoint_path): - pass diff --git a/tests/tts_tests2/test_matcha_tts.py b/tests/tts_tests2/test_matcha_tts.py deleted file mode 100644 index 5fbe95377f..0000000000 --- a/tests/tts_tests2/test_matcha_tts.py +++ /dev/null @@ -1,36 +0,0 @@ -import unittest - -import torch - -from TTS.tts.configs.matcha_tts import MatchaTTSConfig -from TTS.tts.models.matcha_tts import MatchaTTS - -torch.manual_seed(1) -device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - -c = MatchaTTSConfig() - - -class TestMatchTTS(unittest.TestCase): - @staticmethod - def _create_inputs(batch_size=8): - input_dummy = torch.randint(0, 24, (batch_size, 128)).long().to(device) - input_lengths = torch.randint(100, 129, (batch_size,)).long().to(device) - input_lengths[-1] = 128 - mel_spec = torch.rand(batch_size, 30, c.audio["num_mels"]).to(device) - mel_lengths = 
torch.randint(20, 30, (batch_size,)).long().to(device) - speaker_ids = torch.randint(0, 5, (batch_size,)).long().to(device) - return input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids - - def _test_forward(self, batch_size): - input_dummy, input_lengths, mel_spec, mel_lengths, _ = self._create_inputs(batch_size) - config = MatchaTTSConfig(num_chars=32) - model = MatchaTTS(config).to(device) - - model.train() - - model.forward(input_dummy, input_lengths, mel_spec, mel_lengths) - - def test_forward(self): - self._test_forward(1) - self._test_forward(3)