diff --git a/.gitignore b/.gitignore index 9595be5..c8ece8c 100644 --- a/.gitignore +++ b/.gitignore @@ -18,3 +18,4 @@ dist # Created by unit tests .pytest_cache/ +/.gtm/ diff --git a/ovos_audio/playback.py b/ovos_audio/playback.py index f88db20..099d9ee 100644 --- a/ovos_audio/playback.py +++ b/ovos_audio/playback.py @@ -1,6 +1,8 @@ import random import threading +from ovos_audio.transformers import TTSTransformersService from ovos_bus_client.message import Message +from ovos_plugin_manager.templates.tts import TTS from ovos_utils.log import LOG from ovos_utils.sound import play_audio from queue import Empty @@ -13,19 +15,20 @@ class PlaybackThread(Thread): viseme data to enclosure. """ - def __init__(self, queue): + def __init__(self, queue=TTS.queue, bus=None): super(PlaybackThread, self).__init__() - self.queue = queue + self.queue = queue or TTS.queue self._terminated = False self._processing_queue = False self._paused = False self.enclosure = None self.p = None self._tts = [] - self.bus = None + self.bus = bus or None self._now_playing = None self.active_tts = None self._started = threading.Event() + self.tts_transform = TTSTransformersService(self.bus) @property def is_running(self): @@ -55,6 +58,7 @@ def set_bus(self, bus): bus (MycroftBusClient): bus client """ self.bus = bus + self.tts_transform.set_bus(bus) @property def tts(self): @@ -151,6 +155,9 @@ def _play(self): data, visemes, listen, tts_id, message = self._now_playing self.activate_tts(tts_id) self.on_start(message) + + data = self.tts_transform.transform(data, message.context) + self.p = play_audio(data) if visemes: self.show_visemes(visemes) diff --git a/ovos_audio/service.py b/ovos_audio/service.py index 2ed3423..f543408 100644 --- a/ovos_audio/service.py +++ b/ovos_audio/service.py @@ -4,8 +4,8 @@ from os.path import exists from ovos_audio.audio import AudioService from ovos_audio.playback import PlaybackThread -from ovos_audio.tts import TTSFactory from ovos_audio.transformers import DialogTransformersService +from ovos_audio.tts import TTSFactory from ovos_audio.utils import report_timing, validate_message_context from ovos_bus_client import Message, MessageBusClient from ovos_bus_client.session import SessionManager @@ -75,6 +75,7 @@ def __init__(self, ready_hook=on_ready, error_hook=on_error, self.status.bind(self.bus) self.init_messagebus() self.dialog_transform = DialogTransformersService(self.bus) + self.playback_thread = PlaybackThread(TTS.queue, self.bus) try: self._maybe_reload_tts() @@ -273,7 +274,9 @@ def handle_speak(self, message): utterance = message.data['utterance'] # allow dialog transformers to rewrite speech - utt2 = self.dialog_transform.transform(utterance, sess) + utt2, message.context = self.dialog_transform.transform(dialog=utterance, + context=message.context, + sess=sess) if utterance != utt22: LOG.debug(f"original dialog: {utterance}") LOG.info(f"dialog transformed to: {utt2}") @@ -302,7 +305,7 @@ def _maybe_reload_tts(self): # Create new tts instance LOG.info("(re)loading TTS engine") self.tts = TTSFactory.create(config) - self.tts.init(self.bus, PlaybackThread) + self.tts.init(self.bus, self.playback_thread) self._tts_hash = config.get("module", "") # if fallback TTS is the same as main TTS dont load it @@ -350,7 +353,7 @@ def _get_tts_fallback(self): engine: config.get('tts', {}).get(engine, {})}} self.fallback_tts = TTSFactory.create(cfg) self.fallback_tts.validator.validate() - self.fallback_tts.init(self.bus, PlaybackThread) + self.fallback_tts.init(self.bus, self.playback_thread) return self.fallback_tts diff --git a/ovos_audio/transformers.py b/ovos_audio/transformers.py index c6fb9dc..47dbdb3 100644 --- a/ovos_audio/transformers.py +++ b/ovos_audio/transformers.py @@ -1,11 +1,11 @@ -from ovos_plugin_manager.dialog_transformers import find_dialog_transformer_plugins -from ovos_utils.json_helper import merge_dict -from ovos_utils.log import LOG from ovos_bus_client.session import Session, SessionManager +from ovos_plugin_manager.dialog_transformers import find_dialog_transformer_plugins, find_tts_transformer_plugins +from ovos_utils.log import LOG class DialogTransformersService: """ transform dialogs before being sent to TTS """ + def __init__(self, bus, config=None): self.loaded_plugins = {} self.has_loaded = False @@ -25,7 +25,7 @@ def load_plugins(self): self.loaded_plugins[plug_name].bind(self.bus) LOG.info(f"loaded audio transformer plugin: {plug_name}") except Exception as e: - LOG.exception(f"Failed to load audio transformer plugin: " + LOG.exception(f"Failed to load dialog transformer plugin: " f"{plug_name}") self.has_loaded = True @@ -52,24 +52,106 @@ def shutdown(self): except Exception as e: LOG.warning(e) - def transform(self, dialog: str, session: Session= None) -> str: + def transform(self, wav_file: str, context: dict = None, sess: Session = None) -> str: """ Get transformed audio and context for the preceding audio @param dialog: str to be spoken @return: transformed dialog to be sent to TTS """ - session = session or SessionManager.get() # TODO property not yet introduced in Session - # this will be set per Session/Persona - # active_transformers = session.dialog_transformers or self.plugins + sess = sess or SessionManager.get() + # if isinstance(sess, dict): + # sess = Session.deserialize(sess) + # active_transformers = sess.dialog_transformers or self.plugins + active_transformers = self.plugins for module in active_transformers: try: LOG.debug(f"checking dialog transformer: {module}") - dialog = module.transform(dialog) + dialog, context = module.transform(dialog, context=kwargs) LOG.debug(f"{module.name}: {dialog}") except Exception as e: LOG.exception(e) return dialog + + +class TTSTransformersService: + """ transform wav_files after TTS """ + + def __init__(self, bus=None, config=None): + self.loaded_plugins = {} + self.has_loaded = False + self.bus = bus + # to activate a plugin, just add an entry to mycroft.conf for it + self.config = config or Configuration().get("tts_transformers", {}) + self.load_plugins() + + def load_plugins(self): + for plug_name, plug in find_tts_transformer_plugins().items(): + if plug_name in self.config: + # if disabled skip it + if not self.config[plug_name].get("active", True): + continue + try: + self.loaded_plugins[plug_name] = plug() + if self.bus: + self.loaded_plugins[plug_name].bind(self.bus) + LOG.info(f"loaded audio transformer plugin: {plug_name}") + except Exception as e: + LOG.exception(f"Failed to load tts transformer plugin: " + f"{plug_name}") + self.has_loaded = True + + def set_bus(self, bus): + self.bus = bus + for p in self.loaded_plugins.values(): + p.bind(self.bus) + + @property + def plugins(self) -> list: + """ + Return loaded transformers in priority order, such that modules with a + higher `priority` rank are called first and changes from lower ranked + transformers are applied last. + + A plugin of `priority` 1 will override any existing context keys and + will be the last to modify `audio_data` + """ + return sorted(self.loaded_plugins.values(), + key=lambda k: k.priority, reverse=True) + + def shutdown(self): + """ + Shutdown all loaded plugins + """ + for module in self.plugins: + try: + module.shutdown() + except Exception as e: + LOG.warning(e) + + def transform(self, wav_file: str, context: dict = None, sess: Session = None) -> str: + """ + Get transformed audio and context for the preceding audio + @param wav_file: str path for the TTS wav file + @return: path to transformed wav file + """ + + # TODO property not yet introduced in Session + sess = sess or SessionManager.get() + # if isinstance(sess, dict): + # sess = Session.deserialize(sess) + # active_transformers = sess.tts_transformers or self.plugins + + active_transformers = self.plugins + + for module in active_transformers: + try: + LOG.debug(f"checking tts transformer: {module}") + wav_file, context = module.transform(wav_file, context=context or {}) + LOG.debug(f"{module.name}: {wav_file}") + except Exception as e: + LOG.exception(e) + return dialog