diff --git a/ovos_padatious/domain_engine.py b/ovos_padatious/domain_engine.py
new file mode 100644
index 0000000..3467c79
--- /dev/null
+++ b/ovos_padatious/domain_engine.py
@@ -0,0 +1,207 @@
+from collections import defaultdict
+from typing import Dict, List, Optional
+
+from ovos_padatious.intent_container import IntentContainer
+from ovos_padatious.match_data import MatchData
+
+
+class DomainIntentEngine:
+    """
+    A domain-aware intent recognition engine that organizes intents and entities
+    into specific domains, providing flexible and hierarchical intent matching.
+    """
+
+    def __init__(self):
+        """
+        Initialize the DomainIntentEngine.
+
+        Attributes:
+            domain_engine (IntentContainer): A top-level intent container for cross-domain calculations.
+            domains (Dict[str, IntentContainer]): A mapping of domain names to their respective intent containers.
+            training_data (Dict[str, List[str]]): A mapping of domain names to their associated training samples.
+        """
+        self.domain_engine = IntentContainer()
+        self.domains: Dict[str, IntentContainer] = defaultdict(IntentContainer)
+        self.training_data: Dict[str, List[str]] = defaultdict(list)
+        # Samples per domain and intent, so that removing an intent can also
+        # withdraw its samples from the domain's training data.
+        self._intent_samples: Dict[str, Dict[str, List[str]]] = defaultdict(dict)
+        self._must_train = True
+
+    def remove_domain(self, domain_name: str):
+        """
+        Remove a domain and its associated intents and training data.
+
+        Unknown domain names are ignored.
+
+        Args:
+            domain_name (str): The name of the domain to remove.
+        """
+        self.training_data.pop(domain_name, None)
+        self._intent_samples.pop(domain_name, None)
+        self.domains.pop(domain_name, None)
+        if domain_name in self.domain_engine.intent_names:
+            self.domain_engine.remove_intent(domain_name)
+
+    def register_domain_intent(self, domain_name: str, intent_name: str, intent_samples: List[str]):
+        """
+        Register an intent within a specific domain.
+
+        The samples also become training data for the domain itself, and the
+        engine is flagged for retraining before the next query.
+
+        Args:
+            domain_name (str): The name of the domain.
+            intent_name (str): The name of the intent to register.
+            intent_samples (List[str]): A list of sample sentences for the intent.
+        """
+        self.domains[domain_name].add_intent(intent_name, intent_samples)
+        self.training_data[domain_name] += intent_samples
+        self._intent_samples[domain_name].setdefault(intent_name, []).extend(intent_samples)
+        self._must_train = True
+
+    def remove_domain_intent(self, domain_name: str, intent_name: str):
+        """
+        Remove a specific intent from a domain.
+
+        The intent's samples are withdrawn from the domain's training data and
+        the engine is flagged for retraining. Unknown domains are ignored
+        rather than being created as a side effect of the lookup.
+
+        Args:
+            domain_name (str): The name of the domain.
+            intent_name (str): The name of the intent to remove.
+        """
+        if domain_name not in self.domains:
+            return
+        self.domains[domain_name].remove_intent(intent_name)
+        removed = self._intent_samples.get(domain_name, {}).pop(intent_name, [])
+        if not removed:
+            return
+        remaining = self.training_data.get(domain_name, [])
+        for sample in removed:
+            if sample in remaining:
+                remaining.remove(sample)
+        self._must_train = True
+
+    def register_domain_entity(self, domain_name: str, entity_name: str, entity_samples: List[str]):
+        """
+        Register an entity within a specific domain.
+
+        Args:
+            domain_name (str): The name of the domain.
+            entity_name (str): The name of the entity to register.
+            entity_samples (List[str]): A list of sample phrases for the entity.
+        """
+        self.domains[domain_name].add_entity(entity_name, entity_samples)
+
+    def remove_domain_entity(self, domain_name: str, entity_name: str):
+        """
+        Remove a specific entity from a domain.
+
+        Unknown domains are ignored rather than being created as a side effect
+        of the lookup.
+
+        Args:
+            domain_name (str): The name of the domain.
+            entity_name (str): The name of the entity to remove.
+        """
+        if domain_name in self.domains:
+            self.domains[domain_name].remove_entity(entity_name)
+
+    def calc_domains(self, query: str) -> List[MatchData]:
+        """
+        Calculate the matching domains for a query.
+
+        Args:
+            query (str): The input query.
+
+        Returns:
+            List[MatchData]: A list of MatchData objects representing matching domains.
+        """
+        if self._must_train:
+            self.train()
+        return self.domain_engine.calc_intents(query)
+
+    def calc_domain(self, query: str) -> MatchData:
+        """
+        Calculate the best matching domain for a query.
+
+        Args:
+            query (str): The input query.
+
+        Returns:
+            MatchData: The best matching domain.
+        """
+        if self._must_train:
+            self.train()
+        return self.domain_engine.calc_intent(query)
+
+    def calc_intent(self, query: str, domain: Optional[str] = None) -> MatchData:
+        """
+        Calculate the best matching intent for a query within a specific domain.
+
+        Args:
+            query (str): The input query.
+            domain (Optional[str]): The domain to limit the search to. If None,
+                the best matching domain for the query is used.
+
+        Returns:
+            MatchData: The best matching intent, or an empty match (name None,
+                conf 0.0) when the domain is not registered.
+        """
+        if self._must_train:
+            self.train()
+        domain_name = domain or self.domain_engine.calc_intent(query).name
+        if domain_name in self.domains:
+            return self.domains[domain_name].calc_intent(query)
+        return MatchData(name=None, sent=query, matches=None, conf=0.0)
+
+    def calc_intents(self, query: str, domain: Optional[str] = None, top_k_domains: int = 2) -> List[MatchData]:
+        """
+        Calculate matching intents for a query across domains or within a specific domain.
+
+        Args:
+            query (str): The input query.
+            domain (Optional[str]): The specific domain to search in. If None, searches across top-k domains.
+            top_k_domains (int): The number of top domains to consider. Defaults to 2.
+
+        Returns:
+            List[MatchData]: A list of MatchData objects representing matching intents, sorted by confidence.
+                Empty when `domain` is given but not registered.
+        """
+        if self._must_train:
+            self.train()
+        if domain:
+            # .get() keeps the defaultdict from creating a container for an unknown name.
+            container = self.domains.get(domain)
+            matches = container.calc_intents(query) if container is not None else []
+        else:
+            # Rank the domains here instead of relying on the order in which the
+            # container returns them, so the slice really is the top-k.
+            ranked = sorted(self.calc_domains(query), reverse=True, key=lambda k: k.conf)
+            matches = []
+            for match in ranked[:top_k_domains]:
+                if match.name in self.domains:
+                    matches += self.domains[match.name].calc_intents(query)
+        return sorted(matches, reverse=True, key=lambda k: k.conf)
+
+    def train(self):
+        """
+        Train the domain classifier and every per-domain intent container.
+
+        Each domain is registered in the top-level container as a single intent
+        built from the domain's training data.
+        """
+        for domain, samples in self.training_data.items():
+            # train() runs again after every new registration or removal, so drop
+            # the previous definition first: the domain is then never registered
+            # twice and never keeps samples of an intent that was removed.
+            if domain in self.domain_engine.intent_names:
+                self.domain_engine.remove_intent(domain)
+            if samples:
+                self.domain_engine.add_intent(domain, samples)
+        self.domain_engine.train()
+        for container in self.domains.values():
+            container.train()
+        self._must_train = False
diff --git a/ovos_padatious/intent_container.py b/ovos_padatious/intent_container.py
index ef5c288..e7dd29f 100644
--- a/ovos_padatious/intent_container.py
+++ b/ovos_padatious/intent_container.py
@@ -16,7 +16,9 @@
 from functools import wraps
 from typing import List, Dict, Any, Optional
 
+from ovos_config.meta import get_xdg_base
 from ovos_utils.log import LOG
+from ovos_utils.xdg_utils import xdg_data_home
 
 from ovos_padatious import padaos
 from ovos_padatious.entity import Entity
@@ -54,7 +56,8 @@ class IntentContainer:
         cache_dir (str): Directory for caching the neural network models and intent/entity files.
     """
 
-    def __init__(self, cache_dir: str) -> None:
+    def __init__(self, cache_dir: Optional[str] = None) -> None:
+        cache_dir = cache_dir or f"{xdg_data_home()}/{get_xdg_base()}/intent_cache"
         os.makedirs(cache_dir, exist_ok=True)
         self.cache_dir: str = cache_dir
         self.must_train: bool = False
@@ -64,6 +67,10 @@ def __init__(self, cache_dir: str) -> None:
         self.train_thread: Optional[Any] = None  # deprecated
         self.serialized_args: List[Dict[str, Any]] = []  # Serialized calls for training intents/entities
 
+    @property
+    def intent_names(self):
+        return self.intents.intent_names
+
     def clear(self) -> None:
         """
         Clears the current intent and entity managers and resets the container.
diff --git a/ovos_padatious/intent_manager.py b/ovos_padatious/intent_manager.py
index fa15588..29791ff 100644
--- a/ovos_padatious/intent_manager.py
+++ b/ovos_padatious/intent_manager.py
@@ -22,6 +22,10 @@ class IntentManager(TrainingManager):
     def __init__(self, cache):
         super(IntentManager, self).__init__(Intent, cache)
 
+    @property
+    def intent_names(self):
+        return [i.name for i in self.objects + self.objects_to_train]
+
     def calc_intents(self, query, entity_manager):
         sent = tokenize(query)
         matches = []
diff --git a/tests/test_domain.py b/tests/test_domain.py
new file mode 100644
index 0000000..928e370
--- /dev/null
+++ b/tests/test_domain.py
@@ -0,0 +1,193 @@
+import unittest
+from unittest.mock import MagicMock
+from ovos_padatious.intent_container import IntentContainer
+from ovos_padatious.match_data import MatchData
+from ovos_padatious.domain_engine import DomainIntentEngine
+
+
+class TestDomainIntentEngine(unittest.TestCase):
+    def setUp(self):
+        self.engine = DomainIntentEngine()
+
+    def test_register_domain_intent(self):
+        self.engine.register_domain_intent("domain1", "intent1", ["sample1", "sample2"])
+        self.assertIn("domain1", self.engine.training_data)
+        self.assertIn("intent1", self.engine.domains["domain1"].intent_names)
+
+    def test_remove_domain(self):
+        self.engine.register_domain_intent("domain1", "intent1", ["sample1", "sample2"])
+        self.engine.remove_domain("domain1")
+        self.assertNotIn("domain1", self.engine.training_data)
+        self.assertNotIn("domain1", self.engine.domains)
+
+    def test_remove_domain_intent(self):
+        self.engine.register_domain_intent("domain1", "intent1", ["sample1", "sample2"])
+        self.engine.remove_domain_intent("domain1", "intent1")
+        self.assertNotIn("intent1", self.engine.domains["domain1"].intent_names)
+
+    def test_remove_domain_intent_updates_training_data(self):
+        self.engine.register_domain_intent("domain1", "intent1", ["sample1", "sample2"])
+        self.engine.register_domain_intent("domain1", "intent2", ["sample3"])
+        self.engine.remove_domain_intent("domain1", "intent1")
+        self.assertEqual(self.engine.training_data["domain1"], ["sample3"])
+        self.assertTrue(self.engine._must_train)
+
+    def test_remove_from_unknown_domain_does_not_create_it(self):
+        self.engine.remove_domain_intent("missing", "intent1")
+        self.engine.remove_domain_entity("missing", "entity1")
+        self.assertNotIn("missing", self.engine.domains)
+
+    def test_calc_domains(self):
+        self.engine.train = MagicMock()
+        self.engine.domain_engine.calc_intents = MagicMock(return_value=[MatchData(name="domain1", sent="query", matches=None, conf=0.9)])
+        result = self.engine.calc_domains("query")
+        self.engine.train.assert_called_once()
+        self.assertEqual(result[0].name, "domain1")
+
+    def test_calc_domain(self):
+        self.engine.train = MagicMock()
+        self.engine.domain_engine.calc_intent = MagicMock(return_value=MatchData(name="domain1", sent="query", matches=None, conf=0.9))
+        result = self.engine.calc_domain("query")
+        self.engine.train.assert_called_once()
+        self.assertEqual(result.name, "domain1")
+
+    def test_calc_intent(self):
+        self.engine.train = MagicMock()
+        mock_domain_container = MagicMock()
+        mock_domain_container.calc_intent.return_value = MatchData(name="intent1", sent="query", matches=None, conf=0.9)
+        self.engine.domains["domain1"] = mock_domain_container
+
+        self.engine.domain_engine.calc_intent = MagicMock(return_value=MatchData(name="domain1", sent="query", matches=None, conf=0.9))
+        result = self.engine.calc_intent("query")
+        self.assertEqual(result.name, "intent1")
+
+    def test_calc_intents(self):
+        self.engine.train = MagicMock()
+        mock_domain_container = MagicMock()
+        mock_domain_container.calc_intents.return_value = [
+            MatchData(name="intent1", sent="query", matches=None, conf=0.9),
+            MatchData(name="intent2", sent="query", matches=None, conf=0.8),
+        ]
+        self.engine.domains["domain1"] = mock_domain_container
+
+        self.engine.domain_engine.calc_intents = MagicMock(return_value=[MatchData(name="domain1", sent="query", matches=None, conf=0.9)])
+        result = self.engine.calc_intents("query")
+        self.assertEqual(len(result), 2)
+        self.assertEqual(result[0].name, "intent1")
+
+    def test_calc_intents_unknown_domain(self):
+        self.engine.train = MagicMock()
+        self.assertEqual(self.engine.calc_intents("query", domain="missing"), [])
+        self.assertNotIn("missing", self.engine.domains)
+
+    def test_train(self):
+        self.engine.training_data["domain1"] = ["sample1", "sample2"]
+        self.engine.domain_engine.add_intent = MagicMock()
+        self.engine.domain_engine.train = MagicMock()
+
+        mock_domain_container = MagicMock()
+        self.engine.domains["domain1"] = mock_domain_container
+
+        self.engine.train()
+        self.engine.domain_engine.add_intent.assert_called_with("domain1", ["sample1", "sample2"])
+        self.engine.domain_engine.train.assert_called_once()
+        mock_domain_container.train.assert_called_once()
+        self.assertFalse(self.engine._must_train)
+
+
+class TestDomainIntentEngineWithLiveData(unittest.TestCase):
+    def setUp(self):
+        self.engine = DomainIntentEngine()
+        # Sample training data
+        self.training_data = {
+            "IOT": {
+                "turn_on_device": ["Turn on the lights", "Switch on the fan", "Activate the air conditioner"],
+                "turn_off_device": ["Turn off the lights", "Switch off the heater", "Deactivate the air conditioner"],
+            },
+            "greetings": {
+                "say_hello": ["Hello", "Hi there", "Good morning"],
+                "say_goodbye": ["Goodbye", "See you later", "Bye"],
+            },
+            "General Knowledge": {
+                "ask_fact": ["Tell me a fact about space", "What is the capital of France?", "Who invented the telephone?"],
+            },
+            "Question": {
+                "ask_question": ["Why is the sky blue?", "What is quantum mechanics?", "Can you explain photosynthesis?"],
+            },
+            "Media Playback": {
+                "play_music": ["Play some music", "Start the playlist", "Play a song"],
+                "stop_music": ["Stop the music", "Pause playback", "Halt the song"],
+            },
+        }
+        # Register domains and intents
+        for domain, intents in self.training_data.items():
+            for intent, samples in intents.items():
+                self.engine.register_domain_intent(domain, intent, samples)
+        self.engine.train()
+
+    def test_live_data_intent_matching(self):
+        # Test IOT domain
+        query = "Switch on the fan"
+        result = self.engine.calc_intent(query, domain="IOT")
+        self.assertEqual(result.name, "turn_on_device")
+        self.assertGreater(result.conf, 0.8)
+
+        # Test greetings domain
+        query = "Hi there"
+        result = self.engine.calc_intent(query, domain="greetings")
+        self.assertEqual(result.name, "say_hello")
+        self.assertGreater(result.conf, 0.8)
+
+        # Test General Knowledge domain
+        query = "What is the capital of France?"
+        result = self.engine.calc_intent(query, domain="General Knowledge")
+        self.assertEqual(result.name, "ask_fact")
+        self.assertGreater(result.conf, 0.8)
+
+        # Test Question domain
+        query = "Why is the sky blue?"
+        result = self.engine.calc_intent(query, domain="Question")
+        self.assertEqual(result.name, "ask_question")
+        self.assertGreater(result.conf, 0.8)
+
+        # Test Media Playback domain
+        query = "Play a song"
+        result = self.engine.calc_intent(query, domain="Media Playback")
+        self.assertEqual(result.name, "play_music")
+        self.assertGreater(result.conf, 0.8)
+
+    def test_live_data_cross_domain_matching(self):
+        # Test cross-domain intent matching
+        query = "Tell me a fact about space"
+        result = self.engine.calc_domain(query)
+        self.assertEqual(result.name, "General Knowledge")
+        self.assertGreater(result.conf, 0.8)
+
+        # Validate intent from the matched domain
+        result = self.engine.calc_intent(query, domain=result.name)
+        self.assertEqual(result.name, "ask_fact")
+        self.assertGreater(result.conf, 0.8)
+
+    def test_calc_intent_without_domain(self):
+        # Test intent calculation without specifying a domain
+        query = "Turn on the lights"
+        result = self.engine.calc_intent(query)
+        self.assertIsNotNone(result.name, "Intent name should not be None")
+        self.assertEqual(result.name, "turn_on_device")
+        self.assertGreater(result.conf, 0.8)
+
+        query = "Goodbye"
+        result = self.engine.calc_intent(query)
+        self.assertIsNotNone(result.name, "Intent name should not be None")
+        self.assertEqual(result.name, "say_goodbye")
+        self.assertGreater(result.conf, 0.8)
+
+        query = "What is quantum mechanics?"
+        result = self.engine.calc_intent(query)
+        self.assertIsNotNone(result.name, "Intent name should not be None")
+        self.assertEqual(result.name, "ask_question")
+        self.assertGreater(result.conf, 0.8)
+
+
+if __name__ == "__main__":
+    unittest.main()