Skip to content

Commit

Permalink
Send language to Wyoming STT (home-assistant#97344)
Browse files Browse the repository at this point in the history
  • Loading branch information
synesthesiam authored Aug 1, 2023
1 parent 5aa3e36 commit 8ad37d7
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 7 deletions.
7 changes: 6 additions & 1 deletion homeassistant/components/wyoming/stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import AsyncIterable
import logging

from wyoming.asr import Transcript
from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient

Expand Down Expand Up @@ -89,6 +89,10 @@ async def async_process_audio_stream(
"""Process an audio stream to STT service."""
try:
async with AsyncTcpClient(self.service.host, self.service.port) as client:
# Set transcription language
await client.write_event(Transcribe(language=metadata.language).event())

# Begin audio stream
await client.write_event(
AudioStart(
rate=SAMPLE_RATE,
Expand All @@ -106,6 +110,7 @@ async def async_process_audio_stream(
)
await client.write_event(chunk.event())

# End audio stream
await client.write_event(AudioStop().event())

while True:
Expand Down
14 changes: 14 additions & 0 deletions tests/components/wyoming/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import pytest

from homeassistant.components import stt
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant

Expand Down Expand Up @@ -69,3 +70,16 @@ async def init_wyoming_tts(hass: HomeAssistant, tts_config_entry: ConfigEntry):
return_value=TTS_INFO,
):
await hass.config_entries.async_setup(tts_config_entry.entry_id)


@pytest.fixture
def metadata(hass: HomeAssistant) -> stt.SpeechMetadata:
    """Return the speech metadata shared by the Wyoming STT tests.

    The language is read from the Home Assistant test instance's
    configuration. The audio description is fixed: WAV format, PCM codec,
    ``BITRATE_16``, ``SAMPLERATE_16000`` and ``CHANNEL_MONO``.
    """
    speech_metadata = stt.SpeechMetadata(
        # Follow whatever language the test instance is configured with.
        language=hass.config.language,
        format=stt.AudioFormats.WAV,
        codec=stt.AudioCodecs.PCM,
        bit_rate=stt.AudioBitRates.BITRATE_16,
        sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
        channel=stt.AudioChannels.CHANNEL_MONO,
    )
    return speech_metadata
7 changes: 7 additions & 0 deletions tests/components/wyoming/snapshots/test_stt.ambr
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
# serializer version: 1
# name: test_streaming_audio
list([
dict({
'data': dict({
'language': 'en',
}),
'payload': None,
'type': 'transcibe',
}),
dict({
'data': dict({
'channels': 1,
Expand Down
16 changes: 10 additions & 6 deletions tests/components/wyoming/test_stt.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@ async def test_support(hass: HomeAssistant, init_wyoming_stt) -> None:
assert entity.supported_channels == [stt.AudioChannels.CHANNEL_MONO]


async def test_streaming_audio(hass: HomeAssistant, init_wyoming_stt, snapshot) -> None:
async def test_streaming_audio(
hass: HomeAssistant, init_wyoming_stt, metadata, snapshot
) -> None:
"""Test streaming audio."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None
Expand All @@ -40,15 +42,15 @@ async def audio_stream():
"homeassistant.components.wyoming.stt.AsyncTcpClient",
MockAsyncTcpClient([Transcript(text="Hello world").event()]),
) as mock_client:
result = await entity.async_process_audio_stream(None, audio_stream())
result = await entity.async_process_audio_stream(metadata, audio_stream())

assert result.result == stt.SpeechResultState.SUCCESS
assert result.text == "Hello world"
assert mock_client.written == snapshot


async def test_streaming_audio_connection_lost(
hass: HomeAssistant, init_wyoming_stt
hass: HomeAssistant, init_wyoming_stt, metadata
) -> None:
"""Test streaming audio and losing connection."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
Expand All @@ -61,13 +63,15 @@ async def audio_stream():
"homeassistant.components.wyoming.stt.AsyncTcpClient",
MockAsyncTcpClient([None]),
):
result = await entity.async_process_audio_stream(None, audio_stream())
result = await entity.async_process_audio_stream(metadata, audio_stream())

assert result.result == stt.SpeechResultState.ERROR
assert result.text is None


async def test_streaming_audio_oserror(hass: HomeAssistant, init_wyoming_stt) -> None:
async def test_streaming_audio_oserror(
hass: HomeAssistant, init_wyoming_stt, metadata
) -> None:
"""Test streaming audio and error raising."""
entity = stt.async_get_speech_to_text_entity(hass, "stt.test_asr")
assert entity is not None
Expand All @@ -81,7 +85,7 @@ async def audio_stream():
"homeassistant.components.wyoming.stt.AsyncTcpClient",
mock_client,
), patch.object(mock_client, "read_event", side_effect=OSError("Boom!")):
result = await entity.async_process_audio_stream(None, audio_stream())
result = await entity.async_process_audio_stream(metadata, audio_stream())

assert result.result == stt.SpeechResultState.ERROR
assert result.text is None

0 comments on commit 8ad37d7

Please sign in to comment.