fix(vc): support both cpu and cuda (#244)
eginhard authored Jan 8, 2025
1 parent d8c2224 commit fd93176
Showing 3 changed files with 21 additions and 29 deletions.
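
Summary of the change: the FreeVC speaker encoder no longer takes a device argument at construction time. Instead it exposes a device property derived from its parameters and returns torch tensors rather than numpy arrays, and the callers in freevc.py and openvoice.py drop their explicit device handling and redundant tensor conversions, so voice conversion runs on both CPU and CUDA. A minimal sketch of the device-property pattern used here (the class name is illustrative, not part of the commit):

    import torch
    from torch import nn

    class DeviceAwareEncoder(nn.Module):
        """Toy stand-in for SpeakerEncoder, showing the device-property pattern."""

        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(4, 4)

        @property
        def device(self) -> torch.device:
            # Track wherever the parameters currently live instead of
            # storing a fixed device at construction time.
            return next(self.parameters()).device

    enc = DeviceAwareEncoder()
    print(enc.device)                      # cpu
    if torch.cuda.is_available():
        print(enc.to("cuda").device)       # cuda:0
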
42 changes: 18 additions & 24 deletions TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py
@@ -1,6 +1,6 @@
 import logging
 from time import perf_counter as timer
-from typing import List, Union
+from typing import List
 
 import numpy as np
 import torch
@@ -22,26 +22,15 @@


 class SpeakerEncoder(nn.Module):
-    def __init__(self, weights_fpath, device: Union[str, torch.device] = None):
-        """
-        :param device: either a torch device or the name of a torch device (e.g. "cpu", "cuda").
-        If None, defaults to cuda if it is available on your machine, otherwise the model will
-        run on cpu. Outputs are always returned on the cpu, as numpy arrays.
-        """
+    def __init__(self, weights_fpath):
+        """FreeVC speaker encoder."""
         super().__init__()
 
         # Define the network
         self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
         self.linear = nn.Linear(model_hidden_size, model_embedding_size)
         self.relu = nn.ReLU()
 
-        # Get the target device
-        if device is None:
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        elif isinstance(device, str):
-            device = torch.device(device)
-        self.device = device
-
         # Load the pretrained model'speaker weights
         # weights_fpath = Path(__file__).resolve().parent.joinpath("pretrained.pt")
         # if not weights_fpath.exists():
@@ -52,8 +41,11 @@ def __init__(self, weights_fpath, device: Union[str, torch.device] = None):
         checkpoint = load_fsspec(weights_fpath, map_location="cpu")
 
         self.load_state_dict(checkpoint["model_state"], strict=False)
-        self.to(device)
-        logger.info("Loaded the voice encoder model on %s in %.2f seconds.", device.type, timer() - start)
+        logger.info("Loaded the voice encoder model in %.2f seconds.", timer() - start)
+
+    @property
+    def device(self):
+        return next(self.parameters()).device
 
     def forward(self, mels: torch.FloatTensor):
         """
@@ -143,8 +135,8 @@ def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_
         then the last partial utterance will be considered by zero-padding the audio. Otherwise,
         it will be discarded. If there aren't enough frames for one partial utterance,
         this parameter is ignored so that the function always returns at least one slice.
-        :return: the embedding as a numpy array of float32 of shape (model_embedding_size,). If
-        <return_partials> is True, the partial utterances as a numpy array of float32 of shape
+        :return: the embedding as a float tensor of shape (model_embedding_size,). If
+        <return_partials> is True, the partial utterances as a float tensor of shape
         (n_partials, model_embedding_size) and the wav partials as a list of slices will also be
         returned.
         """
@@ -160,11 +152,11 @@ def embed_utterance(self, wav: np.ndarray, return_partials=False, rate=1.3, min_
         mels = np.array([mel[s] for s in mel_slices])
         with torch.no_grad():
             mels = torch.from_numpy(mels).to(self.device)
-            partial_embeds = self(mels).cpu().numpy()
+            partial_embeds = self(mels)
 
         # Compute the utterance embedding from the partial embeddings
-        raw_embed = np.mean(partial_embeds, axis=0)
-        embed = raw_embed / np.linalg.norm(raw_embed, 2)
+        raw_embed = partial_embeds.mean(dim=0)
+        embed = raw_embed / torch.norm(raw_embed, p=2)
 
         if return_partials:
             return embed, partial_embeds, wav_slices
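
Note: the torch operations introduced above are numerically equivalent to the numpy ones they replace, but keep the partial embeddings on whatever device the encoder is on. A quick sanity check (illustrative; the 256-dim size is an assumption standing in for model_embedding_size):

    import numpy as np
    import torch

    partial_embeds = torch.rand(4, 256)              # e.g. 4 partial utterances, 256-dim embeddings
    raw_embed = partial_embeds.mean(dim=0)
    embed = raw_embed / torch.norm(raw_embed, p=2)

    np_raw = np.mean(partial_embeds.numpy(), axis=0)
    np_embed = np_raw / np.linalg.norm(np_raw, 2)
    assert np.allclose(embed.numpy(), np_embed, atol=1e-6)
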
@@ -177,7 +169,9 @@ def embed_speaker(self, wavs: List[np.ndarray], **kwargs):
         :param wavs: list of wavs a numpy arrays of float32.
         :param kwargs: extra arguments to embed_utterance()
-        :return: the embedding as a numpy array of float32 of shape (model_embedding_size,).
+        :return: the embedding as a float tensor of shape (model_embedding_size,).
         """
-        raw_embed = np.mean([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs], axis=0)
-        return raw_embed / np.linalg.norm(raw_embed, 2)
+        raw_embed = torch.mean(
+            torch.stack([self.embed_utterance(wav, return_partials=False, **kwargs) for wav in wavs]), dim=0
+        )
+        return raw_embed / torch.norm(raw_embed, p=2)
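
With these changes, embed_utterance and embed_speaker return float tensors on the encoder's current device, and moving the whole model is a single .to() call. A hypothetical usage sketch (the weights URL is the one referenced below in freevc.py; the silent wav is a placeholder):

    import numpy as np
    import torch
    from TTS.vc.layers.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder

    weights = "https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt"
    encoder = SpeakerEncoder(weights)                        # device is no longer passed in
    encoder = encoder.to("cuda" if torch.cuda.is_available() else "cpu")

    wav = np.zeros(32000, dtype=np.float32)                  # ~2 s of silence at 16 kHz (placeholder)
    embedding = encoder.embed_utterance(wav)                 # float tensor on encoder.device
    print(embedding.shape)                                   # torch.Size([model_embedding_size])
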
5 changes: 2 additions & 3 deletions TTS/vc/models/freevc.py
@@ -314,7 +314,7 @@ def load_pretrained_speaker_encoder(self):
"""Load pretrained speaker encoder model as mentioned in the paper."""
logger.info("Loading pretrained speaker encoder model ...")
self.enc_spk_ex = SpeakerEncoderEx(
"https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt", device=self.device
"https://github.com/coqui-ai/TTS/releases/download/v0.13.0_models/speaker_encoder.pt"
)

def init_multispeaker(self, config: Coqpit):
@@ -454,8 +454,7 @@ def voice_conversion(self, src, tgt):
         wav_tgt, _ = librosa.effects.trim(wav_tgt, top_db=20)
 
         if self.config.model_args.use_spk:
-            g_tgt = self.enc_spk_ex.embed_utterance(wav_tgt)
-            g_tgt = torch.from_numpy(g_tgt)[None, :, None].to(self.device)
+            g_tgt = self.enc_spk_ex.embed_utterance(wav_tgt)[None, :, None]
         else:
             wav_tgt = torch.from_numpy(wav_tgt).unsqueeze(0).to(self.device)
             mel_tgt = mel_spectrogram_torch(
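
In voice_conversion, the target speaker embedding now arrives as a tensor already on the encoder's device, so the torch.from_numpy(...).to(self.device) round-trip is dropped and only the reshape remains. Illustrative reshape (256 stands in for the real embedding size):

    import torch

    embedding = torch.randn(256)          # assumed embedding size, placeholder values
    g_tgt = embedding[None, :, None]      # (E,) -> (1, E, 1), the layout FreeVC expects for g_tgt
    print(g_tgt.shape)                    # torch.Size([1, 256, 1])
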
3 changes: 1 addition & 2 deletions TTS/vc/models/openvoice.py
@@ -284,8 +284,7 @@ def load_audio(self, wav: Union[str, npt.NDArray[np.float32], torch.Tensor, list
         return out.to(self.device).float()
 
     def extract_se(self, audio: Union[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
-        audio_ref = self.load_audio(audio)
-        y = torch.FloatTensor(audio_ref)
+        y = self.load_audio(audio)
         y = y.to(self.device)
         y = y.unsqueeze(0)
         spec = wav_to_spec(
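
In extract_se, load_audio already returns a float tensor on self.device, so re-wrapping it with torch.FloatTensor was redundant on CPU and can fail outright when the tensor lives on CUDA. A small illustration of why the direct path is preferable (a minimal sketch, assuming a device picked at runtime):

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    y = torch.zeros(16000, device=device)     # stand-in for load_audio()'s output
    # Before this commit, y = torch.FloatTensor(audio_ref) rebuilt a CPU tensor
    # (the legacy constructor does not accept CUDA tensors), forcing an extra
    # device transfer. Using load_audio()'s tensor directly keeps it on `device`.
    y = y.unsqueeze(0)                        # (1, num_samples), as in extract_se
    print(y.device, y.shape)
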
