From c53caf9cfe15afff4df41e3e782c7b26c06533ac Mon Sep 17 00:00:00 2001 From: Swen Gross <25036977+emphasize@users.noreply.github.com> Date: Tue, 12 Dec 2023 21:14:27 +0100 Subject: [PATCH] option to remove fallback (handler) (#150) * option to remove fallback (handler) * clean up review * unittests --- ovos_workshop/skills/fallback.py | 31 ++++- test/unittests/skills/test_fallback_skill.py | 117 +++++++++++++++++-- 2 files changed, 135 insertions(+), 13 deletions(-) diff --git a/ovos_workshop/skills/fallback.py b/ovos_workshop/skills/fallback.py index 58372059..56b49056 100644 --- a/ovos_workshop/skills/fallback.py +++ b/ovos_workshop/skills/fallback.py @@ -341,9 +341,6 @@ def _register_system_event_handlers(self): speak_errors=False) self.add_event(f"ovos.skills.fallback.{self.skill_id}.request", self._handle_fallback_request, speak_errors=False) - self.bus.emit(Message("ovos.skills.fallback.register", - {"skill_id": self.skill_id, - "priority": self.priority})) def _handle_fallback_ack(self, message: Message): """ @@ -405,6 +402,10 @@ def wrapper(*args, **kwargs): self._fallback_handlers.append((priority, wrapper)) self.bus.on(f"ovos.skills.fallback.{self.skill_id}", wrapper) + # register with fallback service + self.bus.emit(Message("ovos.skills.fallback.register", + {"skill_id": self.skill_id, + "priority": self.priority})) @backwards_compat(classic_core=_old_register_fallback, pre_008=_old_register_fallback) def register_fallback(self, handler: callable, priority: int): @@ -419,6 +420,30 @@ def register_fallback(self, handler: callable, priority: int): f"ovos.skills.fallback.{self.skill_id}") self._fallback_handlers.append((priority, handler)) self.bus.on(f"ovos.skills.fallback.{self.skill_id}", handler) + # register with fallback service + self.bus.emit(Message("ovos.skills.fallback.register", + {"skill_id": self.skill_id, + "priority": self.priority})) + + def remove_fallback(self, handler_to_del: Optional[callable] = None) -> bool: + """ + Remove fallback registration / fallback handler. + @param handler_to_del: registered callback handler (or wrapped handler) + @return: True if at least one handler was removed, otherwise False + """ + found_handler = False + for i in reversed(range(len(self._fallback_handlers))): + _, handler = self._fallback_handlers[i] + if handler_to_del is None or handler == handler_to_del: + found_handler = True + del self._fallback_handlers[i] + + if not found_handler: + LOG.warning('No fallback matching {}'.format(handler_to_del)) + if len(self._fallback_handlers) == 0: + self.bus.emit(Message("ovos.skills.fallback.deregister", + {"skill_id": self.skill_id})) + return found_handler def default_shutdown(self): """ diff --git a/test/unittests/skills/test_fallback_skill.py b/test/unittests/skills/test_fallback_skill.py index 0188a82b..9b07c35b 100644 --- a/test/unittests/skills/test_fallback_skill.py +++ b/test/unittests/skills/test_fallback_skill.py @@ -1,7 +1,8 @@ from unittest import TestCase -from unittest.mock import patch +from unittest.mock import patch, Mock -from ovos_utils.messagebus import FakeBus +from threading import Event +from ovos_utils.messagebus import FakeBus, Message from ovos_workshop.decorators import fallback_handler from ovos_workshop.skills.base import BaseSkill from ovos_workshop.skills.fallback import FallbackSkillV1, FallbackSkillV2, \ @@ -205,25 +206,121 @@ def test_priority(self): fallback_skill.skill_id] = 80 self.assertEqual(fallback_skill.priority, 80) + FallbackSkillV2.fallback_config = {} + def test_can_answer(self): self.assertFalse(self.fallback_skill.can_answer([""], "en-us")) # TODO def test_register_system_event_handlers(self): - # TODO - pass + self.assertTrue(any(["ovos.skills.fallback.ping" in tup + for tup in self.fallback_skill.events])) + self.assertTrue(any([f"ovos.skills.fallback.{self.fallback_skill.skill_id}.request" + in tup for tup in self.fallback_skill.events])) def test_handle_fallback_ack(self): - # TODO - pass + def mock_pong(message: Message): + self.assertEqual(message.data["skill_id"], + self.fallback_skill.skill_id) + self.assertEqual(message.context["skill_id"], + self.fallback_skill.skill_id) + self.assertEqual(message.data["can_handle"], "test") + + orig_can_answer = self.fallback_skill.can_answer + self.fallback_skill.can_answer = Mock(return_value="test") + self.fallback_skill.bus.once("ovos.skills.fallback.pong", mock_pong) + + self.fallback_skill._handle_fallback_ack(Message("test")) + self.fallback_skill.can_answer = orig_can_answer + def test_handle_fallback_request(self): - # TODO - pass + start_event = Event() + handler_event = Event() + + def mock_start(message: Message): + start_event.set() + + def mock_handler(message: Message): + handler_event.set() + return True + + def mock_resonse(message: Message): + self.assertTrue(message.data["result"]) + self.assertEqual(message.data["fallback_handler"], + "mock_handler") + + self.fallback_skill.bus.once( + f"ovos.skills.fallback.{self.fallback_skill.skill_id}.start", + mock_start + ) + self.fallback_skill.bus.once( + f"ovos.skills.fallback.{self.fallback_skill.skill_id}.response", + mock_resonse + ) + self.fallback_skill._fallback_handlers = [(100, mock_handler)] + + self.fallback_skill._handle_fallback_request(Message("test")) + self.assertTrue(start_event.is_set()) + self.assertTrue(handler_event.is_set()) + + self.fallback_skill._fallback_handlers = [] def test_register_fallback(self): - # TODO - pass + priority = 75 + + def fallback_service_register(message: Message): + self.assertEqual(message.data["skill_id"], + self.fallback_skill.skill_id) + self.assertEqual(message.data["priority"], priority) + + # test with f"ovos.skills.fallback.{self.skill_id}" + def mock_handler(_: Message): + return True + + self.fallback_skill.bus.once( + f"ovos.skills.fallback.register", fallback_service_register + ) + self.fallback_skill.register_fallback(mock_handler, priority) + self.assertEqual(len(self.fallback_skill._fallback_handlers), 1) + self.assertEqual(self.fallback_skill._fallback_handlers[0][0], + priority) + self.assertEqual(self.fallback_skill._fallback_handlers[0][1], + mock_handler) + + self.fallback_skill._fallback_handlers = [] + + def test_remove_fallback(self): + + def mock_handler(_: Message): + return True + + def fallback_service_deregister(message: Message): + deregister_event.set() + self.assertEqual(message.data["skill_id"], + self.fallback_skill.skill_id) + + deregister_event = Event() + self.fallback_skill.bus.once( + f"ovos.skills.fallback.deregister", fallback_service_deregister + ) + self.fallback_skill._fallback_handlers = [(50, mock_handler)] + self.assertEqual(len(self.fallback_skill._fallback_handlers), 1) + self.fallback_skill.remove_fallback(mock_handler) + self.assertEqual(len(self.fallback_skill._fallback_handlers), 0) + self.assertTrue(deregister_event.is_set()) + deregister_event.clear() + self.assertFalse(deregister_event.is_set()) + + self.fallback_skill.bus.once( + f"ovos.skills.fallback.deregister", fallback_service_deregister + ) + self.fallback_skill._fallback_handlers = [(100, mock_handler), (50, mock_handler)] + self.fallback_skill.remove_fallback() + self.assertEqual(len(self.fallback_skill._fallback_handlers), 0) + self.assertTrue(deregister_event.is_set()) + + self.fallback_skill._fallback_handlers = [] def test_default_shutdown(self): # TODO