Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: extend wait_for_response to accept list #52

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 11 additions & 5 deletions ovos_bus_client/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def emit(self, message: Message):
if hasattr(message, 'serialize'):
msg = message.serialize()
else:
msg = json.dumps(message.__dict__)
msg = json.dumps(message.__dict__)
try:
self.client.send(msg)
except WebSocketConnectionClosedException:
Expand Down Expand Up @@ -269,24 +269,30 @@ def wait_for_message(self, message_type: str,
return MessageWaiter(self, message_type).wait(timeout)

def wait_for_response(self, message: Message,
reply_type: Optional[str] = None,
reply_type: Optional[Union[str, List[str]]] = None,
timeout: Union[float, int] = 3.0) -> \
Optional[Message]:
"""
Send a message and wait for a response.

Arguments:
message (Message): message to send
reply_type (str): the message type of the expected reply.
reply_type (str | List[str]): the message type(s) of the expected reply.
Defaults to "<message.msg_type>.response".
timeout: seconds to wait before timeout, defaults to 3

Returns:
The received message or None if the response timed out
"""
message_type = reply_type or message.msg_type + '.response'
message_type = None
if isinstance(reply_type, list):
message_type = reply_type
elif isinstance(reply_type, str):
mikejgray marked this conversation as resolved.
Show resolved Hide resolved
message_type = [reply_type]
elif reply_type is None:
message_type = [message.msg_type + '.response']
waiter = MessageWaiter(self, message_type) # Setup response handler
# Send message and wait for it's response
# Send message and wait for its response
self.emit(message)
return waiter.wait(timeout)

Expand Down
17 changes: 11 additions & 6 deletions ovos_bus_client/client/waiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#

from threading import Event
from typing import List, Union

try:
from mycroft_bus_client.client.waiter import MessageWaiter as _MessageWaiterBase
Expand All @@ -28,20 +29,23 @@ class MessageWaiter:
"""Wait for a single message.
Encapsulate the wait for a message logic separating the setup from
the actual waiting act so the waiting can be setuo, actions can be
the actual waiting act so the waiting can be setup, actions can be
performed and _then_ the message can be waited for.
Argunments:
Arguments:
bus: Bus to check for messages on
message_type: message type to wait for
message_type: message type(s) to wait for
"""
def __init__(self, bus, message_type):
def __init__(self, bus, message_type: Union[str, List[str]]):
self.bus = bus
if not isinstance(message_type, list):
message_type = [message_type]
self.msg_type = message_type
self.received_msg = None
# Setup response handler
self.response_event = Event()
self.bus.once(message_type, self._handler)
for msg in self.msg_type:
self.bus.once(msg, self._handler)

def _handler(self, message):
"""Receive response data."""
Expand All @@ -61,7 +65,8 @@ def wait(self, timeout=3.0):
if not self.response_event.is_set():
# Clean up the event handler
try:
self.bus.remove(self.msg_type, self._handler)
for msg in self.msg_type:
self.bus.remove(msg, self._handler)
except (ValueError, KeyError):
# ValueError occurs on pyee 5.0.1 removing handlers
# registered with once.
Expand Down
122 changes: 61 additions & 61 deletions test/unittests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,46 +11,42 @@
# limitations under the License.

import unittest
from unittest.mock import Mock
from unittest.mock import call, Mock, patch

from pyee import ExecutorEventEmitter

from ovos_bus_client.message import Message
from ovos_bus_client.client.client import MessageBusClient, GUIWebsocketClient
from ovos_bus_client.client import MessageWaiter, MessageCollector

WS_CONF = {
'websocket': {
"host": "testhost",
"port": 1337,
"route": "/core",
"ssl": False
}
}
WS_CONF = {"websocket": {"host": "testhost", "port": 1337, "route": "/core", "ssl": False}}


class TestClient(unittest.TestCase):
def test_echo(self):
from ovos_bus_client.client.client import echo

# TODO

def test_inheritance(self):
from mycroft_bus_client.client import MessageBusClient as _Client

self.assertTrue(issubclass(MessageBusClient, _Client))


class TestMessageBusClient(unittest.TestCase):
from ovos_bus_client.client.client import MessageBusClient

client = MessageBusClient()

def test_build_url(self):
url = MessageBusClient.build_url('localhost', 1337, '/core', False)
self.assertEqual(url, 'ws://localhost:1337/core')
ssl_url = MessageBusClient.build_url('sslhost', 443, '/core', True)
self.assertEqual(ssl_url, 'wss://sslhost:443/core')
url = MessageBusClient.build_url("localhost", 1337, "/core", False)
self.assertEqual(url, "ws://localhost:1337/core")
ssl_url = MessageBusClient.build_url("sslhost", 443, "/core", True)
self.assertEqual(ssl_url, "wss://sslhost:443/core")

def test_create_client(self):
self.assertEqual(self.client.client.url, 'ws://0.0.0.0:8181/core')
self.assertEqual(self.client.client.url, "ws://0.0.0.0:8181/core")
self.assertIsInstance(self.client.emitter, ExecutorEventEmitter)

mock_emitter = Mock()
Expand Down Expand Up @@ -85,9 +81,25 @@ def test_on_collect(self):
# TODO
pass

def test_wait_for_message(self):
# TODO
pass
@patch("ovos_bus_client.client.client.MessageWaiter")
def test_wait_for_message_str(self, mock_message_waiter):
# Arrange
test_message = Message("test.message")
self.client.emit = Mock()
# Act
self.client.wait_for_response(test_message)
# Assert
mock_message_waiter.assert_called_once_with(self.client, ["test.message.response"])

@patch("ovos_bus_client.client.client.MessageWaiter")
def test_wait_for_message_list(self, mock_message_waiter):
# Arrange
test_message = Message("test.message")
self.client.emit = Mock()
# Act
self.client.wait_for_response(test_message, ["test.message.response", "test.message.response2"])
# Assert
mock_message_waiter.assert_called_once_with(self.client, ["test.message.response", "test.message.response2"])

def test_wait_for_response(self):
# TODO
Expand Down Expand Up @@ -145,75 +157,63 @@ def test_on_message(self):
class TestMessageWaiter:
def test_message_wait_success(self):
bus = Mock()
waiter = MessageWaiter(bus, 'delayed.message')
bus.once.assert_called_with('delayed.message', waiter._handler)
waiter = MessageWaiter(bus, "delayed.message")
bus.once.assert_called_with("delayed.message", waiter._handler)

test_msg = Mock(name='test_msg')
test_msg = Mock(name="test_msg")
waiter._handler(test_msg) # Inject response

assert waiter.wait() == test_msg

def test_message_wait_timeout(self):
bus = Mock()
waiter = MessageWaiter(bus, 'delayed.message')
bus.once.assert_called_with('delayed.message', waiter._handler)
waiter = MessageWaiter(bus, "delayed.message")
bus.once.assert_called_with("delayed.message", waiter._handler)

assert waiter.wait(0.3) is None

def test_message_converts_to_list(self):
bus = Mock()
waiter = MessageWaiter(bus, "test.message")
assert isinstance(waiter.msg_type, list)
bus.once.assert_called_with("test.message", waiter._handler)

def test_multiple_messages(self):
bus = Mock()
waiter = MessageWaiter(bus, ["test.message", "test.message2"])
bus.once.assert_has_calls([call("test.message", waiter._handler), call("test.message2", waiter._handler)])


class TestMessageCollector:
def test_message_wait_success(self):
bus = Mock()
collector = MessageCollector(bus, Message('delayed.message'),
min_timeout=0.0, max_timeout=2.0)

test_register = Mock(name='test_register')
test_register.data = {
'query': collector.collect_id,
'timeout': 5,
'handler': 'test_handler1'
}
collector = MessageCollector(bus, Message("delayed.message"), min_timeout=0.0, max_timeout=2.0)

test_register = Mock(name="test_register")
test_register.data = {"query": collector.collect_id, "timeout": 5, "handler": "test_handler1"}
collector._register_handler(test_register) # Inject response

test_response = Mock(name='test_register')
test_response.data = {
'query': collector.collect_id,
'handler': 'test_handler1'
}
test_response = Mock(name="test_register")
test_response.data = {"query": collector.collect_id, "handler": "test_handler1"}
collector._receive_response(test_response)

assert collector.collect() == [test_response]

def test_message_drop_invalid(self):
bus = Mock()
collector = MessageCollector(bus, Message('delayed.message'),
min_timeout=0.0, max_timeout=2.0)

valid_register = Mock(name='valid_register')
valid_register.data = {
'query': collector.collect_id,
'timeout': 5,
'handler': 'test_handler1'
}
invalid_register = Mock(name='invalid_register')
invalid_register.data = {
'query': 'asdf',
'timeout': 5,
'handler': 'test_handler1'
}
collector = MessageCollector(bus, Message("delayed.message"), min_timeout=0.0, max_timeout=2.0)

valid_register = Mock(name="valid_register")
valid_register.data = {"query": collector.collect_id, "timeout": 5, "handler": "test_handler1"}
invalid_register = Mock(name="invalid_register")
invalid_register.data = {"query": "asdf", "timeout": 5, "handler": "test_handler1"}
collector._register_handler(valid_register) # Inject response
collector._register_handler(invalid_register) # Inject response

valid_response = Mock(name='valid_register')
valid_response.data = {
'query': collector.collect_id,
'handler': 'test_handler1'
}
invalid_response = Mock(name='invalid_register')
invalid_response.data = {
'query': 'asdf',
'handler': 'test_handler1'
}
valid_response = Mock(name="valid_register")
valid_response.data = {"query": collector.collect_id, "handler": "test_handler1"}
invalid_response = Mock(name="invalid_register")
invalid_response.data = {"query": "asdf", "handler": "test_handler1"}
collector._receive_response(valid_response)
collector._receive_response(invalid_response)
assert collector.collect() == [valid_response]
Loading