From 94c160e710e646c42e5160577dfe957bf32e7a6e Mon Sep 17 00:00:00 2001 From: Alberto Cetoli Date: Sun, 4 Aug 2024 10:51:35 +0100 Subject: [PATCH] using entailer to pre-filter the retrieved rules --- todo.txt | 12 ++++++++++-- wafl/answerer/answerer_implementation.py | 7 +++++++ wafl/answerer/dialogue_answerer.py | 10 ++++++---- wafl/answerer/entailer.py | 3 +++ wafl/interface/conversation.py | 9 +++++++++ 5 files changed, 35 insertions(+), 6 deletions(-) diff --git a/todo.txt b/todo.txt index e26b2a8a..17ec607e 100644 --- a/todo.txt +++ b/todo.txt @@ -1,5 +1,13 @@ -* Add tqdm to indexing. -* Make it index when wafl start first, not at the first use/login +* apply entailer to rule retrieval: + if more than one rule is retrieved, then the one + that is entailed by the query should be chosen + +* the answer from the indexed files should be directed from a rule. + - facts and rules should live at the highest level of the retrieval + + +/* Add tqdm to indexing. +/* Make it index when wafl start first, not at the first use/login /* The prior items with timestamps might not be necessary. 
def _gather_entailment_scores(pending):
    """Resolve entailment scores that may be plain numbers or awaitables.

    Blocks the calling thread until every score is available and returns
    them as a list in the same order as ``pending``.
    """
    # Function-scope imports keep this block self-contained; the module's
    # top-level import section is left untouched.
    import asyncio
    import inspect
    from concurrent.futures import ThreadPoolExecutor

    if not any(inspect.isawaitable(item) for item in pending):
        # A synchronous scorer (e.g. a test stub) already produced numbers.
        return list(pending)

    async def _resolve(item):
        return await item if inspect.isawaitable(item) else item

    async def _gather():
        return await asyncio.gather(*(_resolve(item) for item in pending))

    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread, so one can be started here.
        return asyncio.run(_gather())

    # An event loop is already running in this thread (the caller is a
    # coroutine), and asyncio.run() refuses to nest. Drive the coroutines on a
    # private loop in a worker thread instead.
    # NOTE(review): this blocks the caller's loop while scores are fetched and
    # assumes the entailer client is not bound to the caller's loop -- confirm.
    # Making this function async and awaiting it at the call site is the
    # cleaner long-term fix, but it changes the interface.
    with ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(asyncio.run, _gather()).result()


def select_best_rules_using_entailer(
    conversation: Conversation,
    rules_as_strings: List[str],
    entailer: Entailer,
    num_rules: int,
) -> List[str]:
    """Return the rules that score highest against the last user utterance.

    Each rule is scored with ``entailer.get_score(query, rule)``. That method
    is a coroutine function, so its results are awaited before sorting; using
    the un-awaited coroutines as sort keys raises ``TypeError`` as soon as two
    rules are compared.

    Args:
        conversation: Source of the query; its last "user" utterance is used.
        rules_as_strings: Candidate rules.
        entailer: Object exposing ``get_score(lhs, rhs)``, sync or async.
        num_rules: Upper bound applied as a slice (``[:num_rules]``).

    Returns:
        The candidate rules ordered by descending score, truncated to
        ``num_rules``. Rules with equal scores keep their input order.
    """
    query_text = conversation.get_last_speaker_utterance("user")
    pending = [entailer.get_score(query_text, rule) for rule in rules_as_strings]
    scores = _gather_entailment_scores(pending)
    # Sort on the score only, so ties never fall back to comparing rule text.
    ranked = sorted(
        zip(scores, rules_as_strings), key=lambda pair: pair[0], reverse=True
    )
    return [rule for _, rule in ranked][:num_rules]
def get_last_speaker_utterance(self, speaker: str) -> str:
    """Return the text of the most recent utterance spoken by ``speaker``.

    The history in ``self.utterances`` is scanned from newest to oldest.
    An empty string is returned when the history is empty or when
    ``speaker`` has not said anything.
    """
    history = self.utterances
    if history:
        # Lazy scan: stops at the first (i.e. latest) matching utterance.
        newest_first = (
            entry.text for entry in reversed(history) if entry.speaker == speaker
        )
        return next(newest_first, "")
    return ""