Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/reranker+optional_ovos_classifiers #529

Merged
merged 8 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ovos_core/intent_services/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(self, bus, config=None):
self.padacioso_service = PadaciosoService(bus, self.config["padatious"])
self.fallback = FallbackService(bus)
self.converse = ConverseService(bus)
self.common_qa = CommonQAService(bus)
self.common_qa = CommonQAService(bus, self.config.get("common_query"))
self.stop = StopService(bus)
self.ocp = OCPPipelineMatcher(self.bus, config=self.config.get("OCP", {}))
self.utterance_plugins = UtteranceTransformersService(bus)
Expand Down
30 changes: 22 additions & 8 deletions ovos_core/intent_services/commonqa_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,11 @@
import time
from ovos_bus_client.message import Message
from ovos_bus_client.session import SessionManager
from ovos_classifiers.opm.heuristics import BM25MultipleChoiceSolver
from ovos_config.config import Configuration
from ovos_utils import flatten_list
from ovos_utils.log import LOG
from ovos_workshop.app import OVOSAbstractApplication

from ovos_plugin_manager.solvers import find_multiple_choice_solver_plugins
from ovos_plugin_manager.templates.pipeline import IntentMatch


Expand All @@ -32,19 +31,30 @@ class Query:


class CommonQAService(OVOSAbstractApplication):
def __init__(self, bus):
def __init__(self, bus, config: Optional[Dict] = None):
super().__init__(bus=bus,
skill_id="common_query.openvoiceos",
resources_dir=f"{dirname(__file__)}")
self.active_queries: Dict[str, Query] = dict()

self.common_query_skills = None
config = Configuration().get('skills', {}).get("common_query") or dict()
config = config or Configuration().get('intents', {}).get("common_query") or dict()
self._extension_time = config.get('extension_time') or 3
CommonQAService._EXTENSION_TIME = self._extension_time
self._min_wait = config.get('min_response_wait') or 2
self._max_time = config.get('max_response_wait') or 6 # regardless of extensions
self.untier = BM25MultipleChoiceSolver() # TODO - allow plugin from config
reranker_module = config.get("reranker", "ovos-choice-solver-bm25") # default to BM25 from ovos-classifiers
self.reranker = None
try:
for name, plug in find_multiple_choice_solver_plugins().items():
if name == reranker_module:
self.reranker = plug()
LOG.info(f"CommonQuery ReRanker: {name}")
break
else:
LOG.info("No CommonQuery ReRanker loaded!")
except Exception as e:
LOG.error(f"Failed to load ReRanker plugin: {e}")
self.add_event('question:query.response', self.handle_query_response)
self.add_event('common_query.question', self.handle_question)
self.add_event('ovos.common_query.pong', self.handle_skill_pong)
Expand Down Expand Up @@ -253,9 +263,13 @@ def _query_timeout(self, message: Message):
tied_ids = [m["skill_id"] for m in ties]
LOG.info(f"Tied skills: {tied_ids}")
answers = {m["answer"]: m for m in ties}
best_ans = self.untier.select_answer(query.query,
list(answers.keys()),
{"lang": query.lang})
if self.reranker is None:
# no re-ranker available: fall back to the first tied answer (not actually random)
best_ans = list(answers.keys())[0]
else:
best_ans = self.reranker.select_answer(query.query,
list(answers.keys()),
{"lang": query.lang})
best = answers[best_ans]

LOG.info('Handling with: ' + str(best['skill_id']))
Expand Down
59 changes: 44 additions & 15 deletions ovos_core/intent_services/ocp_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from ovos_bus_client.message import Message, dig_for_message
from ovos_bus_client.session import SessionManager
from ovos_bus_client.util import wait_for_reply
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
from ovos_classifiers.skovos.features import ClassifierProbaVectorizer, KeywordFeaturesVectorizer
from ovos_utils import classproperty
from ovos_utils.gui import is_gui_connected, is_gui_running
from ovos_utils.log import LOG
Expand Down Expand Up @@ -146,11 +144,16 @@ def __init__(self, bus=None, config=None):

def load_classifiers(self):
# warm up the featurizer so intent matches faster (lazy loaded)
if self.entity_csvs:
OCPFeaturizer.load_csv(self.entity_csvs)
OCPFeaturizer.extract_entities("UNLEASH THE AUTOMATONS")

if self.config.get("experimental_binary_classifier", True): # ocp_medium
try:
OCPFeaturizer.init_keyword_matcher()
if self.entity_csvs:
OCPFeaturizer.load_csv(self.entity_csvs)
OCPFeaturizer.extract_entities("UNLEASH THE AUTOMATONS")
except ImportError: # ovos-classifiers is optional
pass

if self.config.get("experimental_binary_classifier", False): # ocp_medium
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
LOG.info("Using experimental OCP binary classifier")
# TODO - train a single multilingual model instead of this
b = f"{dirname(__file__)}/models"
Expand All @@ -160,7 +163,8 @@ def load_classifiers(self):
c = SklearnOVOSClassifier.from_file(f"{b}/binary_ocp_cv2_kw_medium.clf")
self._binary_en_clf = (c, OCPFeaturizer("binary_ocp_cv2_small"))

if self.config.get("experimental_media_classifier", True):
if self.config.get("experimental_media_classifier", False):
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
LOG.info("Using experimental OCP media type classifier")
# TODO - train a single multilingual model instead of this
b = f"{dirname(__file__)}/models"
Expand Down Expand Up @@ -253,6 +257,8 @@ def handle_skill_register(self, message: Message):
except:
LOG.error(f"{skill_id} reported an invalid media_type: {m}")

if OCPFeaturizer.ocp_keywords is None:
return
# TODO - review below and add missing
# set bias in classifier
# aliases -> {type}_streaming_service bias
Expand Down Expand Up @@ -281,6 +287,8 @@ def handle_skill_register(self, message: Message):

def handle_skill_keyword_register(self, message: Message):
""" register skill provided keywords """
if OCPFeaturizer.ocp_keywords is None:
return
skill_id = message.data["skill_id"]
kw_label = message.data["label"]
media = message.data["media_type"]
Expand Down Expand Up @@ -405,7 +413,10 @@ def match_medium(self, utterances: List[str], lang: str, message: Message = None
media_type, confidence = self.classify_media(utterance, lang)

# extract entities
ents = OCPFeaturizer.extract_entities(utterance)
if OCPFeaturizer.ocp_keywords is None:
ents = {}
else:
ents = OCPFeaturizer.extract_entities(utterance)

# extract the query string
query = self.remove_voc(utterance, "Play", lang).strip()
Expand All @@ -424,7 +435,11 @@ def match_fallback(self, utterances: List[str], lang: str, message: Message = No
""" match an utterance via presence of known OCP keywords,
recommended before fallback_low pipeline stage"""
utterance = utterances[0].lower()
ents = OCPFeaturizer.extract_entities(utterance)
if OCPFeaturizer.ocp_keywords is None:
ents = {}
else:
ents = OCPFeaturizer.extract_entities(utterance)

if not ents:
return None

Expand Down Expand Up @@ -487,7 +502,10 @@ def _process_play_query(self, utterance: str, lang: str, match: dict = None,
# extract the query string
query = self.remove_voc(utterance, "Play", lang).strip()

ents = OCPFeaturizer.extract_entities(utterance)
if OCPFeaturizer.ocp_keywords is None:
ents = {}
else:
ents = OCPFeaturizer.extract_entities(utterance)

return IntentMatch(intent_service="OCP_intents",
intent_type="ocp:play",
Expand Down Expand Up @@ -728,7 +746,8 @@ def voc_match_media(self, query: str, lang: str) -> Tuple[MediaType, float]:
def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]:
""" determine what media type is being requested """
# using a trained classifier (Experimental)
if self.config.get("experimental_media_classifier", True):
if self.config.get("experimental_media_classifier", False):
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
try:
if lang.startswith("en"):
clf: SklearnOVOSClassifier = self._media_en_clf[0]
Expand All @@ -753,7 +772,8 @@ def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]:

def is_ocp_query(self, query: str, lang: str) -> Tuple[bool, float]:
""" determine if a playback question is being asked"""
if self.config.get("experimental_binary_classifier", True):
if self.config.get("experimental_binary_classifier", False):
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
try:
# TODO - train a single multilingual classifier
if lang.startswith("en"):
Expand Down Expand Up @@ -1214,8 +1234,7 @@ def cps2media(res: dict, media_type=MediaType.GENERIC) -> Tuple[MediaEntry, dict
class OCPFeaturizer:
"""used by the experimental media type classifier,
API should be considered unstable"""
# ignore_list accounts for "noise" keywords in the csv file
ocp_keywords = KeywordFeaturesVectorizer(ignore_list=["play", "stop"])
ocp_keywords = None
# defined at training time
_clf_labels = ['ad_keyword', 'album_name', 'anime_genre', 'anime_name', 'anime_streaming_service',
'artist_name', 'asmr_keyword', 'asmr_trigger', 'audio_genre', 'audiobook_narrator',
Expand Down Expand Up @@ -1246,6 +1265,9 @@ class OCPFeaturizer:

def __init__(self, base_clf=None):
self.clf_feats = None
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
from ovos_classifiers.skovos.features import ClassifierProbaVectorizer
self.init_keyword_matcher()
if base_clf:
if isinstance(base_clf, str):
clf_path = f"{dirname(__file__)}/models/{base_clf}.clf"
Expand All @@ -1255,6 +1277,13 @@ def __init__(self, base_clf=None):
for l in self._clf_labels: # no samples, just to ensure featurizer has right number of feats
self.ocp_keywords.register_entity(l, [])

@classmethod
def init_keyword_matcher(cls):
    """Create the shared keyword matcher on first use.

    The import is done locally so that this module can be loaded without
    ovos-classifiers installed; if the package is missing, the ImportError
    propagates to the caller. The matcher is stored on
    ``OCPFeaturizer.ocp_keywords`` and is only built while that attribute
    is still ``None``, so repeated calls reuse the existing instance.
    """
    from ovos_classifiers.skovos.features import KeywordFeaturesVectorizer

    if OCPFeaturizer.ocp_keywords is not None:
        # already initialised by an earlier call, nothing to do
        return
    # ignore_list accounts for "noise" keywords in the csv file
    OCPFeaturizer.ocp_keywords = KeywordFeaturesVectorizer(ignore_list=["play", "stop"])

@classmethod
def load_csv(cls, entity_csvs: list):
for csv in entity_csvs or []:
Expand Down
2 changes: 0 additions & 2 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ ovos-config~=0.0,>=0.0.13a8
ovos-lingua-franca>=0.4.7
ovos-backend-client~=0.1.0
ovos-workshop>=0.0.16a45
# provides plugins and classic machine learning framework
ovos-classifiers<0.1.0, >=0.0.0a53

# ensure default plugin available for any solver plugins
ovos-translate-server-plugin
3 changes: 3 additions & 0 deletions requirements/tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ ovos-listener~=0.0, >=0.0.3a2
ovos-gui~=0.0, >=0.0.2
ovos-messagebus~=0.0

# provides plugins and classic machine learning framework
ovos-classifiers<0.1.0, >=0.0.0a53

# Support OCP tests
ovos_bus_client>=0.0.9a15
ovos-utils>=0.1.0a16
Expand Down
2 changes: 1 addition & 1 deletion test/unittests/skills/test_ocp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def setUp(self):
self.featurizer = OCPFeaturizer()

@patch('os.path.isfile', return_value=True)
@patch('ovos_core.intent_services.ocp_service.KeywordFeaturesVectorizer.load_entities')
@patch('ovos_classifiers.skovos.features.KeywordFeaturesVectorizer.load_entities')
@patch.object(LOG, 'info')
def test_load_csv_with_existing_file(self, mock_log_info, mock_load_entities, mock_isfile):
csv_path = "existing_file.csv"
Expand Down
Loading