diff --git a/neon_stt_plugin_nemo/__init__.py b/neon_stt_plugin_nemo/__init__.py
index 47534e8..9eaf05f 100644
--- a/neon_stt_plugin_nemo/__init__.py
+++ b/neon_stt_plugin_nemo/__init__.py
@@ -25,6 +25,7 @@
 # LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 # NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 # SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+from typing import Optional, List, Tuple
 
 import numpy as np
 
@@ -68,16 +69,26 @@ def _init_model(self, language=None) -> Model:
     def available_languages(self) -> set:
         return set(available_languages)
 
-    def execute(self, audio: AudioData, language = None):
-        '''
-        Executes speach recognition
+    def execute(self, audio: AudioData, language: Optional[str] = None) -> str:
+        """
+        Executes speech recognition and returns the most likely transcription
+        @param audio: input audio data
+        @param language: language of input audio
+        @return: recognized text
+        """
 
-        Parameters:
-            audio : input audio file path
-        Returns:
-            text (str): recognized text
-        '''
-        model = self._init_model(language)
+        return self.transcribe(audio, language)[0][0]
+
+    def transcribe(self, audio,
+                   lang: Optional[str] = None) -> List[Tuple[str, float]]:
+        """
+        Executes speech recognition and returns a list of possible
+        transcriptions with associated confidence levels.
+        @param audio: input audio data
+        @param lang: language of input audio
+        @return: List of (transcript, confidence) elements
+        """
+        model = self._init_model(lang)
         audio_buffer = np.frombuffer(audio.get_raw_data(), dtype=np.int16)
         self.transcriptions = model.stt(audio_buffer, audio.sample_rate)
 
@@ -87,6 +98,5 @@ def execute(self, audio: AudioData, language = None):
             self.transcriptions = []
         else:
             LOG.debug("Audio had data")
-        # TODO: Return a string since we currently only get one result and the
-        # ovos-stt-server only handles strings here
-        return self.transcriptions[0]
+        # Models do not return confidence, so just assume max of 1.0
+        return [(t, 1.0) for t in self.transcriptions]
diff --git a/tests/test_stt.py b/tests/test_stt.py
index 14aebb8..84c9f11 100644
--- a/tests/test_stt.py
+++ b/tests/test_stt.py
@@ -51,6 +51,11 @@ def test_get_stt(self):
                 audio = r.record(source)  # read the entire audio file
             result = self.stt.execute(audio)
             self.assertIn(transcription, result)
+            results = self.stt.transcribe(audio)
+            self.assertEqual(results[0][0], result)
+            for result in results:
+                self.assertIsInstance(result[0], str)
+                self.assertIsInstance(result[1], float)
 
     def test_available_languages(self):
         langs = self.stt.available_languages