From 80ba4e9304ff3cc17911dfd391120bf802442fa0 Mon Sep 17 00:00:00 2001 From: miro Date: Thu, 12 Dec 2024 11:55:56 +0000 Subject: [PATCH] consistency in naming --- ovos_padatious/domain_container.py | 4 ++-- ovos_padatious/opm.py | 31 +++++++++++++++++++----------- 2 files changed, 22 insertions(+), 13 deletions(-) diff --git a/ovos_padatious/domain_container.py b/ovos_padatious/domain_container.py index 137e9fd..49ae1ab 100644 --- a/ovos_padatious/domain_container.py +++ b/ovos_padatious/domain_container.py @@ -40,7 +40,7 @@ def remove_domain(self, domain_name: str): if domain_name in self.domain_engine.intent_names: self.domain_engine.remove_intent(domain_name) - def register_domain_intent(self, domain_name: str, intent_name: str, intent_samples: List[str]): + def add_domain_intent(self, domain_name: str, intent_name: str, intent_samples: List[str]): """ Register an intent within a specific domain. @@ -66,7 +66,7 @@ def remove_domain_intent(self, domain_name: str, intent_name: str): if domain_name in self.domains: self.domains[domain_name].remove_intent(intent_name) - def register_domain_entity(self, domain_name: str, entity_name: str, entity_samples: List[str]): + def add_domain_entity(self, domain_name: str, entity_name: str, entity_samples: List[str]): """ Register an entity within a specific domain. diff --git a/ovos_padatious/opm.py b/ovos_padatious/opm.py index c203e63..cc6002b 100644 --- a/ovos_padatious/opm.py +++ b/ovos_padatious/opm.py @@ -111,15 +111,13 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None, self.conf_med = self.config.get("conf_med") or 0.8 self.conf_low = self.config.get("conf_low") or 0.5 - if engine_class is None: - if self.config.get("domain_engine"): - engine_class = DomainIntentContainer - else: - engine_class = IntentContainer + if engine_class is None and self.config.get("domain_engine"): + engine_class = DomainIntentContainer + self.engine_class = engine_class or IntentContainer intent_cache = expanduser(self.config.get('intent_cache') or f"{xdg_data_home()}/{get_xdg_base()}/intent_cache") - self.containers = {lang: engine_class(cache_dir=f"{intent_cache}/{lang}") for lang in langs} + self.containers = {lang: self.engine_class(cache_dir=f"{intent_cache}/{lang}") for lang in langs} self.finished_training_event = Event() # DEPRECATED self.finished_initial_train = False @@ -242,7 +240,10 @@ def __detach_intent(self, intent_name): for lang in self.containers: for skill_id, intents in self._skill2intent.items(): if intent_name in intents: - self.containers[lang].remove_domain_intent(skill_id, intent_name) + if isinstance(self.containers[lang], DomainIntentContainer): + self.containers[lang].remove_domain_intent(skill_id, intent_name) + else: + self.containers[lang].remove_intent(intent_name) def handle_detach_intent(self, message): """Messagebus handler for detaching padatious intent. @@ -285,7 +286,10 @@ def _register_object(self, message, object_name, register_func): with open(file_name) as f: samples = [line.strip() for line in f.readlines()] - register_func(skill_id, name, samples) + if self.engine_class == DomainIntentContainer: + register_func(skill_id, name, samples) + else: + register_func(name, samples) self.finished_initial_train = False if self.config.get("instant_train", True): @@ -304,7 +308,10 @@ def register_intent(self, message): lang = standardize_lang_tag(lang) if lang in self.containers: self.registered_intents.append(message.data['name']) - self._register_object(message, 'intent', self.containers[lang].register_domain_intent) + if isinstance(self.containers[lang], DomainIntentContainer): + self._register_object(message, 'intent', self.containers[lang].add_domain_intent) + else: + self._register_object(message, 'intent', self.containers[lang].add_intent) def register_entity(self, message): """Messagebus handler for registering entities. @@ -316,8 +323,10 @@ def register_entity(self, message): lang = standardize_lang_tag(lang) if lang in self.containers: self.registered_entities.append(message.data) - self._register_object(message, 'entity', - self.containers[lang].register_domain_entity) + if isinstance(self.containers[lang], DomainIntentContainer): + self._register_object(message, 'entity', self.containers[lang].add_domain_entity) + else: + self._register_object(message, 'entity', self.containers[lang].add_entity) def calc_intent(self, utterances: Union[str, List[str]], lang: Optional[str] = None, message: Optional[Message] = None) -> Optional[PadatiousIntent]: