Skip to content

Commit

Permalink
feat/lang_detection_plugin (#5)
Browse files (browse the repository at this point in the history)
* feat/lang_detection_plugin

* Update requirements.txt
  • Loading branch information
JarbasAl authored Apr 20, 2024
1 parent 61fceff commit d769875
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 22 deletions.
34 changes: 15 additions & 19 deletions ovos_audio_transformer_plugin_speechbrain_langdetect/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from typing import List

import numpy as np
import torch
from ovos_config.config import Configuration
from ovos_config.locale import get_default_lang, get_valid_languages
from ovos_plugin_manager.templates.transformers import AudioTransformer
from ovos_utils.log import LOG
from ovos_config.locale import get_valid_languages
from ovos_plugin_manager.templates.transformers import AudioLanguageDetector
from ovos_utils.xdg_utils import xdg_data_home
from speechbrain.pretrained import EncoderClassifier
from speech_recognition import AudioData
from speechbrain.inference import EncoderClassifier


class SpeechBrainLangClassifier(AudioTransformer):
class SpeechBrainLangClassifier(AudioLanguageDetector):
def __init__(self, config=None):
config = config or {}
super().__init__("ovos-audio-transformer-plugin-speechbrain-langdetect", 10, config)
@@ -90,27 +87,26 @@ def signal2probs(self, signal):
return {langmap[k].lower(): v for k, v in results.items()}

# plugin api
def transform(self, audio_data):
def detect(self, audio_data: bytes, valid_langs=None):
if isinstance(audio_data, AudioData):
audio_data = audio_data.get_wav_data()

signal = self.audiochunk2array(audio_data)

valid = get_valid_languages()
valid = valid_langs or get_valid_languages()
if len(valid) == 1:
# no classification needed
return audio_data, {}

probs = self.signal2probs(signal)

valid2 = [l.split("-")[0] for l in valid]
probs = [(k, v) for k, v in probs.items()
if k.split("-")[0] in valid2]

total = sum(p[1] for p in probs) or 1
probs = [(k, v / total) for k, v in probs]

lang, prob = max(probs, key=lambda k: k[1])
LOG.info(f"Detected speech language '{lang}' with probability {prob}")
return audio_data, {"stt_lang": lang.split(":")[0],
"lang_probability": prob,
"lang_predictions": probs}
return lang, prob


if __name__ == "__main__":
@@ -121,6 +117,6 @@ def transform(self, audio_data):
audio = Recognizer().record(source)

s = SpeechBrainLangClassifier()
_, ctxt = s.transform(audio.get_wav_data())
print(ctxt)
# {'stt_lang': 'en', 'lang_probability': 0.8076384663581848}
lang, prob = s.detect(audio.get_wav_data(), valid_langs=["en-us", "es-es"])
print(lang, prob)
# en-us 0.5979952496320518
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
torchaudio
speechbrain
ovos_plugin_manager
ovos-config>=0.0.12a3
speechbrain~=1.0.0
ovos-plugin-manager~=0.0, >=0.0.26a15
ovos-config>=0.0.12

0 comments on commit d769875

Please sign in to comment.