diff --git a/speech_recognition/__init__.py b/speech_recognition/__init__.py
index 253ab0fe..5ddd6055 100644
--- a/speech_recognition/__init__.py
+++ b/speech_recognition/__init__.py
@@ -1411,6 +1411,17 @@ def recognize_tensorflow(self, audio_data, tensor_graph='tensorflow-data/conv_ac
         human_string = self.tflabels[node_id]
         return human_string
 
+    def pre_load_whisper_model(self, model="base", load_options=None):
+        """
+        Pre-load whisper model
+        model can be any of tiny, base, small, medium, large, tiny.en, base.en, small.en, medium.en. See https://github.com/openai/whisper for more details.
+        """
+        import whisper
+
+        if load_options or not hasattr(self, "whisper_model") or self.whisper_model.get(model) is None:
+            self.whisper_model = getattr(self, "whisper_model", {})
+            self.whisper_model[model] = whisper.load_model(model, **load_options or {})
+
     def recognize_whisper(self, audio_data, model="base", show_dict=False, load_options=None, language=None, translate=False, **transcribe_options):
         """
         Performs speech recognition on ``audio_data`` (an ``AudioData`` instance), using Whisper.
@@ -1430,11 +1441,8 @@ def recognize_whisper(self, audio_data, model="base", show_dict=False, load_opti
         import numpy as np
         import soundfile as sf
         import torch
-        import whisper
 
-        if load_options or not hasattr(self, "whisper_model") or self.whisper_model.get(model) is None:
-            self.whisper_model = getattr(self, "whisper_model", {})
-            self.whisper_model[model] = whisper.load_model(model, **load_options or {})
+        self.pre_load_whisper_model(model=model, load_options=load_options)
 
         # 16 kHz https://github.com/openai/whisper/blob/28769fcfe50755a817ab922a7bc83483159600a9/whisper/audio.py#L98-L99
         wav_bytes = audio_data.get_wav_data(convert_rate=16000)
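
Usage sketch (not part of the patch): how the new pre_load_whisper_model hook is meant to be called, assuming this patch is applied to the speech_recognition package. Recognizer, AudioFile, record, and recognize_whisper already exist in the library; "audio.wav" is a hypothetical input file.

    import speech_recognition as sr

    r = sr.Recognizer()

    # Load (and, on first use, download) the model once at startup, so the
    # first recognize_whisper() call does not block on model loading.
    r.pre_load_whisper_model(model="base")

    with sr.AudioFile("audio.wav") as source:
        audio = r.record(source)

    # Reuses the cached model: load_options is None and
    # self.whisper_model["base"] is already populated, so the guard in
    # pre_load_whisper_model() skips whisper.load_model().
    print(r.recognize_whisper(audio, model="base"))

Note that passing a truthy load_options to either method forces a reload, since it is the first condition in the caching guard.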