Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

performance: support padatious #42

Merged
merged 3 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 40 additions & 3 deletions ocp_pipeline/opm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from ovos_utils.ocp import MediaType, PlaybackType, PlaybackMode, PlayerState, OCP_ID, \
MediaEntry, Playlist, MediaState, TrackState, dict2entry, PluginStream
from ovos_workshop.app import OVOSAbstractApplication
from padacioso import IntentContainer
from ovos_utils.xdg_utils import xdg_data_home
from ovos_config.meta import get_xdg_base

from ocp_pipeline.feats import OCPFeaturizer
from ocp_pipeline.legacy import LegacyCommonPlay
Expand All @@ -48,6 +49,7 @@ class OCPPipelineMatcher(ConfidenceMatcherPipeline, OVOSAbstractApplication):
"next.intent", "prev.intent", "pause.intent", "play_favorites.intent",
"resume.intent", "like_song.intent"]
intent_matchers = {}
intent_cache = f"{xdg_data_home()}/{get_xdg_base()}/intent_cache"

def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
config: Optional[Dict] = None):
Expand Down Expand Up @@ -150,9 +152,22 @@ def register_ocp_api_events(self):
@classmethod
def load_intent_files(cls):
intent_files = cls.load_resource_files()

try:
from ovos_padatious import IntentContainer
is_padatious = True
except ImportError:
from padacioso import IntentContainer
is_padatious = False
LOG.warning("Padatious not available, using padacioso. intent matching will be orders of magnitude slower!")

for lang, intent_data in intent_files.items():
lang = standardize_lang_tag(lang)
cls.intent_matchers[lang] = IntentContainer()
if is_padatious:
cache = f"{cls.intent_cache}/{lang}"
cls.intent_matchers[lang] = IntentContainer(cache)
else:
cls.intent_matchers[lang] = IntentContainer()
for intent_name in cls.intents:
samples = intent_data.get(intent_name)
if samples:
Expand Down Expand Up @@ -301,6 +316,10 @@ def handle_player_state_update(self, message: Message):
def match_high(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentHandlerMatch]:
""" exact matches only, handles playback control
recommended after high confidence intents pipeline stage """

if not len(self.skill_aliases): # skill_id registered when skills load
return None # dont waste compute cycles, no media skills -> no match

lang = self._get_closest_lang(lang)
if lang is None: # no intents registered for this lang
return None
Expand All @@ -310,9 +329,21 @@ def match_high(self, utterances: List[str], lang: str, message: Message = None)
utterance = utterances[0].lower()
match = self.intent_matchers[lang].calc_intent(utterance)

if hasattr(match, "name"): # padatious
match = {
"name": match.name,
"conf": match.conf,
"entities": match.matches
}

if match["name"] is None:
return None
LOG.info(f"OCP exact match: {match}")

if match.get("conf", 1.0) < 0.7:
LOG.debug(f"Ignoring low confidence OCP match: {match}")
return None

LOG.info(f"OCP match: {match}")

player = self.get_player(message)

Expand Down Expand Up @@ -1128,6 +1159,12 @@ def match(self, utterances: List[str], lang: str, message: Message = None) -> Op
return None

match = OCPPipelineMatcher.intent_matchers[lang].calc_intent(utterance)
if hasattr(match, "name"): # padatious
match = {
"name": match.name,
"conf": match.conf,
"entities": match.matches
}

if match["name"] is None:
return None
Expand Down
2 changes: 2 additions & 0 deletions tests/test_ocp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def setUp(self):
os.path.dirname(ocp_pipeline.opm.__file__) + "/models/ocp_entities_v0.csv"
]}
self.ocp = OCPPipelineMatcher(config=config)
self.ocp.skill_aliases["test"] = ["Test Skill"] # pretend a skill is loaded or matching is skipped

def test_match_high(self):
result = self.ocp.match_high(["play metallica"], "en-US")
Expand Down Expand Up @@ -114,6 +115,7 @@ def setUp(self):
os.path.dirname(ocp_pipeline.opm.__file__) + "/models/ocp_entities_v0.csv"
]}
self.ocp = OCPPipelineMatcher(config=config)
self.ocp.skill_aliases["test"] = ["Test Skill"] # pretend a skill is loaded or matching is skipped

def test_match_high(self):
result = self.ocp.match_high(["play metallica"], "en-US")
Expand Down
Loading