Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix:improve_media_clf #46

Merged
merged 1 commit into from
Dec 19, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 42 additions & 29 deletions ocp_pipeline/opm.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,11 +484,15 @@ def _process_play_query(self, utterance: str, lang: str, match: dict = None,
if skill_id not in sess.blacklisted_skills and
any(s.lower() in utterance for s in samples)
]
valid_labels = []
if valid_skills:
LOG.info(f"OCP specific skill names matched: {valid_skills}")
for mtype, skills in self.media2skill.items():
if any([s in skills for s in valid_skills]):
valid_labels.append(mtype)

# classify the query media type
media_type, conf = self.classify_media(utterance, lang)
media_type, conf = self.classify_media(utterance, lang, valid_labels=valid_labels)

# extract the query string
query = self.remove_voc(utterance, "Play", lang).strip()
Expand Down Expand Up @@ -692,70 +696,77 @@ def handle_search_error_intent(self, message: Message):
self.ocp_api.stop(source_message=message)

# NLP
def voc_match_media(self, query: str, lang: str,
                    valid_labels: Optional[List[MediaType]] = None) -> Tuple[MediaType, float]:
    """Guess the media type of a query from localized keyword vocabularies.

    Simplistic approach via ``voc_match``: it works anywhere and is easy to
    localize, but isn't very accurate.

    Args:
        query: the utterance to classify.
        lang: language tag; normalized with ``standardize_lang_tag``.
        valid_labels: media types that may be returned. When empty or
            ``None`` it defaults to the media types whose entry in
            ``self.media2skill`` is non-empty, or to every ``MediaType``
            if there are none.

    Returns:
        ``(media_type, confidence)``; ``(MediaType.GENERIC, 0.0)`` when no
        keyword of an allowed media type matches.
    """
    lang = standardize_lang_tag(lang)
    valid_labels = valid_labels or [m for m, s in self.media2skill.items() if s] or list(MediaType)

    def allowed(*mtypes: MediaType) -> bool:
        # True when at least one of the given media types may be returned
        return any(m in valid_labels for m in mtypes)

    def hit(voc: str) -> bool:
        # keyword vocabulary lookup for the current query/language
        return self.voc_match(query, voc, lang=lang)

    # The order of the checks below is significant: earlier keywords win.
    if allowed(MediaType.MUSIC) and hit("MusicKeyword"):
        # NOTE - before movie to handle "{movie_name} soundtrack"
        return MediaType.MUSIC, 0.6
    elif allowed(MediaType.MOVIE, MediaType.SHORT_FILM, MediaType.SILENT_MOVIE,
                 MediaType.BLACK_WHITE_MOVIE) and hit("MovieKeyword"):
        if allowed(MediaType.SHORT_FILM) and hit("ShortKeyword"):
            return MediaType.SHORT_FILM, 0.7
        elif allowed(MediaType.SILENT_MOVIE) and hit("SilentKeyword"):
            return MediaType.SILENT_MOVIE, 0.7
        elif allowed(MediaType.BLACK_WHITE_MOVIE) and hit("BWKeyword"):
            return MediaType.BLACK_WHITE_MOVIE, 0.7
        # NOTE(review): MOVIE is returned here even if only a movie sub-type
        # is in valid_labels - confirm this fallback is intended
        return MediaType.MOVIE, 0.6
    elif allowed(MediaType.DOCUMENTARY) and hit("DocumentaryKeyword"):
        return MediaType.DOCUMENTARY, 0.6
    elif allowed(MediaType.AUDIOBOOK) and hit("AudioBookKeyword"):
        return MediaType.AUDIOBOOK, 0.6
    elif allowed(MediaType.NEWS) and hit("NewsKeyword"):
        return MediaType.NEWS, 0.6
    elif allowed(MediaType.ANIME) and hit("AnimeKeyword"):
        return MediaType.ANIME, 0.6
    elif allowed(MediaType.CARTOON) and hit("CartoonKeyword"):
        return MediaType.CARTOON, 0.6
    elif allowed(MediaType.PODCAST) and hit("PodcastKeyword"):
        return MediaType.PODCAST, 0.6
    elif allowed(MediaType.TV) and hit("TVKeyword"):
        return MediaType.TV, 0.6
    elif allowed(MediaType.VIDEO_EPISODES) and hit("SeriesKeyword"):
        return MediaType.VIDEO_EPISODES, 0.6
    elif allowed(MediaType.RADIO_THEATRE) and hit("AudioDramaKeyword"):
        # NOTE - before "radio" to allow "radio theatre"
        return MediaType.RADIO_THEATRE, 0.6
    elif allowed(MediaType.RADIO) and hit("RadioKeyword"):
        return MediaType.RADIO, 0.6
    elif allowed(MediaType.VISUAL_STORY) and hit("ComicBookKeyword"):
        return MediaType.VISUAL_STORY, 0.4
    elif allowed(MediaType.GAME) and hit("GameKeyword"):
        return MediaType.GAME, 0.4
    elif allowed(MediaType.AUDIO_DESCRIPTION) and hit("ADKeyword"):
        return MediaType.AUDIO_DESCRIPTION, 0.4
    elif allowed(MediaType.ASMR) and hit("ASMRKeyword"):
        return MediaType.ASMR, 0.4
    elif allowed(MediaType.ADULT, MediaType.HENTAI, MediaType.ADULT_AUDIO) and hit("AdultKeyword"):
        # The keyword alternatives are parenthesised so the valid_labels
        # guard applies to all of them ("and" binds tighter than "or").
        if allowed(MediaType.HENTAI) and (hit("CartoonKeyword") or
                                          hit("AnimeKeyword") or
                                          hit("HentaiKeyword")):
            return MediaType.HENTAI, 0.4
        elif allowed(MediaType.ADULT_AUDIO) and (hit("AudioKeyword") or
                                                 hit("ASMRKeyword")):
            return MediaType.ADULT_AUDIO, 0.4
        # NOTE(review): ADULT is returned here even if only HENTAI/ADULT_AUDIO
        # is in valid_labels - confirm this fallback is intended
        return MediaType.ADULT, 0.4
    elif allowed(MediaType.HENTAI) and hit("HentaiKeyword"):
        return MediaType.HENTAI, 0.4
    elif allowed(MediaType.VIDEO) and hit("VideoKeyword"):
        return MediaType.VIDEO, 0.4
    elif allowed(MediaType.AUDIO) and hit("AudioKeyword"):
        return MediaType.AUDIO, 0.4
    return MediaType.GENERIC, 0.0

def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]:
def classify_media(self, query: str, lang: str, valid_labels: Optional[List[MediaType]] = None) -> Tuple[MediaType, float]:
""" determine what media type is being requested """
lang = standardize_lang_tag(lang)
valid_labels = valid_labels or [m for m, s in self.media2skill.items() if s] or list(MediaType)
LOG.debug(f"valid media types: {valid_labels}")
if len(valid_labels) == 1:
return valid_labels[0], 1.0

# using a trained classifier (Experimental)
if self.config.get("experimental_media_classifier", False):
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
Expand All @@ -768,6 +779,8 @@ def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]:
featurizer: OCPFeaturizer = self._media_clf[1]
X = featurizer.transform([query])
preds = clf.predict_labels(X)[0]
preds = {k: v for k, v in preds.items()
if OCPFeaturizer.label2media(k) in valid_labels}
label = max(preds, key=preds.get)
prob = float(round(preds[label], 3))
LOG.info(f"OVOSCommonPlay MediaType prediction: {label} confidence: {prob}")
Expand All @@ -779,7 +792,7 @@ def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]:
return OCPFeaturizer.label2media(label), prob
except:
LOG.exception(f"OCP classifier exception: {query}")
return self.voc_match_media(query, lang)
return self.voc_match_media(query, lang, valid_labels)

def is_ocp_query(self, query: str, lang: str) -> Tuple[bool, float]:
""" determine if a playback question is being asked"""
Expand Down
Loading