Skip to content

Commit

Permalink
added 48 kHz azure and styletts voice change
Browse files Browse the repository at this point in the history
  • Loading branch information
KoljaB committed Dec 10, 2024
1 parent 905f1fb commit c1f6dc0
Show file tree
Hide file tree
Showing 8 changed files with 265 additions and 51 deletions.
10 changes: 2 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ Let me know if you need any adjustments or additional languages!

## Updates

Latest Version: v0.4.19
Latest Version: v0.4.20

Introducing StyleTTS2 engine:

Expand Down Expand Up @@ -665,14 +665,8 @@ While the source of this library is open-source, the usage of many of the engine
Kolja Beigel
Email: [email protected]
<p align="center">
<a href="https://github.com/KoljaB/RealtimeTTS" target="_blank">
<img src="https://img.shields.io/badge/GitHub-181717?style=for-the-badge&logo=github&logoColor=white" alt="GitHub">
</a>
&nbsp;&nbsp;&nbsp;
<a href="#realtimetts" target="_blank">
<img src="https://img.shields.io/badge/Back%20to%20Top-000000?style=for-the-badge" alt="Back to Top">
</a>
</p>
</p>
4 changes: 2 additions & 2 deletions RealtimeTTS/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,6 @@
EdgeEngine, EdgeVoice = None, None

try:
from .engines import StyleTTSEngine # noqa: F401
from .engines import StyleTTSEngine, StyleTTSVoice # noqa: F401
except ImportError:
StyleTTSEngine = None
StyleTTSEngine, StyleTTSVoice = None
4 changes: 2 additions & 2 deletions RealtimeTTS/engines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,6 @@
EdgeEngine, EdgeVoice = None, None

try:
from .style_engine import StyleTTSEngine # noqa: F401
from .style_engine import StyleTTSEngine, StyleTTSVoice # noqa: F401
except ImportError as e:
StyleTTSEngine = None
StyleTTSEngine, StyleTTSVoice = None
24 changes: 23 additions & 1 deletion RealtimeTTS/engines/azure_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import azure.cognitiveservices.speech as tts
from azure.cognitiveservices.speech import SpeechSynthesisOutputFormat
from .base_engine import BaseEngine
from typing import Union
import requests
Expand Down Expand Up @@ -49,13 +50,24 @@ def _extract_voice_language(locale):


class AzureEngine(BaseEngine):
SUPPORTED_AUDIO_FORMATS = {
"riff-16khz-16bit-mono-pcm": 16000,
"riff-24khz-16bit-mono-pcm": 24000,
"riff-48khz-16bit-mono-pcm": 48000,
}
AUDIO_FORMAT_MAP = {
"riff-16khz-16bit-mono-pcm": tts.SpeechSynthesisOutputFormat.Riff16Khz16BitMonoPcm,
"riff-24khz-16bit-mono-pcm": tts.SpeechSynthesisOutputFormat.Riff24Khz16BitMonoPcm,
"riff-48khz-16bit-mono-pcm": tts.SpeechSynthesisOutputFormat.Riff48Khz16BitMonoPcm,
}
def __init__(
self,
speech_key: str = "",
service_region: str = "",
voice: str = "en-US-AshleyNeural",
rate: float = 0.0,
pitch: float = 0.0,
audio_format: str = "riff-16khz-16bit-mono-pcm",
):
"""
Initializes an azure voice realtime text to speech engine object.
Expand All @@ -66,8 +78,17 @@ def __init__(
voice (str, optional): Voice name. Defaults to "en-US-AshleyNeural".
rate (float, optional): Speech speed as a percentage. Defaults to "0.0". Indicating the relative change.
pitch (float, optional): Speech pitch as a percentage. Defaults to "0.0". Indicating the relative change.
audio_format (str, optional): Audio format for output. Defaults to "riff-16khz-16bit-mono-pcm". Must be one of these supported formats: "riff-16khz-16bit-mono-pcm", "riff-24khz-16bit-mono-pcm", "riff-48khz-16bit-mono-pcm".
Raises:
ValueError: If the provided audio_format is not supported.
"""
if audio_format not in self.SUPPORTED_AUDIO_FORMATS:
raise ValueError(
f"Invalid audio_format '{audio_format}'. Supported formats are: {list(self.SUPPORTED_AUDIO_FORMATS.keys())}"
)

self.audio_format = audio_format
self.sample_rate = self.SUPPORTED_AUDIO_FORMATS[audio_format]
self.speech_key = speech_key
self.service_region = service_region
self.language = voice[:5]
Expand Down Expand Up @@ -138,7 +159,7 @@ def get_stream_info(self):
- Channels (int): The number of audio channels. 1 represents mono audio.
- Sample Rate (int): The sample rate of the audio in Hz. 16000 represents 16kHz sample rate.
"""
return pyaudio.paInt16, 1, 16000
return pyaudio.paInt16, 1, self.sample_rate

def synthesize(self, text: str) -> bool:
"""
Expand All @@ -152,6 +173,7 @@ def synthesize(self, text: str) -> bool:
speech_config = tts.SpeechConfig(
subscription=self.speech_key, region=self.service_region
)
speech_config.set_speech_synthesis_output_format(self.AUDIO_FORMAT_MAP[self.audio_format])
stream_callback = PushAudioOutputStreamSampleCallback(self.queue)
push_stream = tts.audio.PushAudioOutputStream(stream_callback)
stream_config = tts.audio.AudioOutputConfig(stream=push_stream)
Expand Down
164 changes: 147 additions & 17 deletions RealtimeTTS/engines/style_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,59 @@
import torch
import sys
import os
import gc
import time
from numba import cuda

class StyleTTSVoice:
def __init__(self,
model_config_path: str,
model_checkpoint_path: str,
ref_audio_path: str):
"""
Represents a StyleTTS voice configuration.
Args:
model_config_path (str): Path to the StyleTTS model configuration file.
model_checkpoint_path (str): Path to the StyleTTS model checkpoint file.
ref_audio_path (str): Path to the reference audio file for extracting style.
"""
self.model_config_path = model_config_path
self.model_checkpoint_path = model_checkpoint_path
self.ref_audio_path = ref_audio_path

def __str__(self):
"""
String representation of the StyleTTS voice configuration.
"""
return (
f"StyleTTSVoice("
f"Config: {self.model_config_path}, "
f"Checkpoint: {self.model_checkpoint_path}, "
f"Reference Audio: {self.ref_audio_path})"
)

def __repr__(self):
"""
Detailed representation of the StyleTTS voice configuration.
"""
return (
f"StyleTTSVoice:\n"
f" Model Config Path: {self.model_config_path}\n"
f" Model Checkpoint Path: {self.model_checkpoint_path}\n"
f" Reference Audio Path: {self.ref_audio_path}"
)

class StyleTTSEngine(BaseEngine):
def __init__(self,
style_root: str,
model_config_path: str,
model_checkpoint_path: str,
ref_audio_path: str, # path to reference audio for style
voice: StyleTTSVoice,
device: str = 'cuda',
alpha: float = 0.3,
beta: float = 0.7,
diffusion_steps: int = 5,
embedding_scale: float = 1.0):
embedding_scale: float = 1.0,
cuda_reset_delay: float = 0.0): # Delay after resetting CUDA device
"""
Initializes the StyleTTS engine with customizable parameters.
Expand Down Expand Up @@ -66,18 +107,24 @@ def __init__(self,
- A higher scale (e.g., 1.2 or 1.5) strengthens the alignment with the text and reference,
potentially enhancing style adherence and expressiveness.
- A very high scale might introduce artifacts or unnatural audio, so fine-tuning is recommended.
cuda_reset_delay (float): Time in seconds to wait after resetting the CUDA device.
"""
self.device = device if torch.cuda.is_available() else 'cpu'
self.style_root = style_root.replace("\\", "/")
self.model_config_path = model_config_path.replace("\\", "/")
self.model_checkpoint_path = model_checkpoint_path.replace("\\", "/")
self.ref_audio_path = ref_audio_path

# Use the properties from the StyleTTSVoice instance
self.voice = voice
self.model_config_path = self.voice.model_config_path.replace("\\", "/")
self.model_checkpoint_path = self.voice.model_checkpoint_path.replace("\\", "/")
self.ref_audio_path = self.voice.ref_audio_path

# Parameters for synthesis
self.alpha = alpha
self.beta = beta
self.diffusion_steps = diffusion_steps
self.embedding_scale = embedding_scale
self.cuda_reset_delay = cuda_reset_delay # Store the delay parameter

# Add the root directory to sys.path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), self.style_root)))
Expand All @@ -90,20 +137,83 @@ def __init__(self,
def post_init(self):
self.engine_name = "styletts"

def get_stream_info(self):
def unload_model(self):
"""
Returns the PyAudio stream configuration:
- Format: pyaudio.paInt16 (16-bit)
- Channels: 1 (mono)
- Sample Rate: 24000 Hz
Unloads the current model and clears VRAM to prevent memory leaks.
Steps:
1. Move models to CPU to ensure PyTorch releases GPU memory.
2. Delete references to the model and other components to allow garbage collection.
3. Trigger garbage collection and clear the CUDA memory cache.
"""
# Move models to CPU first
if hasattr(self, 'model'):
for key in self.model:
self.model[key].to('cpu')
# Explanation: Moving models to the CPU ensures that all tensors allocated on the GPU
# are detached from the GPU's memory. If a model is directly deleted while still residing
# on the GPU, PyTorch may not fully release its VRAM due to lingering device-side context.

# Delete references
if hasattr(self, 'model'):
del self.model # Remove the main model
if hasattr(self, 'sampler'):
del self.sampler # Remove the diffusion sampler
if hasattr(self, 'text_aligner'):
del self.text_aligner # Remove the ASR-based text aligner
if hasattr(self, 'pitch_extractor'):
del self.pitch_extractor # Remove the pitch extraction model
if hasattr(self, 'plbert'):
del self.plbert # Remove the pre-trained BERT model used for prosody

# Force garbage collection and try to free cache
gc.collect()
torch.cuda.empty_cache()
# Explanation: After removing references, garbage collection ensures that
# Python clears any remaining objects that might still hold references to GPU memory.
# `torch.cuda.empty_cache()` clears PyTorch's internal GPU memory management cache,
# freeing up VRAM for the next model or process.

def set_model_config_path(self, new_path: str):
self.unload_model()
self.model_config_path = new_path.replace("\\", "/")
self.load_model()
print(f"Model config updated to: {new_path}")

def set_model_checkpoint_path(self, new_path: str):
self.unload_model()
self.model_checkpoint_path = new_path.replace("\\", "/")
self.load_model()
print(f"Model checkpoint updated to: {new_path}")

def set_ref_audio_path(self, new_path: str):
# Updating the reference audio doesn't require unloading the model.
# We're just recomputing style embeddings.
self.ref_audio_path = new_path
self.compute_reference_style(self.ref_audio_path)
print(f"Reference audio updated to: {new_path}")

def set_all_parameters(self, model_config_path: str, model_checkpoint_path: str, ref_audio_path: str):
"""
Updates model config, checkpoint, and reference audio simultaneously,
reloading the model only once.
"""
self.unload_model() # Unload the previous model
self.model_config_path = model_config_path.replace("\\", "/")
self.model_checkpoint_path = model_checkpoint_path.replace("\\", "/")
self.ref_audio_path = ref_audio_path
self.load_model() # Reload the new model with updated config and checkpoint
self.compute_reference_style(self.ref_audio_path) # Recompute style embeddings
print(f"Updated all parameters:\n - Model config: {model_config_path}\n - Model checkpoint: {model_checkpoint_path}\n - Reference audio: {ref_audio_path}")

def get_stream_info(self):
import pyaudio
return pyaudio.paInt16, 1, 24000

def synthesize(self, text: str) -> bool:
"""
Synthesizes text to audio stream using the loaded StyleTTS model.
Args:
text (str): Text to synthesize.
"""
Expand Down Expand Up @@ -186,7 +296,7 @@ def load_model(self):
state_dict = params[key]
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
name = k[7:]
new_state_dict[name] = v
self.model[key].load_state_dict(new_state_dict, strict=False)
_ = [self.model[key].eval() for key in self.model]
Expand All @@ -198,8 +308,8 @@ def load_model(self):

# Initialize phonemizer
self.global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us',
preserve_punctuation=True,
with_stress=True)
preserve_punctuation=True,
with_stress=True)

# Initialize diffusion sampler
self.sampler = DiffusionSampler(
Expand Down Expand Up @@ -241,7 +351,7 @@ def inference(self, text: str,
embedding_scale: float = 1.0) -> np.ndarray:
"""
Run inference with given parameters and return audio waveform.
Args:
text (str): Text to synthesize.
alpha (float): Timbre blending factor.
Expand Down Expand Up @@ -322,3 +432,23 @@ def inference(self, text: str,
waveform = waveform[..., :-50]

return waveform

def get_voices(self):
"""
Retrieves the installed voices available for the StyleTTS engine.
We return an empty list since StyleTTS does not support voice retrieval.
"""
voice_objects = []
return voice_objects

def set_voice(self, voice: StyleTTSVoice):
"""
Sets the voice to be used for speech synthesis.
"""
if isinstance(voice, StyleTTSVoice):
self.voice = voice
self.set_all_parameters(
model_config_path=voice.model_config_path,
model_checkpoint_path=voice.model_checkpoint_path,
ref_audio_path=voice.ref_audio_path,
)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
current_version = "0.4.19"
current_version = "0.4.20"

import setuptools

Expand Down
22 changes: 22 additions & 0 deletions tests/azure_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
if __name__ == "__main__":
import os
from RealtimeTTS import TextToAudioStream, AzureEngine

def dummy_generator():
yield "Hey guys! These here are realtime spoken sentences based on local text synthesis. "
yield "With a local, neuronal, cloned voice. So every spoken sentence sounds unique."

# for normal use with minimal logging:
import os
engine = AzureEngine(
os.environ["AZURE_SPEECH_KEY"],
os.environ["AZURE_SPEECH_REGION"],
audio_format="riff-48khz-16bit-mono-pcm"
)

stream = TextToAudioStream(engine)

print("Starting to play stream")
stream.feed(dummy_generator()).play(log_synthesized_text=True)

engine.shutdown()
Loading

0 comments on commit c1f6dc0

Please sign in to comment.