Skip to content

Commit

Permalink
feat/domain_engine
Browse files Browse the repository at this point in the history
match the skill first, and then the intent
  • Loading branch information
JarbasAl committed Dec 9, 2024
1 parent d5965cb commit 26af1fc
Show file tree
Hide file tree
Showing 4 changed files with 353 additions and 1 deletion.
162 changes: 162 additions & 0 deletions ovos_padatious/domain_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
from collections import defaultdict
from typing import Dict, List, Optional

from ovos_padatious.intent_container import IntentContainer
from ovos_padatious.match_data import MatchData


class DomainIntentEngine:
"""
A domain-aware intent recognition engine that organizes intents and entities
into specific domains, providing flexible and hierarchical intent matching.
"""

def __init__(self):
"""
Initialize the DomainIntentEngine.
Attributes:
domain_engine (IntentContainer): A top-level intent container for cross-domain calculations.
domains (Dict[str, IntentContainer]): A mapping of domain names to their respective intent containers.
training_data (Dict[str, List[str]]): A mapping of domain names to their associated training samples.
"""
self.domain_engine = IntentContainer()
self.domains: Dict[str, IntentContainer] = defaultdict(IntentContainer)
self.training_data: Dict[str, List[str]] = defaultdict(list)
self._must_train = True

def remove_domain(self, domain_name: str):
"""
Remove a domain and its associated intents and training data.
Args:
domain_name (str): The name of the domain to remove.
"""
if domain_name in self.training_data:
self.training_data.pop(domain_name)
if domain_name in self.domains:
self.domains.pop(domain_name)
if domain_name in self.domain_engine.intent_names:
self.domain_engine.remove_intent(domain_name)

def register_domain_intent(self, domain_name: str, intent_name: str, intent_samples: List[str]):
"""
Register an intent within a specific domain.
Args:
domain_name (str): The name of the domain.
intent_name (str): The name of the intent to register.
intent_samples (List[str]): A list of sample sentences for the intent.
"""
self.domains[domain_name].add_intent(intent_name, intent_samples)
self.training_data[domain_name] += intent_samples
self._must_train = True

def remove_domain_intent(self, domain_name: str, intent_name: str):
"""
Remove a specific intent from a domain.
Args:
domain_name (str): The name of the domain.
intent_name (str): The name of the intent to remove.
"""
self.domains[domain_name].remove_intent(intent_name)

def register_domain_entity(self, domain_name: str, entity_name: str, entity_samples: List[str]):
"""
Register an entity within a specific domain.
Args:
domain_name (str): The name of the domain.
entity_name (str): The name of the entity to register.
entity_samples (List[str]): A list of sample phrases for the entity.
"""
self.domains[domain_name].add_entity(entity_name, entity_samples)

def remove_domain_entity(self, domain_name: str, entity_name: str):
"""
Remove a specific entity from a domain.
Args:
domain_name (str): The name of the domain.
entity_name (str): The name of the entity to remove.
"""
self.domains[domain_name].remove_entity(entity_name)

def calc_domains(self, query: str) -> List[MatchData]:
"""
Calculate the matching domains for a query.
Args:
query (str): The input query.
Returns:
List[MatchData]: A list of MatchData objects representing matching domains.
"""
if self._must_train:
self.train()

return self.domain_engine.calc_intents(query)

def calc_domain(self, query: str) -> MatchData:
"""
Calculate the best matching domain for a query.
Args:
query (str): The input query.
Returns:
MatchData: The best matching domain.
"""
if self._must_train:
self.train()
return self.domain_engine.calc_intent(query)

def calc_intent(self, query: str, domain: Optional[str] = None) -> MatchData:
"""
Calculate the best matching intent for a query within a specific domain.
Args:
query (str): The input query.
domain (Optional[str]): The domain to limit the search to. Defaults to None.
Returns:
MatchData: The best matching intent.
"""
if self._must_train:
self.train()
domain: str = domain or self.domain_engine.calc_intent(query).name
if domain in self.domains:
return self.domains[domain].calc_intent(query)
return MatchData(name=None, sent=query, matches=None, conf=0.0)

def calc_intents(self, query: str, domain: Optional[str] = None, top_k_domains: int = 2) -> List[MatchData]:
"""
Calculate matching intents for a query across domains or within a specific domain.
Args:
query (str): The input query.
domain (Optional[str]): The specific domain to search in. If None, searches across top-k domains.
top_k_domains (int): The number of top domains to consider. Defaults to 2.
Returns:
List[MatchData]: A list of MatchData objects representing matching intents, sorted by confidence.
"""
if self._must_train:
self.train()
if domain:
return self.domains[domain].calc_intents(query)
matches = []
domains = self.calc_domains(query)[:top_k_domains]
for domain in domains:
if domain.name in self.domains:
matches += self.domains[domain.name].calc_intents(query)
return sorted(matches, reverse=True, key=lambda k: k.conf)

def train(self):
for domain, samples in self.training_data.items():
self.domain_engine.add_intent(domain, samples)
self.domain_engine.train()
for domain in self.domains:
self.domains[domain].train()
self._must_train = False
9 changes: 8 additions & 1 deletion ovos_padatious/intent_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from functools import wraps
from typing import List, Dict, Any, Optional

from ovos_config.meta import get_xdg_base
from ovos_utils.log import LOG
from ovos_utils.xdg_utils import xdg_data_home

from ovos_padatious import padaos
from ovos_padatious.entity import Entity
Expand Down Expand Up @@ -54,7 +56,8 @@ class IntentContainer:
cache_dir (str): Directory for caching the neural network models and intent/entity files.
"""

def __init__(self, cache_dir: str) -> None:
def __init__(self, cache_dir: str = None) -> None:
cache_dir = cache_dir or f"{xdg_data_home()}/{get_xdg_base()}/intent_cache"
os.makedirs(cache_dir, exist_ok=True)
self.cache_dir: str = cache_dir
self.must_train: bool = False
Expand All @@ -64,6 +67,10 @@ def __init__(self, cache_dir: str) -> None:
self.train_thread: Optional[Any] = None # deprecated
self.serialized_args: List[Dict[str, Any]] = [] # Serialized calls for training intents/entities

@property
def intent_names(self):
return self.intents.intent_names

def clear(self) -> None:
"""
Clears the current intent and entity managers and resets the container.
Expand Down
4 changes: 4 additions & 0 deletions ovos_padatious/intent_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class IntentManager(TrainingManager):
def __init__(self, cache):
super(IntentManager, self).__init__(Intent, cache)

@property
def intent_names(self):
return [i.name for i in self.objects + self.objects_to_train]

def calc_intents(self, query, entity_manager):
sent = tokenize(query)
matches = []
Expand Down
179 changes: 179 additions & 0 deletions tests/test_domain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import unittest
from unittest.mock import MagicMock
from ovos_padatious.intent_container import IntentContainer
from ovos_padatious.match_data import MatchData
from ovos_padatious.domain_engine import DomainIntentEngine # Replace 'your_module' with the actual module name


class TestDomainIntentEngine(unittest.TestCase):
def setUp(self):
self.engine = DomainIntentEngine()

def test_register_domain_intent(self):
self.engine.register_domain_intent("domain1", "intent1", ["sample1", "sample2"])
self.assertIn("domain1", self.engine.training_data)
self.assertIn("intent1", self.engine.domains["domain1"].intent_names)

def test_remove_domain(self):
self.engine.register_domain_intent("domain1", "intent1", ["sample1", "sample2"])
self.engine.remove_domain("domain1")
self.assertNotIn("domain1", self.engine.training_data)
self.assertNotIn("domain1", self.engine.domains)

def test_remove_domain_intent(self):
self.engine.register_domain_intent("domain1", "intent1", ["sample1", "sample2"])
self.engine.remove_domain_intent("domain1", "intent1")
self.assertNotIn("intent1", self.engine.domains["domain1"].intent_names)

def test_calc_domains(self):
self.engine.train = MagicMock()
self.engine.domain_engine.calc_intents = MagicMock(return_value=[MatchData(name="domain1", sent="query", matches=None, conf=0.9)])
result = self.engine.calc_domains("query")
self.engine.train.assert_called_once()
self.assertEqual(result[0].name, "domain1")

def test_calc_domain(self):
self.engine.train = MagicMock()
self.engine.domain_engine.calc_intent = MagicMock(return_value=MatchData(name="domain1", sent="query", matches=None, conf=0.9))
result = self.engine.calc_domain("query")
self.engine.train.assert_called_once()
self.assertEqual(result.name, "domain1")

def test_calc_intent(self):
self.engine.train = MagicMock()
mock_domain_container = MagicMock()
mock_domain_container.calc_intent.return_value = MatchData(name="intent1", sent="query", matches=None, conf=0.9)
self.engine.domains["domain1"] = mock_domain_container

self.engine.domain_engine.calc_intent = MagicMock(return_value=MatchData(name="domain1", sent="query", matches=None, conf=0.9))
result = self.engine.calc_intent("query")
self.assertEqual(result.name, "intent1")

def test_calc_intents(self):
self.engine.train = MagicMock()
mock_domain_container = MagicMock()
mock_domain_container.calc_intents.return_value = [
MatchData(name="intent1", sent="query", matches=None, conf=0.9),
MatchData(name="intent2", sent="query", matches=None, conf=0.8),
]
self.engine.domains["domain1"] = mock_domain_container

self.engine.domain_engine.calc_intents = MagicMock(return_value=[MatchData(name="domain1", sent="query", matches=None, conf=0.9)])
result = self.engine.calc_intents("query")
self.assertEqual(len(result), 2)
self.assertEqual(result[0].name, "intent1")

def test_train(self):
self.engine.training_data["domain1"] = ["sample1", "sample2"]
self.engine.domain_engine.add_intent = MagicMock()
self.engine.domain_engine.train = MagicMock()

mock_domain_container = MagicMock()
self.engine.domains["domain1"] = mock_domain_container

self.engine.train()
self.engine.domain_engine.add_intent.assert_called_with("domain1", ["sample1", "sample2"])
self.engine.domain_engine.train.assert_called_once()
mock_domain_container.train.assert_called_once()
self.assertFalse(self.engine._must_train)



class TestDomainIntentEngineWithLiveData(unittest.TestCase):
def setUp(self):
self.engine = DomainIntentEngine()
# Sample training data
self.training_data = {
"IOT": {
"turn_on_device": ["Turn on the lights", "Switch on the fan", "Activate the air conditioner"],
"turn_off_device": ["Turn off the lights", "Switch off the heater", "Deactivate the air conditioner"],
},
"greetings": {
"say_hello": ["Hello", "Hi there", "Good morning"],
"say_goodbye": ["Goodbye", "See you later", "Bye"],
},
"General Knowledge": {
"ask_fact": ["Tell me a fact about space", "What is the capital of France?", "Who invented the telephone?"],
},
"Question": {
"ask_question": ["Why is the sky blue?", "What is quantum mechanics?", "Can you explain photosynthesis?"],
},
"Media Playback": {
"play_music": ["Play some music", "Start the playlist", "Play a song"],
"stop_music": ["Stop the music", "Pause playback", "Halt the song"],
},
}
# Register domains and intents
for domain, intents in self.training_data.items():
for intent, samples in intents.items():
self.engine.register_domain_intent(domain, intent, samples)
self.engine.train()

def test_live_data_intent_matching(self):
# Test IOT domain
query = "Switch on the fan"
result = self.engine.calc_intent(query, domain="IOT")
self.assertEqual(result.name, "turn_on_device")
self.assertGreater(result.conf, 0.8)

# Test greetings domain
query = "Hi there"
result = self.engine.calc_intent(query, domain="greetings")
self.assertEqual(result.name, "say_hello")
self.assertGreater(result.conf, 0.8)

# Test General Knowledge domain
query = "What is the capital of France?"
result = self.engine.calc_intent(query, domain="General Knowledge")
self.assertEqual(result.name, "ask_fact")
self.assertGreater(result.conf, 0.8)

# Test Question domain
query = "Why is the sky blue?"
result = self.engine.calc_intent(query, domain="Question")
self.assertEqual(result.name, "ask_question")
self.assertGreater(result.conf, 0.8)

# Test Media Playback domain
query = "Play a song"
result = self.engine.calc_intent(query, domain="Media Playback")
self.assertEqual(result.name, "play_music")
self.assertGreater(result.conf, 0.8)

def test_live_data_cross_domain_matching(self):
# Test cross-domain intent matching
query = "Tell me a fact about space"
result = self.engine.calc_domain(query)
self.assertEqual(result.name, "General Knowledge")
self.assertGreater(result.conf, 0.8)

# Validate intent from the matched domain
result = self.engine.calc_intent(query, domain=result.name)
self.assertEqual(result.name, "ask_fact")
self.assertGreater(result.conf, 0.8)

def test_calc_intent_without_domain(self):
# Test intent calculation without specifying a domain
query = "Turn on the lights"
result = self.engine.calc_intent(query)
self.assertIsNotNone(result.name, "Intent name should not be None")
self.assertEqual(result.name, "turn_on_device")
self.assertGreater(result.conf, 0.8)

query = "Goodbye"
result = self.engine.calc_intent(query)
self.assertIsNotNone(result.name, "Intent name should not be None")
self.assertEqual(result.name, "say_goodbye")
self.assertGreater(result.conf, 0.8)

query = "What is quantum mechanics?"
result = self.engine.calc_intent(query)
self.assertIsNotNone(result.name, "Intent name should not be None")
self.assertEqual(result.name, "ask_question")
self.assertGreater(result.conf, 0.8)

if __name__ == "__main__":
unittest.main()

if __name__ == "__main__":
unittest.main()

0 comments on commit 26af1fc

Please sign in to comment.