Skip to content

Commit

Permalink
feat:b64 (#152)
Browse files Browse the repository at this point in the history
* feat:b64

previously it was alowed to "upload" audio as b64 and the transcription was directly injected for utterance handling

this commit adds support for returning the transcript instead of injecting it

* Apply suggestions from code review

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* drop distutils.spawn in favor of shutil.which

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
JarbasAl and coderabbitai[bot] authored Oct 23, 2024
1 parent beda994 commit 9175ae0
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 25 deletions.
21 changes: 14 additions & 7 deletions ovos_dinkum_listener/plugins.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, List, Tuple
from typing import Any, Dict, Optional, List, Tuple, Union

from ovos_config.config import Configuration
from ovos_plugin_manager.stt import OVOSSTTFactory
Expand Down Expand Up @@ -56,18 +56,25 @@ def create_streaming_thread(self):
return FakeStreamThread(self.queue, self.lang, self.engine, sample_rate,
sample_width)

def transcribe(self, audio: Optional = None,
def transcribe(self, audio: Optional[Union[bytes, AudioData]] = None,
lang: Optional[str] = None) -> List[Tuple[str, float]]:
"""transcribe audio data to a list of
possible transcriptions and respective confidences"""
# plugins expect AudioData objects
audiod = AudioData(audio or self.stream.buffer.read(),
sample_rate=self.stream.sample_rate,
sample_width=self.stream.sample_width)
transcripts = self.engine.transcribe(audiod, lang)
if audio is None:
audiod = AudioData(self.stream.buffer.read(),
sample_rate=self.stream.sample_rate,
sample_width=self.stream.sample_width)
self.stream.buffer.clear()
return transcripts
elif isinstance(audio, bytes):
audiod = AudioData(audio,
sample_rate=self.stream.sample_rate,
sample_width=self.stream.sample_width)
elif isinstance(audio, AudioData):
audiod = audio
else:
raise ValueError(f"'audio' must be 'bytes' or 'AudioData', got '{type(audio)}'")
return self.engine.transcribe(audiod, lang)


def load_stt_module(config: Dict[str, Any] = None) -> StreamingSTT:
Expand Down
68 changes: 50 additions & 18 deletions ovos_dinkum_listener/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,37 +12,40 @@
import base64
import json
import subprocess
import time
import wave
from shutil import which
from enum import Enum
from hashlib import md5
from os.path import dirname
from pathlib import Path
from tempfile import NamedTemporaryFile
from threading import Thread, RLock, Event
from typing import List, Tuple
from typing import List, Tuple, Optional

import speech_recognition as sr
from distutils.spawn import find_executable
import time
from ovos_bus_client import MessageBusClient
from ovos_bus_client.message import Message
from ovos_bus_client.session import SessionManager
from ovos_config import Configuration
from ovos_config.locations import get_xdg_data_save_path
from ovos_plugin_manager.microphone import OVOSMicrophoneFactory
from ovos_plugin_manager.stt import get_stt_lang_configs, get_stt_supported_langs, get_stt_module_configs
from ovos_plugin_manager.templates.stt import STT
from ovos_plugin_manager.templates.vad import VADEngine
from ovos_plugin_manager.utils.tts_cache import hash_sentence
from ovos_plugin_manager.vad import OVOSVADFactory
from ovos_plugin_manager.vad import get_vad_configs
from ovos_plugin_manager.wakewords import get_ww_lang_configs, get_ww_supported_langs, get_ww_module_configs
from ovos_utils.log import LOG, log_deprecation
from ovos_utils.process_utils import ProcessStatus, StatusCallbackMap, ProcessState

from ovos_dinkum_listener._util import _TemplateFilenameFormatter
from ovos_dinkum_listener.plugins import load_stt_module, load_fallback_stt
from ovos_dinkum_listener.transformers import AudioTransformersService
from ovos_dinkum_listener.voice_loop import DinkumVoiceLoop, ListeningMode, ListeningState
from ovos_dinkum_listener.voice_loop.hotwords import HotwordContainer
from ovos_dinkum_listener._util import _TemplateFilenameFormatter

try:
from ovos_backend_client.api import DatasetApi
except ImportError:
Expand All @@ -64,11 +67,11 @@ def bytes2audiodata(data):
recognizer = sr.Recognizer()
with NamedTemporaryFile() as fp:
fp.write(data)

if find_executable("ffmpeg"):
ffmpeg = which("ffmpeg")
if ffmpeg:
p = fp.name + "converted.wav"
# ensure file format
cmd = ["ffmpeg", "-i", fp.name, "-acodec", "pcm_s16le", "-ar",
cmd = [ffmpeg, "-i", fp.name, "-acodec", "pcm_s16le", "-ar",
"16000", "-ac", "1", "-f", "wav", p, "-y"]
subprocess.call(cmd)
else:
Expand Down Expand Up @@ -150,7 +153,12 @@ class OVOSDinkumVoiceService(Thread):
def __init__(self, on_ready=on_ready, on_error=on_error,
on_stopping=on_stopping, on_alive=on_alive,
on_started=on_started, watchdog=lambda: None, mic=None,
bus=None, validate_source=True, *args, **kwargs):
bus=None, validate_source=True,
stt: Optional[STT] = None,
fallback_stt: Optional[STT] = None,
vad: Optional[VADEngine] = None,
disable_fallback: bool = False,
*args, **kwargs):
"""
watchdog: (callable) function to call periodically indicating
operational status.
Expand Down Expand Up @@ -186,9 +194,14 @@ def __init__(self, on_ready=on_ready, on_error=on_error,
self.mic = mic or OVOSMicrophoneFactory.create(microphone_config)

self.hotwords = HotwordContainer(self.bus)
self.vad = OVOSVADFactory.create()
self.stt = load_stt_module()
self.fallback_stt = load_fallback_stt()
self.vad = vad or OVOSVADFactory.create()
self.stt = stt or load_stt_module()
self.disable_fallback = disable_fallback
self.disable_reload = stt is not None
if disable_fallback:
self.fallback_stt = None
else:
self.fallback_stt = fallback_stt or load_fallback_stt()
self.transformers = AudioTransformersService(self.bus, self.config)

self._load_lock = RLock()
Expand Down Expand Up @@ -374,6 +387,7 @@ def register_event_handlers(self):

self.bus.on('recognizer_loop:sleep', self._handle_sleep)
self.bus.on('recognizer_loop:wake_up', self._handle_wake_up)
self.bus.on('recognizer_loop:b64_transcribe', self._handle_b64_transcribe)
self.bus.on('recognizer_loop:b64_audio', self._handle_b64_audio)
self.bus.on('recognizer_loop:record_stop', self._handle_stop_recording)
self.bus.on('recognizer_loop:state.set', self._handle_change_state)
Expand Down Expand Up @@ -671,15 +685,15 @@ def __normtranscripts(self, transcripts: List[Tuple[str, float]]) -> List[str]:
]
hallucinations = self.config.get("hallucination_list", default_hallucinations) \
if self.config.get("filter_hallucinations", True) else []
utts = [u[0].lstrip(" \"'").strip(" \"'") for u in transcripts]
utts = [u[0].lstrip(" \"'").strip(" \"'") for u in transcripts if u[0]]
filtered_hutts = [u for u in utts if u and u.lower() not in hallucinations]
hutts = [u for u in utts if u and u not in filtered_hutts]
hutts = [u for u in utts if u not in filtered_hutts]
if hutts:
LOG.debug(f"Filtered hallucinations: {hutts}")
return filtered_hutts

def _stt_text(self, transcripts: List[Tuple[str, float]], stt_context: dict):
utts = self.__normtranscripts(transcripts)
utts = self.__normtranscripts(transcripts) if transcripts else []
LOG.debug(f"STT: {utts}")
if utts:
lang = stt_context.get("lang") or Configuration().get("lang", "en-us")
Expand Down Expand Up @@ -922,8 +936,25 @@ def _handle_sound_played(self, message: Message):
if self.voice_loop.state == ListeningState.CONFIRMATION:
self.voice_loop.state = ListeningState.BEFORE_COMMAND

def _handle_b64_transcribe(self, message: Message):
""" transcribe base64 encoded audio and return result via message"""
LOG.debug("Handling Base64 STT request")
b64audio = message.data["audio"]
lang = message.data.get("lang", self.voice_loop.stt.lang)

wav_data = base64.b64decode(b64audio)

self.voice_loop.stt.stream_start()
audio = bytes2audiodata(wav_data)
utterances = self.voice_loop.stt.transcribe(audio, lang)
self.voice_loop.stt.stream_stop()

LOG.debug(f"transcripts: {utterances}")
self.bus.emit(message.response({"transcriptions": utterances, "lang": lang}))

def _handle_b64_audio(self, message: Message):
""" transcribe base64 encoded audio """
""" transcribe base64 encoded audio and inject result into bus"""
LOG.debug("Handling Base64 Incoming Audio")
b64audio = message.data["audio"]
lang = message.data.get("lang", self.voice_loop.stt.lang)

Expand Down Expand Up @@ -1055,7 +1086,7 @@ def reload_configuration(self):
Configuration object reports a change
"""
if self._config_hash() == self._applied_config_hash:
LOG.info(f"No relevant configuration changed")
LOG.debug("No relevant configuration changed")
return
LOG.info("Reloading changed configuration")
if not self._load_lock.acquire(timeout=30):
Expand All @@ -1071,7 +1102,7 @@ def reload_configuration(self):
# Configuration changed, update status and reload
self.status.set_alive()

if new_hash['stt'] != self._applied_config_hash['stt']:
if not self.disable_reload and new_hash['stt'] != self._applied_config_hash['stt']:
LOG.info(f"Reloading STT")
if self.stt:
LOG.debug(f"old={self.stt.__class__}: {self.stt.config}")
Expand All @@ -1083,7 +1114,8 @@ def reload_configuration(self):
if self.stt:
LOG.debug(f"new={self.stt.__class__}: {self.stt.config}")

if new_hash['fallback'] != self._applied_config_hash['fallback']:
if not self.disable_reload and not self.disable_fallback and new_hash['fallback'] != \
self._applied_config_hash['fallback']:
LOG.info(f"Reloading Fallback STT")
if self.fallback_stt:
LOG.debug(f"old={self.fallback_stt.__class__}: "
Expand Down

0 comments on commit 9175ae0

Please sign in to comment.