diff --git a/CHANGELOG.md b/CHANGELOG.md index 5a8f86a..c8aa2c7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,12 +1,12 @@ # Changelog -## [0.1.3a1](https://github.com/OpenVoiceOS/ovos-ocp-pipeline-plugin/tree/0.1.3a1) (2024-10-16) +## [1.0.0a1](https://github.com/OpenVoiceOS/ovos-ocp-pipeline-plugin/tree/1.0.0a1) (2024-10-16) -[Full Changelog](https://github.com/OpenVoiceOS/ovos-ocp-pipeline-plugin/compare/0.1.2...0.1.3a1) +[Full Changelog](https://github.com/OpenVoiceOS/ovos-ocp-pipeline-plugin/compare/0.1.3...1.0.0a1) -**Merged pull requests:** +**Breaking changes:** -- fix:standardize\_lang [\#12](https://github.com/OpenVoiceOS/ovos-ocp-pipeline-plugin/pull/12) ([JarbasAl](https://github.com/JarbasAl)) +- feat:pipeline plugin factory [\#14](https://github.com/OpenVoiceOS/ovos-ocp-pipeline-plugin/pull/14) ([JarbasAl](https://github.com/JarbasAl)) diff --git a/ocp_pipeline/opm.py b/ocp_pipeline/opm.py index da90673..ab246c9 100644 --- a/ocp_pipeline/opm.py +++ b/ocp_pipeline/opm.py @@ -4,22 +4,26 @@ from dataclasses import dataclass from os.path import join, dirname from threading import RLock -from typing import List, Tuple, Optional, Union +from typing import Tuple, Optional, Dict, List, Union +from langcodes import closest_match from ovos_bus_client.apis.ocp import ClassicAudioServiceInterface from ovos_bus_client.apis.ocp import OCPInterface, OCPQuery +from ovos_bus_client.client import MessageBusClient from ovos_bus_client.message import Message, dig_for_message from ovos_bus_client.session import SessionManager +from ovos_config import Configuration from ovos_plugin_manager.ocp import available_extractors -from ovos_plugin_manager.templates.pipeline import IntentMatch, PipelinePlugin +from ovos_plugin_manager.templates.pipeline import IntentHandlerMatch, ConfidenceMatcherPipeline, PipelineStageMatcher, \ + PipelineMatch from ovos_utils.lang import standardize_lang_tag, get_language_dir -from ovos_utils.log import LOG +from ovos_utils.log import LOG, deprecated, log_deprecation from ovos_utils.messagebus import FakeBus from ovos_utils.ocp import MediaType, PlaybackType, PlaybackMode, PlayerState, OCP_ID, \ MediaEntry, Playlist, MediaState, TrackState, dict2entry, PluginStream from ovos_workshop.app import OVOSAbstractApplication from padacioso import IntentContainer -from langcodes import closest_match + from ocp_pipeline.feats import OCPFeaturizer from ocp_pipeline.legacy import LegacyCommonPlay @@ -35,25 +39,24 @@ class OCPPlayerProxy: media_type: MediaType = MediaType.GENERIC -class OCPPipelineMatcher(PipelinePlugin, OVOSAbstractApplication): +class OCPPipelineMatcher(ConfidenceMatcherPipeline, OVOSAbstractApplication): intents = ["play.intent", "open.intent", "media_stop.intent", "next.intent", "prev.intent", "pause.intent", "play_favorites.intent", "resume.intent", "like_song.intent"] + intent_matchers = {} - def __init__(self, bus=None, config=None): + def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None, + config: Optional[Dict] = None): OVOSAbstractApplication.__init__( self, bus=bus or FakeBus(), skill_id=OCP_ID, resources_dir=f"{dirname(__file__)}") - PipelinePlugin.__init__(self, config) + ConfidenceMatcherPipeline.__init__(self, bus, config) self.ocp_api = OCPInterface(self.bus) self.legacy_api = ClassicAudioServiceInterface(self.bus) - self.mycroft_cps = LegacyCommonPlay(self.bus) - self.config = config or {} self.search_lock = RLock() self.ocp_sessions = {} # session_id: PlaybackCapabilities - self.intent_matchers = {} self.skill_aliases = { # "skill_id": ["names"] } @@ -100,16 +103,19 @@ def load_classifiers(self): c = SklearnOVOSClassifier.from_file(f"{b}/media_ocp_cv2_kw_medium.clf") self._media_en_clf = (c, OCPFeaturizer("media_ocp_cv2_medium")) - def load_resource_files(self): + @classmethod + def load_resource_files(cls): intents = {} - for lang in self.native_langs: + langs = Configuration().get('secondary_langs', []) + [Configuration().get('lang', "en-US")] + langs = set([standardize_lang_tag(l) for l in langs]) + for lang in langs: lang = standardize_lang_tag(lang) intents[lang] = {} locale_folder = get_language_dir(join(dirname(__file__), "locale"), lang) if locale_folder is not None: for f in os.listdir(locale_folder): path = join(locale_folder, f) - if f in self.intents: + if f in cls.intents: with open(path) as intent: samples = intent.read().split("\n") for idx, s in enumerate(samples): @@ -137,19 +143,22 @@ def register_ocp_api_events(self): self.add_event("mycroft.audio.service.stop", self._handle_legacy_audio_stop) self.bus.emit(Message("ovos.common_play.status")) # sync player state on launch - def register_ocp_intents(self): - intent_files = self.load_resource_files() + @classmethod + def load_intent_files(cls): + intent_files = cls.load_resource_files() for lang, intent_data in intent_files.items(): lang = standardize_lang_tag(lang) - self.intent_matchers[lang] = IntentContainer() - for intent_name in self.intents: + cls.intent_matchers[lang] = IntentContainer() + for intent_name in cls.intents: samples = intent_data.get(intent_name) if samples: LOG.debug(f"registering OCP intent: {intent_name}") - self.intent_matchers[lang].add_intent( + cls.intent_matchers[lang].add_intent( intent_name.replace(".intent", ""), samples) + def register_ocp_intents(self): + self.load_intent_files() self.add_event("ocp:play", self.handle_play_intent, is_intent=True) self.add_event("ocp:play_favorites", self.handle_play_favorites_intent, is_intent=True) self.add_event("ocp:open", self.handle_open_intent, is_intent=True) @@ -160,7 +169,6 @@ def register_ocp_intents(self): self.add_event("ocp:media_stop", self.handle_stop_intent, is_intent=True) self.add_event("ocp:search_error", self.handle_search_error_intent, is_intent=True) self.add_event("ocp:like_song", self.handle_like_intent, is_intent=True) - self.add_event("ocp:legacy_cps", self.handle_legacy_cps, is_intent=True) def update_player_proxy(self, player: OCPPlayerProxy): """remember OCP session state""" @@ -287,7 +295,7 @@ def handle_player_state_update(self, message: Message): self.update_player_proxy(player) # pipeline - def match_high(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentMatch]: + def match_high(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentHandlerMatch]: """ exact matches only, handles playback control recommended after high confidence intents pipeline stage """ lang = self._get_closest_lang(lang) @@ -323,13 +331,12 @@ def match_high(self, utterances: List[str], lang: str, message: Message = None) else: return None - return IntentMatch(intent_service="OCP_intents", - intent_type=f'ocp:{match["name"]}', - intent_data=match, - skill_id=OCP_ID, - utterance=utterance) + return IntentHandlerMatch(match_type=f'ocp:{match["name"]}', + match_data=match, + skill_id=OCP_ID, + utterance=utterance) - def match_medium(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentMatch]: + def match_medium(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentHandlerMatch]: """ match a utterance via classifiers, recommended before common_qa pipeline stage""" lang = standardize_lang_tag(lang) @@ -353,17 +360,16 @@ def match_medium(self, utterances: List[str], lang: str, message: Message = None # extract the query string query = self.remove_voc(utterance, "Play", lang).strip() - return IntentMatch(intent_service="OCP_media", - intent_type="ocp:play", - intent_data={"media_type": media_type, - "entities": ents, - "query": query, - "is_ocp_conf": bconf, - "conf": confidence}, - skill_id=OCP_ID, - utterance=utterance) - - def match_fallback(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentMatch]: + return IntentHandlerMatch(match_type="ocp:play", + match_data={"media_type": media_type, + "entities": ents, + "query": query, + "is_ocp_conf": bconf, + "conf": confidence}, + skill_id=OCP_ID, + utterance=utterance) + + def match_low(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentHandlerMatch]: """ match an utterance via presence of known OCP keywords, recommended before fallback_low pipeline stage""" utterance = utterances[0].lower() @@ -386,39 +392,36 @@ def match_fallback(self, utterances: List[str], lang: str, message: Message = No # extract the query string query = self.remove_voc(utterance, "Play", lang).strip() - return IntentMatch(intent_service="OCP_fallback", - intent_type="ocp:play", - intent_data={"media_type": media_type, - "entities": ents, - "query": query, - "conf": float(confidence)}, - skill_id=OCP_ID, - utterance=utterance) + return IntentHandlerMatch(match_type="ocp:play", + match_data={"media_type": media_type, + "entities": ents, + "query": query, + "conf": float(confidence)}, + skill_id=OCP_ID, + utterance=utterance) def _process_play_query(self, utterance: str, lang: str, match: dict = None, - message: Optional[Message] = None) -> Optional[IntentMatch]: + message: Optional[Message] = None) -> Optional[IntentHandlerMatch]: lang = standardize_lang_tag(lang) match = match or {} player = self.get_player(message) # if media is currently paused, empty string means "resume playback" if player.player_state == PlayerState.PAUSED and \ self._should_resume(utterance, lang, message=message): - return IntentMatch(intent_service="OCP_intents", - intent_type="ocp:resume", - intent_data=match, - skill_id=OCP_ID, - utterance=utterance) + return IntentHandlerMatch(match_type="ocp:resume", + match_data=match, + skill_id=OCP_ID, + utterance=utterance) if not utterance: # user just said "play", we are missing the search query phrase = self.get_response("play.what", num_retries=2) if not phrase: # let the error intent handler take action - return IntentMatch(intent_service="OCP_intents", - intent_type="ocp:search_error", - intent_data=match, - skill_id=OCP_ID, - utterance=utterance) + return IntentHandlerMatch(match_type="ocp:search_error", + match_data=match, + skill_id=OCP_ID, + utterance=utterance) sess = SessionManager.get(message) # if a skill was explicitly requested, search it first @@ -441,18 +444,17 @@ def _process_play_query(self, utterance: str, lang: str, match: dict = None, else: ents = OCPFeaturizer.extract_entities(utterance) - return IntentMatch(intent_service="OCP_intents", - intent_type="ocp:play", - intent_data={"media_type": media_type, - "query": query, - "entities": ents, - "skills": valid_skills, - "conf": match["conf"], - "media_conf": float(conf), - # "results": results, - "lang": lang}, - skill_id=OCP_ID, - utterance=utterance) + return IntentHandlerMatch(match_type="ocp:play", + match_data={"media_type": media_type, + "query": query, + "entities": ents, + "skills": valid_skills, + "conf": match["conf"], + "media_conf": float(conf), + # "results": results, + "lang": lang}, + skill_id=OCP_ID, + utterance=utterance) # bus api def handle_search_query(self, message: Message): @@ -1032,10 +1034,53 @@ def _handle_legacy_audio_end(self, message: Message): player.media_state = MediaState.END_OF_MEDIA self.update_player_proxy(player) + @classmethod + def _get_closest_lang(cls, lang: str) -> Optional[str]: + if cls.intent_matchers: + lang = standardize_lang_tag(lang) + closest, score = closest_match(lang, list(cls.intent_matchers.keys())) + # https://langcodes-hickford.readthedocs.io/en/sphinx/index.html#distance-values + # 0 -> These codes represent the same language, possibly after filling in values and normalizing. + # 1- 3 -> These codes indicate a minor regional difference. + # 4 - 10 -> These codes indicate a significant but unproblematic regional difference. + if score < 10: + return closest + return None + + def shutdown(self): + self.default_shutdown() # remove events registered via self.add_event + + # deprecated + @property + def mycroft_cps(self) -> LegacyCommonPlay: + log_deprecation("self.mycroft_cps is deprecated, use MycroftCPSLegacyPipeline instead", "2.0.0") + return LegacyCommonPlay(self.bus) + + @deprecated("match_fallback has been renamed match_low", "2.0.0") + def match_fallback(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentHandlerMatch]: + return self.match_low(utterances, lang, message) + + @deprecated("match_legacy is deprecated! use MycroftCPSLegacyPipeline class directly instead", "2.0.0") + def match_legacy(self, utterances: List[str], lang: str, message: Message = None) -> Optional[PipelineMatch]: + """ match legacy mycroft common play skills (must import from deprecated mycroft module) + not recommended, legacy support only + + legacy base class at mycroft/skills/common_play_skill.py marked for removal in ovos-core 0.1.0 + """ + return MycroftCPSLegacyPipeline(self.bus, self.config).match_high(utterances, lang, message) + + +class MycroftCPSLegacyPipeline(PipelineStageMatcher): + def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None, + config: Optional[Dict] = None): + super().__init__(bus, config) + self.mycroft_cps = LegacyCommonPlay(self.bus) + OCPPipelineMatcher.load_intent_files() + self.bus.on("ocp:legacy_cps", self.handle_legacy_cps) + ############ # Legacy Mycroft CommonPlay skills - - def match_legacy(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentMatch]: + def match(self, utterances: List[str], lang: str, message: Message = None) -> Optional[PipelineMatch]: """ match legacy mycroft common play skills (must import from deprecated mycroft module) not recommended, legacy support only @@ -1047,34 +1092,29 @@ def match_legacy(self, utterances: List[str], lang: str, message: Message = None utterance = utterances[0].lower() - lang = self._get_closest_lang(lang) + lang = OCPPipelineMatcher._get_closest_lang(lang) if lang is None: # no intents registered for this lang return None - match = self.intent_matchers[lang].calc_intent(utterance) + match = OCPPipelineMatcher.intent_matchers[lang].calc_intent(utterance) if match["name"] is None: return None if match["name"] == "play": LOG.info(f"Legacy Mycroft CommonPlay match: {match}") utterance = match["entities"].pop("query") - return IntentMatch(intent_service="OCP_media", - intent_type="ocp:legacy_cps", - intent_data={"query": utterance, - "conf": 0.7}, - skill_id=OCP_ID, - utterance=utterance) - - def _get_closest_lang(self, lang: str) -> Optional[str]: - if self.intent_matchers: - lang = standardize_lang_tag(lang) - closest, score = closest_match(lang, list(self.intent_matchers.keys())) - # https://langcodes-hickford.readthedocs.io/en/sphinx/index.html#distance-values - # 0 -> These codes represent the same language, possibly after filling in values and normalizing. - # 1- 3 -> These codes indicate a minor regional difference. - # 4 - 10 -> These codes indicate a significant but unproblematic regional difference. - if score < 10: - return closest + self.bus.emit(Message("ocp:legacy_cps", + {"query": utterance, "conf": 0.7})) + return PipelineMatch(handled=True, + match_data={"query": utterance, + "conf": 0.7}, + skill_id=OCP_ID, + utterance=utterance) + + def match_medium(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentHandlerMatch]: + return None + + def match_low(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentHandlerMatch]: return None def handle_legacy_cps(self, message: Message): @@ -1095,4 +1135,3 @@ def handle_legacy_cps(self, message: Message): def shutdown(self): self.mycroft_cps.shutdown() - self.default_shutdown() # remove events registered via self.add_event diff --git a/ocp_pipeline/version.py b/ocp_pipeline/version.py index 35edff0..cd3bcb6 100644 --- a/ocp_pipeline/version.py +++ b/ocp_pipeline/version.py @@ -1,6 +1,6 @@ # START_VERSION_BLOCK -VERSION_MAJOR = 0 -VERSION_MINOR = 1 -VERSION_BUILD = 3 -VERSION_ALPHA = 0 +VERSION_MAJOR = 1 +VERSION_MINOR = 0 +VERSION_BUILD = 0 +VERSION_ALPHA = 1 # END_VERSION_BLOCK diff --git a/requirements.txt b/requirements.txt index b37171d..8af7f93 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ ovos-workshop>=0.1.7,<2.0.0 ovos-classifiers ovos-utils>=0.3.5,<1.0.0 +ovos-plugin-manager>=0.5.0,<1.0.0 langcodes \ No newline at end of file diff --git a/setup.py b/setup.py index fe7df61..b8c99ba 100644 --- a/setup.py +++ b/setup.py @@ -41,7 +41,8 @@ def required(requirements_file): if pkg.strip() and not pkg.startswith("#")] -PLUGIN_ENTRY_POINT = 'ovos-ocp-pipeline-plugin=ocp_pipeline.opm:OCPPipelineMatcher' +PLUGIN_ENTRY_POINT = ('ovos-ocp-pipeline-plugin=ocp_pipeline.opm:OCPPipelineMatcher', + 'ovos-ocp-pipeline-plugin-legacy=ocp_pipeline.opm:MycroftCPSLegacyPipeline') setup( name="ovos-ocp-pipeline-plugin", diff --git a/tests/test_ocp.py b/tests/test_ocp.py index 48c5392..30157ed 100644 --- a/tests/test_ocp.py +++ b/tests/test_ocp.py @@ -52,59 +52,56 @@ def setUp(self): self.ocp = OCPPipelineMatcher(config=config) def test_match_high(self): - result = self.ocp.match_high(["play metallica"], "en-us") + result = self.ocp.match_high(["play metallica"], "en-US") self.assertIsNotNone(result) - self.assertEqual(result.intent_service, 'OCP_intents') - self.assertEqual(result.intent_type, 'ocp:play') + self.assertEqual(result.match_type, 'ocp:play') def test_match_high_with_invalid_input(self): - result = self.ocp.match_high(["put on some music"], "en-us") + result = self.ocp.match_high(["put on some music"], "en-US") self.assertIsNone(result) def test_match_medium(self): - result = self.ocp.match_medium(["put on some movie"], "en-us") + result = self.ocp.match_medium(["put on some movie"], "en-US") self.assertIsNotNone(result) - self.assertEqual(result.intent_service, 'OCP_media') - self.assertEqual(result.intent_type, 'ocp:play') + self.assertEqual(result.match_type, 'ocp:play') def test_match_medium_with_invalid_input(self): - result = self.ocp.match_medium(["i wanna hear metallica"], "en-us") + result = self.ocp.match_medium(["i wanna hear metallica"], "en-US") self.assertIsNone(result) def test_match_fallback(self): - result = self.ocp.match_fallback(["i want music"], "en-us") + result = self.ocp.match_low(["i want music"], "en-US") self.assertIsNotNone(result) - self.assertEqual(result.intent_service, 'OCP_fallback') - self.assertEqual(result.intent_type, 'ocp:play') + self.assertEqual(result.match_type, 'ocp:play') def test_match_fallback_with_invalid_input(self): - result = self.ocp.match_fallback(["do the thing"], "en-us") + result = self.ocp.match_low(["do the thing"], "en-US") self.assertIsNone(result) def test_predict(self): - self.assertTrue(self.ocp.is_ocp_query("play a song", "en-us")[0]) - self.assertTrue(self.ocp.is_ocp_query("play a movie", "en-us")[0]) - self.assertTrue(self.ocp.is_ocp_query("play a podcast", "en-us")[0]) - self.assertFalse(self.ocp.is_ocp_query("tell me a joke", "en-us")[0]) - self.assertFalse(self.ocp.is_ocp_query("who are you", "en-us")[0]) - self.assertFalse(self.ocp.is_ocp_query("you suck", "en-us")[0]) + self.assertTrue(self.ocp.is_ocp_query("play a song", "en-US")[0]) + self.assertTrue(self.ocp.is_ocp_query("play a movie", "en-US")[0]) + self.assertTrue(self.ocp.is_ocp_query("play a podcast", "en-US")[0]) + self.assertFalse(self.ocp.is_ocp_query("tell me a joke", "en-US")[0]) + self.assertFalse(self.ocp.is_ocp_query("who are you", "en-US")[0]) + self.assertFalse(self.ocp.is_ocp_query("you suck", "en-US")[0]) def test_predict_prob(self): noise = "hglisjerhksrtjhdgsf" - self.assertEqual(self.ocp.classify_media(f"play {noise} music", "en-us")[0], MediaType.MUSIC) - self.assertIsInstance(self.ocp.classify_media(f"play music {noise}", "en-us")[1], float) - self.assertEqual(self.ocp.classify_media(f"play {noise} movie soundtrack", "en-us")[0], MediaType.MUSIC) - self.assertEqual(self.ocp.classify_media(f"play movie {noise}", "en-us")[0], MediaType.MOVIE) - self.assertEqual(self.ocp.classify_media(f"play silent {noise} movie", "en-us")[0], MediaType.SILENT_MOVIE) - self.assertEqual(self.ocp.classify_media(f"play {noise} black and white movie", "en-us")[0], + self.assertEqual(self.ocp.classify_media(f"play {noise} music", "en-US")[0], MediaType.MUSIC) + self.assertIsInstance(self.ocp.classify_media(f"play music {noise}", "en-US")[1], float) + self.assertEqual(self.ocp.classify_media(f"play {noise} movie soundtrack", "en-US")[0], MediaType.MUSIC) + self.assertEqual(self.ocp.classify_media(f"play movie {noise}", "en-US")[0], MediaType.MOVIE) + self.assertEqual(self.ocp.classify_media(f"play silent {noise} movie", "en-US")[0], MediaType.SILENT_MOVIE) + self.assertEqual(self.ocp.classify_media(f"play {noise} black and white movie", "en-US")[0], MediaType.BLACK_WHITE_MOVIE) - self.assertEqual(self.ocp.classify_media(f"play short {noise} film", "en-us")[0], MediaType.SHORT_FILM) - self.assertEqual(self.ocp.classify_media(f"play cartoons {noise}", "en-us")[0], MediaType.CARTOON) - self.assertEqual(self.ocp.classify_media(f"play {noise} episode", "en-us")[0], MediaType.VIDEO_EPISODES) - self.assertEqual(self.ocp.classify_media(f"play {noise} podcast", "en-us")[0], MediaType.PODCAST) - self.assertEqual(self.ocp.classify_media(f"play {noise} book", "en-us")[0], MediaType.AUDIOBOOK) - self.assertEqual(self.ocp.classify_media(f"play radio {noise} FM", "en-us")[0], MediaType.RADIO) - self.assertEqual(self.ocp.classify_media(f"read {noise}", "en-us")[0], MediaType.AUDIOBOOK) + self.assertEqual(self.ocp.classify_media(f"play short {noise} film", "en-US")[0], MediaType.SHORT_FILM) + self.assertEqual(self.ocp.classify_media(f"play cartoons {noise}", "en-US")[0], MediaType.CARTOON) + self.assertEqual(self.ocp.classify_media(f"play {noise} episode", "en-US")[0], MediaType.VIDEO_EPISODES) + self.assertEqual(self.ocp.classify_media(f"play {noise} podcast", "en-US")[0], MediaType.PODCAST) + self.assertEqual(self.ocp.classify_media(f"play {noise} book", "en-US")[0], MediaType.AUDIOBOOK) + self.assertEqual(self.ocp.classify_media(f"play radio {noise} FM", "en-US")[0], MediaType.RADIO) + self.assertEqual(self.ocp.classify_media(f"read {noise}", "en-US")[0], MediaType.AUDIOBOOK) class TestOCPPipelineMatcher(unittest.TestCase): @@ -119,61 +116,57 @@ def setUp(self): self.ocp = OCPPipelineMatcher(config=config) def test_match_high(self): - result = self.ocp.match_high(["play metallica"], "en-us") + result = self.ocp.match_high(["play metallica"], "en-US") self.assertIsNotNone(result) - self.assertEqual(result.intent_service, 'OCP_intents') - self.assertEqual(result.intent_type, 'ocp:play') + self.assertEqual(result.match_type, 'ocp:play') def test_match_high_with_invalid_input(self): - result = self.ocp.match_high(["put on some metallica"], "en-us") + result = self.ocp.match_high(["put on some metallica"], "en-US") self.assertIsNone(result) def test_match_medium(self): - result = self.ocp.match_medium(["put on some metallica"], "en-us") + result = self.ocp.match_medium(["put on some metallica"], "en-US") self.assertIsNotNone(result) - self.assertEqual(result.intent_service, 'OCP_media') - self.assertEqual(result.intent_type, 'ocp:play') + self.assertEqual(result.match_type, 'ocp:play') def test_match_medium_with_invalid_input(self): - result = self.ocp.match_medium(["i wanna hear metallica"], "en-us") + result = self.ocp.match_medium(["i wanna hear metallica"], "en-US") self.assertIsNone(result) def test_match_fallback(self): - result = self.ocp.match_fallback(["i wanna hear metallica"], "en-us") - print(result) + result = self.ocp.match_low(["i wanna hear metallica"], "en-US") self.assertIsNotNone(result) - self.assertEqual(result.intent_service, 'OCP_fallback') - self.assertEqual(result.intent_type, 'ocp:play') + self.assertEqual(result.match_type, 'ocp:play') def test_match_fallback_with_invalid_input(self): - result = self.ocp.match_fallback(["do the thing"], "en-us") + result = self.ocp.match_low(["do the thing"], "en-US") self.assertIsNone(result) def test_predict(self): - self.assertTrue(self.ocp.is_ocp_query("play a song", "en-us")[0]) - self.assertTrue(self.ocp.is_ocp_query("play my morning jams", "en-us")[0]) - self.assertTrue(self.ocp.is_ocp_query("i want to watch the matrix", "en-us")[0]) - self.assertFalse(self.ocp.is_ocp_query("tell me a joke", "en-us")[0]) - self.assertFalse(self.ocp.is_ocp_query("who are you", "en-us")[0]) - self.assertFalse(self.ocp.is_ocp_query("you suck", "en-us")[0]) + self.assertTrue(self.ocp.is_ocp_query("play a song", "en-US")[0]) + self.assertTrue(self.ocp.is_ocp_query("play my morning jams", "en-US")[0]) + self.assertTrue(self.ocp.is_ocp_query("i want to watch the matrix", "en-US")[0]) + self.assertFalse(self.ocp.is_ocp_query("tell me a joke", "en-US")[0]) + self.assertFalse(self.ocp.is_ocp_query("who are you", "en-US")[0]) + self.assertFalse(self.ocp.is_ocp_query("you suck", "en-US")[0]) def test_predict_prob(self): # "metallica" in csv dataset self.ocp.config["classifier_threshold"] = 0.2 - self.assertEqual(self.ocp.classify_media("play metallica", "en-us")[0], MediaType.MUSIC) - self.assertIsInstance(self.ocp.classify_media("play metallica", "en-us")[1], float) + self.assertEqual(self.ocp.classify_media("play metallica", "en-US")[0], MediaType.MUSIC) + self.assertIsInstance(self.ocp.classify_media("play metallica", "en-US")[1], float) self.ocp.config["classifier_threshold"] = 0.5 - self.assertEqual(self.ocp.classify_media("play metallica", "en-us")[0], MediaType.GENERIC) - self.assertIsInstance(self.ocp.classify_media("play metallica", "en-us")[1], float) + self.assertEqual(self.ocp.classify_media("play metallica", "en-US")[0], MediaType.GENERIC) + self.assertIsInstance(self.ocp.classify_media("play metallica", "en-US")[1], float) @unittest.skip("TODO - classifiers needs retraining") def test_predict_prob_with_unknown_entity(self): # "klownevilus" not in the csv dataset self.ocp.config["classifier_threshold"] = 0.2 - self.assertEqual(self.ocp.classify_media("play klownevilus", "en-us")[0], MediaType.MUSIC) - self.assertIsInstance(self.ocp.classify_media("play klownevilus", "en-us")[1], float) + self.assertEqual(self.ocp.classify_media("play klownevilus", "en-US")[0], MediaType.MUSIC) + self.assertIsInstance(self.ocp.classify_media("play klownevilus", "en-US")[1], float) self.ocp.config["classifier_threshold"] = 0.5 - self.assertEqual(self.ocp.classify_media("play klownevilus", "en-us")[0], MediaType.GENERIC) + self.assertEqual(self.ocp.classify_media("play klownevilus", "en-US")[0], MediaType.GENERIC) self.ocp.config["classifier_threshold"] = 0.1 self.ocp.handle_skill_keyword_register(Message("", { @@ -183,7 +176,7 @@ def test_predict_prob_with_unknown_entity(self): "samples": ["klownevilus"] })) # should be MOVIE not MUSIC TODO fix me - self.assertEqual(self.ocp.classify_media("play klownevilus", "en-us")[0], MediaType.MOVIE) + self.assertEqual(self.ocp.classify_media("play klownevilus", "en-US")[0], MediaType.MOVIE) if __name__ == '__main__':