diff --git a/ovos_padatious/opm.py b/ovos_padatious/opm.py index 4d51ca3..5d9d72b 100644 --- a/ovos_padatious/opm.py +++ b/ovos_padatious/opm.py @@ -15,7 +15,7 @@ """Intent service wrapping padatious.""" from functools import lru_cache from os.path import expanduser, isfile -from threading import Event +from threading import Event, RLock from typing import Optional, Dict, List, Union from langcodes import closest_match @@ -90,7 +90,7 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None, config: Optional[Dict] = None): super().__init__(bus, config) - + self.lock = RLock() core_config = Configuration() self.lang = standardize_lang_tag(core_config.get("lang", "en-US")) langs = core_config.get('secondary_langs') or [] @@ -121,6 +121,7 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None, self.bus.on('intent.service.padatious.get', self.handle_get_padatious) self.bus.on('intent.service.padatious.manifest.get', self.handle_padatious_manifest) self.bus.on('intent.service.padatious.entities.manifest.get', self.handle_entity_manifest) + self.bus.on('mycroft.skills.train', self.train) LOG.debug('Loaded Padatious intent pipeline') @@ -189,21 +190,26 @@ def train(self, message=None): Args: message (Message): optional triggering message """ - name = message.data["name"] if message else "" - if not any(engine.must_train - for engine in self.containers.values()): - LOG.debug(f"Nothing new to train for '{name}'") + LOG.debug("Padatious training start") + if not any(engine.must_train for engine in self.containers.values()): + LOG.debug(f"Nothing new to train for padatious") + # inform the rest of the system to not wait for training finish + self.bus.emit(Message('mycroft.skills.trained')) return - for lang in self.containers: - if self.containers[lang].must_train: - LOG.debug(f"Training '{name}' for lang '{lang}'") - self.containers[lang].train() + with self.lock: + for lang in self.containers: + if self.containers[lang].must_train: + LOG.debug(f"Training padatious for lang '{lang}'") + self.containers[lang].train() + + LOG.debug(f"Training complete for padatious!") + if not self.finished_initial_train: + self.finished_initial_train = True - LOG.debug(f"Training complete for '{name}'!") - if not self.finished_initial_train: - self.bus.emit(Message('mycroft.skills.trained')) - self.finished_initial_train = True + # inform the rest of the system to stop waiting for training finish + self.bus.emit(Message('mycroft.skills.trained')) + LOG.debug("Padatious training end") @deprecated("'wait_and_train' has been deprecated, use 'train' directly", "2.0.0") def wait_and_train(self): @@ -264,7 +270,9 @@ def _register_object(self, message, object_name, register_func): register_func(name, samples) - self.train(message) + self.finished_initial_train = False + if self.config.get("instant_train", True): + self.train(message) def register_intent(self, message): """Messagebus handler for registering intents. diff --git a/ovos_padatious/simple_intent.py b/ovos_padatious/simple_intent.py index 572db9b..1cd97bc 100644 --- a/ovos_padatious/simple_intent.py +++ b/ovos_padatious/simple_intent.py @@ -13,7 +13,7 @@ # limitations under the License. from fann2 import libfann as fann - +from ovos_utils.log import LOG from ovos_padatious.id_manager import IdManager from ovos_padatious.util import resolve_conflicts, StrEnum @@ -26,7 +26,7 @@ class Ids(StrEnum): w_4 = ':4' -class SimpleIntent(object): +class SimpleIntent: """General intent used to match sentences or phrases""" LENIENCE = 0.6 @@ -69,6 +69,9 @@ def train(self, train_data): inputs = [] outputs = [] + n_pos = len(list(train_data.my_sents(self.name))) + n_neg = len(list(train_data.other_sents(self.name))) + def add(vec, out): inputs.append(self.vectorize(vec)) outputs.append([out]) @@ -115,13 +118,14 @@ def calc_weight(w): return pow(len(w), 3.0) train_data = fann.training_data() train_data.set_train_data(inputs, outputs) - + LOG.debug(f"Training {self.name} with samples: {n_pos} positive + {n_neg} negative") for _ in range(10): self.configure_net() self.net.train_on_data(train_data, 1000, 0, 0) self.net.test_data(train_data) if self.net.get_bit_fail() == 0: break + LOG.debug(f"Training {self.name} finished!") def save(self, prefix): prefix += '.intent' diff --git a/ovos_padatious/training_manager.py b/ovos_padatious/training_manager.py index 3cf9ac6..95981a9 100644 --- a/ovos_padatious/training_manager.py +++ b/ovos_padatious/training_manager.py @@ -73,6 +73,7 @@ def add(self, name: str, lines: List[str], reload_cache: bool = False, must_trai must_train (bool): Whether training is required for the new intent/entity. """ if not must_train: + LOG.debug(f"Loading {name} from intent cache") self.objects.append(self.cls.from_file(name=name, folder=self.cache)) # general case: load resource (entity or intent) to training queue # or if no change occurred to memory data structures @@ -87,11 +88,13 @@ def add(self, name: str, lines: List[str], reload_cache: bool = False, must_trai retrain = reload_cache or old_hsh != new_hsh if not retrain: try: + LOG.debug(f"Loading {name} from intent cache") self.objects.append(self.cls.from_file(name=name, folder=self.cache)) - except: - LOG.error(f"Failed to load intent from cache: {name}") + except Exception as e: + LOG.error(f"Failed to load intent from cache: {name} - {str(e)}") retrain = True if retrain: + LOG.debug(f"Queuing {name} for training") self.objects_to_train.append(self.cls(name=name, hsh=new_hsh)) self.train_data.add_lines(name, lines)