Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add device support in TTS and Synthesizer #2855

Merged
merged 9 commits into from
Aug 14, 2023
Merged

Conversation

jaketae
Copy link
Contributor

@jaketae jaketae commented Aug 10, 2023

Context

In #2282, we proposed up the possibility of implementing atts.to(device) interface as a substitute for use_cuda or gpu flags. The current flags do not allow users to specify the specific GPU device (e.g., cuda:3). It also does not allow users to use other accelerated backends, such as Apple Silicon GPUs (MPS), which PyTorch now supports.

Solution

We make TTS and Synthesizer classes inherit from nn.Module. This gives us .to(device) for free for both of the classes.

We can now run TTS on Apple Silicon (tested on M2 Max). Not all kernels have been implemented in MPS in PyTorch yet, so we need to set the environment variable

export PYTORCH_ENABLE_MPS_FALLBACK=1

to enable CPU fallback. With this set, we can now run

>>> from TTS.api import TTS
>>> model_name = TTS.list_models()[0]
>>> tts = TTS(model_name)
>>> tts = tts.to("mps")
>>> tts.tts_to_file(text="Hello world!", speaker=tts.speakers[0], language=tts.languages[0], file_path="output.wav")

Also tested with make test.

@CLAassistant
Copy link

CLAassistant commented Aug 10, 2023

CLA assistant check
All committers have signed the CLA.

self.manager = ModelManager(models_file=self.get_models_file_path(), progress_bar=progress_bar, verbose=False)

self.synthesizer = None
self.voice_converter = None
self.csapi = None
self.model_name = None

if gpu:
warnings.warn("`gpu` will be deprecated. Please use `tts.to(device)` instead.")
Copy link
Contributor Author

@jaketae jaketae Aug 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added warning. We could add specific dates or versions to better inform users about future plans, but I left it this way because I didn't have enough context on the future releases roadmap.

@@ -5,19 +5,21 @@
from torch import nn


def numpy_to_torch(np_array, dtype, cuda=False):
def numpy_to_torch(np_array, dtype, cuda=False, device="cpu"):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added new device argument to functions called in Synthesizer. To retain backwards compatibility, we keep the cuda argument for now; we should probably clean them up in the future and provide a single way of configuring device/enabling CUDA.

use_gl = self.vocoder_model is None
if not use_gl:
vocoder_device = next(self.vocoder_model.parameters()).device
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In some obscure use cases, the user could have placed the feature frontend and the vocoder on different devices.

>>> tts.synthesizer.tts_model = tts.synthesizer.tts_model.to("cuda:0")
>>> tts.synthesizer.vocoder_model = tts.synthesizer.vocoder_model.to("cuda:1")

We check the device of the vocoder, if it exists.

@jaketae jaketae marked this pull request as ready for review August 10, 2023 22:30
@jaketae
Copy link
Contributor Author

jaketae commented Aug 10, 2023

Hi @erogol, curious to hear your thoughts on this implementation! The guiding philosophy was to use the PyTorch .to(device) API while keeping all functionality intact and retaining backwards compatibility with use_cuda or gpu.


I've signed the CLA, but the first few commits didn't have my GitHub email (I just got a new laptop and forgot to set up my Git user information), which is why the CLA test is marked as pending.

@erogol
Copy link
Member

erogol commented Aug 13, 2023

@jaketae thanks for the PR. I'll review it tomorrow 👍

Copy link
Member

@erogol erogol left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All looks good!! Thanks for the PR. If you think its done I can merge.

@jaketae
Copy link
Contributor Author

jaketae commented Aug 14, 2023

@erogol, thanks for the quick review! I think we can go ahead with the merge unless you have second thoughts. I'll maybe open a follow-up PR to improve docs or the README where applicable. Thanks!

@erogol erogol merged commit 409db50 into coqui-ai:dev Aug 14, 2023
44 checks passed
@erogol
Copy link
Member

erogol commented Aug 14, 2023

@jaketae awesome thanks again

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants