diff --git a/ovos_stt_plugin_citrinet/__init__.py b/ovos_stt_plugin_citrinet/__init__.py index fec6591..06d815d 100644 --- a/ovos_stt_plugin_citrinet/__init__.py +++ b/ovos_stt_plugin_citrinet/__init__.py @@ -17,7 +17,7 @@ def __init__(self, config: dict = None): lang = self.lang.split("-")[0] if lang not in self.available_languages: raise ValueError(f"unsupported language, must be one of {self.available_languages}") - LOG.info(f"preloading model: {Model.langs[lang]}") + LOG.info(f"preloading model: {Model.default_models[lang]}") self.load_model(lang) def load_model(self, lang: str): @@ -27,7 +27,7 @@ def load_model(self, lang: str): @property def available_languages(self) -> set: - return set(Model.langs) + return set(Model.default_models.keys()) def execute(self, audio: AudioData, language: Optional[str] = None): ''' diff --git a/ovos_stt_plugin_citrinet/engine.py b/ovos_stt_plugin_citrinet/engine.py index db03818..fe7d026 100644 --- a/ovos_stt_plugin_citrinet/engine.py +++ b/ovos_stt_plugin_citrinet/engine.py @@ -42,34 +42,16 @@ class Model: - langs = { - "en": { - "model": "neongeckocom/stt_en_citrinet_512_gamma_0_25", - }, - "es": { - "model": "neongeckocom/stt_es_citrinet_512_gamma_0_25", - }, - "fr": { - "model": "neongeckocom/stt_fr_citrinet_512_gamma_0_25", - }, - "de": { - "model": "neongeckocom/stt_de_citrinet_512_gamma_0_25", - }, - "it": { - "model": "neongeckocom/stt_it_citrinet_512_gamma_0_25", - }, - "uk": { - "model": "neongeckocom/stt_uk_citrinet_512_gamma_0_25", - }, - "nl": { - "model": "neongeckocom/stt_nl_citrinet_512_gamma_0_25", - }, - "pt": { - "model": "neongeckocom/stt_pt_citrinet_512_gamma_0_25", - }, - "ca": { - "model": "projecte-aina/stt-ca-citrinet-512" - }, + default_models = { + "en": "neongeckocom/stt_en_citrinet_512_gamma_0_25", + "es": "neongeckocom/stt_es_citrinet_512_gamma_0_25", + "fr": "neongeckocom/stt_fr_citrinet_512_gamma_0_25", + "de": "neongeckocom/stt_de_citrinet_512_gamma_0_25", + "it": "neongeckocom/stt_it_citrinet_512_gamma_0_25", + "uk": "neongeckocom/stt_uk_citrinet_512_gamma_0_25", + "nl": "neongeckocom/stt_nl_citrinet_512_gamma_0_25", + "pt": "neongeckocom/stt_pt_citrinet_512_gamma_0_25", + "ca": "projecte-aina/stt-ca-citrinet-512", } sample_rate = 16000 subfolder_name = "onnx" @@ -81,9 +63,9 @@ def __init__(self, lang: str, model_folder: Optional[str] = None): self._init_model(lang) def _init_model(self, lang: str): - if lang not in self.langs: - raise ValueError(f"Unsupported language '{lang}'. Available languages: {list(self.langs.keys())}") - model_name = self.langs[lang]["model"] + if lang not in self.default_models: + raise ValueError(f"Unsupported language '{lang}'. Available languages: {list(self.default_models.keys())}") + model_name = self.default_models[lang] self._init_preprocessor(model_name) self._init_encoder(model_name) self._init_tokenizer(model_name)