From 44cc847f22c36fcb60f6e5a464b4e03f64dd9895 Mon Sep 17 00:00:00 2001
From: JarbasAI <33701864+JarbasAl@users.noreply.github.com>
Date: Mon, 23 Oct 2023 21:35:44 +0100
Subject: [PATCH] fix/valid langs list (#2)

---
 .../__init__.py | 14 +++++++++-----
 1 file changed, 9 insertions(+), 5 deletions(-)

diff --git a/ovos_audio_transformer_plugin_speechbrain_langdetect/__init__.py b/ovos_audio_transformer_plugin_speechbrain_langdetect/__init__.py
index a6bf2d0..298dcbb 100644
--- a/ovos_audio_transformer_plugin_speechbrain_langdetect/__init__.py
+++ b/ovos_audio_transformer_plugin_speechbrain_langdetect/__init__.py
@@ -1,7 +1,8 @@
 import numpy as np
 import torch
-
 from ovos_bus_client.session import SessionManager
+from ovos_config.config import Configuration
+from ovos_config.locale import get_default_lang
 from ovos_plugin_manager.templates.transformers import AudioTransformer
 from ovos_utils.log import LOG
 from ovos_utils.xdg_utils import xdg_data_home
@@ -19,6 +20,11 @@ def __init__(self, config=None):
         else:
             self.engine = EncoderClassifier.from_hparams(source=model, savedir=f"{xdg_data_home()}/speechbrain")
 
+    @property
+    def valid_langs(self) -> List[str]:
+        return list(set([get_default_lang()] +
+                        Configuration().get("secondary_langs", [])))
+
     @staticmethod
     def audiochunk2array(audio_data):
         # Convert buffer to float32 using NumPy
@@ -43,9 +49,7 @@ def signal2probs(self, signal):
 
     def transform(self, audio_data):
         signal = self.audiochunk2array(audio_data)
-        # list of lang codes for this request from bus message/config
-        s = SessionManager.get()
-        valid = [l.split("-")[0] for l in s.valid_languages]
+        valid = [l.split("-")[0] for l in self.valid_langs]
         if len(valid) == 1:
             # no classification needed
             return audio_data, {}
@@ -53,7 +57,7 @@ def transform(self, audio_data):
         probs = self.signal2probs(signal)
         probs = [(k, v) for k, v in probs.items() if k in valid]
         total = sum(p[1] for p in probs) or 1
-        probs = [(k, v/total) for k, v in probs]
+        probs = [(k, v / total) for k, v in probs]
         lang, prob = max(probs, key=lambda k: k[1])
         LOG.info(f"Detected speech language '{lang}' with probability {prob}")
         return audio_data, {"stt_lang": lang.split(":")[0],