Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat:delayed_padatious_training #29

Merged
merged 6 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions ovos_padatious/opm.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Intent service wrapping padatious."""
from functools import lru_cache
from os.path import expanduser, isfile
from threading import Event
from threading import Event, RLock
from typing import Optional, Dict, List, Union

from langcodes import closest_match
Expand Down Expand Up @@ -90,7 +90,7 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
config: Optional[Dict] = None):

super().__init__(bus, config)

self.lock = RLock()
core_config = Configuration()
self.lang = standardize_lang_tag(core_config.get("lang", "en-US"))
langs = core_config.get('secondary_langs') or []
Expand Down Expand Up @@ -121,6 +121,7 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
self.bus.on('intent.service.padatious.get', self.handle_get_padatious)
self.bus.on('intent.service.padatious.manifest.get', self.handle_padatious_manifest)
self.bus.on('intent.service.padatious.entities.manifest.get', self.handle_entity_manifest)
self.bus.on('padatious:train', self.train)

LOG.debug('Loaded Padatious intent pipeline')

Expand Down Expand Up @@ -189,21 +190,26 @@ def train(self, message=None):
Args:
message (Message): optional triggering message
"""
name = message.data["name"] if message else ""
if not any(engine.must_train
for engine in self.containers.values()):
LOG.debug(f"Nothing new to train for '{name}'")
LOG.debug("Padatious training start")
if not any(engine.must_train for engine in self.containers.values()):
LOG.debug(f"Nothing new to train for padatious")
# inform the rest of the system to not wait for training finish
self.bus.emit(Message('mycroft.skills.trained'))
return

for lang in self.containers:
if self.containers[lang].must_train:
LOG.debug(f"Training '{name}' for lang '{lang}'")
self.containers[lang].train()
with self.lock:
for lang in self.containers:
if self.containers[lang].must_train:
LOG.debug(f"Training padatious for lang '{lang}'")
self.containers[lang].train()

LOG.debug(f"Training complete for padatious!")
if not self.finished_initial_train:
self.finished_initial_train = True

LOG.debug(f"Training complete for '{name}'!")
if not self.finished_initial_train:
self.bus.emit(Message('mycroft.skills.trained'))
self.finished_initial_train = True
# inform the rest of the system to stop waiting for training finish
self.bus.emit(Message('mycroft.skills.trained'))
LOG.debug("Padatious training end")

@deprecated("'wait_and_train' has been deprecated, use 'train' directly", "2.0.0")
def wait_and_train(self):
Expand Down Expand Up @@ -264,7 +270,9 @@ def _register_object(self, message, object_name, register_func):

register_func(name, samples)

self.train(message)
self.finished_initial_train = False
if self.config.get("instant_train", True):
Copy link
Member Author

@JarbasAl JarbasAl Dec 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for backwards compat, but we might want to consider changing this default value in next breaking change release, this impacts thread safety and should likely be False

@coderabbitai open an issue

self.train(message)
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved

def register_intent(self, message):
"""Messagebus handler for registering intents.
Expand Down
10 changes: 7 additions & 3 deletions ovos_padatious/simple_intent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

from fann2 import libfann as fann

from ovos_utils.log import LOG
from ovos_padatious.id_manager import IdManager
from ovos_padatious.util import resolve_conflicts, StrEnum

Expand All @@ -26,7 +26,7 @@ class Ids(StrEnum):
w_4 = ':4'


class SimpleIntent(object):
class SimpleIntent:
"""General intent used to match sentences or phrases"""
LENIENCE = 0.6

Expand Down Expand Up @@ -69,6 +69,9 @@ def train(self, train_data):
inputs = []
outputs = []

n_pos = len(list(train_data.my_sents(self.name)))
n_neg = len(list(train_data.other_sents(self.name)))

def add(vec, out):
inputs.append(self.vectorize(vec))
outputs.append([out])
Expand Down Expand Up @@ -115,13 +118,14 @@ def calc_weight(w): return pow(len(w), 3.0)

train_data = fann.training_data()
train_data.set_train_data(inputs, outputs)

LOG.debug(f"Training {self.name} with samples: {n_pos} positive + {n_neg} negative")
for _ in range(10):
self.configure_net()
self.net.train_on_data(train_data, 1000, 0, 0)
self.net.test_data(train_data)
if self.net.get_bit_fail() == 0:
break
LOG.debug(f"Training {self.name} finished!")

def save(self, prefix):
prefix += '.intent'
Expand Down
3 changes: 3 additions & 0 deletions ovos_padatious/training_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def add(self, name: str, lines: List[str], reload_cache: bool = False, must_trai
must_train (bool): Whether training is required for the new intent/entity.
"""
if not must_train:
LOG.debug(f"Loading {name} from intent cache")
self.objects.append(self.cls.from_file(name=name, folder=self.cache))
# general case: load resource (entity or intent) to training queue
# or if no change occurred to memory data structures
Expand All @@ -87,11 +88,13 @@ def add(self, name: str, lines: List[str], reload_cache: bool = False, must_trai
retrain = reload_cache or old_hsh != new_hsh
if not retrain:
try:
LOG.debug(f"Loading {name} from intent cache")
self.objects.append(self.cls.from_file(name=name, folder=self.cache))
except:
LOG.error(f"Failed to load intent from cache: {name}")
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
retrain = True
if retrain:
LOG.debug(f"Queuing {name} for training")
self.objects_to_train.append(self.cls(name=name, hsh=new_hsh))
self.train_data.add_lines(name, lines)

Expand Down
Loading