Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl committed Dec 12, 2024
1 parent 80ba4e9 commit 02d2aa3
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions tests/test_domain.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,43 @@
import unittest
from unittest.mock import MagicMock
from ovos_padatious.intent_container import IntentContainer

from ovos_padatious.domain_container import DomainIntentContainer # Replace 'your_module' with the actual module name

from ovos_padatious.match_data import MatchData
from ovos_padatious.domain_engine import DomainIntentEngine # Replace 'your_module' with the actual module name


class TestDomainIntentEngine(unittest.TestCase):
def setUp(self):
self.engine = DomainIntentEngine()
self.engine = DomainIntentContainer()

def test_register_domain_intent(self):
self.engine.register_domain_intent("domain1", "intent1", ["sample1", "sample2"])
self.engine.add_domain_intent("domain1", "intent1", ["sample1", "sample2"])
self.assertIn("domain1", self.engine.training_data)
self.assertIn("intent1", self.engine.domains["domain1"].intent_names)

def test_remove_domain(self):
self.engine.register_domain_intent("domain1", "intent1", ["sample1", "sample2"])
self.engine.add_domain_intent("domain1", "intent1", ["sample1", "sample2"])
self.engine.remove_domain("domain1")
self.assertNotIn("domain1", self.engine.training_data)
self.assertNotIn("domain1", self.engine.domains)

def test_remove_domain_intent(self):
self.engine.register_domain_intent("domain1", "intent1", ["sample1", "sample2"])
self.engine.add_domain_intent("domain1", "intent1", ["sample1", "sample2"])
self.engine.remove_domain_intent("domain1", "intent1")
self.assertNotIn("intent1", self.engine.domains["domain1"].intent_names)

def test_calc_domains(self):
self.engine.train = MagicMock()
self.engine.domain_engine.calc_intents = MagicMock(return_value=[MatchData(name="domain1", sent="query", matches=None, conf=0.9)])
self.engine.domain_engine.calc_intents = MagicMock(
return_value=[MatchData(name="domain1", sent="query", matches=None, conf=0.9)])
result = self.engine.calc_domains("query")
self.engine.train.assert_called_once()
self.assertEqual(result[0].name, "domain1")

def test_calc_domain(self):
self.engine.train = MagicMock()
self.engine.domain_engine.calc_intent = MagicMock(return_value=MatchData(name="domain1", sent="query", matches=None, conf=0.9))
self.engine.domain_engine.calc_intent = MagicMock(
return_value=MatchData(name="domain1", sent="query", matches=None, conf=0.9))
result = self.engine.calc_domain("query")
self.engine.train.assert_called_once()
self.assertEqual(result.name, "domain1")
Expand All @@ -45,7 +48,8 @@ def test_calc_intent(self):
mock_domain_container.calc_intent.return_value = MatchData(name="intent1", sent="query", matches=None, conf=0.9)
self.engine.domains["domain1"] = mock_domain_container

self.engine.domain_engine.calc_intent = MagicMock(return_value=MatchData(name="domain1", sent="query", matches=None, conf=0.9))
self.engine.domain_engine.calc_intent = MagicMock(
return_value=MatchData(name="domain1", sent="query", matches=None, conf=0.9))
result = self.engine.calc_intent("query")
self.assertEqual(result.name, "intent1")

Expand All @@ -58,7 +62,8 @@ def test_calc_intents(self):
]
self.engine.domains["domain1"] = mock_domain_container

self.engine.domain_engine.calc_intents = MagicMock(return_value=[MatchData(name="domain1", sent="query", matches=None, conf=0.9)])
self.engine.domain_engine.calc_intents = MagicMock(
return_value=[MatchData(name="domain1", sent="query", matches=None, conf=0.9)])
result = self.engine.calc_intents("query")
self.assertEqual(len(result), 2)
self.assertEqual(result[0].name, "intent1")
Expand All @@ -75,13 +80,12 @@ def test_train(self):
self.engine.domain_engine.add_intent.assert_called_with("domain1", ["sample1", "sample2"])
self.engine.domain_engine.train.assert_called_once()
mock_domain_container.train.assert_called_once()
self.assertFalse(self.engine._must_train)

self.assertFalse(self.engine.must_train)


class TestDomainIntentEngineWithLiveData(unittest.TestCase):
def setUp(self):
self.engine = DomainIntentEngine()
self.engine = DomainIntentContainer()
# Sample training data
self.training_data = {
"IOT": {
Expand All @@ -93,10 +97,12 @@ def setUp(self):
"say_goodbye": ["Goodbye", "See you later", "Bye"],
},
"General Knowledge": {
"ask_fact": ["Tell me a fact about space", "What is the capital of France?", "Who invented the telephone?"],
"ask_fact": ["Tell me a fact about space", "What is the capital of France?",
"Who invented the telephone?"],
},
"Question": {
"ask_question": ["Why is the sky blue?", "What is quantum mechanics?", "Can you explain photosynthesis?"],
"ask_question": ["Why is the sky blue?", "What is quantum mechanics?",
"Can you explain photosynthesis?"],
},
"Media Playback": {
"play_music": ["Play some music", "Start the playlist", "Play a song"],
Expand All @@ -106,7 +112,7 @@ def setUp(self):
# Register domains and intents
for domain, intents in self.training_data.items():
for intent, samples in intents.items():
self.engine.register_domain_intent(domain, intent, samples)
self.engine.add_domain_intent(domain, intent, samples)
self.engine.train()

def test_live_data_intent_matching(self):
Expand Down Expand Up @@ -172,8 +178,6 @@ def test_calc_intent_without_domain(self):
self.assertEqual(result.name, "ask_question")
self.assertGreater(result.conf, 0.8)

if __name__ == "__main__":
unittest.main()

if __name__ == "__main__":
unittest.main()

0 comments on commit 02d2aa3

Please sign in to comment.