Skip to content

Commit

Permalink
Flesh out test
Browse files Browse the repository at this point in the history
  • Loading branch information
synesthesiam committed Feb 19, 2024
1 parent 70652bb commit 36353ff
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 74 deletions.
60 changes: 59 additions & 1 deletion tests/shared.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,64 @@
"""Shared code for Wyoming satellite tests."""
from wyoming.audio import AudioChunk
import asyncio
import io
from collections.abc import Iterable
from typing import Optional

from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.client import AsyncClient
from wyoming.event import Event

AUDIO_START = AudioStart(rate=16000, width=2, channels=1)
AUDIO_STOP = AudioStop()

AUDIO_CHUNK = AudioChunk(
rate=16000, width=2, channels=1, audio=bytes([255] * 960) # 30ms
)


class FakeStreamReaderWriter:
def __init__(self) -> None:
self._undrained_data = bytes()
self._value = bytes()
self._data_ready = asyncio.Event()

def write(self, data: bytes) -> None:
self._undrained_data += data

def writelines(self, data: Iterable[bytes]) -> None:
for line in data:
self.write(line)

async def drain(self) -> None:
self._value += self._undrained_data
self._undrained_data = bytes()
self._data_ready.set()
self._data_ready.clear()

async def readline(self) -> bytes:
while b"\n" not in self._value:
await self._data_ready.wait()

with io.BytesIO(self._value) as value_io:
data = value_io.readline()
self._value = self._value[len(data) :]
return data

async def readexactly(self, n: int) -> bytes:
while len(self._value) < n:
await self._data_ready.wait()

data = self._value[:n]
self._value = self._value[n:]
return data


class MicClient(AsyncClient):
async def read_event(self) -> Optional[Event]:
# Send 30ms of audio every 30ms
await asyncio.sleep(AUDIO_CHUNK.seconds)
return AUDIO_CHUNK.event()

async def write_event(self, event: Event) -> None:
# Output only
pass
146 changes: 85 additions & 61 deletions tests/test_satellite.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,40 @@
import asyncio
import io
import logging
from collections.abc import Iterable
from pathlib import Path
from typing import Final, Optional
from unittest.mock import patch

import pytest
from wyoming.asr import Transcript
from wyoming.audio import AudioChunk
from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.client import AsyncClient
from wyoming.event import Event, async_read_event
from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite, StreamingStarted, StreamingStopped
from wyoming.tts import Synthesize
from wyoming.wake import Detection

from wyoming_satellite import (
EventSettings,
MicSettings,
SatelliteSettings,
SndSettings,
WakeSettings,
WakeStreamingSatellite,
)

from .shared import AUDIO_CHUNK
from .shared import (
AUDIO_CHUNK,
AUDIO_START,
AUDIO_STOP,
FakeStreamReaderWriter,
MicClient,
)

_LOGGER = logging.getLogger()

TIMEOUT: Final = 1


class MicClient(AsyncClient):
def __init__(self) -> None:
super().__init__()

async def read_event(self) -> Optional[Event]:
await asyncio.sleep(AUDIO_CHUNK.seconds)
return AUDIO_CHUNK.event()

async def write_event(self, event: Event) -> None:
# Output only
pass


class WakeClient(AsyncClient):
def __init__(self) -> None:
super().__init__()
Expand All @@ -63,12 +55,40 @@ async def write_event(self, event: Event) -> None:
self._event_ready.set()


class SndClient(AsyncClient):
def __init__(self) -> None:
super().__init__()
self.synthesize = asyncio.Event()
self.audio_start = asyncio.Event()
self.audio_chunk = asyncio.Event()
self.audio_stop = asyncio.Event()

async def read_event(self) -> Optional[Event]:
# Input only
pass

async def write_event(self, event: Event) -> None:
if AudioChunk.is_type(event.type):
self.audio_chunk.set()
elif Synthesize.is_type(event.type):
self.synthesize.set()
elif AudioStart.is_type(event.type):
self.audio_start.set()
elif AudioStop.is_type(event.type):
self.audio_stop.set()


class EventClient(AsyncClient):
def __init__(self) -> None:
super().__init__()
self.detection = asyncio.Event()
self.streaming_started = asyncio.Event()
self.streaming_stopped = asyncio.Event()
self.transcript = asyncio.Event()
self.synthesize = asyncio.Event()
self.audio_start = asyncio.Event()
self.audio_chunk = asyncio.Event()
self.audio_stop = asyncio.Event()

async def read_event(self) -> Optional[Event]:
# Input only
Expand All @@ -81,54 +101,34 @@ async def write_event(self, event: Event) -> None:
self.streaming_started.set()
elif StreamingStopped.is_type(event.type):
self.streaming_stopped.set()
elif Transcript.is_type(event.type):
self.transcript.set()
elif Synthesize.is_type(event.type):
self.synthesize.set()
elif AudioChunk.is_type(event.type):
self.audio_chunk.set()
elif AudioStart.is_type(event.type):
self.audio_start.set()
elif AudioStop.is_type(event.type):
self.audio_stop.set()


class FakeStreamReaderWriter:
def __init__(self) -> None:
self._undrained_data = bytes()
self._value = bytes()
self._data_ready = asyncio.Event()

def write(self, data: bytes) -> None:
self._undrained_data += data

def writelines(self, data: Iterable[bytes]) -> None:
for line in data:
self.write(line)

async def drain(self) -> None:
self._value += self._undrained_data
self._undrained_data = bytes()
self._data_ready.set()
self._data_ready.clear()

async def readline(self) -> bytes:
while b"\n" not in self._value:
await self._data_ready.wait()

with io.BytesIO(self._value) as value_io:
data = value_io.readline()
self._value = self._value[len(data) :]
return data

async def readexactly(self, n: int) -> bytes:
while len(self._value) < n:
await self._data_ready.wait()

data = self._value[:n]
self._value = self._value[n:]
return data
# -----------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_satellite_and_server(tmp_path: Path) -> None:
async def test_wake_satellite() -> None:
mic_client = MicClient()
snd_client = SndClient()
wake_client = WakeClient()
event_client = EventClient()

with patch(
"wyoming_satellite.satellite.SatelliteBase._make_mic_client",
return_value=mic_client,
), patch(
"wyoming_satellite.satellite.SatelliteBase._make_snd_client",
return_value=snd_client,
), patch(
"wyoming_satellite.satellite.SatelliteBase._make_wake_client",
return_value=wake_client,
Expand All @@ -139,19 +139,22 @@ async def test_satellite_and_server(tmp_path: Path) -> None:
satellite = WakeStreamingSatellite(
SatelliteSettings(
mic=MicSettings(uri="test"),
snd=SndSettings(uri="test"),
wake=WakeSettings(uri="test"),
event=EventSettings(uri="test"),
)
)

# Fake server connection
server_io = FakeStreamReaderWriter()
await satellite.set_server("test", server_io) # type: ignore

async def event_from_satellite() -> Optional[Event]:
return await async_read_event(server_io)

satellite_task = asyncio.create_task(satellite.run(), name="satellite")

# Fake server connection
server_io = FakeStreamReaderWriter()
await satellite.set_server("test", server_io) # type: ignore

# Start satellite
await satellite.event_from_server(RunSatellite().event())

# Trigger detection
Expand All @@ -165,9 +168,7 @@ async def event_from_satellite() -> Optional[Event]:
assert RunPipeline.is_type(event.type), event
run_pipeline = RunPipeline.from_event(event)
assert run_pipeline.start_stage == PipelineStage.ASR

# No TTS
assert run_pipeline.end_stage == PipelineStage.HANDLE
assert run_pipeline.end_stage == PipelineStage.TTS

# Event service should have received detection
await asyncio.wait_for(event_client.detection.wait(), timeout=TIMEOUT)
Expand All @@ -185,6 +186,9 @@ async def event_from_satellite() -> Optional[Event]:
# Send transcript
await satellite.event_from_server(Transcript(text="test").event())

# Event service should have received transcript
await asyncio.wait_for(event_client.transcript.wait(), timeout=TIMEOUT)

# Wait for streaming to stop
while satellite.is_streaming:
event = await asyncio.wait_for(event_from_satellite(), timeout=TIMEOUT)
Expand All @@ -194,5 +198,25 @@ async def event_from_satellite() -> Optional[Event]:
# Event service should have received streaming stop
await asyncio.wait_for(event_client.streaming_stopped.wait(), timeout=TIMEOUT)

# Fake a TTS response
await satellite.event_from_server(Synthesize(text="test").event())

# Event service should have received synthesize
await asyncio.wait_for(event_client.synthesize.wait(), timeout=TIMEOUT)

# Audio start, chunk, stop
await satellite.event_from_server(AUDIO_START.event())
await asyncio.wait_for(snd_client.audio_start.wait(), timeout=TIMEOUT)
await asyncio.wait_for(event_client.audio_start.wait(), timeout=TIMEOUT)

# Event service does not get audio chunks, just start/stop
await satellite.event_from_server(AUDIO_CHUNK.event())
await asyncio.wait_for(snd_client.audio_chunk.wait(), timeout=TIMEOUT)

await satellite.event_from_server(AUDIO_STOP.event())
await asyncio.wait_for(snd_client.audio_stop.wait(), timeout=TIMEOUT)
await asyncio.wait_for(event_client.audio_stop.wait(), timeout=TIMEOUT)

# Stop satellite
await satellite.stop()
await satellite_task
13 changes: 1 addition & 12 deletions tests/test_wake_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,11 @@
WakeStreamingSatellite,
)

from .shared import AUDIO_CHUNK
from .shared import MicClient

_LOGGER = logging.getLogger()


class MicClient(AsyncClient):
async def read_event(self) -> Optional[Event]:
# Send 30ms of audio every 30ms
await asyncio.sleep(AUDIO_CHUNK.seconds)
return AUDIO_CHUNK.event()

async def write_event(self, event: Event) -> None:
# Output only
pass


class WakeClient(AsyncClient):
def __init__(self) -> None:
super().__init__()
Expand Down

0 comments on commit 36353ff

Please sign in to comment.