Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix/valid langs list #2

Merged
merged 1 commit into from
Oct 23, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions ovos_audio_transformer_plugin_speechbrain_langdetect/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import numpy as np
import torch

from ovos_bus_client.session import SessionManager
from ovos_config.config import Configuration
from ovos_config.locale import get_default_lang
from ovos_plugin_manager.templates.transformers import AudioTransformer
from ovos_utils.log import LOG
from ovos_utils.xdg_utils import xdg_data_home
Expand All @@ -19,6 +20,11 @@ def __init__(self, config=None):
else:
self.engine = EncoderClassifier.from_hparams(source=model, savedir=f"{xdg_data_home()}/speechbrain")

@property
def valid_langs(self) -> List[str]:
    """Return the language codes this plugin may choose between.

    The list is the default language (from ``get_default_lang()``)
    followed by the entries of the ``secondary_langs`` configuration key,
    with duplicates removed.

    Returns:
        A list of unique language codes. The default language is always
        the first element and the remaining entries keep their
        configured order.
    """
    # A config that sets the key explicitly to null yields None rather
    # than the [] fallback; treat that the same as "no secondary langs".
    secondary = Configuration().get("secondary_langs", []) or []
    # dict.fromkeys de-duplicates while preserving insertion order, so the
    # result is deterministic (a plain set() would give arbitrary order).
    return list(dict.fromkeys([get_default_lang(), *secondary]))

@staticmethod
def audiochunk2array(audio_data):
# Convert buffer to float32 using NumPy
Expand All @@ -43,17 +49,15 @@ def signal2probs(self, signal):
def transform(self, audio_data):
signal = self.audiochunk2array(audio_data)

# list of lang codes for this request from bus message/config
s = SessionManager.get()
valid = [l.split("-")[0] for l in s.valid_languages]
valid = [l.split("-")[0] for l in self.valid_langs]
if len(valid) == 1:
# no classification needed
return audio_data, {}

probs = self.signal2probs(signal)
probs = [(k, v) for k, v in probs.items() if k in valid]
total = sum(p[1] for p in probs) or 1
probs = [(k, v/total) for k, v in probs]
probs = [(k, v / total) for k, v in probs]
lang, prob = max(probs, key=lambda k: k[1])
LOG.info(f"Detected speech language '{lang}' with probability {prob}")
return audio_data, {"stt_lang": lang.split(":")[0],
Expand Down
Loading