Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat/reranker+optional_ovos_classifiers #529

Merged
merged 8 commits into from
Aug 4, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions ovos_core/intent_services/commonqa_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import time
from ovos_bus_client.message import Message
from ovos_bus_client.session import SessionManager
from ovos_classifiers.opm.heuristics import BM25MultipleChoiceSolver
from ovos_config.config import Configuration
from ovos_utils import flatten_list
from ovos_utils.log import LOG
Expand Down Expand Up @@ -44,7 +43,16 @@ def __init__(self, bus):
CommonQAService._EXTENSION_TIME = self._extension_time
self._min_wait = config.get('min_response_wait') or 2
self._max_time = config.get('max_response_wait') or 6 # regardless of extensions
self.untier = BM25MultipleChoiceSolver() # TODO - allow plugin from config
try:
from ovos_classifiers.opm.heuristics import BM25MultipleChoiceSolver
reranker = BM25MultipleChoiceSolver() # TODO - allow plugin from config
except Exception as e:
LOG.error(f"Failed to load CommonQuery ReRanker: {e}")
reranker = None
if reranker:
self.reranker = reranker
else:
self.reranker = None
self.add_event('question:query.response', self.handle_query_response)
self.add_event('common_query.question', self.handle_question)
self.add_event('ovos.common_query.pong', self.handle_skill_pong)
Expand Down Expand Up @@ -253,9 +261,13 @@ def _query_timeout(self, message: Message):
tied_ids = [m["skill_id"] for m in ties]
LOG.info(f"Tied skills: {tied_ids}")
answers = {m["answer"]: m for m in ties}
best_ans = self.untier.select_answer(query.query,
list(answers.keys()),
{"lang": query.lang})
if self.reranker is None:
# no re-ranker available: fall back to the first tied answer (deterministic, not random)
best_ans = list(answers.keys())[0]
else:
best_ans = self.reranker.select_answer(query.query,
list(answers.keys()),
{"lang": query.lang})
best = answers[best_ans]

LOG.info('Handling with: ' + str(best['skill_id']))
Expand Down
20 changes: 16 additions & 4 deletions ovos_core/intent_services/ocp_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from ovos_bus_client.message import Message, dig_for_message
from ovos_bus_client.session import SessionManager
from ovos_bus_client.util import wait_for_reply
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
from ovos_classifiers.skovos.features import ClassifierProbaVectorizer, KeywordFeaturesVectorizer
from ovos_utils import classproperty
from ovos_utils.gui import is_gui_connected, is_gui_running
from ovos_utils.log import LOG
Expand Down Expand Up @@ -151,6 +149,7 @@ def load_classifiers(self):
OCPFeaturizer.extract_entities("UNLEASH THE AUTOMATONS")

if self.config.get("experimental_binary_classifier", True): # ocp_medium
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
LOG.info("Using experimental OCP binary classifier")
# TODO - train a single multilingual model instead of this
b = f"{dirname(__file__)}/models"
Expand All @@ -161,6 +160,7 @@ def load_classifiers(self):
self._binary_en_clf = (c, OCPFeaturizer("binary_ocp_cv2_small"))

if self.config.get("experimental_media_classifier", True):
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
LOG.info("Using experimental OCP media type classifier")
# TODO - train a single multilingual model instead of this
b = f"{dirname(__file__)}/models"
Expand Down Expand Up @@ -253,6 +253,8 @@ def handle_skill_register(self, message: Message):
except:
LOG.error(f"{skill_id} reported an invalid media_type: {m}")

if OCPFeaturizer.ocp_keywords is None:
return
# TODO - review below and add missing
# set bias in classifier
# aliases -> {type}_streaming_service bias
Expand Down Expand Up @@ -281,6 +283,8 @@ def handle_skill_register(self, message: Message):

def handle_skill_keyword_register(self, message: Message):
""" register skill provided keywords """
if OCPFeaturizer.ocp_keywords is None:
return
skill_id = message.data["skill_id"]
kw_label = message.data["label"]
media = message.data["media_type"]
Expand Down Expand Up @@ -729,6 +733,7 @@ def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]:
""" determine what media type is being requested """
# using a trained classifier (Experimental)
if self.config.get("experimental_media_classifier", True):
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
try:
if lang.startswith("en"):
clf: SklearnOVOSClassifier = self._media_en_clf[0]
Expand All @@ -754,6 +759,7 @@ def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]:
def is_ocp_query(self, query: str, lang: str) -> Tuple[bool, float]:
""" determine if a playback question is being asked"""
if self.config.get("experimental_binary_classifier", True):
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
try:
# TODO - train a single multilingual classifier
if lang.startswith("en"):
Expand Down Expand Up @@ -1214,8 +1220,7 @@ def cps2media(res: dict, media_type=MediaType.GENERIC) -> Tuple[MediaEntry, dict
class OCPFeaturizer:
"""used by the experimental media type classifier,
API should be considered unstable"""
# ignore_list accounts for "noise" keywords in the csv file
ocp_keywords = KeywordFeaturesVectorizer(ignore_list=["play", "stop"])
ocp_keywords = None
# defined at training time
_clf_labels = ['ad_keyword', 'album_name', 'anime_genre', 'anime_name', 'anime_streaming_service',
'artist_name', 'asmr_keyword', 'asmr_trigger', 'audio_genre', 'audiobook_narrator',
Expand Down Expand Up @@ -1246,6 +1251,11 @@ class OCPFeaturizer:

def __init__(self, base_clf=None):
self.clf_feats = None
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
from ovos_classifiers.skovos.features import ClassifierProbaVectorizer, KeywordFeaturesVectorizer
if OCPFeaturizer.ocp_keywords is None:
# ignore_list accounts for "noise" keywords in the csv file
OCPFeaturizer.ocp_keywords = KeywordFeaturesVectorizer(ignore_list=["play", "stop"])
if base_clf:
if isinstance(base_clf, str):
clf_path = f"{dirname(__file__)}/models/{base_clf}.clf"
Expand All @@ -1257,6 +1267,8 @@ def __init__(self, base_clf=None):

@classmethod
def load_csv(cls, entity_csvs: list):
if OCPFeaturizer.ocp_keywords is None:
return
for csv in entity_csvs or []:
if not os.path.isfile(csv):
# check for bundled files
Expand Down
2 changes: 0 additions & 2 deletions requirements/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@ ovos-config~=0.0,>=0.0.13a8
ovos-lingua-franca>=0.4.7
ovos-backend-client~=0.1.0
ovos-workshop>=0.0.16a45
# provides plugins and classic machine learning framework
ovos-classifiers<0.1.0, >=0.0.0a53

# ensure default plugin available for any solver plugins
ovos-translate-server-plugin
3 changes: 3 additions & 0 deletions requirements/tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ ovos-listener~=0.0, >=0.0.3a2
ovos-gui~=0.0, >=0.0.2
ovos-messagebus~=0.0

# provides plugins and classic machine learning framework
ovos-classifiers<0.1.0, >=0.0.0a53

# Support OCP tests
ovos_bus_client>=0.0.9a15
ovos-utils>=0.1.0a16
Expand Down
2 changes: 1 addition & 1 deletion test/unittests/skills/test_ocp.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def setUp(self):
self.featurizer = OCPFeaturizer()

@patch('os.path.isfile', return_value=True)
@patch('ovos_core.intent_services.ocp_service.KeywordFeaturesVectorizer.load_entities')
@patch('ovos_classifiers.skovos.features.KeywordFeaturesVectorizer.load_entities')
@patch.object(LOG, 'info')
def test_load_csv_with_existing_file(self, mock_log_info, mock_load_entities, mock_isfile):
csv_path = "existing_file.csv"
Expand Down
Loading