Skip to content

Commit

Permalink
fix/valid langs list (#2)
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl authored Oct 23, 2023
1 parent d856755 commit 44cc847
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions ovos_audio_transformer_plugin_speechbrain_langdetect/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import torch

from ovos_bus_client.session import SessionManager
from ovos_config.config import Configuration
from ovos_config.locale import get_default_lang
from ovos_plugin_manager.templates.transformers import AudioTransformer
from ovos_utils.log import LOG
from ovos_utils.xdg_utils import xdg_data_home
Expand All @@ -19,6 +20,11 @@ def __init__(self, config=None):
else:
self.engine = EncoderClassifier.from_hparams(source=model, savedir=f"{xdg_data_home()}/speechbrain")

@property
def valid_langs(self) -> List[str]:
return list(set([get_default_lang()] +
Configuration().get("secondary_langs", [])))

@staticmethod
def audiochunk2array(audio_data):
# Convert buffer to float32 using NumPy
Expand All @@ -43,17 +49,15 @@ def signal2probs(self, signal):
def transform(self, audio_data):
signal = self.audiochunk2array(audio_data)

# list of lang codes for this request from bus message/config
s = SessionManager.get()
valid = [l.split("-")[0] for l in s.valid_languages]
valid = [l.split("-")[0] for l in self.valid_langs]
if len(valid) == 1:
# no classification needed
return audio_data, {}

probs = self.signal2probs(signal)
probs = [(k, v) for k, v in probs.items() if k in valid]
total = sum(p[1] for p in probs) or 1
probs = [(k, v/total) for k, v in probs]
probs = [(k, v / total) for k, v in probs]
lang, prob = max(probs, key=lambda k: k[1])
LOG.info(f"Detected speech language '{lang}' with probability {prob}")
return audio_data, {"stt_lang": lang.split(":")[0],
Expand Down

0 comments on commit 44cc847

Please sign in to comment.