diff --git a/documentation/source/configuration.rst b/documentation/source/configuration.rst
index f67dfbe7..b098092a 100644
--- a/documentation/source/configuration.rst
+++ b/documentation/source/configuration.rst
@@ -7,37 +7,33 @@
 A typical configuration file looks like this:
 
 .. code-block:: text
 
-    {
-      "waking_up_word": "computer",
-      "waking_up_sound": true,
-      "deactivate_sound": true,
-      "rules": "rules.yaml",
-      "index": "indices.yaml",
-      "cache_filename": "knowledge_cache",
-      "prompt_filename": "main.prompt",
-      "functions": "functions.py",
-      "max_recursion": 2,
-      "llm_model": {
-        "model_host": "localhost",
-        "model_port": 8080,
-        "temperature": 0.4
-      },
-      "listener_model": {
-        "model_host": "localhost",
-        "model_port": 8080,
-        "listener_hotword_logp": -8,
-        "listener_volume_threshold": 0.6,
-        "listener_silence_timeout": 0.7
-      },
-      "speaker_model": {
-        "model_host": "localhost",
-        "model_port": 8080
-      },
-      "text_embedding_model": {
-        "model_host": "localhost",
-        "model_port": 8080
-      }
-    }
+    {
+      "waking_up_word": "computer",
+      "waking_up_sound": true,
+      "deactivate_sound": true,
+      "rules": "rules.yaml",
+      "index": "indices.yaml",
+      "cache_filename": "knowledge_cache",
+      "prompt_filename": "main.prompt",
+      "functions": "functions.py",
+      "max_recursion": 2,
+      "frontend_port": 8090,
+      "backend": {
+        "host": "localhost",
+        "port": 8080,
+        "token": "secret"
+      },
+      "generation_config": {
+        "temperature": 0.4
+      },
+      "listener_model": {
+        "listener_hotword_logp": -8,
+        "listener_volume_threshold": 0.6,
+        "listener_silence_timeout": 0.7,
+        "interruptible": true
+      }
+    }
+
@@ -59,20 +55,10 @@
 These settings regulate the following:
 
   * "frontend_port" is the port where the web frontend is running. The default is 8090.
 
-  * "llm_model" is the configuration to connect to wafl-llm in the backend. The default url is "localhost:8080". The "temperature" parameter is used to set the temperature for the LLM model. The default is 0.4.
-
-  * "listener_model" is the configuration to connect to the listener model in the backend. The default is "localhost:8080".
-
-    - The listener model is used to detect the wake-up word. The similarity threshold for the detection can be set with the "listener_hotword_logp" parameter.
-
-    - The "listener_volume_threshold" parameter is used to set the volume threshold for any conversation. Any word uttered with a volume below this threshold is ignored.
-
-    - The "listener_silence_timeout" parameter is used to set the silence timeout for any conversation. If no word is uttered for a time longer than this timeout, the conversation is considered finished.
-
-  * "speaker_model" is the configuration to connect to the speaker model in the backend. The default is "localhost:8080".
+  * "backend" is the configuration used to connect to the wafl-llm backend. The default is "localhost:8080".
 
-  * "text_embedding_model" is the configuration to connect to the text embedding model in the backend. The default is "localhost:8080".
+  * "generation_config" contains the settings that control how the response is generated. The default is "temperature: 0.4".
 
+  * "listener_model" is the configuration of the listener model.
+    These settings determine the hotword detection threshold, the volume threshold below which speech is ignored, the silence timeout after which a conversation is considered finished, and whether the bot's speech can be interrupted.
+    The defaults are "listener_hotword_logp: -8", "listener_volume_threshold: 0.6", "listener_silence_timeout: 0.7", and "interruptible: true".
diff --git a/requirements.txt b/requirements.txt index 19f7062d..7bba6a9b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ flask[async]==3.0.3 flask-cors==4.0.1 nltk==3.8.1 -gensim==4.3.3 sklearn==0.0 python-Levenshtein==0.25.1 fuzzywuzzy==0.18.0 diff --git a/setup.py b/setup.py index 3b239ac4..c926112f 100644 --- a/setup.py +++ b/setup.py @@ -20,9 +20,8 @@ "wafl.connectors.clients", "wafl.connectors.factories", "wafl.connectors.remote", - "wafl.dataclasses", + "wafl.data_objects", "wafl.events", - "wafl.extractors", "wafl.handlers", "wafl.inference", "wafl.interface", @@ -49,7 +48,6 @@ "flask[async]==3.0.3", "flask-cors==4.0.1", "nltk==3.8.1", - "gensim==4.3.3", "sklearn==0.0", "python-Levenshtein==0.25.1", "fuzzywuzzy==0.18.0", diff --git a/tests/config.json b/tests/config.json index 121d77f0..af40d7a7 100644 --- a/tests/config.json +++ b/tests/config.json @@ -8,24 +8,19 @@ "prompt_filename": "main.prompt", "functions": "functions.py", "max_recursion": 2, - "llm_model": { - "model_host": "localhost", - "model_port": 8080, + "frontend_port": 8090, + "backend": { + "host": "aragorn", + "port": 8080, + "token": "secret" + }, + "generation_config": { "temperature": 0.4 }, "listener_model": { - "model_host": "localhost", - "model_port": 8080, "listener_hotword_logp": -8, "listener_volume_threshold": 0.6, - "listener_silence_timeout": 0.7 - }, - "speaker_model": { - "model_host": "localhost", - "model_port": 8080 - }, - "text_embedding_model": { - "model_host": "localhost", - "model_port": 8080 + "listener_silence_timeout": 0.7, + "interruptible": true } } diff --git a/tests/main.prompt b/tests/main.prompt index 07b45290..8cbb4d64 100644 --- a/tests/main.prompt +++ b/tests/main.prompt @@ -8,4 +8,6 @@ The rules that *must* be followed are: Create a plausible dialogue based on the aforementioned summary and rules. Do not repeat yourself. Be friendly but not too servile. -Follow the rules if present and they apply to the dialogue. Do not improvise if rules are present. \ No newline at end of file +Follow the rules if present and they apply to the dialogue. Do not improvise if rules are present. +The user query might be incomplete or ambiguous or ungrammatical. The bot *must* ask for clarification if needed. +The bot only answers if the query is clear and unambiguous. \ No newline at end of file diff --git a/tests/test_connection.py b/tests/test_connection.py index 317a91fd..112a34c3 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -12,7 +12,7 @@ class TestConnection(TestCase): def test__connection_to_generative_model_can_generate_text(self): config = Configuration.load_local_config() - connector = RemoteLLMConnector(config.get_value("llm_model")) + connector = RemoteLLMConnector(config) prediction = asyncio.run( connector.predict( PromptCreator.create_from_one_instruction( @@ -25,7 +25,7 @@ def test__connection_to_generative_model_can_generate_text(self): def test__connection_to_generative_model_can_generate_text_within_tags(self): config = Configuration.load_local_config() - connector = RemoteLLMConnector(config.get_value("llm_model")) + connector = RemoteLLMConnector(config) connector._num_prediction_tokens = 200 text = 'Generate a full paragraph based on this chapter title " The First Contact". The theme of the paragraph is space opera. Include the characters "Alberto" and "Maria". Write at least three sentences.' 
prompt = f""" @@ -43,7 +43,7 @@ def test__connection_to_generative_model_can_generate_text_within_tags(self): def test__connection_to_generative_model_can_generate_a_python_list(self): config = Configuration.load_local_config() - connector = RemoteLLMConnector(config.get_value("llm_model")) + connector = RemoteLLMConnector(config) connector._num_prediction_tokens = 200 prompt = "Generate a Python list of 4 chapters names for a space opera book. The output needs to be a python list of strings: " prediction = asyncio.run( diff --git a/tests/test_entailer.py b/tests/test_entailer.py new file mode 100644 index 00000000..6ccc7088 --- /dev/null +++ b/tests/test_entailer.py @@ -0,0 +1,34 @@ +import asyncio +import os + +from unittest import TestCase +from wafl.config import Configuration +from wafl.connectors.remote.remote_entailer_connector import RemoteEntailerConnector +from wafl.connectors.clients.entailer_client import EntailerClient + +_path = os.path.dirname(__file__) + + +class TestConnection(TestCase): + def test__entailer_connector(self): + config = Configuration.load_local_config() + connector = RemoteEntailerConnector(config) + prediction = asyncio.run( + connector.predict( + "The first contact is a romance novel set in the middle ages.", + "The first contact is a science fiction novel about the first contact between humans and aliens.", + ) + ) + assert prediction["score"] < 0.5 + + def test__entailment_client(self): + + config = Configuration.load_local_config() + client = EntailerClient(config) + prediction = asyncio.run( + client.get_entailment_score( + "The first contact is a romance novel set in the middle ages.", + "The first contact is a science fiction novel about the first contact between humans and aliens.", + ) + ) + assert prediction < 0.5 diff --git a/tests/test_indexing.py b/tests/test_indexing.py index 1e64b4a4..0ea51557 100644 --- a/tests/test_indexing.py +++ b/tests/test_indexing.py @@ -5,7 +5,7 @@ from unittest import TestCase from wafl.config import Configuration -from wafl.dataclasses.dataclasses import Query +from wafl.data_objects.dataclasses import Query from wafl.knowledge.indexing_implementation import add_to_index, load_knowledge _path = os.path.dirname(__file__) diff --git a/tests/test_speaker.py b/tests/test_speaker.py index 35fe57fc..79db6a9a 100644 --- a/tests/test_speaker.py +++ b/tests/test_speaker.py @@ -4,7 +4,7 @@ from unittest import TestCase from wafl.config import Configuration -from wafl.speaker.fairseq_speaker import FairSeqSpeaker +from wafl.speaker.tts_speaker import TTSSpeaker from wafl.speaker.soundfile_speaker import SoundFileSpeaker _wafl_greetings = """ @@ -17,24 +17,30 @@ class TestSpeaker(TestCase): def test_voice(self): config = Configuration.load_local_config() - speaker = FairSeqSpeaker(config) + speaker = TTSSpeaker(config) text = "Hello world" asyncio.run(speaker.speak(text)) def test_long_text(self): config = Configuration.load_local_config() - speaker = FairSeqSpeaker(config) + speaker = TTSSpeaker(config) text = ( "Shall I compare you to a summer's day? Thou art more lovely and temperate." 
)
         asyncio.run(speaker.speak(text))
 
-    def test_number_pronunciation(self):
+    def test_number_pronunciation1(self):
         config = Configuration.load_local_config()
-        speaker = FairSeqSpeaker(config)
+        speaker = TTSSpeaker(config)
         text = "The time is 54 past 8"
         asyncio.run(speaker.speak(text))
 
+    def test_number_pronunciation2(self):
+        config = Configuration.load_local_config()
+        speaker = TTSSpeaker(config)
+        text = "The time is 8 54"
+        asyncio.run(speaker.speak(text))
+
     def test_on_sound(self):
         speaker = SoundFileSpeaker()
         speaker.speak(os.path.join(_path, "../wafl/sounds/activation.wav"))
diff --git a/tests/test_voice.py b/tests/test_voice.py
index 6ecd785b..7b90ef83 100644
--- a/tests/test_voice.py
+++ b/tests/test_voice.py
@@ -15,7 +15,7 @@
 rules:
   - the user's name is Jane:
-    - write "I hear you"
+    - reply with "I hear you" and nothing else
 """.strip()
 
 _path = os.path.dirname(__file__)
@@ -23,7 +23,7 @@
 class TestVoice(TestCase):
     def test__activation(self):
-        interface = DummyInterface(to_utter=["computer", "my name is Jane"])
+        interface = DummyInterface(to_utter=["computer my name is Jane"])
         config = Configuration.load_local_config()
         config.set_value("rules", _wafl_example)
         conversation_events = ConversationEvents(config=config, interface=interface)
diff --git a/todo.txt b/todo.txt
index fa7be712..db4f8d22 100644
--- a/todo.txt
+++ b/todo.txt
@@ -1,4 +1,42 @@
-* why do I need to re-initialise the retrievers after unpickling the knowledge?
+* use prior text and response from failed substitutions between [] instead of just the iteration number (line 89, dialogue_answerer.py)
+* remove dead code
+* re-fine-tune phi to get better performance
+
+* delete rules and memory from discourse_answerer
+
+* This is wrong - from wafl_llm:
+  <|end|><|assistant|><|user|> Hi!<|end|><|assistant|>
+  The user is sandwiched between the assistant. It should be:
+  <|end|><|assistant|> Hi!<|end|><|user|>
+
+/* make interruptible speech optional
+
+
+* use entailment score to flag a rule for execution before the answer.
+* get the full model list from the wafl_llm backend. Only specify the connection port and host in wafl
+
+* the answer from the indexed files should be directed from a rule.
+  - facts and rules should live at the highest level of the retrieval
+
+
+/* apply entailer to rule retrieval:
+/ if more than one rule is retrieved, then the one
+/ that is entailed by the query should be chosen
+
+
+
+/* Add tqdm to indexing.
+/* Make it index when wafl starts first, not at the first use/login
+
+/* The prior items with timestamps might not be necessary.
+/ - Just implement a queue with a fixed size
+
+* add entailer to wafl_llm
+
+
+/* why do I need to re-initialise the retrievers after unpickling the knowledge?
+  - maybe you should save the retrievers in the knowledge object separately?
+  - It was gensim that was not serializable.
Took it out
 
 /* knowledge cache does not cache the rules or facts
diff --git a/wafl/answerer/answerer_implementation.py b/wafl/answerer/answerer_implementation.py
index 3c8b36e8..6d01549b 100644
--- a/wafl/answerer/answerer_implementation.py
+++ b/wafl/answerer/answerer_implementation.py
@@ -1,10 +1,9 @@
 import re
-import traceback
 
 from typing import List, Tuple
-
+from wafl.answerer.entailer import Entailer
 from wafl.exceptions import CloseConversation
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact, Sources
 from wafl.interface.conversation import Conversation, Utterance
 
@@ -53,8 +52,11 @@ async def substitute_memory_in_answer_and_get_memories_if_present(
 
 async def execute_results_in_answer(answer_text: str, module, functions) -> str:
+    if "<execute>" in answer_text and "</execute>" not in answer_text:
+        answer_text += "</execute>"
+
     matches = re.finditer(
-        r"<execute>(.*?)</execute>|(<execute>.*?\))$",
+        r"<execute>(.*?)</execute>",
         answer_text,
         re.DOTALL | re.MULTILINE,
     )
@@ -65,14 +67,6 @@ async def execute_results_in_answer(answer_text: str, module, functions) -> str:
         result = await _run_code(to_execute, module, functions)
         answer_text = answer_text.replace(match.group(0), result)
 
-    matches = re.finditer(r"(<execute>.*?\))$", answer_text, re.DOTALL | re.MULTILINE)
-    for match in matches:
-        to_execute = match.group(1)
-        if not to_execute:
-            continue
-        result = await _run_code(to_execute, module, functions)
-        answer_text = answer_text.replace(match.group(0), result)
-
     return answer_text
 
@@ -104,8 +98,7 @@ async def _run_code(to_execute: str, module, functions) -> str:
             result = (
                 f"Error while executing\n\n```python\n{to_execute}\n```\n\n{str(e)}"
             )
-            traceback.print_exc()
-            break
+            raise RuntimeError(result)
 
     if not result:
         result = f"\n```python\n{to_execute}\n```"
 
     return result
 
@@ -113,22 +106,32 @@
-def get_text_from_facts_and_thresholds(
+def create_memory_from_fact_list(facts: List[Fact], max_num_facts: int) -> str:
+    text_fact_list = [
+        "\n\n- " + " " + fact.text + " "
+        for fact in facts
+        if fact.source == Sources.FROM_TEXT
+    ][:max_num_facts]
+    rule_fact_list = [
+        "\n\n- " + " " + fact.text + " "
+        for fact in facts
+        if fact.source in [None, Sources.FROM_RULES]
+    ]
+    return "".join(text_fact_list + rule_fact_list)
+
+
+def get_facts_with_metadata_from_facts_and_thresholds(
     facts_and_thresholds: List[Tuple[Fact, float]], memory: str
 ) -> List[str]:
-    text_list = []
+    fact_list = []
     for item in facts_and_thresholds:
         if item[0].text not in memory:
-            text = item[0].text
+            new_fact = item[0].copy()
             if item[0].metadata:
-                text = (
-                    f"Metadata for the following text: {str(item[0].metadata)}"
-                    + "\n"
-                    + text
-                )
-            text_list.append(text)
+                new_fact.text = new_fact.text
+            fact_list.append(new_fact)
 
-    return text_list
+    return fact_list
 
 
 def add_dummy_utterances_to_continue_generation(
@@ -150,3 +153,21 @@
 def add_memories_to_facts(facts: str, memories: List[str]) -> str:
     return facts + "\n" + "\n".join(memories)
+
+
+async def select_best_rules_using_entailer(
+    conversation: Conversation,
+    rules_as_strings: List[str],
+    entailer: Entailer,
+    num_rules: int,
+) -> List[str]:
+    query_text = conversation.get_last_speaker_utterance("user")
+    # Sort the rules by entailment score, highest first
+    scores = []
+    for rule in rules_as_strings:
+        scores.append(await entailer.get_score(query_text, rule))
+    rules_as_strings = [
+        rule
+        for _, rule in sorted(
+            zip(scores, rules_as_strings), key=lambda pair: pair[0], reverse=True
+        )
+    ]
+    return
rules_as_strings[:num_rules] diff --git a/wafl/answerer/base_answerer.py b/wafl/answerer/base_answerer.py deleted file mode 100644 index a999dd0f..00000000 --- a/wafl/answerer/base_answerer.py +++ /dev/null @@ -1,3 +0,0 @@ -class BaseAnswerer: - async def answer(self, query_text: str) -> "Answer": - raise NotImplementedError diff --git a/wafl/answerer/dialogue_answerer.py b/wafl/answerer/dialogue_answerer.py index f12be579..9447070f 100644 --- a/wafl/answerer/dialogue_answerer.py +++ b/wafl/answerer/dialogue_answerer.py @@ -1,33 +1,37 @@ from importlib import import_module from inspect import getmembers, isfunction -from typing import List, Tuple +from typing import List + +from wafl.answerer.entailer import Entailer from wafl.answerer.answerer_implementation import ( substitute_memory_in_answer_and_get_memories_if_present, create_one_liner, - get_text_from_facts_and_thresholds, + get_facts_with_metadata_from_facts_and_thresholds, add_dummy_utterances_to_continue_generation, add_memories_to_facts, execute_results_in_answer, + create_memory_from_fact_list, + select_best_rules_using_entailer, ) -from wafl.answerer.base_answerer import BaseAnswerer from wafl.answerer.rule_maker import RuleMaker from wafl.connectors.clients.llm_chat_client import LLMChatClient -from wafl.dataclasses.dataclasses import Query, Answer -from wafl.interface.conversation import Conversation +from wafl.data_objects.dataclasses import Query, Answer +from wafl.interface.conversation import Conversation, Utterance from wafl.simple_text_processing.questions import is_question -class DialogueAnswerer(BaseAnswerer): +class DialogueAnswerer: def __init__(self, config, knowledge, interface, code_path, logger): self._threshold_for_facts = 0.85 self._client = LLMChatClient(config) + self._entailer = Entailer(config) self._knowledge = knowledge self._logger = logger self._interface = interface self._max_num_past_utterances = 5 - self._max_num_past_utterances_for_facts = 5 - self._max_num_past_utterances_for_rules = 2 - self._prior_facts_with_timestamp = [] + self._max_num_facts = 5 + self._max_num_rules = 2 + self._prior_facts = [] self._init_python_module(code_path.replace(".py", "")) self._prior_rules = [] self._max_predictions = 3 @@ -38,6 +42,10 @@ def __init__(self, config, knowledge, interface, code_path, logger): max_num_rules=1, ) + def reset(self): + self._prior_facts = [] + self._prior_rules = [] + async def answer(self, query_text: str) -> Answer: if self._logger: self._logger.write(f"Dialogue Answerer: the query is {query_text}") @@ -45,35 +53,50 @@ async def answer(self, query_text: str) -> Answer: conversation = self._interface.get_utterances_list_with_timestamp().get_last_n( self._max_num_past_utterances ) + conversation.add_utterance(Utterance(speaker="user", text=query_text)) rules_text = await self._get_relevant_rules(conversation) if not conversation: conversation = create_one_liner(query_text) - conversational_timestamp = len(conversation) - facts = await self._get_relevant_facts( + memory = await self._get_relevant_facts( query, has_prior_rules=bool(rules_text), - conversational_timestamp=conversational_timestamp, ) final_answer_text = "" - for _ in range(self._max_predictions): - original_answer_text = await self._client.get_answer( - text=facts, - rules_text=rules_text, - dialogue=conversation, - ) - await self._interface.add_fact(f"The bot predicts: {original_answer_text}") - answer_text, memories = await self._apply_substitutions( - original_answer_text - ) - - final_answer_text += answer_text + 
is_finished = False
+        for num_attempts in range(self._max_predictions):
+            try:
+                original_answer_text = await self._client.get_answer(
+                    text=memory,
+                    rules_text=rules_text,
+                    dialogue=conversation,
+                )
+                await self._interface.add_fact(
+                    f"The bot predicts: {original_answer_text}"
+                )
+                answer_text, memories = await self._apply_substitutions(
+                    original_answer_text
+                )
 
-            if not memories:
-                break
+                final_answer_text += answer_text
+                if not memories:
+                    is_finished = True
+                    break
 
+                memory = add_memories_to_facts(memory, memories)
+                add_dummy_utterances_to_continue_generation(conversation, answer_text)
+
+            except RuntimeError as e:
+                if self._logger:
+                    self._logger.write(f"Error in generating answer: {e}")
+                conversation.add_utterance(
+                    Utterance(
+                        speaker="bot",
+                        text=f"[when using the answer {original_answer_text} the system says {e}]\n",
+                    )
+                )
 
-            facts = add_memories_to_facts(facts, memories)
-            add_dummy_utterances_to_continue_generation(conversation, answer_text)
+        if not is_finished:
+            final_answer_text += " I was unable to generate a full answer. Please see the logs for more information."
 
         if self._logger:
             self._logger.write(
@@ -82,22 +105,19 @@
 
         return Answer.create_from_text(final_answer_text)
 
-    async def _get_relevant_facts(
-        self, query: Query, has_prior_rules: bool, conversational_timestamp: int
-    ) -> str:
-        memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])
-        self._prior_facts_with_timestamp = self._get_prior_facts_with_timestamp(
-            conversational_timestamp
-        )
+    async def _get_relevant_facts(self, query: Query, has_prior_rules: bool) -> str:
+        memory = create_memory_from_fact_list(self._prior_facts, self._max_num_facts)
         facts_and_thresholds = await self._knowledge.ask_for_facts_with_threshold(
             query, is_from_user=True, threshold=self._threshold_for_facts
         )
         if facts_and_thresholds:
-            facts = get_text_from_facts_and_thresholds(facts_and_thresholds, memory)
-            self._prior_facts_with_timestamp.extend(
-                (item, conversational_timestamp) for item in facts
+            facts = get_facts_with_metadata_from_facts_and_thresholds(
+                facts_and_thresholds, memory
+            )
+            self._prior_facts.extend(facts)
+            memory = create_memory_from_fact_list(
+                self._prior_facts, self._max_num_facts
             )
-            memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])
 
             await self._interface.add_fact(f"The bot remembers the facts:\n{memory}")
 
         else:
@@ -110,11 +130,14 @@
         return memory
 
     async def _get_relevant_rules(self, conversation: Conversation) -> List[str]:
-        rules = await self._rule_creator.create_from_query(conversation)
-        for rule in rules:
+        rules_as_strings = await self._rule_creator.create_from_query(conversation)
+        rules_as_strings = await select_best_rules_using_entailer(
+            conversation, rules_as_strings, self._entailer, num_rules=1
+        )
+        for rule in rules_as_strings:
             if rule not in self._prior_rules:
                 self._prior_rules.insert(0, rule)
-        self._prior_rules = self._prior_rules[: self._max_num_past_utterances_for_rules]
+        self._prior_rules = self._prior_rules[: self._max_num_rules]
         return self._prior_rules
 
     def _init_python_module(self, module_name):
@@ -129,13 +152,3 @@
             self._functions,
         )
     )
-
-    def _get_prior_facts_with_timestamp(
-        self, conversational_timestamp: int
-    ) -> List[Tuple[str, int]]:
-        return [
-            item
-            for item in self._prior_facts_with_timestamp
-            if item[1]
-            > conversational_timestamp - self._max_num_past_utterances_for_facts
- ] diff --git a/wafl/answerer/entailer.py b/wafl/answerer/entailer.py index 54e4e3e2..3f3c2ab9 100644 --- a/wafl/answerer/entailer.py +++ b/wafl/answerer/entailer.py @@ -1,41 +1,14 @@ -import os -import textwrap - -from wafl.connectors.factories.llm_connector_factory import LLMConnectorFactory -from wafl.connectors.prompt_template import PromptTemplate -from wafl.interface.conversation import Utterance, Conversation - -_path = os.path.dirname(__file__) +from wafl.connectors.clients.entailer_client import EntailerClient class Entailer: def __init__(self, config): - self._connector = LLMConnectorFactory.get_connector(config) + self.entailer_client = EntailerClient(config) self._config = config - async def left_entails_right(self, lhs: str, rhs: str, dialogue) -> str: - prompt = await self._get_answer_prompt(lhs, rhs, dialogue) - result = await self._connector.generate(prompt) - result = self._clean_result(result) - return result == "yes" - - async def _get_answer_prompt(self, lhs, rhs, dialogue): - return PromptTemplate( - system_prompt="", - conversation=self._get_dialogue_prompt(lhs, rhs, dialogue), - ) - - def _clean_result(self, result): - result = result.replace("", "") - result = result.split("\n")[0] - result = result.strip() - return result.lower() + async def left_entails_right(self, lhs: str, rhs: str) -> bool: + prediction = await self.entailer_client.get_entailment_score(lhs, rhs) + return prediction > 0.5 - def _get_dialogue_prompt(self, dialogue, lhs, rhs): - text = f""" -Your task is to determine whether two sentences are similar. -1) {lhs.lower()} -2) {rhs.lower()} -Please answer "yes" if the two sentences are similar or "no" if not: - """.strip() - return Conversation([Utterance(speaker="user", text=text)]) + async def get_score(self, lhs: str, rhs: str) -> float: + return await self.entailer_client.get_entailment_score(lhs, rhs) diff --git a/wafl/answerer/rule_maker.py b/wafl/answerer/rule_maker.py index 115dfcfc..7454fe73 100644 --- a/wafl/answerer/rule_maker.py +++ b/wafl/answerer/rule_maker.py @@ -1,7 +1,7 @@ from typing import List -from wafl.dataclasses.dataclasses import Query -from wafl.dataclasses.rules import Rule +from wafl.data_objects.dataclasses import Query +from wafl.data_objects.rules import Rule class RuleMaker: diff --git a/wafl/changelog.txt b/wafl/changelog.txt new file mode 100644 index 00000000..bfd13914 --- /dev/null +++ b/wafl/changelog.txt @@ -0,0 +1,7 @@ +- version 0.1.3 +* added multi-threaded support for multiple files indexing +* TODO: ADD support for multiple knowledge bases. + It needs to index the rules and the files separately! 
+* the interface should show where the facts come from in the web interface
+* add support for wafl studio where you can concatenate actions (and create corresponding yaml files)
+* use <> tags for concatenation
\ No newline at end of file
diff --git a/wafl/command_line.py b/wafl/command_line.py
index 4ebb6643..114f86be 100644
--- a/wafl/command_line.py
+++ b/wafl/command_line.py
@@ -9,6 +9,7 @@
     run_testcases,
     print_incipit,
     download_models,
+    load_indices,
 )
 from wafl.runners.run_from_actions import run_action
 
@@ -52,26 +53,31 @@ def process_cli():
     elif command == "run":
         from wafl.runners.run_web_and_audio_interface import run_app
 
+        load_indices()
         run_app()
         remove_preprocessed("/")
 
     elif command == "run-cli":
+        load_indices()
         run_from_command_line()
         remove_preprocessed("/")
 
     elif command == "run-audio":
         from wafl.runners.run_from_audio import run_from_audio
 
+        load_indices()
         run_from_audio()
         remove_preprocessed("/")
 
     elif command == "run-server":
         from wafl.runners.run_web_interface import run_server_only_app
 
+        load_indices()
         run_server_only_app()
         remove_preprocessed("/")
 
     elif command == "run-tests":
+        load_indices()
         run_testcases()
         remove_preprocessed("/")
diff --git a/wafl/config.py b/wafl/config.py
index 45c00c8e..cb5afcb7 100644
--- a/wafl/config.py
+++ b/wafl/config.py
@@ -2,6 +2,10 @@
 import os
 import shutil
 
+from wafl.connectors.remote.remote_configuration_connector import (
+    RemoteConfigurationConnector,
+)
+
 _path = os.path.dirname(__file__)
 
@@ -21,6 +25,10 @@
         with open(filename) as file:
             self._data = json.load(file)
 
+        self._remote_config = RemoteConfigurationConnector(
+            self._data["backend"]["host"], self._data["backend"]["port"]
+        )
+
     def get_value(self, key):
         if key in self._data:
             return self._data[key]
diff --git a/wafl/connectors/clients/clients_implementation.py b/wafl/connectors/clients/clients_implementation.py
deleted file mode 100644
index ec463c56..00000000
--- a/wafl/connectors/clients/clients_implementation.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import csv
-import os
-import joblib
-
-from wafl.knowledge.single_file_knowledge import SingleFileKnowledge
-
-_path = os.path.dirname(__file__)
-
-
-async def load_knowledge_from_file(filename, config):
-    items_list = []
-    with open(os.path.join(_path, "../../data/", filename + ".csv")) as file:
-        csvreader = csv.reader(file)
-        for row in csvreader:
-            items_list.append(row[0].strip())
-
-    knowledge = await SingleFileKnowledge.create_from_list(items_list, config)
-    joblib.dump(knowledge, os.path.join(_path, f"../../data/{filename}.knowledge"))
-    return knowledge
diff --git a/wafl/connectors/clients/entailer_client.py b/wafl/connectors/clients/entailer_client.py
new file mode 100644
index 00000000..bdde7dc1
--- /dev/null
+++ b/wafl/connectors/clients/entailer_client.py
@@ -0,0 +1,19 @@
+import os
+
+from wafl.connectors.factories.entailer_connector_factory import (
+    EntailerConnectorFactory,
+)
+
+_path = os.path.dirname(__file__)
+
+
+class EntailerClient:
+    def __init__(self, config):
+        self._connector = EntailerConnectorFactory.get_connector(config)
+        self._config = config
+
+    async def get_entailment_score(self, lhs: str, rhs: str) -> float:
+        prediction = await self._connector.predict(lhs, rhs)
+        if "score" not in prediction:
+            raise ValueError("The entailment prediction does not contain a score.")
+        return prediction["score"]
diff --git a/wafl/connectors/clients/information_client.py b/wafl/connectors/clients/information_client.py
index 772afb00..533fc902
---
a/wafl/connectors/clients/information_client.py +++ b/wafl/connectors/clients/information_client.py @@ -1,9 +1,6 @@ import os -import textwrap -from typing import List from wafl.connectors.factories.llm_connector_factory import LLMConnectorFactory -from wafl.connectors.prompt_template import PromptTemplate _path = os.path.dirname(__file__) diff --git a/wafl/connectors/factories/entailer_connector_factory.py b/wafl/connectors/factories/entailer_connector_factory.py new file mode 100644 index 00000000..3de3cdff --- /dev/null +++ b/wafl/connectors/factories/entailer_connector_factory.py @@ -0,0 +1,8 @@ +from wafl.config import Configuration +from wafl.connectors.remote.remote_entailer_connector import RemoteEntailerConnector + + +class EntailerConnectorFactory: + @staticmethod + def get_connector(config: Configuration): + return RemoteEntailerConnector(config) diff --git a/wafl/connectors/factories/llm_connector_factory.py b/wafl/connectors/factories/llm_connector_factory.py index 2d1c8714..ea82f837 100644 --- a/wafl/connectors/factories/llm_connector_factory.py +++ b/wafl/connectors/factories/llm_connector_factory.py @@ -1,7 +1,8 @@ +from wafl.config import Configuration from wafl.connectors.remote.remote_llm_connector import RemoteLLMConnector class LLMConnectorFactory: @staticmethod - def get_connector(config): - return RemoteLLMConnector(config.get_value("llm_model")) + def get_connector(config: Configuration): + return RemoteLLMConnector(config) diff --git a/wafl/connectors/factories/sentence_embedder_connector_factory.py b/wafl/connectors/factories/sentence_embedder_connector_factory.py index 7ab1a708..2db98989 100644 --- a/wafl/connectors/factories/sentence_embedder_connector_factory.py +++ b/wafl/connectors/factories/sentence_embedder_connector_factory.py @@ -1,3 +1,4 @@ +from wafl.config import Configuration from wafl.connectors.remote.remote_sentence_embedder_connector import ( RemoteSentenceEmbedderConnector, ) @@ -5,5 +6,5 @@ class SentenceEmbedderConnectorFactory: @staticmethod - def get_connector(model_name, config): - return RemoteSentenceEmbedderConnector(config.get_value(model_name)) + def get_connector(config: Configuration): + return RemoteSentenceEmbedderConnector(config) diff --git a/wafl/connectors/factories/speaker_connector_factory.py b/wafl/connectors/factories/speaker_connector_factory.py index 578497eb..da11134b 100644 --- a/wafl/connectors/factories/speaker_connector_factory.py +++ b/wafl/connectors/factories/speaker_connector_factory.py @@ -1,7 +1,8 @@ +from wafl.config import Configuration from wafl.connectors.remote.remote_speaker_connector import RemoteSpeakerConnector class SpeakerConnectorFactory: @staticmethod - def get_connector(config): - return RemoteSpeakerConnector(config.get_value("speaker_model")) + def get_connector(config: Configuration): + return RemoteSpeakerConnector(config) diff --git a/wafl/connectors/factories/whisper_connector_factory.py b/wafl/connectors/factories/whisper_connector_factory.py index 5b4e1c2f..d8a7068c 100644 --- a/wafl/connectors/factories/whisper_connector_factory.py +++ b/wafl/connectors/factories/whisper_connector_factory.py @@ -1,7 +1,8 @@ +from wafl.config import Configuration from wafl.connectors.remote.remote_whisper_connector import RemoteWhisperConnector class WhisperConnectorFactory: @staticmethod - def get_connector(config): - return RemoteWhisperConnector(config.get_value("listener_model")) + def get_connector(config: Configuration): + return RemoteWhisperConnector(config) diff --git 
a/wafl/connectors/remote/remote_configuration_connector.py b/wafl/connectors/remote/remote_configuration_connector.py
new file mode 100644
index 00000000..3590b5d5
--- /dev/null
+++ b/wafl/connectors/remote/remote_configuration_connector.py
@@ -0,0 +1,61 @@
+import aiohttp
+import asyncio
+import json
+
+from typing import Dict
+from wafl.variables import get_variables
+
+
+class RemoteConfigurationConnector:
+    _max_tries = 3
+
+    def __init__(self, host: str, port: int):
+        self._server_url = f"https://{host}:{port}/predictions/configuration"
+        try:
+            loop = asyncio.get_running_loop()
+
+        except RuntimeError:
+            loop = None
+
+        if (not loop or (loop and not loop.is_running())) and not asyncio.run(
+            self.check_connection()
+        ):
+            raise RuntimeError(
+                "Cannot connect to a running Configuration handler. Is WAFL-LLM running?"
+            )
+
+    async def predict(self) -> Dict[str, str]:
+        payload = {"version": get_variables()["version"]}
+        for _ in range(self._max_tries):
+            async with aiohttp.ClientSession(
+                connector=aiohttp.TCPConnector(ssl=False)
+            ) as session:
+                async with session.post(self._server_url, json=payload) as response:
+                    data = await response.text()
+                    prediction = json.loads(data)
+                    listener_model = prediction["listener_model"]
+                    speaker_model = prediction["speaker_model"]
+                    text_embedding_model = prediction["text_embedding_model"]
+                    entailer_model = prediction["entailer_model"]
+                    llm_model = prediction["llm_model"]
+                    return {
+                        "listener_model": listener_model,
+                        "speaker_model": speaker_model,
+                        "text_embedding_model": text_embedding_model,
+                        "entailer_model": entailer_model,
+                        "llm_model": llm_model,
+                    }
+
+        return {}
+
+    async def check_connection(self) -> bool:
+        try:
+            async with aiohttp.ClientSession(
+                connector=aiohttp.TCPConnector(ssl=False)
+            ) as session:
+                payload = {"version": get_variables()["version"]}
+                async with session.post(self._server_url, json=payload) as response:
+                    return response.status == 200
+
+        except Exception:
+            return False
diff --git a/wafl/connectors/remote/remote_entailer_connector.py b/wafl/connectors/remote/remote_entailer_connector.py
new file mode 100644
index 00000000..50430cf9
--- /dev/null
+++ b/wafl/connectors/remote/remote_entailer_connector.py
@@ -0,0 +1,60 @@
+import aiohttp
+import asyncio
+import json
+from typing import Dict
+
+from wafl.config import Configuration
+
+
+class RemoteEntailerConnector:
+    _max_tries = 3
+
+    def __init__(self, config: Configuration):
+        host = config.get_value("backend")["host"]
+        port = config.get_value("backend")["port"]
+
+        self._server_url = f"https://{host}:" f"{port}/predictions/entailer"
+        try:
+            loop = asyncio.get_running_loop()
+
+        except RuntimeError:
+            loop = None
+
+        if (not loop or (loop and not loop.is_running())) and not asyncio.run(
+            self.check_connection()
+        ):
+            raise RuntimeError("Cannot connect to a running Entailment Model.")
+
+    async def predict(self, lhs: str, rhs: str) -> Dict[str, float]:
+        payload = {"lhs": lhs, "rhs": rhs}
+        for _ in range(self._max_tries):
+            async with aiohttp.ClientSession(
+                connector=aiohttp.TCPConnector(ssl=False)
+            ) as session:
+                async with session.post(self._server_url, json=payload) as response:
+                    data = await response.text()
+                    prediction = json.loads(data)
+                    if "score" in prediction:
+                        score = prediction["score"]
+                        return {"score": float(score)}
+                    return {"score": -1.0}
+
+        return {"score": -1.0}
+
+    async def check_connection(self):
+        payload = {"lhs": "test", "rhs": "test"}
+        try:
+            async with aiohttp.ClientSession(
+                conn_timeout=3,
connector=aiohttp.TCPConnector(ssl=False) + ) as session: + async with session.post(self._server_url, json=payload) as response: + await response.text() + return True + + except aiohttp.client.InvalidURL: + print() + print("Is the entailer server running?") + print("Please run 'bash start-llm.sh' (see docs for explanation).") + print() + + return False diff --git a/wafl/connectors/remote/remote_llm_connector.py b/wafl/connectors/remote/remote_llm_connector.py index a2b3d7f7..01b8bcd0 100644 --- a/wafl/connectors/remote/remote_llm_connector.py +++ b/wafl/connectors/remote/remote_llm_connector.py @@ -3,6 +3,7 @@ import aiohttp import asyncio +from wafl.config import Configuration from wafl.connectors.base_llm_connector import BaseLLMConnector from wafl.connectors.prompt_template import PromptTemplate from wafl.variables import is_supported @@ -11,14 +12,14 @@ class RemoteLLMConnector(BaseLLMConnector): _max_tries = 3 _max_reply_length = 1024 - _num_prediction_tokens = 200 + _num_prediction_tokens = 1024 _cache = {} - def __init__(self, config, last_strings=None, num_replicas=1): + def __init__(self, config: Configuration, last_strings=None, num_replicas=1): super().__init__(last_strings) - host = config["model_host"] - port = config["model_port"] - self._default_temperature = config["temperature"] + host = config.get_value("backend")["host"] + port = config.get_value("backend")["port"] + self._default_temperature = config.get_value("generation_config")["temperature"] self._server_url = f"https://{host}:{port}/predictions/bot" self._num_replicas = num_replicas diff --git a/wafl/connectors/remote/remote_sentence_embedder_connector.py b/wafl/connectors/remote/remote_sentence_embedder_connector.py index ae71e6fe..9b75b2f0 100644 --- a/wafl/connectors/remote/remote_sentence_embedder_connector.py +++ b/wafl/connectors/remote/remote_sentence_embedder_connector.py @@ -10,8 +10,8 @@ class RemoteSentenceEmbedderConnector: _max_tries = 3 def __init__(self, config): - host = config["model_host"] - port = config["model_port"] + host = config.get_value("backend")["host"] + port = config.get_value("backend")["port"] self._server_url = f"https://{host}:" f"{port}/predictions/sentence_embedder" try: diff --git a/wafl/connectors/remote/remote_speaker_connector.py b/wafl/connectors/remote/remote_speaker_connector.py index 3d1226b4..a0402fa8 100644 --- a/wafl/connectors/remote/remote_speaker_connector.py +++ b/wafl/connectors/remote/remote_speaker_connector.py @@ -4,15 +4,16 @@ import json from typing import Dict +from wafl.config import Configuration class RemoteSpeakerConnector: _max_tries = 3 - def __init__(self, config): + def __init__(self, config: Configuration): self._server_url = ( - f"https://{config['model_host']}:" - f"{config['model_port']}/predictions/speaker" + f"https://{config.get_value('backend')['host']}:" + f"{config.get_value('backend')['port']}/predictions/speaker" ) try: loop = asyncio.get_running_loop() diff --git a/wafl/connectors/remote/remote_whisper_connector.py b/wafl/connectors/remote/remote_whisper_connector.py index d9498a83..8632a62f 100644 --- a/wafl/connectors/remote/remote_whisper_connector.py +++ b/wafl/connectors/remote/remote_whisper_connector.py @@ -4,14 +4,16 @@ from typing import Dict +from wafl.config import Configuration + class RemoteWhisperConnector: _max_tries = 3 - def __init__(self, config): + def __init__(self, config: Configuration): self._server_url = ( - f"https://{config['model_host']}:" - f"{config['model_port']}/predictions/whisper" + 
f"https://{config.get_value('backend')['host']}:" + f"{config.get_value('backend')['port']}/predictions/whisper" ) try: loop = asyncio.get_running_loop() diff --git a/wafl/dataclasses/__init__.py b/wafl/data_objects/__init__.py similarity index 100% rename from wafl/dataclasses/__init__.py rename to wafl/data_objects/__init__.py diff --git a/wafl/dataclasses/dataclasses.py b/wafl/data_objects/dataclasses.py similarity index 100% rename from wafl/dataclasses/dataclasses.py rename to wafl/data_objects/dataclasses.py diff --git a/wafl/data_objects/facts.py b/wafl/data_objects/facts.py new file mode 100644 index 00000000..88926ab0 --- /dev/null +++ b/wafl/data_objects/facts.py @@ -0,0 +1,33 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Union + + +class Sources(Enum): + FROM_TEXT = 1 + FROM_RULES = 2 + + +@dataclass +class Fact: + text: Union[str, dict] + is_question: bool = False + variable: str = None + is_interruption: bool = False + destination: str = None + metadata: Union[str, dict] = None + source: Sources = Sources.FROM_RULES + + def toJSON(self): + return str(self) + + def copy(self): + return Fact( + self.text, + self.is_question, + self.variable, + self.is_interruption, + self.destination, + self.metadata, + self.source, + ) diff --git a/wafl/dataclasses/rules.py b/wafl/data_objects/rules.py similarity index 100% rename from wafl/dataclasses/rules.py rename to wafl/data_objects/rules.py diff --git a/wafl/dataclasses/facts.py b/wafl/dataclasses/facts.py deleted file mode 100644 index 0445adff..00000000 --- a/wafl/dataclasses/facts.py +++ /dev/null @@ -1,16 +0,0 @@ -from dataclasses import dataclass -from typing import Union - - -@dataclass -class Fact: - text: Union[str, dict] - is_question: bool = False - variable: str = None - is_interruption: bool = False - source: str = None - destination: str = None - metadata: Union[str, dict] = None - - def toJSON(self): - return str(self) diff --git a/wafl/events/conversation_events.py b/wafl/events/conversation_events.py index d83d52ce..4c0da959 100644 --- a/wafl/events/conversation_events.py +++ b/wafl/events/conversation_events.py @@ -19,6 +19,7 @@ def __init__( config: "Configuration", interface: "BaseInterface", logger=None, + knowledge=None, ): self._config = config try: @@ -29,6 +30,8 @@ def __init__( if not loop or not loop.is_running(): self._knowledge = asyncio.run(load_knowledge(config, logger)) + else: + self._knowledge = knowledge self._answerer = create_answerer(config, self._knowledge, interface, logger) self._answerer._client._connector._cache = {} @@ -38,6 +41,9 @@ def __init__( if logger: self._logger.set_depth(0) + def reset(self): + self._answerer.reset() + async def output(self, text: str): await self._interface.output(text) @@ -74,6 +80,9 @@ async def _process_query(self, text: str): ): await self._interface.output("I don't know what to reply") + if not text_is_question and not self._interface.get_utterances_list(): + await self._interface.output("I don't know what to reply") + if ( not text_is_question and answer.is_true() diff --git a/wafl/extractors/__init__.py b/wafl/extractors/__init__.py deleted file mode 100644 index ea96f559..00000000 --- a/wafl/extractors/__init__.py +++ /dev/null @@ -1 +0,0 @@ -print diff --git a/wafl/extractors/utils.py b/wafl/extractors/utils.py deleted file mode 100644 index 4401c22f..00000000 --- a/wafl/extractors/utils.py +++ /dev/null @@ -1,23 +0,0 @@ -import re - - -def get_answer_from_text(text): - _claim_yn = "The claim makes sense:\n" - pos = 
text.find(_claim_yn) + len(_claim_yn) - return text[pos] - - -def get_text_up_to_question(text): - _claim_yn = "The claim makes sense:\n" - return text[: text.find(_claim_yn) + len(_claim_yn)] - - -def get_function_description(text): - if "<" not in text: - return "" - - return re.sub(r".*<(.*)>$", r"\1", text, re.MULTILINE).strip() - - -def get_code(text): - return re.sub(r"(.*)<.*>$", r"\1", text, re.MULTILINE).strip() diff --git a/wafl/handlers/conversation_handler.py b/wafl/handlers/conversation_handler.py index ccb4d1a9..18f1e7af 100644 --- a/wafl/handlers/conversation_handler.py +++ b/wafl/handlers/conversation_handler.py @@ -66,6 +66,7 @@ async def _main_loop(self): and interactions == 1 ): self._interface.deactivate() + self._conversation_events.reset() num_misses = 0 if ( diff --git a/wafl/inference/utils.py b/wafl/inference/utils.py index 2f25bda1..d270814b 100644 --- a/wafl/inference/utils.py +++ b/wafl/inference/utils.py @@ -2,7 +2,7 @@ from typing import List, Dict, Tuple, Any from fuzzywuzzy import process -from wafl.dataclasses.dataclasses import Answer +from wafl.data_objects.dataclasses import Answer from wafl.simple_text_processing.normalize import normalized from wafl.simple_text_processing.questions import is_question diff --git a/wafl/interface/conversation.py b/wafl/interface/conversation.py index 68687eb1..c65f1f5b 100644 --- a/wafl/interface/conversation.py +++ b/wafl/interface/conversation.py @@ -111,6 +111,15 @@ def get_last_speaker_utterances(self, speaker: str, n: int) -> List[str]: if utterance.speaker == speaker ][-n:] + def get_last_speaker_utterance(self, speaker: str) -> str: + if not self.utterances: + return "" + + for utterance in reversed(self.utterances): + if utterance.speaker == speaker: + return utterance.text + return "" + def get_first_timestamp(self) -> float: return self.utterances[0].timestamp if self.utterances else None diff --git a/wafl/interface/voice_interface.py b/wafl/interface/voice_interface.py index 17ac6d5f..0b8a7821 100644 --- a/wafl/interface/voice_interface.py +++ b/wafl/interface/voice_interface.py @@ -1,4 +1,3 @@ -import asyncio import os import random import re @@ -7,7 +6,7 @@ from wafl.interface.base_interface import BaseInterface from wafl.interface.utils import not_good_enough from wafl.listener.whisper_listener import WhisperListener -from wafl.speaker.fairseq_speaker import FairSeqSpeaker +from wafl.speaker.tts_speaker import TTSSpeaker from wafl.speaker.soundfile_speaker import SoundFileSpeaker _path = os.path.dirname(__file__) @@ -27,7 +26,7 @@ def __init__(self, config): self._deactivation_sound_filename = self.__get_deactivation_sound_from_config( config ) - self._speaker = FairSeqSpeaker(config) + self._speaker = TTSSpeaker(config) self._listener = WhisperListener(config) self._listener.set_timeout( config.get_value("listener_model")["listener_silence_timeout"] @@ -77,7 +76,7 @@ async def input(self) -> str: text = text.lower().capitalize() print(COLOR_START + "user> " + text + COLOR_END) utterance = remove_text_between_brackets(text) - if utterance.strip(): + if utterance.strip() and self._is_listening: self._insert_utterance(speaker="user", text=text) return remove_unclear(text) diff --git a/wafl/knowledge/indexing_implementation.py b/wafl/knowledge/indexing_implementation.py index a780c7de..d4c3a8f8 100644 --- a/wafl/knowledge/indexing_implementation.py +++ b/wafl/knowledge/indexing_implementation.py @@ -1,24 +1,50 @@ +import asyncio import os - import joblib import yaml +import threading +from tqdm import tqdm 
from wafl.config import Configuration
 from wafl.knowledge.single_file_knowledge import SingleFileKnowledge
 from wafl.readers.reader_factory import ReaderFactory
 
+async def add_file_to_knowledge(knowledge, filename):
+    reader = ReaderFactory.get_reader(filename)
+    for chunk in reader.get_chunks(filename):
+        await knowledge.add_fact(chunk)
+
+
 async def _add_indices_to_knowledge(knowledge, text):
     indices = yaml.safe_load(text)
     if "paths" not in indices or not indices["paths"]:
         return knowledge
 
     for path in indices["paths"]:
-        for root, _, files in os.walk(path):
-            for file in files:
-                reader = ReaderFactory.get_reader(file)
-                for chunk in reader.get_chunks(os.path.join(root, file)):
-                    await knowledge.add_fact(chunk)
+        print(f"Indexing path: {path}")
+        file_count = sum(len(files) for _, _, files in os.walk(path))
+        with tqdm(total=file_count) as pbar:
+            for root, _, files in os.walk(path):
+                threads = []
+                for file in files:
+                    threads.append(
+                        threading.Thread(
+                            target=asyncio.run,
+                            args=(
+                                add_file_to_knowledge(
+                                    knowledge, os.path.join(root, file)
+                                ),
+                            ),
+                        )
+                    )
+                if not threads:
+                    continue
+                num_threads = min(10, len(threads))
+                for i in range(0, len(threads), num_threads):
+                    batch = threads[i : i + num_threads]
+                    for thread in batch:
+                        thread.start()
+                    for thread in batch:
+                        thread.join()
+                    pbar.update(len(batch))
 
     return knowledge
 
@@ -27,10 +53,12 @@ async def load_knowledge(config, logger=None):
     if ".yaml" in config.get_value("rules") and not any(
         item in config.get_value("rules") for item in [" ", "\n"]
     ):
+        rules_filename = config.get_value("rules")
         with open(config.get_value("rules")) as file:
             rules_txt = file.read()
 
     else:
+        rules_filename = None
         rules_txt = config.get_value("rules")
 
     index_filename = config.get_value("index")
@@ -41,17 +69,18 @@
     cache_filename = config.get_value("cache_filename")
     if os.path.exists(cache_filename):
-        knowledge = joblib.load(cache_filename)
-        if knowledge.hash == hash(rules_txt) and os.path.getmtime(
-            cache_filename
-        ) > os.path.getmtime(index_filename):
-            await knowledge.initialize_retrievers()
+        if (
+            rules_filename
+            and os.path.getmtime(cache_filename) > os.path.getmtime(rules_filename)
+            and os.path.getmtime(cache_filename) > os.path.getmtime(index_filename)
+        ):
+            knowledge = joblib.load(cache_filename)
             return knowledge
 
     knowledge = SingleFileKnowledge(config, rules_txt, logger=logger)
     knowledge = await _add_indices_to_knowledge(knowledge, index_txt)
-    joblib.dump(knowledge, config.get_value("cache_filename"))
     await knowledge.initialize_retrievers()
+    joblib.dump(knowledge, config.get_value("cache_filename"))
     return knowledge
diff --git a/wafl/knowledge/single_file_knowledge.py b/wafl/knowledge/single_file_knowledge.py
index 8c9c5a25..2fc15c39 100644
--- a/wafl/knowledge/single_file_knowledge.py
+++ b/wafl/knowledge/single_file_knowledge.py
@@ -4,9 +4,10 @@
 from typing import List
 
 import nltk
+from tqdm import tqdm
 
 from wafl.config import Configuration
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact
 from wafl.knowledge.base_knowledge import BaseKnowledge
 from wafl.knowledge.utils import (
     text_is_exact_string,
@@ -169,7 +170,8 @@
         return text
 
     async def initialize_retrievers(self):
-        for index, fact in self._facts_dict.items():
+        print("Initializing fact retrievers")
+        for index, fact in tqdm(self._facts_dict.items()):
             if text_is_exact_string(fact.text):
                 continue
 
@@ -181,7 +183,8 @@
                 clean_text_for_retrieval(fact.text), index
clean_text_for_retrieval(fact.text), index ) - for index, rule in self._rules_dict.items(): + print("Initializing rule retrievers") + for index, rule in tqdm(self._rules_dict.items()): if text_is_exact_string(rule.effect.text): continue @@ -189,10 +192,6 @@ async def initialize_retrievers(self): clean_text_for_retrieval(rule.effect.text), index ) - for index, rule in self._rules_dict.items(): - if not text_is_exact_string(rule.effect.text): - continue - await self._rules_string_retriever.add_text_and_index( rule.effect.text, index ) diff --git a/wafl/parsing/line_rules_parser.py b/wafl/parsing/line_rules_parser.py index 73371f3e..b6bfa3ee 100644 --- a/wafl/parsing/line_rules_parser.py +++ b/wafl/parsing/line_rules_parser.py @@ -1,6 +1,6 @@ from wafl.simple_text_processing.questions import is_question -from wafl.dataclasses.facts import Fact -from wafl.dataclasses.rules import Rule +from wafl.data_objects.facts import Fact +from wafl.data_objects.rules import Rule def parse_rule_from_single_line(text): diff --git a/wafl/parsing/rules_parser.py b/wafl/parsing/rules_parser.py index 70d3b5f1..bb813e93 100644 --- a/wafl/parsing/rules_parser.py +++ b/wafl/parsing/rules_parser.py @@ -1,7 +1,7 @@ import yaml -from wafl.dataclasses.facts import Fact -from wafl.dataclasses.rules import Rule +from wafl.data_objects.facts import Fact +from wafl.data_objects.rules import Rule from wafl.simple_text_processing.deixis import from_user_to_bot diff --git a/wafl/readers/base_reader.py b/wafl/readers/base_reader.py index ea995601..5f0aaef2 100644 --- a/wafl/readers/base_reader.py +++ b/wafl/readers/base_reader.py @@ -1,6 +1,6 @@ from typing import List -from wafl.dataclasses.facts import Fact +from wafl.data_objects.facts import Fact class BaseReader: diff --git a/wafl/readers/pdf_reader.py b/wafl/readers/pdf_reader.py index 4f610616..dc94f664 100644 --- a/wafl/readers/pdf_reader.py +++ b/wafl/readers/pdf_reader.py @@ -2,7 +2,7 @@ from logging import getLogger from typing import List -from wafl.dataclasses.facts import Fact +from wafl.data_objects.facts import Fact, Sources from wafl.readers.base_reader import BaseReader _logger = getLogger(__name__) @@ -20,6 +20,7 @@ def get_chunks(self, filename: str) -> List[Fact]: Fact( text=page.get_text(), metadata={"filename": filename, "page_number": i}, + source=Sources.FROM_TEXT, ) for i, page in enumerate(doc) ] diff --git a/wafl/readers/reader_factory.py b/wafl/readers/reader_factory.py index 14ccb70c..6fc33bfc 100644 --- a/wafl/readers/reader_factory.py +++ b/wafl/readers/reader_factory.py @@ -4,7 +4,7 @@ class ReaderFactory: _chunk_size = 10000 - _overlap = 100 + _overlap = 500 _extension_to_reader_dict = {".pdf": PdfReader, ".txt": TextReader} @staticmethod @@ -13,7 +13,4 @@ def get_reader(filename): if extension in filename.lower(): return reader(ReaderFactory._chunk_size, ReaderFactory._overlap) - ### add pdf reader - ### add metadata and show in the UI - return TextReader(ReaderFactory._chunk_size, ReaderFactory._overlap) diff --git a/wafl/readers/text_reader.py b/wafl/readers/text_reader.py index b22c4ffe..8457ee04 100644 --- a/wafl/readers/text_reader.py +++ b/wafl/readers/text_reader.py @@ -1,7 +1,7 @@ from logging import getLogger from typing import List -from wafl.dataclasses.facts import Fact +from wafl.data_objects.facts import Fact, Sources from wafl.readers.base_reader import BaseReader _logger = getLogger(__name__) @@ -20,6 +20,7 @@ def get_chunks(self, filename: str) -> List[Fact]: Fact( text=chunk, metadata={"filename": filename, "chunk_number": 
i}, + source=Sources.FROM_TEXT, ) for i, chunk in enumerate(chunks) ] diff --git a/wafl/retriever/dense_retriever.py b/wafl/retriever/dense_retriever.py index 4bf3bb8c..acf28876 100644 --- a/wafl/retriever/dense_retriever.py +++ b/wafl/retriever/dense_retriever.py @@ -2,7 +2,6 @@ import numpy as np from typing import List, Tuple -from gensim.models import KeyedVectors from wafl.connectors.factories.sentence_embedder_connector_factory import ( SentenceEmbedderConnectorFactory, ) @@ -15,21 +14,25 @@ class DenseRetriever(BaseRetriever): _threshold_length = 5 def __init__(self, model_name, config): - self._connector = SentenceEmbedderConnectorFactory.get_connector( - model_name, config - ) - self._embeddings_model = KeyedVectors(384) + self._connector = SentenceEmbedderConnectorFactory.get_connector(config) + self._matrix = np.zeros((0, 384)) + self._indices = [] async def add_text_and_index(self, text: str, index: str): embeddings = await self._get_embeddings_from_text(text) - self._embeddings_model.add_vectors([index], [embeddings]) - self._embeddings_model.fill_norms(force=True) + self._matrix = np.vstack([self._matrix, embeddings]) + self._indices.append(index) async def get_indices_and_scores_from_text( - self, text: str + self, text: str, topn: int = 5 ) -> List[Tuple[str, float]]: embeddings = await self._get_embeddings_from_text(text) - return self._embeddings_model.similar_by_vector(embeddings, topn=5) + scores = np.dot(self._matrix, embeddings) / ( + np.linalg.norm(self._matrix, axis=1) * np.linalg.norm(embeddings) + ) + indices_and_scores = list(zip(self._indices, scores)) + indices_and_scores.sort(key=lambda x: x[1], reverse=True) + return indices_and_scores[:topn] async def _get_embeddings_from_text(self, text: str) -> "numpy.array": return (await self._connector.predict(text))["embedding"] diff --git a/wafl/run.py b/wafl/run.py index b0397e84..4138ac48 100644 --- a/wafl/run.py +++ b/wafl/run.py @@ -4,6 +4,7 @@ from wafl.exceptions import CloseConversation from wafl.events.conversation_events import ConversationEvents from wafl.interface.command_line_interface import CommandLineInterface +from wafl.knowledge.indexing_implementation import load_knowledge from wafl.logger.local_file_logger import LocalFileLogger from wafl.testcases import ConversationTestCases from wafl.variables import get_variables @@ -17,6 +18,12 @@ def print_incipit(): print() +def load_indices(): + print("Loading knowledge indices...") + config = Configuration.load_local_config() + asyncio.run(load_knowledge(config, _logger)) + + def run_from_command_line(): interface = CommandLineInterface() config = Configuration.load_local_config() diff --git a/wafl/runners/run_from_audio.py b/wafl/runners/run_from_audio.py index 7b523687..7a1d8f35 100644 --- a/wafl/runners/run_from_audio.py +++ b/wafl/runners/run_from_audio.py @@ -1,6 +1,9 @@ +import asyncio + from wafl.config import Configuration from wafl.events.conversation_events import ConversationEvents from wafl.interface.voice_interface import VoiceInterface +from wafl.knowledge.indexing_implementation import load_knowledge from wafl.logger.local_file_logger import LocalFileLogger from wafl.handlers.conversation_handler import ConversationHandler from wafl.scheduler.scheduler import Scheduler @@ -10,6 +13,7 @@ def run_from_audio(): config = Configuration.load_local_config() + asyncio.run(load_knowledge(config, _logger)) interface = VoiceInterface(config) conversation_events = ConversationEvents( config=config, diff --git a/wafl/runners/run_web_and_audio_interface.py 
b/wafl/runners/run_web_and_audio_interface.py index 9d5f833c..177f4501 100644 --- a/wafl/runners/run_web_and_audio_interface.py +++ b/wafl/runners/run_web_and_audio_interface.py @@ -1,3 +1,4 @@ +import asyncio import random import sys import threading @@ -6,6 +7,7 @@ from wafl.interface.list_interface import ListInterface from wafl.interface.voice_interface import VoiceInterface +from wafl.knowledge.indexing_implementation import load_knowledge from wafl.scheduler.scheduler import Scheduler from wafl.handlers.conversation_handler import ConversationHandler from wafl.logger.local_file_logger import LocalFileLogger diff --git a/wafl/speaker/fairseq_speaker.py b/wafl/speaker/tts_speaker.py similarity index 69% rename from wafl/speaker/fairseq_speaker.py rename to wafl/speaker/tts_speaker.py index 4e775f14..99066349 100644 --- a/wafl/speaker/fairseq_speaker.py +++ b/wafl/speaker/tts_speaker.py @@ -8,15 +8,16 @@ from wafl.speaker.utils import convert_numbers_to_words -class FairSeqSpeaker(BaseSpeaker): +class TTSSpeaker(BaseSpeaker): def __init__(self, config): self._connector = SpeakerConnectorFactory.get_connector(config) self._p = pyaudio.PyAudio() self._input_chunk_size = 1024 - self._output_chunk_size = 16384 + self._output_chunk_size = 4096 self._volume_threshold = ( - config.get_value("listener_model")["listener_volume_threshold"] / 5e3 + config.get_value("listener_model")["listener_volume_threshold"] * 1e-4 ) + self._interruptible = config.get_value("listener_model")["interruptible"] async def speak(self, text): text = convert_numbers_to_words(text) @@ -32,11 +33,15 @@ async def speak(self, text): ) stream.start_stream() await asyncio.sleep(0.1) - for i in range(0, len(wav), self._output_chunk_size): - inp = stream.read(self._input_chunk_size) - if _rms(inp) > self._volume_threshold: - break - stream.write(wav[i : i + self._output_chunk_size]) + if self._interruptible: + for i in range(0, len(wav), self._output_chunk_size): + inp = stream.read(self._input_chunk_size) + if _rms(inp) > self._volume_threshold: + break + stream.write(wav[i : i + self._output_chunk_size]) + else: + stream.write(wav) + stream.stop_stream() stream.close() await asyncio.sleep(0.1) diff --git a/wafl/templates/config.json b/wafl/templates/config.json index 7af9893e..57657b63 100644 --- a/wafl/templates/config.json +++ b/wafl/templates/config.json @@ -9,24 +9,18 @@ "functions": "functions.py", "max_recursion": 2, "frontend_port": 8090, - "llm_model": { - "model_host": "localhost", - "model_port": 8080, + "backend": { + "host": "localhost", + "port": 8080, + "token": "secret" + }, + "generation_config": { "temperature": 0.4 }, "listener_model": { - "model_host": "localhost", - "model_port": 8080, "listener_hotword_logp": -8, "listener_volume_threshold": 0.6, - "listener_silence_timeout": 0.7 - }, - "speaker_model": { - "model_host": "localhost", - "model_port": 8080 - }, - "text_embedding_model": { - "model_host": "localhost", - "model_port": 8080 + "listener_silence_timeout": 0.7, + "interruptible": false } } diff --git a/wafl/templates/main.prompt b/wafl/templates/main.prompt index 07b45290..8cbb4d64 100644 --- a/wafl/templates/main.prompt +++ b/wafl/templates/main.prompt @@ -8,4 +8,6 @@ The rules that *must* be followed are: Create a plausible dialogue based on the aforementioned summary and rules. Do not repeat yourself. Be friendly but not too servile. -Follow the rules if present and they apply to the dialogue. Do not improvise if rules are present. 
\ No newline at end of file +Follow the rules if present and they apply to the dialogue. Do not improvise if rules are present. +The user query might be incomplete or ambiguous or ungrammatical. The bot *must* ask for clarification if needed. +The bot only answers if the query is clear and unambiguous. \ No newline at end of file diff --git a/wafl/testcases.py b/wafl/testcases.py index bcc49f4d..cb97af4f 100644 --- a/wafl/testcases.py +++ b/wafl/testcases.py @@ -1,4 +1,6 @@ from wafl.answerer.entailer import Entailer +from wafl.knowledge.indexing_implementation import load_knowledge + from wafl.simple_text_processing.deixis import from_user_to_bot, from_bot_to_user from wafl.exceptions import CloseConversation from wafl.events.conversation_events import ConversationEvents @@ -25,8 +27,10 @@ async def test_single_case(self, name): test_lines = self._testcase_data[name]["lines"] is_negated = self._testcase_data[name]["negated"] interface = DummyInterface(user_lines) - conversation_events = ConversationEvents(self._config, interface=interface) - await conversation_events._knowledge.initialize_retrievers() + knowledge = await load_knowledge(self._config) + conversation_events = ConversationEvents( + self._config, interface=interface, knowledge=knowledge + ) print(self.BLUE_COLOR_START + f"\nRunning test '{name}'." + self.COLOR_END) continue_conversations = True @@ -77,9 +81,7 @@ async def _lhs_is_similar_to(self, lhs, rhs, prior_dialogue): if lhs_name != rhs_name: return False - return await self._entailer.left_entails_right( - lhs, rhs, "\n".join(prior_dialogue) - ) + return await self._entailer.left_entails_right(lhs, rhs) def _apply_deixis(self, line): name = line.split(":")[0].strip() diff --git a/wafl/variables.py b/wafl/variables.py index 0d6490e8..b2e183e7 100644 --- a/wafl/variables.py +++ b/wafl/variables.py @@ -1,9 +1,9 @@ def get_variables(): return { - "version": "0.1.1", + "version": "0.1.3", } def is_supported(wafl_llm_version): - supported_versions = ["0.1.0"] + supported_versions = ["0.1.1"] return wafl_llm_version in supported_versions