Skip to content

Commit

Permalink
fix:b64 improvements (#107)
Browse files Browse the repository at this point in the history
* fix:b64 improvements

* Update ovos_audio/service.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* fix:b64

* fix:b64

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
JarbasAl and coderabbitai[bot] authored Oct 23, 2024
1 parent 8e51f38 commit 29ab229
Showing 1 changed file with 35 additions and 15 deletions.
50 changes: 35 additions & 15 deletions ovos_audio/service.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import base64
import binascii
import json
import os
import os.path
import time
import json
from hashlib import md5
from os.path import exists
from queue import Queue
from tempfile import gettempdir
from threading import Thread, Lock
from typing import Optional

import binascii
import time
from ovos_bus_client import Message, MessageBusClient
from ovos_bus_client.session import SessionManager
from ovos_config.config import Configuration
Expand Down Expand Up @@ -53,7 +54,9 @@ class PlaybackService(Thread):
def __init__(self, ready_hook=on_ready, error_hook=on_error,
stopping_hook=on_stopping, alive_hook=on_alive,
started_hook=on_started, watchdog=lambda: None,
bus=None, disable_ocp=None, validate_source=True):
bus=None, disable_ocp=None, validate_source=True,
tts: Optional[TTS] = None,
disable_fallback: bool = False):
super(PlaybackService, self).__init__()

LOG.info("Starting Audio Service")
Expand All @@ -68,10 +71,12 @@ def __init__(self, ready_hook=on_ready, error_hook=on_error,
self.config = Configuration()
self.native_sources = self.config["Audio"].get("native_sources",
["debug_cli", "audio"])
self.tts = None
self.tts: Optional[TTS] = tts
self._tts_hash = None
self.lock = Lock()
self.fallback_tts = None
self.disable_reload = tts is not None
self.disable_fallback = disable_fallback
self.fallback_tts: Optional[TTS] = None
self._fallback_tts_hash = None
self._last_stop_signal = 0
self.validate_source = validate_source
Expand All @@ -90,7 +95,8 @@ def __init__(self, ready_hook=on_ready, error_hook=on_error,

try:
self._maybe_reload_tts()
Configuration.set_config_watcher(self._maybe_reload_tts)
if not self.disable_reload:
Configuration.set_config_watcher(self._maybe_reload_tts)
except Exception as e:
LOG.exception(e)
self.status.set_error(e)
Expand Down Expand Up @@ -273,14 +279,19 @@ def handle_b64_audio(self, message):
stopwatch = Stopwatch()
stopwatch.start()
utterance = message.data['utterance']
listen = message.data.get("listen", False)

ctxt = self.tts._get_ctxt({"message": message})
wav, _ = self.tts.synth(utterance, ctxt)
with open(wav, "rb") as f:
# cast to str() to get a path, as it is a AudioFile object from tts cache
with open(str(wav), "rb") as f:
audio = f.read()

b64_audio = base64.b64encode(audio)
self.bus.emit(message.response({"audio": b64_audio}))
b64_audio = base64.b64encode(audio).decode("utf-8")
self.bus.emit(message.response({"audio": b64_audio,
"listen": listen,
'tts_id': self.tts.plugin_id,
"utterance": utterance}))

stopwatch.stop()
report_timing(sess.session_id, stopwatch,
Expand Down Expand Up @@ -326,15 +337,20 @@ def handle_speak(self, message):
self.execute_tts(utterance, sess.session_id, listen, message)

stopwatch.stop()
plugin_id = self.tts.plugin_id if self.tts else ""
report_timing(sess.session_id, stopwatch,
{'utterance': utterance,
'tts': self.tts.plugin_id})
'tts': plugin_id})

def _maybe_reload_tts(self):
"""
Load TTS modules if not yet loaded or if configuration has changed.
Optionally pre-loads fallback TTS if configured
"""
if self.disable_reload:
LOG.debug("skipping TTS reload")
return

config = Configuration().get("tts", {})
tts_m = config.get("module", "")
ftts_m = config.get("fallback_module", "")
Expand All @@ -354,6 +370,10 @@ def _maybe_reload_tts(self):
self.tts.init(self.bus, self.playback_thread)
self._tts_hash = _tts_hash

if self.disable_fallback:
LOG.debug("skipping fallback TTS reload")
return

# if fallback TTS is the same as main TTS dont load it
if config.get("module", "") == config.get("fallback_module", "") or not config.get("fallback_module", ""):
LOG.debug("Skipping fallback TTS init, fallback is empty or same as main TTS")
Expand Down Expand Up @@ -391,7 +411,7 @@ def execute_tts(self, utterance, ident, listen=False, message: Message = None):
if self._tts_hash != self._fallback_tts_hash:
self.execute_fallback_tts(utterance, ident, listen, message)

def _get_tts_fallback(self):
def _get_tts_fallback(self) -> Optional[TTS]:
"""Lazily initializes the fallback TTS if needed."""
if not self.fallback_tts:
config = Configuration()
Expand Down Expand Up @@ -428,7 +448,7 @@ def execute_fallback_tts(self, utterance, ident, listen, message: Message = None
LOG.exception(f"TTS FAILURE! utterance : {utterance}")

@property
def is_speaking(self):
def is_speaking(self) -> bool:
return self.tts.playback is not None and \
self.tts.playback._now_playing is not None

Expand All @@ -448,7 +468,7 @@ def handle_stop(self, message: Message):
self.bus.emit(message.forward("mycroft.stop.handled", {"by": "TTS"}))

@staticmethod
def _resolve_sound_uri(uri: str):
def _resolve_sound_uri(uri: str) -> Optional[str]:
""" helper to resolve sound files full path"""
if uri is None:
return None
Expand All @@ -462,7 +482,7 @@ def _resolve_sound_uri(uri: str):
return audio_file

@staticmethod
def _path_from_hexdata(hex_audio, audio_ext=None):
def _path_from_hexdata(hex_audio, audio_ext=None) -> str:
""" hex_audio contains hex string encoded bytes
audio_ext if not provided assumed to be wav
Expand Down

0 comments on commit 29ab229

Please sign in to comment.