diff --git a/ovos_dinkum_listener/plugins.py b/ovos_dinkum_listener/plugins.py index 3ef83bd..23d0272 100644 --- a/ovos_dinkum_listener/plugins.py +++ b/ovos_dinkum_listener/plugins.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, List, Tuple +from typing import Any, Dict, Optional, List, Tuple, Union from ovos_config.config import Configuration from ovos_plugin_manager.stt import OVOSSTTFactory @@ -56,18 +56,25 @@ def create_streaming_thread(self): return FakeStreamThread(self.queue, self.lang, self.engine, sample_rate, sample_width) - def transcribe(self, audio: Optional = None, + def transcribe(self, audio: Optional[Union[bytes, AudioData]] = None, lang: Optional[str] = None) -> List[Tuple[str, float]]: """transcribe audio data to a list of possible transcriptions and respective confidences""" # plugins expect AudioData objects - audiod = AudioData(audio or self.stream.buffer.read(), - sample_rate=self.stream.sample_rate, - sample_width=self.stream.sample_width) - transcripts = self.engine.transcribe(audiod, lang) if audio is None: + audiod = AudioData(self.stream.buffer.read(), + sample_rate=self.stream.sample_rate, + sample_width=self.stream.sample_width) self.stream.buffer.clear() - return transcripts + elif isinstance(audio, bytes): + audiod = AudioData(audio, + sample_rate=self.stream.sample_rate, + sample_width=self.stream.sample_width) + elif isinstance(audio, AudioData): + audiod = audio + else: + raise ValueError(f"'audio' must be 'bytes' or 'AudioData', got '{type(audio)}'") + return self.engine.transcribe(audiod, lang) def load_stt_module(config: Dict[str, Any] = None) -> StreamingSTT: diff --git a/ovos_dinkum_listener/service.py b/ovos_dinkum_listener/service.py index a85c7ec..2d165df 100644 --- a/ovos_dinkum_listener/service.py +++ b/ovos_dinkum_listener/service.py @@ -12,18 +12,18 @@ import base64 import json import subprocess -import time import wave +from shutil import which from enum import Enum from hashlib import md5 from os.path import dirname from pathlib import Path from tempfile import NamedTemporaryFile from threading import Thread, RLock, Event -from typing import List, Tuple +from typing import List, Tuple, Optional import speech_recognition as sr -from distutils.spawn import find_executable +import time from ovos_bus_client import MessageBusClient from ovos_bus_client.message import Message from ovos_bus_client.session import SessionManager @@ -31,6 +31,8 @@ from ovos_config.locations import get_xdg_data_save_path from ovos_plugin_manager.microphone import OVOSMicrophoneFactory from ovos_plugin_manager.stt import get_stt_lang_configs, get_stt_supported_langs, get_stt_module_configs +from ovos_plugin_manager.templates.stt import STT +from ovos_plugin_manager.templates.vad import VADEngine from ovos_plugin_manager.utils.tts_cache import hash_sentence from ovos_plugin_manager.vad import OVOSVADFactory from ovos_plugin_manager.vad import get_vad_configs @@ -38,11 +40,12 @@ from ovos_utils.log import LOG, log_deprecation from ovos_utils.process_utils import ProcessStatus, StatusCallbackMap, ProcessState +from ovos_dinkum_listener._util import _TemplateFilenameFormatter from ovos_dinkum_listener.plugins import load_stt_module, load_fallback_stt from ovos_dinkum_listener.transformers import AudioTransformersService from ovos_dinkum_listener.voice_loop import DinkumVoiceLoop, ListeningMode, ListeningState from ovos_dinkum_listener.voice_loop.hotwords import HotwordContainer -from ovos_dinkum_listener._util import _TemplateFilenameFormatter + try: from ovos_backend_client.api import DatasetApi except ImportError: @@ -64,11 +67,11 @@ def bytes2audiodata(data): recognizer = sr.Recognizer() with NamedTemporaryFile() as fp: fp.write(data) - - if find_executable("ffmpeg"): + ffmpeg = which("ffmpeg") + if ffmpeg: p = fp.name + "converted.wav" # ensure file format - cmd = ["ffmpeg", "-i", fp.name, "-acodec", "pcm_s16le", "-ar", + cmd = [ffmpeg, "-i", fp.name, "-acodec", "pcm_s16le", "-ar", "16000", "-ac", "1", "-f", "wav", p, "-y"] subprocess.call(cmd) else: @@ -150,7 +153,12 @@ class OVOSDinkumVoiceService(Thread): def __init__(self, on_ready=on_ready, on_error=on_error, on_stopping=on_stopping, on_alive=on_alive, on_started=on_started, watchdog=lambda: None, mic=None, - bus=None, validate_source=True, *args, **kwargs): + bus=None, validate_source=True, + stt: Optional[STT] = None, + fallback_stt: Optional[STT] = None, + vad: Optional[VADEngine] = None, + disable_fallback: bool = False, + *args, **kwargs): """ watchdog: (callable) function to call periodically indicating operational status. @@ -186,9 +194,14 @@ def __init__(self, on_ready=on_ready, on_error=on_error, self.mic = mic or OVOSMicrophoneFactory.create(microphone_config) self.hotwords = HotwordContainer(self.bus) - self.vad = OVOSVADFactory.create() - self.stt = load_stt_module() - self.fallback_stt = load_fallback_stt() + self.vad = vad or OVOSVADFactory.create() + self.stt = stt or load_stt_module() + self.disable_fallback = disable_fallback + self.disable_reload = stt is not None + if disable_fallback: + self.fallback_stt = None + else: + self.fallback_stt = fallback_stt or load_fallback_stt() self.transformers = AudioTransformersService(self.bus, self.config) self._load_lock = RLock() @@ -374,6 +387,7 @@ def register_event_handlers(self): self.bus.on('recognizer_loop:sleep', self._handle_sleep) self.bus.on('recognizer_loop:wake_up', self._handle_wake_up) + self.bus.on('recognizer_loop:b64_transcribe', self._handle_b64_transcribe) self.bus.on('recognizer_loop:b64_audio', self._handle_b64_audio) self.bus.on('recognizer_loop:record_stop', self._handle_stop_recording) self.bus.on('recognizer_loop:state.set', self._handle_change_state) @@ -671,15 +685,15 @@ def __normtranscripts(self, transcripts: List[Tuple[str, float]]) -> List[str]: ] hallucinations = self.config.get("hallucination_list", default_hallucinations) \ if self.config.get("filter_hallucinations", True) else [] - utts = [u[0].lstrip(" \"'").strip(" \"'") for u in transcripts] + utts = [u[0].lstrip(" \"'").strip(" \"'") for u in transcripts if u[0]] filtered_hutts = [u for u in utts if u and u.lower() not in hallucinations] - hutts = [u for u in utts if u and u not in filtered_hutts] + hutts = [u for u in utts if u not in filtered_hutts] if hutts: LOG.debug(f"Filtered hallucinations: {hutts}") return filtered_hutts def _stt_text(self, transcripts: List[Tuple[str, float]], stt_context: dict): - utts = self.__normtranscripts(transcripts) + utts = self.__normtranscripts(transcripts) if transcripts else [] LOG.debug(f"STT: {utts}") if utts: lang = stt_context.get("lang") or Configuration().get("lang", "en-us") @@ -922,8 +936,25 @@ def _handle_sound_played(self, message: Message): if self.voice_loop.state == ListeningState.CONFIRMATION: self.voice_loop.state = ListeningState.BEFORE_COMMAND + def _handle_b64_transcribe(self, message: Message): + """ transcribe base64 encoded audio and return result via message""" + LOG.debug("Handling Base64 STT request") + b64audio = message.data["audio"] + lang = message.data.get("lang", self.voice_loop.stt.lang) + + wav_data = base64.b64decode(b64audio) + + self.voice_loop.stt.stream_start() + audio = bytes2audiodata(wav_data) + utterances = self.voice_loop.stt.transcribe(audio, lang) + self.voice_loop.stt.stream_stop() + + LOG.debug(f"transcripts: {utterances}") + self.bus.emit(message.response({"transcriptions": utterances, "lang": lang})) + def _handle_b64_audio(self, message: Message): - """ transcribe base64 encoded audio """ + """ transcribe base64 encoded audio and inject result into bus""" + LOG.debug("Handling Base64 Incoming Audio") b64audio = message.data["audio"] lang = message.data.get("lang", self.voice_loop.stt.lang) @@ -1055,7 +1086,7 @@ def reload_configuration(self): Configuration object reports a change """ if self._config_hash() == self._applied_config_hash: - LOG.info(f"No relevant configuration changed") + LOG.debug("No relevant configuration changed") return LOG.info("Reloading changed configuration") if not self._load_lock.acquire(timeout=30): @@ -1071,7 +1102,7 @@ def reload_configuration(self): # Configuration changed, update status and reload self.status.set_alive() - if new_hash['stt'] != self._applied_config_hash['stt']: + if not self.disable_reload and new_hash['stt'] != self._applied_config_hash['stt']: LOG.info(f"Reloading STT") if self.stt: LOG.debug(f"old={self.stt.__class__}: {self.stt.config}") @@ -1083,7 +1114,8 @@ def reload_configuration(self): if self.stt: LOG.debug(f"new={self.stt.__class__}: {self.stt.config}") - if new_hash['fallback'] != self._applied_config_hash['fallback']: + if not self.disable_reload and not self.disable_fallback and new_hash['fallback'] != \ + self._applied_config_hash['fallback']: LOG.info(f"Reloading Fallback STT") if self.fallback_stt: LOG.debug(f"old={self.fallback_stt.__class__}: "