Skip to content

Commit

Permalink
fix/get_response (#138)
Browse files Browse the repository at this point in the history
  • Loading branch information
JarbasAl authored Oct 2, 2023
1 parent 7e5453f commit 171b545
Showing 1 changed file with 93 additions and 86 deletions.
179 changes: 93 additions & 86 deletions ovos_workshop/skills/ovos.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def __init__(self, name: Optional[str] = None,
self.public_api: Dict[str, dict] = {}

self._original_converse = self.converse # for get_response

self.__responses = {}
self._threads = [] # for killable events decorator

# yay, following python best practices again!
Expand Down Expand Up @@ -935,6 +937,8 @@ def _register_system_event_handlers(self):
self.add_event('mycroft.skills.settings.changed',
self.handle_settings_change, speak_errors=False)

self.add_event("skill.converse.get_response", self.__handle_get_response, speak_errors=False)

def _send_public_api(self, message: Message):
"""
Respond with the skill's public api.
Expand Down Expand Up @@ -1492,7 +1496,7 @@ def play_audio(self, filename: str, instant: bool = False):
self.bus.emit(message.forward("mycroft.audio.queue",
{"uri": filename}))

def __get_response_v1(self):
def __get_response_v1(self, session=None):
"""Helper to get a response from the user
NOTE: There is a race condition here. There is a small amount of
Expand All @@ -1507,13 +1511,10 @@ def __get_response_v1(self):
Returns:
str: user's response or None on a timeout
"""
srcm = dig_for_message() or Message("", context={"source": "skills",
"skill_id": self.skill_id})
self.bus.emit(srcm.forward("skill.converse.get_response.enable",
{"skill_id": self.skill_id}))
session = session or SessionManager.get()

# TODO: Support `message` signature like default?
def converse(utterances, lang=None):
self.__responses[session.session_id] = utterances
converse.response = utterances[0] if utterances else None
converse.finished = True
return True
Expand All @@ -1525,23 +1526,39 @@ def converse(utterances, lang=None):
self.converse = converse

# 10 for listener, 5 for SST, then timeout
ans = []
# NOTE: a threading.Event is not used otherwise we can't raise the
# AbortEvent exception to kill the thread
# this is for compat with killable_intents decorators
start = time.time()
while time.time() - start <= 15 and not converse.finished:
# TODO: Refactor to event-based handling
while time.time() - start <= 15 and not ans:
ans = self.__responses[session.session_id]
time.sleep(0.1)
if self.__response is not False:
if self.__response is None:
# aborted externally (if None)
self.log.debug("get_response aborted")
if ans is None:
# aborted externally (if None)
self.log.debug("get_response aborted")
converse.finished = True
converse.response = self.__response # external override
break

self.converse = self._original_converse
self.bus.emit(srcm.forward("skill.converse.get_response.disable",
{"skill_id": self.skill_id}))
return converse.response
return ans

def __handle_get_response(self, message):

skill_id = message.data["skill_id"]
if skill_id != self.skill_id:
return # not for us!

# validate session_id to ensure this isnt another
# user querying the skill at same time
sess2 = SessionManager.get(message)
if sess2.session_id not in self.__responses:
LOG.debug(f"ignoring get_response answer for session: {sess2.session_id}")
return # not for us!

utterances = message.data["utterances"]
# received get_response
self.__responses[sess2.session_id] = utterances

@backwards_compat(classic_core=__get_response_v1, pre_008=__get_response_v1)
def __get_response(self, session: Session):
Expand All @@ -1563,51 +1580,21 @@ def __get_response(self, session: Session):
"skill_id": self.skill_id})
srcm.context["session"] = session.serialize()

self.bus.emit(srcm.forward("skill.converse.get_response.enable",
{"skill_id": self.skill_id}))
utterances = []

LOG.debug(f"get_response session: {session.session_id}")

def _handle_get_response(message):
nonlocal utterances

skill_id = message.data["skill_id"]
if skill_id != self.skill_id:
return # not for us!

# validate session_id to ensure this isnt another
# user querying the skill at same time
sess2 = SessionManager.get(message)
if session.session_id != sess2.session_id:
LOG.debug(f"ignoring get_response answer for session: {sess2.session_id}")
return # not for us!

utterances = message.data["utterances"]
# received get_response

self.bus.on("skill.converse.get_response", _handle_get_response)
ans = []

# NOTE: a threading.Event is not used otherwise we can't raise the
# AbortEvent exception to kill the thread
# this is for compat with killable_intents decorators
start = time.time()
while time.time() - start <= 15 and not len(utterances):
while time.time() - start <= 15 and not ans:
ans = self.__responses[session.session_id]
time.sleep(0.1)
if self.__response is not False:
if self.__response is None:
# aborted externally (if None)
self.log.debug("get_response aborted")
else:
utterances = [self.__response] # external override

self.bus.remove("skill.converse.get_response", _handle_get_response)
self.bus.emit(srcm.forward("skill.converse.get_response.disable",
{"skill_id": self.skill_id}))

if utterances:
return utterances[0]
return None
if ans is None:
# aborted externally (if None)
self.log.debug("get_response aborted")
break
return ans

def get_response(self, dialog: str = '', data: Optional[dict] = None,
validator: Optional[Callable[[str], bool]] = None,
Expand Down Expand Up @@ -1635,6 +1622,11 @@ def get_response(self, dialog: str = '', data: Optional[dict] = None,
Message('mycroft.mic.listen', context={"skill_id": self.skill_id})
data = data or {}

session = SessionManager.get(message)
self.__responses[session.session_id] = []
self.bus.emit(message.forward("skill.converse.get_response.enable",
{"skill_id": self.skill_id}))

def on_fail_default(utterance):
fail_data = data.copy()
fail_data['utterance'] = utterance
Expand Down Expand Up @@ -1663,8 +1655,11 @@ def validator_default(utterance):
else:
msg = message.reply('mycroft.mic.listen')
self.bus.emit(msg)
return self._wait_response(is_cancel, validator, on_fail_fn,
num_retries, message)
ans = self._wait_response(is_cancel, validator, on_fail_fn,
num_retries, message)
self.bus.emit(message.forward("skill.converse.get_response.disable",
{"skill_id": self.skill_id}))
return ans

def _wait_response(self, is_cancel: callable, validator: callable,
on_fail: callable, num_retries: int,
Expand All @@ -1678,25 +1673,52 @@ def _wait_response(self, is_cancel: callable, validator: callable,
@param num_retries: Number of times to retry getting a response
@returns: User response if validated, else None
"""
self.__response = False
session = SessionManager.get(message)

# self.__responses.get(session.session_id) <- set in a killable thread
self._real_wait_response(is_cancel, validator, on_fail, num_retries, message)
while self.__response is False:

# wait for answer from killable thread
ans = []
while not ans:
# TODO: Refactor to Event
time.sleep(0.1)
return self.__response or None
ans = self.__responses.get(session.session_id)
if ans or ans is None: # canceled response
break

if session.session_id in self.__responses:
self.__responses.pop(session.session_id)

if isinstance(ans, list):
ans = ans[0] # TODO handle multiple transcriptions

# catch user saying 'cancel'
if is_cancel(ans):
return None

# returns the validated value or the response
# (backwards compat)
validated = validator(ans)
ans = ans if validated is True else validated

return ans

def _handle_killed_wait_response(self):
"""
Handle "stop" request when getting a response.
"""
self.__response = None
self.__responses = {k: None for k in self.__responses}
self.converse = self._original_converse

@killable_event("mycroft.skills.abort_question", exc=AbortQuestion,
callback=_handle_killed_wait_response, react_to_stop=True)
def _real_wait_response(self, is_cancel, validator, on_fail, num_retries,
message: Message):
"""
runs in a thread, result retrieved via self.__responses[sess.session_id]
Loop until a valid response is received from the user or the retry
limit is reached.
Expand All @@ -1707,46 +1729,31 @@ def _real_wait_response(self, is_cancel, validator, on_fail, num_retries,
"""
sess = SessionManager.get(message)
msg = message.reply('mycroft.mic.listen')

num_fails = 0
while True:
if self.__response is not False:
# usually None when aborted externally
# also allows overriding returned result from other events
return self.__response

response = self.__get_response(sess)

if response is None:
break # killed externally

if not response: # empty list
# if nothing said, prompt one more time
num_none_fails = 1 if num_retries < 0 else num_retries
if num_fails >= num_none_fails:
self.__response = None
return
else:
# catch user saying 'cancel'
if is_cancel(response):
self.__response = None
return

validated = validator(response)
# returns the validated value or the response
# (backwards compat)
if validated is not False and validated is not None:
self.__response = response if validated is True else validated
return
if num_fails >= num_retries:
self.__responses[sess.session_id] = None # stop trying

num_fails += 1
if 0 < num_retries < num_fails or self.__response is not False:
self.__response = None
return
num_fails += 1

if self.__responses.get(sess.session_id) is None:
return # dont prompt

# re-prompt user
line = on_fail(response)
if line:
self.speak(line, expect_response=True)
else:
self.bus.emit(msg)
self.bus.emit(message.reply('mycroft.mic.listen'))

@staticmethod
def __acknowledge_classic():
Expand Down

0 comments on commit 171b545

Please sign in to comment.