diff --git a/ocp_pipeline/opm.py b/ocp_pipeline/opm.py index 273b8f3..da90673 100644 --- a/ocp_pipeline/opm.py +++ b/ocp_pipeline/opm.py @@ -12,13 +12,14 @@ from ovos_bus_client.session import SessionManager from ovos_plugin_manager.ocp import available_extractors from ovos_plugin_manager.templates.pipeline import IntentMatch, PipelinePlugin +from ovos_utils.lang import standardize_lang_tag, get_language_dir from ovos_utils.log import LOG from ovos_utils.messagebus import FakeBus from ovos_utils.ocp import MediaType, PlaybackType, PlaybackMode, PlayerState, OCP_ID, \ MediaEntry, Playlist, MediaState, TrackState, dict2entry, PluginStream from ovos_workshop.app import OVOSAbstractApplication from padacioso import IntentContainer - +from langcodes import closest_match from ocp_pipeline.feats import OCPFeaturizer from ocp_pipeline.legacy import LegacyCommonPlay @@ -102,16 +103,18 @@ def load_classifiers(self): def load_resource_files(self): intents = {} for lang in self.native_langs: + lang = standardize_lang_tag(lang) intents[lang] = {} - locale_folder = join(dirname(__file__), "locale", lang) - for f in os.listdir(locale_folder): - path = join(locale_folder, f) - if f in self.intents: - with open(path) as intent: - samples = intent.read().split("\n") - for idx, s in enumerate(samples): - samples[idx] = s.replace("{{", "{").replace("}}", "}") - intents[lang][f] = samples + locale_folder = get_language_dir(join(dirname(__file__), "locale"), lang) + if locale_folder is not None: + for f in os.listdir(locale_folder): + path = join(locale_folder, f) + if f in self.intents: + with open(path) as intent: + samples = intent.read().split("\n") + for idx, s in enumerate(samples): + samples[idx] = s.replace("{{", "{").replace("}}", "}") + intents[lang][f] = samples return intents def register_ocp_api_events(self): @@ -138,6 +141,7 @@ def register_ocp_intents(self): intent_files = self.load_resource_files() for lang, intent_data in intent_files.items(): + lang = standardize_lang_tag(lang) self.intent_matchers[lang] = IntentContainer() for intent_name in self.intents: samples = intent_data.get(intent_name) @@ -286,7 +290,8 @@ def handle_player_state_update(self, message: Message): def match_high(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentMatch]: """ exact matches only, handles playback control recommended after high confidence intents pipeline stage """ - if lang not in self.intent_matchers: + lang = self._get_closest_lang(lang) + if lang is None: # no intents registered for this lang return None self.bus.emit(Message("ovos.common_play.status")) # sync @@ -327,6 +332,8 @@ def match_high(self, utterances: List[str], lang: str, message: Message = None) def match_medium(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentMatch]: """ match a utterance via classifiers, recommended before common_qa pipeline stage""" + lang = standardize_lang_tag(lang) + utterance = utterances[0].lower() # is this a OCP query ? is_ocp, bconf = self.is_ocp_query(utterance, lang) @@ -368,6 +375,8 @@ def match_fallback(self, utterances: List[str], lang: str, message: Message = No if not ents: return None + lang = standardize_lang_tag(lang) + # classify the query media type media_type, confidence = self.classify_media(utterance, lang) @@ -388,7 +397,7 @@ def match_fallback(self, utterances: List[str], lang: str, message: Message = No def _process_play_query(self, utterance: str, lang: str, match: dict = None, message: Optional[Message] = None) -> Optional[IntentMatch]: - + lang = standardize_lang_tag(lang) match = match or {} player = self.get_player(message) # if media is currently paused, empty string means "resume playback" @@ -455,6 +464,7 @@ def handle_search_query(self, message: Message): if num: phrase += " " + num + lang = standardize_lang_tag(lang) # classify the query media type media_type, prob = self.classify_media(utterance, lang) # search common play skills @@ -503,6 +513,7 @@ def handle_play_intent(self, message: Message): skills = message.data.get("skills", []) # search common play skills + lang = standardize_lang_tag(lang) results = self._search(query, media_type, lang, skills=skills, message=message) @@ -613,6 +624,7 @@ def handle_search_error_intent(self, message: Message): # NLP def voc_match_media(self, query: str, lang: str) -> Tuple[MediaType, float]: + lang = standardize_lang_tag(lang) # simplistic approach via voc_match, works anywhere # and it's easy to localize, but isn't very accurate if self.voc_match(query, "MusicKeyword", lang=lang): @@ -674,6 +686,7 @@ def voc_match_media(self, query: str, lang: str) -> Tuple[MediaType, float]: def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]: """ determine what media type is being requested """ + lang = standardize_lang_tag(lang) # using a trained classifier (Experimental) if self.config.get("experimental_media_classifier", False): from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier @@ -701,6 +714,7 @@ def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]: def is_ocp_query(self, query: str, lang: str) -> Tuple[bool, float]: """ determine if a playback question is being asked""" + lang = standardize_lang_tag(lang) if self.config.get("experimental_binary_classifier", False): from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier try: @@ -731,6 +745,7 @@ def _should_resume(self, phrase: str, lang: str, message: Optional[Message] = No @param phrase: Extracted playback phrase @return: True if player should resume, False if this is a new request """ + lang = standardize_lang_tag(lang) player = self.get_player(message) if player.player_state == PlayerState.PAUSED: if not phrase.strip() or \ @@ -782,6 +797,7 @@ def normalize_results(self, results: list) -> List[Union[MediaEntry, Playlist, P def filter_results(self, results: list, phrase: str, lang: str, media_type: MediaType = MediaType.GENERIC, message: Optional[Message] = None) -> list: + lang = standardize_lang_tag(lang) # ignore very low score matches l1 = len(results) results = [r for r in results @@ -1031,6 +1047,10 @@ def match_legacy(self, utterances: List[str], lang: str, message: Message = None utterance = utterances[0].lower() + lang = self._get_closest_lang(lang) + if lang is None: # no intents registered for this lang + return None + match = self.intent_matchers[lang].calc_intent(utterance) if match["name"] is None: @@ -1045,6 +1065,18 @@ def match_legacy(self, utterances: List[str], lang: str, message: Message = None skill_id=OCP_ID, utterance=utterance) + def _get_closest_lang(self, lang: str) -> Optional[str]: + if self.intent_matchers: + lang = standardize_lang_tag(lang) + closest, score = closest_match(lang, list(self.intent_matchers.keys())) + # https://langcodes-hickford.readthedocs.io/en/sphinx/index.html#distance-values + # 0 -> These codes represent the same language, possibly after filling in values and normalizing. + # 1- 3 -> These codes indicate a minor regional difference. + # 4 - 10 -> These codes indicate a significant but unproblematic regional difference. + if score < 10: + return closest + return None + def handle_legacy_cps(self, message: Message): """intent handler for legacy CPS matches""" utt = message.data["query"] diff --git a/requirements.txt b/requirements.txt index ebc59f3..b37171d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,4 @@ ovos-workshop>=0.1.7,<2.0.0 -ovos-classifiers \ No newline at end of file +ovos-classifiers +ovos-utils>=0.3.5,<1.0.0 +langcodes \ No newline at end of file