forked from MycroftAI/padatious
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
match the skill first, and then the intent
- Loading branch information
Showing
4 changed files
with
353 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
from collections import defaultdict | ||
from typing import Dict, List, Optional | ||
|
||
from ovos_padatious.intent_container import IntentContainer | ||
from ovos_padatious.match_data import MatchData | ||
|
||
|
||
class DomainIntentEngine: | ||
""" | ||
A domain-aware intent recognition engine that organizes intents and entities | ||
into specific domains, providing flexible and hierarchical intent matching. | ||
""" | ||
|
||
def __init__(self): | ||
""" | ||
Initialize the DomainIntentEngine. | ||
Attributes: | ||
domain_engine (IntentContainer): A top-level intent container for cross-domain calculations. | ||
domains (Dict[str, IntentContainer]): A mapping of domain names to their respective intent containers. | ||
training_data (Dict[str, List[str]]): A mapping of domain names to their associated training samples. | ||
""" | ||
self.domain_engine = IntentContainer() | ||
self.domains: Dict[str, IntentContainer] = defaultdict(IntentContainer) | ||
self.training_data: Dict[str, List[str]] = defaultdict(list) | ||
self._must_train = True | ||
|
||
def remove_domain(self, domain_name: str): | ||
""" | ||
Remove a domain and its associated intents and training data. | ||
Args: | ||
domain_name (str): The name of the domain to remove. | ||
""" | ||
if domain_name in self.training_data: | ||
self.training_data.pop(domain_name) | ||
if domain_name in self.domains: | ||
self.domains.pop(domain_name) | ||
if domain_name in self.domain_engine.intent_names: | ||
self.domain_engine.remove_intent(domain_name) | ||
|
||
def register_domain_intent(self, domain_name: str, intent_name: str, intent_samples: List[str]): | ||
""" | ||
Register an intent within a specific domain. | ||
Args: | ||
domain_name (str): The name of the domain. | ||
intent_name (str): The name of the intent to register. | ||
intent_samples (List[str]): A list of sample sentences for the intent. | ||
""" | ||
self.domains[domain_name].add_intent(intent_name, intent_samples) | ||
self.training_data[domain_name] += intent_samples | ||
self._must_train = True | ||
|
||
def remove_domain_intent(self, domain_name: str, intent_name: str): | ||
""" | ||
Remove a specific intent from a domain. | ||
Args: | ||
domain_name (str): The name of the domain. | ||
intent_name (str): The name of the intent to remove. | ||
""" | ||
self.domains[domain_name].remove_intent(intent_name) | ||
|
||
def register_domain_entity(self, domain_name: str, entity_name: str, entity_samples: List[str]): | ||
""" | ||
Register an entity within a specific domain. | ||
Args: | ||
domain_name (str): The name of the domain. | ||
entity_name (str): The name of the entity to register. | ||
entity_samples (List[str]): A list of sample phrases for the entity. | ||
""" | ||
self.domains[domain_name].add_entity(entity_name, entity_samples) | ||
|
||
def remove_domain_entity(self, domain_name: str, entity_name: str): | ||
""" | ||
Remove a specific entity from a domain. | ||
Args: | ||
domain_name (str): The name of the domain. | ||
entity_name (str): The name of the entity to remove. | ||
""" | ||
self.domains[domain_name].remove_entity(entity_name) | ||
|
||
def calc_domains(self, query: str) -> List[MatchData]: | ||
""" | ||
Calculate the matching domains for a query. | ||
Args: | ||
query (str): The input query. | ||
Returns: | ||
List[MatchData]: A list of MatchData objects representing matching domains. | ||
""" | ||
if self._must_train: | ||
self.train() | ||
|
||
return self.domain_engine.calc_intents(query) | ||
|
||
def calc_domain(self, query: str) -> MatchData: | ||
""" | ||
Calculate the best matching domain for a query. | ||
Args: | ||
query (str): The input query. | ||
Returns: | ||
MatchData: The best matching domain. | ||
""" | ||
if self._must_train: | ||
self.train() | ||
return self.domain_engine.calc_intent(query) | ||
|
||
def calc_intent(self, query: str, domain: Optional[str] = None) -> MatchData: | ||
""" | ||
Calculate the best matching intent for a query within a specific domain. | ||
Args: | ||
query (str): The input query. | ||
domain (Optional[str]): The domain to limit the search to. Defaults to None. | ||
Returns: | ||
MatchData: The best matching intent. | ||
""" | ||
if self._must_train: | ||
self.train() | ||
domain: str = domain or self.domain_engine.calc_intent(query).name | ||
if domain in self.domains: | ||
return self.domains[domain].calc_intent(query) | ||
return MatchData(name=None, sent=query, matches=None, conf=0.0) | ||
|
||
def calc_intents(self, query: str, domain: Optional[str] = None, top_k_domains: int = 2) -> List[MatchData]: | ||
""" | ||
Calculate matching intents for a query across domains or within a specific domain. | ||
Args: | ||
query (str): The input query. | ||
domain (Optional[str]): The specific domain to search in. If None, searches across top-k domains. | ||
top_k_domains (int): The number of top domains to consider. Defaults to 2. | ||
Returns: | ||
List[MatchData]: A list of MatchData objects representing matching intents, sorted by confidence. | ||
""" | ||
if self._must_train: | ||
self.train() | ||
if domain: | ||
return self.domains[domain].calc_intents(query) | ||
matches = [] | ||
domains = self.calc_domains(query)[:top_k_domains] | ||
for domain in domains: | ||
if domain.name in self.domains: | ||
matches += self.domains[domain.name].calc_intents(query) | ||
return sorted(matches, reverse=True, key=lambda k: k.conf) | ||
|
||
def train(self): | ||
for domain, samples in self.training_data.items(): | ||
self.domain_engine.add_intent(domain, samples) | ||
self.domain_engine.train() | ||
for domain in self.domains: | ||
self.domains[domain].train() | ||
self._must_train = False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import unittest | ||
from unittest.mock import MagicMock | ||
from ovos_padatious.intent_container import IntentContainer | ||
from ovos_padatious.match_data import MatchData | ||
from ovos_padatious.domain_engine import DomainIntentEngine # Replace 'your_module' with the actual module name | ||
|
||
|
||
class TestDomainIntentEngine(unittest.TestCase): | ||
def setUp(self): | ||
self.engine = DomainIntentEngine() | ||
|
||
def test_register_domain_intent(self): | ||
self.engine.register_domain_intent("domain1", "intent1", ["sample1", "sample2"]) | ||
self.assertIn("domain1", self.engine.training_data) | ||
self.assertIn("intent1", self.engine.domains["domain1"].intent_names) | ||
|
||
def test_remove_domain(self): | ||
self.engine.register_domain_intent("domain1", "intent1", ["sample1", "sample2"]) | ||
self.engine.remove_domain("domain1") | ||
self.assertNotIn("domain1", self.engine.training_data) | ||
self.assertNotIn("domain1", self.engine.domains) | ||
|
||
def test_remove_domain_intent(self): | ||
self.engine.register_domain_intent("domain1", "intent1", ["sample1", "sample2"]) | ||
self.engine.remove_domain_intent("domain1", "intent1") | ||
self.assertNotIn("intent1", self.engine.domains["domain1"].intent_names) | ||
|
||
def test_calc_domains(self): | ||
self.engine.train = MagicMock() | ||
self.engine.domain_engine.calc_intents = MagicMock(return_value=[MatchData(name="domain1", sent="query", matches=None, conf=0.9)]) | ||
result = self.engine.calc_domains("query") | ||
self.engine.train.assert_called_once() | ||
self.assertEqual(result[0].name, "domain1") | ||
|
||
def test_calc_domain(self): | ||
self.engine.train = MagicMock() | ||
self.engine.domain_engine.calc_intent = MagicMock(return_value=MatchData(name="domain1", sent="query", matches=None, conf=0.9)) | ||
result = self.engine.calc_domain("query") | ||
self.engine.train.assert_called_once() | ||
self.assertEqual(result.name, "domain1") | ||
|
||
def test_calc_intent(self): | ||
self.engine.train = MagicMock() | ||
mock_domain_container = MagicMock() | ||
mock_domain_container.calc_intent.return_value = MatchData(name="intent1", sent="query", matches=None, conf=0.9) | ||
self.engine.domains["domain1"] = mock_domain_container | ||
|
||
self.engine.domain_engine.calc_intent = MagicMock(return_value=MatchData(name="domain1", sent="query", matches=None, conf=0.9)) | ||
result = self.engine.calc_intent("query") | ||
self.assertEqual(result.name, "intent1") | ||
|
||
def test_calc_intents(self): | ||
self.engine.train = MagicMock() | ||
mock_domain_container = MagicMock() | ||
mock_domain_container.calc_intents.return_value = [ | ||
MatchData(name="intent1", sent="query", matches=None, conf=0.9), | ||
MatchData(name="intent2", sent="query", matches=None, conf=0.8), | ||
] | ||
self.engine.domains["domain1"] = mock_domain_container | ||
|
||
self.engine.domain_engine.calc_intents = MagicMock(return_value=[MatchData(name="domain1", sent="query", matches=None, conf=0.9)]) | ||
result = self.engine.calc_intents("query") | ||
self.assertEqual(len(result), 2) | ||
self.assertEqual(result[0].name, "intent1") | ||
|
||
def test_train(self): | ||
self.engine.training_data["domain1"] = ["sample1", "sample2"] | ||
self.engine.domain_engine.add_intent = MagicMock() | ||
self.engine.domain_engine.train = MagicMock() | ||
|
||
mock_domain_container = MagicMock() | ||
self.engine.domains["domain1"] = mock_domain_container | ||
|
||
self.engine.train() | ||
self.engine.domain_engine.add_intent.assert_called_with("domain1", ["sample1", "sample2"]) | ||
self.engine.domain_engine.train.assert_called_once() | ||
mock_domain_container.train.assert_called_once() | ||
self.assertFalse(self.engine._must_train) | ||
|
||
|
||
|
||
class TestDomainIntentEngineWithLiveData(unittest.TestCase): | ||
def setUp(self): | ||
self.engine = DomainIntentEngine() | ||
# Sample training data | ||
self.training_data = { | ||
"IOT": { | ||
"turn_on_device": ["Turn on the lights", "Switch on the fan", "Activate the air conditioner"], | ||
"turn_off_device": ["Turn off the lights", "Switch off the heater", "Deactivate the air conditioner"], | ||
}, | ||
"greetings": { | ||
"say_hello": ["Hello", "Hi there", "Good morning"], | ||
"say_goodbye": ["Goodbye", "See you later", "Bye"], | ||
}, | ||
"General Knowledge": { | ||
"ask_fact": ["Tell me a fact about space", "What is the capital of France?", "Who invented the telephone?"], | ||
}, | ||
"Question": { | ||
"ask_question": ["Why is the sky blue?", "What is quantum mechanics?", "Can you explain photosynthesis?"], | ||
}, | ||
"Media Playback": { | ||
"play_music": ["Play some music", "Start the playlist", "Play a song"], | ||
"stop_music": ["Stop the music", "Pause playback", "Halt the song"], | ||
}, | ||
} | ||
# Register domains and intents | ||
for domain, intents in self.training_data.items(): | ||
for intent, samples in intents.items(): | ||
self.engine.register_domain_intent(domain, intent, samples) | ||
self.engine.train() | ||
|
||
def test_live_data_intent_matching(self): | ||
# Test IOT domain | ||
query = "Switch on the fan" | ||
result = self.engine.calc_intent(query, domain="IOT") | ||
self.assertEqual(result.name, "turn_on_device") | ||
self.assertGreater(result.conf, 0.8) | ||
|
||
# Test greetings domain | ||
query = "Hi there" | ||
result = self.engine.calc_intent(query, domain="greetings") | ||
self.assertEqual(result.name, "say_hello") | ||
self.assertGreater(result.conf, 0.8) | ||
|
||
# Test General Knowledge domain | ||
query = "What is the capital of France?" | ||
result = self.engine.calc_intent(query, domain="General Knowledge") | ||
self.assertEqual(result.name, "ask_fact") | ||
self.assertGreater(result.conf, 0.8) | ||
|
||
# Test Question domain | ||
query = "Why is the sky blue?" | ||
result = self.engine.calc_intent(query, domain="Question") | ||
self.assertEqual(result.name, "ask_question") | ||
self.assertGreater(result.conf, 0.8) | ||
|
||
# Test Media Playback domain | ||
query = "Play a song" | ||
result = self.engine.calc_intent(query, domain="Media Playback") | ||
self.assertEqual(result.name, "play_music") | ||
self.assertGreater(result.conf, 0.8) | ||
|
||
def test_live_data_cross_domain_matching(self): | ||
# Test cross-domain intent matching | ||
query = "Tell me a fact about space" | ||
result = self.engine.calc_domain(query) | ||
self.assertEqual(result.name, "General Knowledge") | ||
self.assertGreater(result.conf, 0.8) | ||
|
||
# Validate intent from the matched domain | ||
result = self.engine.calc_intent(query, domain=result.name) | ||
self.assertEqual(result.name, "ask_fact") | ||
self.assertGreater(result.conf, 0.8) | ||
|
||
def test_calc_intent_without_domain(self): | ||
# Test intent calculation without specifying a domain | ||
query = "Turn on the lights" | ||
result = self.engine.calc_intent(query) | ||
self.assertIsNotNone(result.name, "Intent name should not be None") | ||
self.assertEqual(result.name, "turn_on_device") | ||
self.assertGreater(result.conf, 0.8) | ||
|
||
query = "Goodbye" | ||
result = self.engine.calc_intent(query) | ||
self.assertIsNotNone(result.name, "Intent name should not be None") | ||
self.assertEqual(result.name, "say_goodbye") | ||
self.assertGreater(result.conf, 0.8) | ||
|
||
query = "What is quantum mechanics?" | ||
result = self.engine.calc_intent(query) | ||
self.assertIsNotNone(result.name, "Intent name should not be None") | ||
self.assertEqual(result.name, "ask_question") | ||
self.assertGreater(result.conf, 0.8) | ||
|
||
if __name__ == "__main__": | ||
unittest.main() | ||
|
||
if __name__ == "__main__": | ||
unittest.main() |