Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/domain_engine #31

Open
wants to merge 6 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
171 changes: 171 additions & 0 deletions ovos_padatious/domain_container.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,171 @@
from collections import defaultdict
from typing import Dict, List, Optional
from ovos_utils.log import LOG
from ovos_padatious.intent_container import IntentContainer
from ovos_padatious.match_data import MatchData


class DomainIntentContainer:
    """
    A domain-aware intent recognition engine that organizes intents and entities
    into specific domains, providing flexible and hierarchical intent matching.

    A top-level ``IntentContainer`` (``domain_engine``) holds one pseudo-intent
    per domain, built from every sample sentence registered in that domain.
    Each domain additionally owns its own ``IntentContainer`` with the real
    intents and entities.
    """

    def __init__(self, cache_dir: Optional[str] = None):
        """
        Initialize the DomainIntentContainer.

        Args:
            cache_dir (Optional[str]): Cache directory handed to every
                underlying IntentContainer (the domain classifier and each
                per-domain container).

        Attributes:
            domain_engine (IntentContainer): A top-level intent container for cross-domain calculations.
            domains (Dict[str, IntentContainer]): A mapping of domain names to their respective intent containers.
            training_data (Dict[str, List[str]]): A mapping of domain names to their associated training samples.
            must_train (bool): True when ``train()`` has to run before the next query.
        """
        self.cache_dir = cache_dir
        self.domain_engine = IntentContainer(cache_dir=cache_dir)
        self.domains: Dict[str, IntentContainer] = {}
        self.training_data: Dict[str, List[str]] = defaultdict(list)
        # domain -> intent -> samples; lets the samples of a single intent be
        # dropped from ``training_data`` again when that intent is removed.
        self._intent_samples: Dict[str, Dict[str, List[str]]] = defaultdict(dict)
        self.must_train = True

    def _rebuild_domain_samples(self, domain_name: str) -> None:
        """Recompute ``training_data[domain_name]`` from the intents still registered."""
        remaining = [sample
                     for samples in self._intent_samples.get(domain_name, {}).values()
                     for sample in samples]
        if remaining:
            self.training_data[domain_name] = remaining
            return
        # No intent left: the domain must stop competing in domain
        # classification. The per-domain container is kept on purpose, since
        # entities may still be registered in it.
        self.training_data.pop(domain_name, None)
        self._intent_samples.pop(domain_name, None)
        if domain_name in self.domain_engine.intent_names:
            self.domain_engine.remove_intent(domain_name)

    def remove_domain(self, domain_name: str):
        """
        Remove a domain and its associated intents and training data.

        Args:
            domain_name (str): The name of the domain to remove.
        """
        self.training_data.pop(domain_name, None)
        self._intent_samples.pop(domain_name, None)
        self.domains.pop(domain_name, None)
        if domain_name in self.domain_engine.intent_names:
            self.domain_engine.remove_intent(domain_name)

    def add_domain_intent(self, domain_name: str, intent_name: str, intent_samples: List[str]):
        """
        Register an intent within a specific domain.

        The samples are also appended to the domain's training data, and the
        container is flagged for retraining.

        Args:
            domain_name (str): The name of the domain.
            intent_name (str): The name of the intent to register.
            intent_samples (List[str]): A list of sample sentences for the intent.
        """
        if domain_name not in self.domains:
            self.domains[domain_name] = IntentContainer(cache_dir=self.cache_dir)
        self.domains[domain_name].add_intent(intent_name, intent_samples)
        self.training_data[domain_name] += intent_samples
        self._intent_samples[domain_name].setdefault(intent_name, []).extend(intent_samples)
        self.must_train = True

    def remove_domain_intent(self, domain_name: str, intent_name: str):
        """
        Remove a specific intent from a domain.

        The samples that were registered for the intent are dropped from the
        domain's training data as well, so the domain classifier no longer
        learns from them, and the container is flagged for retraining.

        Args:
            domain_name (str): The name of the domain.
            intent_name (str): The name of the intent to remove.
        """
        if domain_name in self.domains:
            self.domains[domain_name].remove_intent(intent_name)
        registered = self._intent_samples.get(domain_name)
        if registered and intent_name in registered:
            del registered[intent_name]
            self._rebuild_domain_samples(domain_name)
            self.must_train = True

    def add_domain_entity(self, domain_name: str, entity_name: str, entity_samples: List[str]):
        """
        Register an entity within a specific domain.

        Args:
            domain_name (str): The name of the domain.
            entity_name (str): The name of the entity to register.
            entity_samples (List[str]): A list of sample phrases for the entity.
        """
        if domain_name not in self.domains:
            self.domains[domain_name] = IntentContainer(cache_dir=self.cache_dir)
        self.domains[domain_name].add_entity(entity_name, entity_samples)

    def remove_domain_entity(self, domain_name: str, entity_name: str):
        """
        Remove a specific entity from a domain.

        Args:
            domain_name (str): The name of the domain.
            entity_name (str): The name of the entity to remove.
        """
        if domain_name in self.domains:
            self.domains[domain_name].remove_entity(entity_name)

    def calc_domains(self, query: str) -> List[MatchData]:
        """
        Calculate the matching domains for a query.

        Args:
            query (str): The input query.

        Returns:
            List[MatchData]: A list of MatchData objects representing matching domains,
                in the order returned by the domain engine.
        """
        if self.must_train:
            self.train()
        return self.domain_engine.calc_intents(query)

    def calc_domain(self, query: str) -> MatchData:
        """
        Calculate the best matching domain for a query.

        Args:
            query (str): The input query.

        Returns:
            MatchData: The best matching domain.
        """
        if self.must_train:
            self.train()
        return self.domain_engine.calc_intent(query)

    def calc_intent(self, query: str, domain: Optional[str] = None) -> MatchData:
        """
        Calculate the best matching intent for a query within a specific domain.

        Args:
            query (str): The input query.
            domain (Optional[str]): The domain to limit the search to. When omitted,
                the best matching domain for the query is used.

        Returns:
            MatchData: The best matching intent, or an empty match (``name=None``,
                ``conf=0.0``) when the selected domain is not registered.
        """
        if self.must_train:
            self.train()
        selected: str = domain or self.domain_engine.calc_intent(query).name
        if selected in self.domains:
            return self.domains[selected].calc_intent(query)
        return MatchData(name=None, sent=query, matches=None, conf=0.0)

    def calc_intents(self, query: str, domain: Optional[str] = None, top_k_domains: int = 2) -> List[MatchData]:
        """
        Calculate matching intents for a query across domains or within a specific domain.

        Args:
            query (str): The input query.
            domain (Optional[str]): The specific domain to search in. If None, searches across top-k domains.
            top_k_domains (int): The number of top domains to consider. Defaults to 2.

        Returns:
            List[MatchData]: Matching intents. When searching across domains the list is
                sorted by confidence (highest first); for an explicit domain it is
                returned as produced by that domain's container. An unregistered
                explicit domain yields an empty list.
        """
        if self.must_train:
            self.train()
        if domain:
            if domain not in self.domains:
                # mirror calc_intent(): an unknown domain is "no match", not an error
                return []
            return self.domains[domain].calc_intents(query)
        # Rank explicitly before slicing: taking the first k entries only gives
        # the k best domains if the list is ordered by confidence, which this
        # code cannot assume of calc_domains().
        ranked = sorted(self.calc_domains(query), key=lambda m: m.conf, reverse=True)
        matches = []
        for candidate in ranked[:top_k_domains]:
            if candidate.name in self.domains:
                matches += self.domains[candidate.name].calc_intents(query)
        return sorted(matches, reverse=True, key=lambda k: k.conf)

    def train(self):
        """
        Train the domain classifier and every per-domain container.

        Each domain is registered in ``domain_engine`` as an intent whose
        samples are the domain's accumulated training data; afterwards the
        sub-containers are trained and ``must_train`` is cleared.
        """
        already_registered = set(self.domain_engine.intent_names)
        for domain, samples in self.training_data.items():
            LOG.debug(f"Training domain: {domain}")
            if domain in already_registered:
                # NOTE(review): drop the previous registration so a retrain
                # reflects only the current samples; assumes add_intent() does
                # not by itself replace an intent of the same name - confirm.
                self.domain_engine.remove_intent(domain)
            self.domain_engine.add_intent(domain, samples)
        self.domain_engine.train()
        for domain in self.domains:
            LOG.debug(f"Training domain sub-intents: {domain}")
            self.domains[domain].train()
        self.must_train = False
9 changes: 8 additions & 1 deletion ovos_padatious/intent_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from functools import wraps
from typing import List, Dict, Any, Optional

from ovos_config.meta import get_xdg_base
from ovos_utils.log import LOG
from ovos_utils.xdg_utils import xdg_data_home

from ovos_padatious import padaos
from ovos_padatious.entity import Entity
Expand Down Expand Up @@ -54,7 +56,8 @@ class IntentContainer:
cache_dir (str): Directory for caching the neural network models and intent/entity files.
"""

def __init__(self, cache_dir: str) -> None:
def __init__(self, cache_dir: str = None) -> None:
cache_dir = cache_dir or f"{xdg_data_home()}/{get_xdg_base()}/intent_cache"
JarbasAl marked this conversation as resolved.
Show resolved Hide resolved
os.makedirs(cache_dir, exist_ok=True)
self.cache_dir: str = cache_dir
self.must_train: bool = False
Expand All @@ -64,6 +67,10 @@ def __init__(self, cache_dir: str) -> None:
self.train_thread: Optional[Any] = None # deprecated
self.serialized_args: List[Dict[str, Any]] = [] # Serialized calls for training intents/entities

@property
def intent_names(self):
    """Return the ``intent_names`` reported by this container's ``intents`` manager."""
    manager = self.intents
    return manager.intent_names

def clear(self) -> None:
"""
Clears the current intent and entity managers and resets the container.
Expand Down
4 changes: 4 additions & 0 deletions ovos_padatious/intent_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ def __init__(self, cache: str, debug: bool = False):
super().__init__(Intent, cache)
self.debug = debug

@property
def intent_names(self):
    """Return the ``name`` of every entry in ``objects`` followed by those in ``objects_to_train``."""
    names = []
    for entry in self.objects + self.objects_to_train:
        names.append(entry.name)
    return names

def calc_intents(self, query: str, entity_manager) -> List[MatchData]:
"""
Calculate matches for the given query against all registered intents.
Expand Down
55 changes: 42 additions & 13 deletions ovos_padatious/opm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
# limitations under the License.
#
"""Intent service wrapping padatious."""
from collections import defaultdict
from functools import lru_cache
from os.path import expanduser, isfile
from threading import Event, RLock
from typing import Optional, Dict, List, Union
from typing import Optional, Dict, List, Union, Type

from langcodes import closest_match
from ovos_bus_client.client import MessageBusClient
Expand All @@ -31,9 +32,16 @@
from ovos_utils.log import LOG, deprecated, log_deprecation
from ovos_utils.xdg_utils import xdg_data_home

from ovos_padatious import IntentContainer as PadatiousIntentContainer
from ovos_padatious import IntentContainer
from ovos_padatious.domain_container import DomainIntentContainer
from ovos_padatious.match_data import MatchData as PadatiousIntent

PadatiousIntentContainer = IntentContainer # backwards compat

# for easy typing
PadatiousEngine = Union[Type[IntentContainer],
Type[DomainIntentContainer]]


class PadatiousMatcher:
"""Matcher class to avoid redundancy in padatious intent matching."""
Expand Down Expand Up @@ -87,7 +95,8 @@ class PadatiousPipeline(ConfidenceMatcherPipeline):
"""Service class for padatious intent matching."""

def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
config: Optional[Dict] = None):
config: Optional[Dict] = None,
engine_class: Optional[PadatiousEngine] = IntentContainer):

super().__init__(bus, config)
self.lock = RLock()
Expand All @@ -102,16 +111,20 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
self.conf_med = self.config.get("conf_med") or 0.8
self.conf_low = self.config.get("conf_low") or 0.5

if engine_class is None and self.config.get("domain_engine"):
engine_class = DomainIntentContainer

self.engine_class = engine_class or IntentContainer
intent_cache = expanduser(self.config.get('intent_cache') or
f"{xdg_data_home()}/{get_xdg_base()}/intent_cache")
self.containers = {lang: PadatiousIntentContainer(f"{intent_cache}/{lang}")
for lang in langs}
self.containers = {lang: self.engine_class(cache_dir=f"{intent_cache}/{lang}") for lang in langs}

self.finished_training_event = Event() # DEPRECATED
self.finished_initial_train = False

self.registered_intents = []
self.registered_entities = []
self._skill2intent = defaultdict(list)
self.max_words = 50 # if an utterance contains more words than this, don't attempt to match

self.bus.on('padatious:register_intent', self.register_intent)
Expand Down Expand Up @@ -225,7 +238,12 @@ def __detach_intent(self, intent_name):
if intent_name in self.registered_intents:
self.registered_intents.remove(intent_name)
for lang in self.containers:
self.containers[lang].remove_intent(intent_name)
for skill_id, intents in self._skill2intent.items():
if intent_name in intents:
if isinstance(self.containers[lang], DomainIntentContainer):
self.containers[lang].remove_domain_intent(skill_id, intent_name)
else:
self.containers[lang].remove_intent(intent_name)

def handle_detach_intent(self, message):
"""Messagebus handler for detaching padatious intent.
Expand All @@ -242,8 +260,7 @@ def handle_detach_skill(self, message):
message (Message): message triggering action
"""
skill_id = message.data['skill_id']
remove_list = [i for i in self.registered_intents if skill_id in i]
for i in remove_list:
for i in self._skill2intent[skill_id]:
self.__detach_intent(i)

def _register_object(self, message, object_name, register_func):
Expand All @@ -254,6 +271,7 @@ def _register_object(self, message, object_name, register_func):
object_name (str): type of entry to register
register_func (callable): function to call for registration
"""
skill_id = message.data.get("skill_id") or message.context.get("skill_id")
file_name = message.data.get('file_name')
samples = message.data.get("samples")
name = message.data['name']
Expand All @@ -268,7 +286,10 @@ def _register_object(self, message, object_name, register_func):
with open(file_name) as f:
samples = [line.strip() for line in f.readlines()]

register_func(name, samples)
if self.engine_class == DomainIntentContainer:
register_func(skill_id, name, samples)
else:
register_func(name, samples)

self.finished_initial_train = False
if self.config.get("instant_train", True):
Expand All @@ -280,11 +301,17 @@ def register_intent(self, message):
Args:
message (Message): message triggering action
"""
skill_id = message.data.get("skill_id") or message.context.get("skill_id")
self._skill2intent[skill_id].append(message.data['name'])

Comment on lines +304 to +306
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ensure skill_id is properly obtained and handled

In the register_intent method, skill_id is obtained using:

skill_id = message.data.get("skill_id") or message.context.get("skill_id")

If skill_id is None, the intent might be registered under a None key in _skill2intent, which could lead to issues when detaching intents or skills. It's important to ensure that skill_id is not None. Consider adding a check to handle this situation:

if not skill_id:
    LOG.error("Skill ID is missing. Cannot register intent without a valid skill ID.")
    return

This will prevent unintended behaviors and ensure that each intent is associated with a valid skill.

lang = message.data.get('lang', self.lang)
lang = standardize_lang_tag(lang)
if lang in self.containers:
self.registered_intents.append(message.data['name'])
self._register_object(message, 'intent', self.containers[lang].add_intent)
if isinstance(self.containers[lang], DomainIntentContainer):
self._register_object(message, 'intent', self.containers[lang].add_domain_intent)
else:
self._register_object(message, 'intent', self.containers[lang].add_intent)

def register_entity(self, message):
"""Messagebus handler for registering entities.
Expand All @@ -296,8 +323,10 @@ def register_entity(self, message):
lang = standardize_lang_tag(lang)
if lang in self.containers:
self.registered_entities.append(message.data)
self._register_object(message, 'entity',
self.containers[lang].add_entity)
if isinstance(self.containers[lang], DomainIntentContainer):
self._register_object(message, 'entity', self.containers[lang].add_domain_entity)
else:
self._register_object(message, 'entity', self.containers[lang].add_entity)

def calc_intent(self, utterances: Union[str, List[str]], lang: Optional[str] = None,
message: Optional[Message] = None) -> Optional[PadatiousIntent]:
Expand Down Expand Up @@ -390,7 +419,7 @@ def handle_entity_manifest(self, message):

@lru_cache(maxsize=3) # repeat calls under different conf levels wont re-run code
def _calc_padatious_intent(utt: str,
intent_container: PadatiousIntentContainer,
intent_container: Union[IntentContainer, DomainIntentContainer],
sess: Session) -> Optional[PadatiousIntent]:
"""
Try to match an utterance to an intent in an intent_container
Expand Down
Loading
Loading