diff --git a/ovos_padatious/domain_container.py b/ovos_padatious/domain_container.py
new file mode 100644
index 0000000..49ae1ab
--- /dev/null
+++ b/ovos_padatious/domain_container.py
@@ -0,0 +1,217 @@
+from collections import defaultdict
+from typing import Dict, List, Optional
+from ovos_utils.log import LOG
+from ovos_padatious.intent_container import IntentContainer
+from ovos_padatious.match_data import MatchData
+
+
+class DomainIntentContainer:
+    """
+    A domain-aware intent recognition engine that organizes intents and entities
+    into specific domains, providing flexible and hierarchical intent matching.
+    """
+
+    def __init__(self, cache_dir: Optional[str] = None):
+        """
+        Initialize the DomainIntentContainer.
+
+        Args:
+            cache_dir (Optional[str]): Cache directory passed to every IntentContainer created here.
+
+        Attributes:
+            domain_engine (IntentContainer): A top-level intent container for cross-domain calculations.
+            domains (Dict[str, IntentContainer]): A mapping of domain names to their respective intent containers.
+            training_data (Dict[str, List[str]]): A mapping of domain names to their associated training samples.
+            must_train (bool): True while there are changes that train() has not processed yet.
+        """
+        self.cache_dir = cache_dir
+        self.domain_engine = IntentContainer(cache_dir=cache_dir)
+        self.domains: Dict[str, IntentContainer] = {}
+        self.training_data: Dict[str, List[str]] = defaultdict(list)
+        # samples contributed by each intent, per domain, so that removing an
+        # intent can also withdraw its samples from training_data
+        self._intent_samples: Dict[str, Dict[str, List[str]]] = defaultdict(dict)
+        self.must_train = True
+
+    def _forget_intent_samples(self, domain_name: str, intent_name: str):
+        """
+        Withdraw the samples recorded for an intent from the domain training data.
+
+        Args:
+            domain_name (str): The name of the domain.
+            intent_name (str): The name of the intent whose samples are withdrawn.
+        """
+        samples = self._intent_samples.get(domain_name, {}).pop(intent_name, [])
+        remaining = self.training_data.get(domain_name)
+        if not remaining:
+            return
+        for sample in samples:
+            if sample in remaining:
+                remaining.remove(sample)
+
+    def remove_domain(self, domain_name: str):
+        """
+        Remove a domain and its associated intents and training data.
+
+        Args:
+            domain_name (str): The name of the domain to remove.
+        """
+        self.training_data.pop(domain_name, None)
+        self._intent_samples.pop(domain_name, None)
+        self.domains.pop(domain_name, None)
+        if domain_name in self.domain_engine.intent_names:
+            self.domain_engine.remove_intent(domain_name)
+
+    def add_domain_intent(self, domain_name: str, intent_name: str, intent_samples: List[str]):
+        """
+        Register an intent within a specific domain.
+
+        Args:
+            domain_name (str): The name of the domain.
+            intent_name (str): The name of the intent to register.
+            intent_samples (List[str]): A list of sample sentences for the intent.
+        """
+        if domain_name not in self.domains:
+            self.domains[domain_name] = IntentContainer(cache_dir=self.cache_dir)
+        self.domains[domain_name].add_intent(intent_name, intent_samples)
+        # a re-registration replaces the samples recorded for the intent earlier
+        self._forget_intent_samples(domain_name, intent_name)
+        self._intent_samples[domain_name][intent_name] = list(intent_samples)
+        self.training_data[domain_name] += intent_samples
+        self.must_train = True
+
+    def remove_domain_intent(self, domain_name: str, intent_name: str):
+        """
+        Remove a specific intent from a domain.
+
+        The samples of the intent are withdrawn from the domain training data,
+        so the next training run no longer routes them to this domain.
+
+        Args:
+            domain_name (str): The name of the domain.
+            intent_name (str): The name of the intent to remove.
+        """
+        if domain_name in self.domains:
+            self.domains[domain_name].remove_intent(intent_name)
+            self._forget_intent_samples(domain_name, intent_name)
+            self.must_train = True
+
+    def add_domain_entity(self, domain_name: str, entity_name: str, entity_samples: List[str]):
+        """
+        Register an entity within a specific domain.
+
+        Args:
+            domain_name (str): The name of the domain.
+            entity_name (str): The name of the entity to register.
+            entity_samples (List[str]): A list of sample phrases for the entity.
+        """
+        if domain_name not in self.domains:
+            self.domains[domain_name] = IntentContainer(cache_dir=self.cache_dir)
+        self.domains[domain_name].add_entity(entity_name, entity_samples)
+
+    def remove_domain_entity(self, domain_name: str, entity_name: str):
+        """
+        Remove a specific entity from a domain.
+
+        Args:
+            domain_name (str): The name of the domain.
+            entity_name (str): The name of the entity to remove.
+        """
+        if domain_name in self.domains:
+            self.domains[domain_name].remove_entity(entity_name)
+
+    def calc_domains(self, query: str) -> List[MatchData]:
+        """
+        Calculate the matching domains for a query.
+
+        Args:
+            query (str): The input query.
+
+        Returns:
+            List[MatchData]: A list of MatchData objects representing matching domains.
+        """
+        if self.must_train:
+            self.train()
+
+        return self.domain_engine.calc_intents(query)
+
+    def calc_domain(self, query: str) -> MatchData:
+        """
+        Calculate the best matching domain for a query.
+
+        Args:
+            query (str): The input query.
+
+        Returns:
+            MatchData: The best matching domain.
+        """
+        if self.must_train:
+            self.train()
+        return self.domain_engine.calc_intent(query)
+
+    def calc_intent(self, query: str, domain: Optional[str] = None) -> MatchData:
+        """
+        Calculate the best matching intent for a query within a specific domain.
+
+        Args:
+            query (str): The input query.
+            domain (Optional[str]): The domain to limit the search to. If None, the best matching domain is used.
+
+        Returns:
+            MatchData: The best matching intent, or an empty match (name None, conf 0.0) for an unknown domain.
+        """
+        if self.must_train:
+            self.train()
+        domain = domain or self.domain_engine.calc_intent(query).name
+        if domain in self.domains:
+            return self.domains[domain].calc_intent(query)
+        return MatchData(name=None, sent=query, matches=None, conf=0.0)
+
+    def calc_intents(self, query: str, domain: Optional[str] = None, top_k_domains: int = 2) -> List[MatchData]:
+        """
+        Calculate matching intents for a query across domains or within a specific domain.
+
+        Args:
+            query (str): The input query.
+            domain (Optional[str]): The specific domain to search in. If None, searches across top-k domains.
+            top_k_domains (int): The number of top domains to consider. Defaults to 2.
+
+        Returns:
+            List[MatchData]: A list of MatchData objects representing matching intents, sorted by confidence.
+                Empty when an explicitly requested domain is unknown.
+        """
+        if self.must_train:
+            self.train()
+        if domain:
+            # unknown domains yield no matches, mirroring calc_intent, instead of raising KeyError
+            if domain not in self.domains:
+                return []
+            return self.domains[domain].calc_intents(query)
+        matches = []
+        # rank explicitly: nothing visible here guarantees the order calc_domains returns
+        ranked = sorted(self.calc_domains(query), key=lambda m: m.conf, reverse=True)
+        for domain_match in ranked[:top_k_domains]:
+            if domain_match.name in self.domains:
+                matches += self.domains[domain_match.name].calc_intents(query)
+        return sorted(matches, reverse=True, key=lambda k: k.conf)
+
+    def train(self):
+        """
+        Train the domain classifier and every per-domain intent container.
+
+        Each domain is registered in the domain engine as one intent whose samples are
+        the samples of all intents of that domain. Clears must_train when done.
+        """
+        for domain, samples in self.training_data.items():
+            # drop the previous registration so repeated training runs replace it
+            if domain in self.domain_engine.intent_names:
+                self.domain_engine.remove_intent(domain)
+            if not samples:
+                continue  # nothing left to learn for this domain
+            LOG.debug(f"Training domain: {domain}")
+            self.domain_engine.add_intent(domain, samples)
+        self.domain_engine.train()
+        for domain in self.domains:
+            LOG.debug(f"Training domain sub-intents: {domain}")
+            self.domains[domain].train()
+        self.must_train = False
diff --git a/ovos_padatious/intent_container.py b/ovos_padatious/intent_container.py
index ef5c288..e7dd29f 100644
--- a/ovos_padatious/intent_container.py
+++ b/ovos_padatious/intent_container.py
@@ -16,7 +16,9 @@
 from functools import wraps
 from typing import List, Dict, Any, Optional
 
+from ovos_config.meta import get_xdg_base
 from ovos_utils.log import LOG
+from ovos_utils.xdg_utils import xdg_data_home
 
 from ovos_padatious import padaos
 from ovos_padatious.entity import Entity
@@ -54,7 +56,8 @@ class IntentContainer:
         cache_dir (str): Directory for caching the neural network models and intent/entity files.
     """
 
-    def __init__(self, cache_dir: str) -> None:
+    def __init__(self, cache_dir: Optional[str] = None) -> None:
+        cache_dir = cache_dir or f"{xdg_data_home()}/{get_xdg_base()}/intent_cache"
         os.makedirs(cache_dir, exist_ok=True)
         self.cache_dir: str = cache_dir
         self.must_train: bool = False
@@ -64,6 +67,11 @@ def __init__(self, cache_dir: str) -> None:
         self.train_thread: Optional[Any] = None  # deprecated
         self.serialized_args: List[Dict[str, Any]] = []  # Serialized calls for training intents/entities
 
+    @property
+    def intent_names(self) -> List[str]:
+        """Names of the registered intents, as reported by the intent manager."""
+        return self.intents.intent_names
+
     def clear(self) -> None:
         """
         Clears the current intent and entity managers and resets the container.
diff --git a/ovos_padatious/intent_manager.py b/ovos_padatious/intent_manager.py
index ddfcb3e..fc65e01 100644
--- a/ovos_padatious/intent_manager.py
+++ b/ovos_padatious/intent_manager.py
@@ -32,6 +32,11 @@ def __init__(self, cache: str, debug: bool = False):
         super().__init__(Intent, cache)
         self.debug = debug
 
+    @property
+    def intent_names(self) -> List[str]:
+        """Names of the intents held in `objects` and in `objects_to_train`."""
+        return [i.name for i in self.objects + self.objects_to_train]
+
     def calc_intents(self, query: str, entity_manager) -> List[MatchData]:
         """
         Calculate matches for the given query against all registered intents.
diff --git a/ovos_padatious/opm.py b/ovos_padatious/opm.py
index 5d9d72b..faf1516 100644
--- a/ovos_padatious/opm.py
+++ b/ovos_padatious/opm.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 #
 """Intent service wrapping padatious."""
+from collections import defaultdict
 from functools import lru_cache
 from os.path import expanduser, isfile
 from threading import Event, RLock
-from typing import Optional, Dict, List, Union
+from typing import Optional, Dict, List, Union, Type
 
 from langcodes import closest_match
 from ovos_bus_client.client import MessageBusClient
@@ -31,9 +32,16 @@
 from ovos_utils.log import LOG, deprecated, log_deprecation
 from ovos_utils.xdg_utils import xdg_data_home
 
-from ovos_padatious import IntentContainer as PadatiousIntentContainer
+from ovos_padatious import IntentContainer
+from ovos_padatious.domain_container import DomainIntentContainer
 from ovos_padatious.match_data import MatchData as PadatiousIntent
 
+PadatiousIntentContainer = IntentContainer  # backwards compat
+
+# for easy typing
+PadatiousEngine = Union[Type[IntentContainer],
+                        Type[DomainIntentContainer]]
+
 
 class PadatiousMatcher:
     """Matcher class to avoid redundancy in padatious intent matching."""
@@ -87,7 +95,8 @@ class PadatiousPipeline(ConfidenceMatcherPipeline):
     """Service class for padatious intent matching."""
 
     def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
-                 config: Optional[Dict] = None):
+                 config: Optional[Dict] = None,
+                 engine_class: Optional[PadatiousEngine] = None):  # None: choose from config below
 
         super().__init__(bus, config)
         self.lock = RLock()
@@ -102,16 +111,20 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
         self.conf_med = self.config.get("conf_med") or 0.8
         self.conf_low = self.config.get("conf_low") or 0.5
 
+        if engine_class is None and self.config.get("domain_engine"):  # an explicit argument wins over config
+            engine_class = DomainIntentContainer
+
+        self.engine_class = engine_class or IntentContainer
         intent_cache = expanduser(self.config.get('intent_cache') or
                                   f"{xdg_data_home()}/{get_xdg_base()}/intent_cache")
-        self.containers = {lang: PadatiousIntentContainer(f"{intent_cache}/{lang}")
-                           for lang in langs}
+        self.containers = {lang: self.engine_class(cache_dir=f"{intent_cache}/{lang}") for lang in langs}
 
         self.finished_training_event = Event()  # DEPRECATED
         self.finished_initial_train = False
 
         self.registered_intents = []
         self.registered_entities = []
+        self._skill2intent = defaultdict(list)  # skill_id -> names of the intents it registered
         self.max_words = 50  # if an utterance contains more words than this, don't attempt to match
 
         self.bus.on('padatious:register_intent', self.register_intent)
@@ -225,7 +238,12 @@ def __detach_intent(self, intent_name):
         if intent_name in self.registered_intents:
             self.registered_intents.remove(intent_name)
             for lang in self.containers:
-                self.containers[lang].remove_intent(intent_name)
+                for skill_id, intents in self._skill2intent.items():
+                    if intent_name in intents:
+                        if isinstance(self.containers[lang], DomainIntentContainer):
+                            self.containers[lang].remove_domain_intent(skill_id, intent_name)
+                        else:
+                            self.containers[lang].remove_intent(intent_name)
 
     def handle_detach_intent(self, message):
         """Messagebus handler for detaching padatious intent.
@@ -242,8 +260,13 @@ def handle_detach_skill(self, message):
             message (Message): message triggering action
         """
         skill_id = message.data['skill_id']
-        remove_list = [i for i in self.registered_intents if skill_id in i]
-        for i in remove_list:
+        # snapshot the names; .get() avoids creating an entry for an unknown skill
+        for i in list(self._skill2intent.get(skill_id, [])):
             self.__detach_intent(i)
+        # a detached skill owns nothing anymore: drop its domain and its bookkeeping
+        for container in self.containers.values():
+            if isinstance(container, DomainIntentContainer):
+                container.remove_domain(skill_id)
+        self._skill2intent.pop(skill_id, None)
 
     def _register_object(self, message, object_name, register_func):
@@ -254,6 +277,7 @@ def _register_object(self, message, object_name, register_func):
             object_name (str): type of entry to register
             register_func (callable): function to call for registration
         """
+        skill_id = message.data.get("skill_id") or message.context.get("skill_id")
         file_name = message.data.get('file_name')
         samples = message.data.get("samples")
         name = message.data['name']
@@ -268,7 +292,12 @@ def _register_object(self, message, object_name, register_func):
             with open(file_name) as f:
                 samples = [line.strip() for line in f.readlines()]
 
-        register_func(name, samples)
+        # dispatch on the container that owns register_func, like the isinstance checks
+        # used elsewhere, so subclasses of DomainIntentContainer are handled too
+        if isinstance(getattr(register_func, "__self__", None), DomainIntentContainer):
+            register_func(skill_id, name, samples)
+        else:
+            register_func(name, samples)
 
         self.finished_initial_train = False
         if self.config.get("instant_train", True):
@@ -280,11 +309,18 @@ def register_intent(self, message):
         Args:
             message (Message): message triggering action
         """
+        skill_id = message.data.get("skill_id") or message.context.get("skill_id")
+        if message.data['name'] not in self._skill2intent[skill_id]:  # one entry per intent, however often it is registered
+            self._skill2intent[skill_id].append(message.data['name'])
+
         lang = message.data.get('lang', self.lang)
         lang = standardize_lang_tag(lang)
         if lang in self.containers:
             self.registered_intents.append(message.data['name'])
-            self._register_object(message, 'intent', self.containers[lang].add_intent)
+            if isinstance(self.containers[lang], DomainIntentContainer):
+                self._register_object(message, 'intent', self.containers[lang].add_domain_intent)
+            else:
+                self._register_object(message, 'intent', self.containers[lang].add_intent)
 
     def register_entity(self, message):
         """Messagebus handler for registering entities.
@@ -296,8 +323,10 @@ def register_entity(self, message):
         lang = standardize_lang_tag(lang)
         if lang in self.containers:
             self.registered_entities.append(message.data)
-            self._register_object(message, 'entity',
-                                  self.containers[lang].add_entity)
+            if isinstance(self.containers[lang], DomainIntentContainer):
+                self._register_object(message, 'entity', self.containers[lang].add_domain_entity)
+            else:
+                self._register_object(message, 'entity', self.containers[lang].add_entity)
 
     def calc_intent(self, utterances: Union[str, List[str]], lang: Optional[str] = None,
                     message: Optional[Message] = None) -> Optional[PadatiousIntent]:
@@ -390,7 +419,7 @@ def handle_entity_manifest(self, message):
 
 @lru_cache(maxsize=3)  # repeat calls under different conf levels wont re-run code
 def _calc_padatious_intent(utt: str,
-                           intent_container: PadatiousIntentContainer,
+                           intent_container: Union[IntentContainer, DomainIntentContainer],
                            sess: Session) -> Optional[PadatiousIntent]:
     """
     Try to match an utterance to an intent in an intent_container
diff --git a/tests/test_domain.py b/tests/test_domain.py
new file mode 100644
index 0000000..b898234
--- /dev/null
+++ b/tests/test_domain.py
@@ -0,0 +1,189 @@
+import tempfile
+import unittest
+from unittest.mock import MagicMock
+
+from ovos_padatious.domain_container import DomainIntentContainer
+from ovos_padatious.match_data import MatchData
+
+
+class TestDomainIntentEngine(unittest.TestCase):
+    def setUp(self):
+        # isolate every test in a throwaway cache so nothing is written to the user's intent cache
+        cache = tempfile.TemporaryDirectory()
+        self.addCleanup(cache.cleanup)
+        self.engine = DomainIntentContainer(cache_dir=cache.name)
+
+    def test_register_domain_intent(self):
+        self.engine.add_domain_intent("domain1", "intent1", ["sample1", "sample2"])
+        self.assertIn("domain1", self.engine.training_data)
+        self.assertIn("intent1", self.engine.domains["domain1"].intent_names)
+
+    def test_remove_domain(self):
+        self.engine.add_domain_intent("domain1", "intent1", ["sample1", "sample2"])
+        self.engine.remove_domain("domain1")
+        self.assertNotIn("domain1", self.engine.training_data)
+        self.assertNotIn("domain1", self.engine.domains)
+
+    def test_remove_domain_intent(self):
+        self.engine.add_domain_intent("domain1", "intent1", ["sample1", "sample2"])
+        self.engine.remove_domain_intent("domain1", "intent1")
+        self.assertNotIn("intent1", self.engine.domains["domain1"].intent_names)
+
+    def test_calc_domains(self):
+        self.engine.train = MagicMock()
+        self.engine.domain_engine.calc_intents = MagicMock(
+            return_value=[MatchData(name="domain1", sent="query", matches=None, conf=0.9)])
+        result = self.engine.calc_domains("query")
+        self.engine.train.assert_called_once()
+        self.assertEqual(result[0].name, "domain1")
+
+    def test_calc_domain(self):
+        self.engine.train = MagicMock()
+        self.engine.domain_engine.calc_intent = MagicMock(
+            return_value=MatchData(name="domain1", sent="query", matches=None, conf=0.9))
+        result = self.engine.calc_domain("query")
+        self.engine.train.assert_called_once()
+        self.assertEqual(result.name, "domain1")
+
+    def test_calc_intent(self):
+        self.engine.train = MagicMock()
+        mock_domain_container = MagicMock()
+        mock_domain_container.calc_intent.return_value = MatchData(name="intent1", sent="query", matches=None, conf=0.9)
+        self.engine.domains["domain1"] = mock_domain_container
+
+        self.engine.domain_engine.calc_intent = MagicMock(
+            return_value=MatchData(name="domain1", sent="query", matches=None, conf=0.9))
+        result = self.engine.calc_intent("query")
+        self.assertEqual(result.name, "intent1")
+
+    def test_calc_intents(self):
+        self.engine.train = MagicMock()
+        mock_domain_container = MagicMock()
+        mock_domain_container.calc_intents.return_value = [
+            MatchData(name="intent1", sent="query", matches=None, conf=0.9),
+            MatchData(name="intent2", sent="query", matches=None, conf=0.8),
+        ]
+        self.engine.domains["domain1"] = mock_domain_container
+
+        self.engine.domain_engine.calc_intents = MagicMock(
+            return_value=[MatchData(name="domain1", sent="query", matches=None, conf=0.9)])
+        result = self.engine.calc_intents("query")
+        self.assertEqual(len(result), 2)
+        self.assertEqual(result[0].name, "intent1")
+
+    def test_train(self):
+        self.engine.training_data["domain1"] = ["sample1", "sample2"]
+        self.engine.domain_engine.add_intent = MagicMock()
+        self.engine.domain_engine.train = MagicMock()
+
+        mock_domain_container = MagicMock()
+        self.engine.domains["domain1"] = mock_domain_container
+
+        self.engine.train()
+        self.engine.domain_engine.add_intent.assert_called_with("domain1", ["sample1", "sample2"])
+        self.engine.domain_engine.train.assert_called_once()
+        mock_domain_container.train.assert_called_once()
+        self.assertFalse(self.engine.must_train)
+
+
+class TestDomainIntentEngineWithLiveData(unittest.TestCase):
+    def setUp(self):
+        # throwaway cache: models trained here must not leak into (or come from) the user's cache
+        cache = tempfile.TemporaryDirectory()
+        self.addCleanup(cache.cleanup)
+        self.engine = DomainIntentContainer(cache_dir=cache.name)
+        # Sample training data
+        self.training_data = {
+            "IOT": {
+                "turn_on_device": ["Turn on the lights", "Switch on the fan", "Activate the air conditioner"],
+                "turn_off_device": ["Turn off the lights", "Switch off the heater", "Deactivate the air conditioner"],
+            },
+            "greetings": {
+                "say_hello": ["Hello", "Hi there", "Good morning"],
+                "say_goodbye": ["Goodbye", "See you later", "Bye"],
+            },
+            "General Knowledge": {
+                "ask_fact": ["Tell me a fact about space", "What is the capital of France?",
+                             "Who invented the telephone?"],
+            },
+            "Question": {
+                "ask_question": ["Why is the sky blue?", "What is quantum mechanics?",
+                                 "Can you explain photosynthesis?"],
+            },
+            "Media Playback": {
+                "play_music": ["Play some music", "Start the playlist", "Play a song"],
+                "stop_music": ["Stop the music", "Pause playback", "Halt the song"],
+            },
+        }
+        # Register domains and intents
+        for domain, intents in self.training_data.items():
+            for intent, samples in intents.items():
+                self.engine.add_domain_intent(domain, intent, samples)
+        self.engine.train()
+
+    def test_live_data_intent_matching(self):
+        # Test IOT domain
+        query = "Switch on the fan"
+        result = self.engine.calc_intent(query, domain="IOT")
+        self.assertEqual(result.name, "turn_on_device")
+        self.assertGreater(result.conf, 0.8)
+
+        # Test greetings domain
+        query = "Hi there"
+        result = self.engine.calc_intent(query, domain="greetings")
+        self.assertEqual(result.name, "say_hello")
+        self.assertGreater(result.conf, 0.8)
+
+        # Test General Knowledge domain
+        query = "What is the capital of France?"
+        result = self.engine.calc_intent(query, domain="General Knowledge")
+        self.assertEqual(result.name, "ask_fact")
+        self.assertGreater(result.conf, 0.8)
+
+        # Test Question domain
+        query = "Why is the sky blue?"
+        result = self.engine.calc_intent(query, domain="Question")
+        self.assertEqual(result.name, "ask_question")
+        self.assertGreater(result.conf, 0.8)
+
+        # Test Media Playback domain
+        query = "Play a song"
+        result = self.engine.calc_intent(query, domain="Media Playback")
+        self.assertEqual(result.name, "play_music")
+        self.assertGreater(result.conf, 0.8)
+
+    def test_live_data_cross_domain_matching(self):
+        # Test cross-domain intent matching
+        query = "Tell me a fact about space"
+        result = self.engine.calc_domain(query)
+        self.assertEqual(result.name, "General Knowledge")
+        self.assertGreater(result.conf, 0.8)
+
+        # Validate intent from the matched domain
+        result = self.engine.calc_intent(query, domain=result.name)
+        self.assertEqual(result.name, "ask_fact")
+        self.assertGreater(result.conf, 0.8)
+
+    def test_calc_intent_without_domain(self):
+        # Test intent calculation without specifying a domain
+        query = "Turn on the lights"
+        result = self.engine.calc_intent(query)
+        self.assertIsNotNone(result.name, "Intent name should not be None")
+        self.assertEqual(result.name, "turn_on_device")
+        self.assertGreater(result.conf, 0.8)
+
+        query = "Goodbye"
+        result = self.engine.calc_intent(query)
+        self.assertIsNotNone(result.name, "Intent name should not be None")
+        self.assertEqual(result.name, "say_goodbye")
+        self.assertGreater(result.conf, 0.8)
+
+        query = "What is quantum mechanics?"
+        result = self.engine.calc_intent(query)
+        self.assertIsNotNone(result.name, "Intent name should not be None")
+        self.assertEqual(result.name, "ask_question")
+        self.assertGreater(result.conf, 0.8)
+
+
+if __name__ == "__main__":
+    unittest.main()