Skip to content

Commit

Permalink
Add support for multiple transcripts with updated unit tests (#25)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel McKnight <[email protected]>
  • Loading branch information
NeonDaniel and NeonDaniel authored Sep 19, 2024
1 parent c9d54dc commit 2dee34b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 12 deletions.
34 changes: 22 additions & 12 deletions neon_stt_plugin_nemo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
# LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
# NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from typing import Optional, List, Tuple

import numpy as np

Expand Down Expand Up @@ -68,16 +69,26 @@ def _init_model(self, language=None) -> Model:
def available_languages(self) -> set:
return set(available_languages)

def execute(self, audio: AudioData, language = None):
'''
Executes speach recognition
def execute(self, audio: AudioData, language: Optional[str] = None) -> str:
"""
Executes speech recognition and returns the most likely transcription
@param audio: input audio data
@param language: language of input audio
@return: recognized text
"""

Parameters:
audio : input audio file path
Returns:
text (str): recognized text
'''
model = self._init_model(language)
return self.transcribe(audio, language)[0][0]

def transcribe(self, audio,
lang: Optional[str] = None) -> List[Tuple[str, float]]:
"""
Executes speech recognition and returns a list of possible
transcriptions with associated confidence levels.
@param audio: input audio data
@param lang: language of input audio
@return: List of (transcript, confidence) elements
"""
model = self._init_model(lang)

audio_buffer = np.frombuffer(audio.get_raw_data(), dtype=np.int16)
self.transcriptions = model.stt(audio_buffer, audio.sample_rate)
Expand All @@ -87,6 +98,5 @@ def execute(self, audio: AudioData, language = None):
self.transcriptions = []
else:
LOG.debug("Audio had data")
# TODO: Return a string since we currently only get one result and the
# ovos-stt-server only handles strings here
return self.transcriptions[0]
# Models do not return confidence, so just assume max of 1.0
return [(t, 1.0) for t in self.transcriptions]
5 changes: 5 additions & 0 deletions tests/test_stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,11 @@ def test_get_stt(self):
audio = r.record(source) # read the entire audio file
result = self.stt.execute(audio)
self.assertIn(transcription, result)
results = self.stt.transcribe(audio)
self.assertEqual(results[0][0], result)
for result in results:
self.assertIsInstance(result[0], str)
self.assertIsInstance(result[1], float)

def test_available_languages(self):
langs = self.stt.available_languages
Expand Down

0 comments on commit 2dee34b

Please sign in to comment.