From 8ad37d7640067613fd80020724fa65d89966a7d1 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Tue, 1 Aug 2023 03:05:01 -0500 Subject: [PATCH] Send language to Wyoming STT (#97344) --- homeassistant/components/wyoming/stt.py | 7 ++++++- tests/components/wyoming/conftest.py | 14 ++++++++++++++ tests/components/wyoming/snapshots/test_stt.ambr | 7 +++++++ tests/components/wyoming/test_stt.py | 16 ++++++++++------ 4 files changed, 37 insertions(+), 7 deletions(-) diff --git a/homeassistant/components/wyoming/stt.py b/homeassistant/components/wyoming/stt.py index 3f5487881a32b2..e64a2f14667020 100644 --- a/homeassistant/components/wyoming/stt.py +++ b/homeassistant/components/wyoming/stt.py @@ -2,7 +2,7 @@ from collections.abc import AsyncIterable import logging -from wyoming.asr import Transcript +from wyoming.asr import Transcribe, Transcript from wyoming.audio import AudioChunk, AudioStart, AudioStop from wyoming.client import AsyncTcpClient @@ -89,6 +89,10 @@ async def async_process_audio_stream( """Process an audio stream to STT service.""" try: async with AsyncTcpClient(self.service.host, self.service.port) as client: + # Set transcription language + await client.write_event(Transcribe(language=metadata.language).event()) + + # Begin audio stream await client.write_event( AudioStart( rate=SAMPLE_RATE, @@ -106,6 +110,7 @@ async def async_process_audio_stream( ) await client.write_event(chunk.event()) + # End audio stream await client.write_event(AudioStop().event()) while True: diff --git a/tests/components/wyoming/conftest.py b/tests/components/wyoming/conftest.py index 0dd9041a0d5efc..6b4e705914f107 100644 --- a/tests/components/wyoming/conftest.py +++ b/tests/components/wyoming/conftest.py @@ -4,6 +4,7 @@ import pytest +from homeassistant.components import stt from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant @@ -69,3 +70,16 @@ async def init_wyoming_tts(hass: HomeAssistant, tts_config_entry: ConfigEntry): return_value=TTS_INFO, ): await hass.config_entries.async_setup(tts_config_entry.entry_id) + + +@pytest.fixture +def metadata(hass: HomeAssistant) -> stt.SpeechMetadata: + """Get default STT metadata.""" + return stt.SpeechMetadata( + language=hass.config.language, + format=stt.AudioFormats.WAV, + codec=stt.AudioCodecs.PCM, + bit_rate=stt.AudioBitRates.BITRATE_16, + sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, + channel=stt.AudioChannels.CHANNEL_MONO, + ) diff --git a/tests/components/wyoming/snapshots/test_stt.ambr b/tests/components/wyoming/snapshots/test_stt.ambr index 08fe6a1ef8e722..784f89b2ab8117 100644 --- a/tests/components/wyoming/snapshots/test_stt.ambr +++ b/tests/components/wyoming/snapshots/test_stt.ambr @@ -1,6 +1,13 @@ # serializer version: 1 # name: test_streaming_audio list([ + dict({ + 'data': dict({ + 'language': 'en', + }), + 'payload': None, + 'type': 'transcibe', + }), dict({ 'data': dict({ 'channels': 1, diff --git a/tests/components/wyoming/test_stt.py b/tests/components/wyoming/test_stt.py index 021419f3a5e430..1938d44d310620 100644 --- a/tests/components/wyoming/test_stt.py +++ b/tests/components/wyoming/test_stt.py @@ -27,7 +27,9 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None: assert entity.supported_channels == [stt.AudioChannels.CHANNEL_MONO] -async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None: +async def test_streaming_audio( + hass: HomeAssistant, init_wyoming_stt, metadata, snapshot +) -> None: """Test streaming audio.""" entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr") assert entity is not None @@ -40,7 +42,7 @@ async def audio_stream(): "homeassistant.components.wyoming.stt.AsyncTcpClient", MockAsyncTcpClient([Transcript(text="Hello world").event()]), ) as mock_client: - result = await entity.async_process_audio_stream(None, audio_stream()) + result = await entity.async_process_audio_stream(metadata, audio_stream()) assert result.result == stt.SpeechResultState.SUCCESS assert result.text == "Hello world" @@ -48,7 +50,7 @@ async def audio_stream(): async def test_streaming_audio_connection_lost( - hass: HomeAssistant, init_wyoming_stt + hass: HomeAssistant, init_wyoming_stt, metadata ) -> None: """Test streaming audio and losing connection.""" entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr") @@ -61,13 +63,15 @@ async def audio_stream(): "homeassistant.components.wyoming.stt.AsyncTcpClient", MockAsyncTcpClient([None]), ): - result = await entity.async_process_audio_stream(None, audio_stream()) + result = await entity.async_process_audio_stream(metadata, audio_stream()) assert result.result == stt.SpeechResultState.ERROR assert result.text is None -async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None: +async def test_streaming_audio_oserror( + hass: HomeAssistant, init_wyoming_stt, metadata +) -> None: """Test streaming audio and error raising.""" entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr") assert entity is not None @@ -81,7 +85,7 @@ async def audio_stream(): "homeassistant.components.wyoming.stt.AsyncTcpClient", mock_client, ), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")): - result = await entity.async_process_audio_stream(None, audio_stream()) + result = await entity.async_process_audio_stream(metadata, audio_stream()) assert result.result == stt.SpeechResultState.ERROR assert result.text is None