diff --git a/documentation/source/configuration.rst b/documentation/source/configuration.rst
index f67dfbe7..b098092a 100644
--- a/documentation/source/configuration.rst
+++ b/documentation/source/configuration.rst
@@ -7,37 +7,33 @@ A typical configuration file looks like this:
.. code-block:: text
- {
- "waking_up_word": "computer",
- "waking_up_sound": true,
- "deactivate_sound": true,
- "rules": "rules.yaml",
- "index": "indices.yaml",
- "cache_filename": "knowledge_cache",
- "prompt_filename": "main.prompt",
- "functions": "functions.py",
- "max_recursion": 2,
- "llm_model": {
- "model_host": "localhost",
- "model_port": 8080,
- "temperature": 0.4
- },
- "listener_model": {
- "model_host": "localhost",
- "model_port": 8080,
- "listener_hotword_logp": -8,
- "listener_volume_threshold": 0.6,
- "listener_silence_timeout": 0.7
- },
- "speaker_model": {
- "model_host": "localhost",
- "model_port": 8080
- },
- "text_embedding_model": {
- "model_host": "localhost",
- "model_port": 8080
- }
- }
+    {
+        "waking_up_word": "computer",
+        "waking_up_sound": true,
+        "deactivate_sound": true,
+        "rules": "rules.yaml",
+        "index": "indices.yaml",
+        "cache_filename": "knowledge_cache",
+        "prompt_filename": "main.prompt",
+        "functions": "functions.py",
+        "max_recursion": 2,
+        "frontend_port": 8090,
+        "backend": {
+            "host": "localhost",
+            "port": 8080,
+            "token": "secret"
+        },
+        "generation_config": {
+            "temperature": 0.4
+        },
+        "listener_model": {
+            "listener_hotword_logp": -8,
+            "listener_volume_threshold": 0.6,
+            "listener_silence_timeout": 0.7,
+            "interruptible": true
+        }
+    }
+
@@ -59,20 +55,10 @@ These settings regulate the following:
* "frontend_port" is the port where the web frontend is running. The default is 8090.
- * "llm_model" is the configuration to connect to wafl-llm in the backend. The default url is "localhost:8080". The "temperature" parameter is used to set the temperature for the LLM model. The default is 0.4.
-
- * "listener_model" is the configuration to connect to the listener model in the backend. The default is "localhost:8080".
-
- - The listener model is used to detect the wake-up word.
- The similarity threshold for the detection can be set with the "listener_hotword_logp" parameter.
-
- - The "listener_volume_threshold" parameter is used to set the volume threshold for any conversation.
- Any word uttered with a volume below this threshold is ignored.
-
- - The "listener_silence_timeout" parameter is used to set the silence timeout for any conversation.
- If no word is uttered for a time longer than this timeout, the conversation is considered finished.
-
- * "speaker_model" is the configuration to connect to the speaker model in the backend. The default is "localhost:8080".
+ * "backend" is the configuration used to connect to the wafl-llm backend: its "host", "port", and access "token". The default is "localhost:8080".
- * "text_embedding_model" is the configuration to connect to the text embedding model in the backend. The default is "localhost:8080".
+ * "generation_config" contains the parameters used when generating a reply. At present the only item is the LLM sampling "temperature". The default is "temperature: 0.4".
+ * "listener_model" is the configuration of the voice listener.
+   Its items set the log-probability threshold for hotword detection ("listener_hotword_logp"), the volume below which speech is ignored ("listener_volume_threshold"), the silence timeout after which a conversation is considered finished ("listener_silence_timeout"), and whether the bot's speech can be interrupted by the user ("interruptible").
+   The defaults are "listener_hotword_logp: -8", "listener_volume_threshold: 0.6", "listener_silence_timeout: 0.7", and "interruptible: true".
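For reference, a minimal sketch of how the new layout is read at runtime, assuming the `Configuration` class from `wafl/config.py`, a `config.json` shaped like the example above, and a reachable wafl-llm backend (instantiating `Configuration` now opens a connection to it):

```python
from wafl.config import Configuration

# Loads config.json from the current working directory.
config = Configuration.load_local_config()

backend = config.get_value("backend")  # {"host": ..., "port": ..., "token": ...}
server_root = f"https://{backend['host']}:{backend['port']}"

# Generation and listener settings now live in their own sub-dictionaries.
temperature = config.get_value("generation_config")["temperature"]
interruptible = config.get_value("listener_model")["interruptible"]
```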
diff --git a/requirements.txt b/requirements.txt
index 19f7062d..7bba6a9b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,6 @@
flask[async]==3.0.3
flask-cors==4.0.1
nltk==3.8.1
-gensim==4.3.3
sklearn==0.0
python-Levenshtein==0.25.1
fuzzywuzzy==0.18.0
diff --git a/setup.py b/setup.py
index 3b239ac4..c926112f 100644
--- a/setup.py
+++ b/setup.py
@@ -20,9 +20,8 @@
"wafl.connectors.clients",
"wafl.connectors.factories",
"wafl.connectors.remote",
- "wafl.dataclasses",
+ "wafl.data_objects",
"wafl.events",
- "wafl.extractors",
"wafl.handlers",
"wafl.inference",
"wafl.interface",
@@ -49,7 +48,6 @@
"flask[async]==3.0.3",
"flask-cors==4.0.1",
"nltk==3.8.1",
- "gensim==4.3.3",
"sklearn==0.0",
"python-Levenshtein==0.25.1",
"fuzzywuzzy==0.18.0",
diff --git a/tests/config.json b/tests/config.json
index 121d77f0..af40d7a7 100644
--- a/tests/config.json
+++ b/tests/config.json
@@ -8,24 +8,19 @@
"prompt_filename": "main.prompt",
"functions": "functions.py",
"max_recursion": 2,
- "llm_model": {
- "model_host": "localhost",
- "model_port": 8080,
+ "frontend_port": 8090,
+ "backend": {
+ "host": "aragorn",
+ "port": 8080,
+ "token": "secret"
+ },
+ "generation_config": {
"temperature": 0.4
},
"listener_model": {
- "model_host": "localhost",
- "model_port": 8080,
"listener_hotword_logp": -8,
"listener_volume_threshold": 0.6,
- "listener_silence_timeout": 0.7
- },
- "speaker_model": {
- "model_host": "localhost",
- "model_port": 8080
- },
- "text_embedding_model": {
- "model_host": "localhost",
- "model_port": 8080
+ "listener_silence_timeout": 0.7,
+ "interruptible": true
}
}
diff --git a/tests/main.prompt b/tests/main.prompt
index 07b45290..8cbb4d64 100644
--- a/tests/main.prompt
+++ b/tests/main.prompt
@@ -8,4 +8,6 @@ The rules that *must* be followed are:
Create a plausible dialogue based on the aforementioned summary and rules.
Do not repeat yourself. Be friendly but not too servile.
-Follow the rules if present and they apply to the dialogue. Do not improvise if rules are present.
\ No newline at end of file
+Follow the rules if present and they apply to the dialogue. Do not improvise if rules are present.
+The user query might be incomplete, ambiguous, or ungrammatical. The bot *must* ask for clarification if needed.
+The bot only answers if the query is clear and unambiguous.
\ No newline at end of file
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 317a91fd..112a34c3 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -12,7 +12,7 @@
class TestConnection(TestCase):
def test__connection_to_generative_model_can_generate_text(self):
config = Configuration.load_local_config()
- connector = RemoteLLMConnector(config.get_value("llm_model"))
+ connector = RemoteLLMConnector(config)
prediction = asyncio.run(
connector.predict(
PromptCreator.create_from_one_instruction(
@@ -25,7 +25,7 @@ def test__connection_to_generative_model_can_generate_text(self):
def test__connection_to_generative_model_can_generate_text_within_tags(self):
config = Configuration.load_local_config()
- connector = RemoteLLMConnector(config.get_value("llm_model"))
+ connector = RemoteLLMConnector(config)
connector._num_prediction_tokens = 200
text = 'Generate a full paragraph based on this chapter title " The First Contact". The theme of the paragraph is space opera. Include the characters "Alberto" and "Maria". Write at least three sentences.'
prompt = f"""
@@ -43,7 +43,7 @@ def test__connection_to_generative_model_can_generate_text_within_tags(self):
def test__connection_to_generative_model_can_generate_a_python_list(self):
config = Configuration.load_local_config()
- connector = RemoteLLMConnector(config.get_value("llm_model"))
+ connector = RemoteLLMConnector(config)
connector._num_prediction_tokens = 200
prompt = "Generate a Python list of 4 chapters names for a space opera book. The output needs to be a python list of strings: "
prediction = asyncio.run(
diff --git a/tests/test_entailer.py b/tests/test_entailer.py
new file mode 100644
index 00000000..6ccc7088
--- /dev/null
+++ b/tests/test_entailer.py
@@ -0,0 +1,34 @@
+import asyncio
+import os
+
+from unittest import TestCase
+from wafl.config import Configuration
+from wafl.connectors.remote.remote_entailer_connector import RemoteEntailerConnector
+from wafl.connectors.clients.entailer_client import EntailerClient
+
+_path = os.path.dirname(__file__)
+
+
+class TestConnection(TestCase):
+ def test__entailer_connector(self):
+ config = Configuration.load_local_config()
+ connector = RemoteEntailerConnector(config)
+ prediction = asyncio.run(
+ connector.predict(
+ "The first contact is a romance novel set in the middle ages.",
+ "The first contact is a science fiction novel about the first contact between humans and aliens.",
+ )
+ )
+ assert prediction["score"] < 0.5
+
+ def test__entailment_client(self):
+
+ config = Configuration.load_local_config()
+ client = EntailerClient(config)
+ prediction = asyncio.run(
+ client.get_entailment_score(
+ "The first contact is a romance novel set in the middle ages.",
+ "The first contact is a science fiction novel about the first contact between humans and aliens.",
+ )
+ )
+ assert prediction < 0.5
diff --git a/tests/test_indexing.py b/tests/test_indexing.py
index 1e64b4a4..0ea51557 100644
--- a/tests/test_indexing.py
+++ b/tests/test_indexing.py
@@ -5,7 +5,7 @@
from unittest import TestCase
from wafl.config import Configuration
-from wafl.dataclasses.dataclasses import Query
+from wafl.data_objects.dataclasses import Query
from wafl.knowledge.indexing_implementation import add_to_index, load_knowledge
_path = os.path.dirname(__file__)
diff --git a/tests/test_speaker.py b/tests/test_speaker.py
index 35fe57fc..79db6a9a 100644
--- a/tests/test_speaker.py
+++ b/tests/test_speaker.py
@@ -4,7 +4,7 @@
from unittest import TestCase
from wafl.config import Configuration
-from wafl.speaker.fairseq_speaker import FairSeqSpeaker
+from wafl.speaker.tts_speaker import TTSSpeaker
from wafl.speaker.soundfile_speaker import SoundFileSpeaker
_wafl_greetings = """
@@ -17,24 +17,30 @@
class TestSpeaker(TestCase):
def test_voice(self):
config = Configuration.load_local_config()
- speaker = FairSeqSpeaker(config)
+ speaker = TTSSpeaker(config)
text = "Hello world"
asyncio.run(speaker.speak(text))
def test_long_text(self):
config = Configuration.load_local_config()
- speaker = FairSeqSpeaker(config)
+ speaker = TTSSpeaker(config)
text = (
"Shall I compare you to a summer's day? Thou art more lovely and temperate."
)
asyncio.run(speaker.speak(text))
- def test_number_pronunciation(self):
+ def test_number_pronunciation1(self):
config = Configuration.load_local_config()
- speaker = FairSeqSpeaker(config)
+ speaker = TTSSpeaker(config)
text = "The time is 54 past 8"
asyncio.run(speaker.speak(text))
+ def test_number_pronunciation2(self):
+ config = Configuration.load_local_config()
+ speaker = TTSSpeaker(config)
+ text = "The time is 8 54"
+ asyncio.run(speaker.speak(text))
+
def test_on_sound(self):
speaker = SoundFileSpeaker()
speaker.speak(os.path.join(_path, "../wafl/sounds/activation.wav"))
diff --git a/tests/test_voice.py b/tests/test_voice.py
index 6ecd785b..7b90ef83 100644
--- a/tests/test_voice.py
+++ b/tests/test_voice.py
@@ -15,7 +15,7 @@
rules:
- the user's name is Jane:
- - write "I hear you"
+ - reply with "I hear you" and nothing else
""".strip()
_path = os.path.dirname(__file__)
@@ -23,7 +23,7 @@
class TestVoice(TestCase):
def test__activation(self):
- interface = DummyInterface(to_utter=["computer", "my name is Jane"])
+ interface = DummyInterface(to_utter=["computer my name is Jane"])
config = Configuration.load_local_config()
config.set_value("rules", _wafl_example)
conversation_events = ConversationEvents(config=config, interface=interface)
diff --git a/todo.txt b/todo.txt
index fa7be712..db4f8d22 100644
--- a/todo.txt
+++ b/todo.txt
@@ -1,4 +1,42 @@
-* why do I need to re-initialise the retrievers after unpickling the knowledge?
+* use prior text and response from failed substitutions between [] instead of just the iteration number (line 89, dialogue_answerer.py)
+* remove dead code
+* re-fine-tune phi to get better performance
+
+* delete rules and memory from discourse_answerer
+
+* This is wrong - from wafl_llm
+  <|end|><|assistant|><|user|> Hi!<|end|><|assistant|>
+  The user turn is sandwiched between assistant tags. It should be:
+ <|end|><|assistant|> Hi!<|end|><|user|>
+
+/* make interruptible speech optional
+
+
+* use entailment score to flag a rule for execution before the answer.
+* get the full model list from the wafl_llm backend. Only specify the connection host and port in wafl
+
+* the answer from the indexed files should be directed from a rule.
+ - facts and rules should live at the highest level of the retrieval
+
+
+/* apply entailer to rule retrieval:
+/ if more than one rule is retrieved, then the one
+/ that is entailed by the query should be chosen
+
+
+
+/* Add tqdm to indexing.
+/* Make it index when wafl start first, not at the first use/login
+
+/* The prior items with timestamps might not be necessary.
+/ - Just implement a queue with a fixed size
+
+* add entailer to wafl_llm
+
+
+/* why do I need to re-initialise the retrievers after unpickling the knowledge?
+ - maybe you should save the retrievers in the knowledge object separately?
+ - It was gensim that was not serializable. Took it out
/* knowledge cache does not cache the rules or facts
diff --git a/wafl/answerer/answerer_implementation.py b/wafl/answerer/answerer_implementation.py
index 3c8b36e8..6d01549b 100644
--- a/wafl/answerer/answerer_implementation.py
+++ b/wafl/answerer/answerer_implementation.py
@@ -1,10 +1,9 @@
import re
-import traceback
from typing import List, Tuple
-
+from wafl.answerer.entailer import Entailer
from wafl.exceptions import CloseConversation
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact, Sources
from wafl.interface.conversation import Conversation, Utterance
@@ -53,8 +52,11 @@ async def substitute_memory_in_answer_and_get_memories_if_present(
async def execute_results_in_answer(answer_text: str, module, functions) -> str:
+    if "<execute>" in answer_text and "</execute>" not in answer_text:
+        answer_text += "</execute>"
+
matches = re.finditer(
-        r"<execute>(.*?)</execute>|<execute>(.*?\))$",
+        r"<execute>(.*?)</execute>",
answer_text,
re.DOTALL | re.MULTILINE,
)
@@ -65,14 +67,6 @@ async def execute_results_in_answer(answer_text: str, module, functions) -> str:
result = await _run_code(to_execute, module, functions)
answer_text = answer_text.replace(match.group(0), result)
-    matches = re.finditer(r"<execute>(.*?\))$", answer_text, re.DOTALL | re.MULTILINE)
- for match in matches:
- to_execute = match.group(1)
- if not to_execute:
- continue
- result = await _run_code(to_execute, module, functions)
- answer_text = answer_text.replace(match.group(0), result)
-
return answer_text
@@ -104,8 +98,7 @@ async def _run_code(to_execute: str, module, functions) -> str:
result = (
f"Error while executing\n\n```python\n{to_execute}\n```\n\n{str(e)}"
)
- traceback.print_exc()
- break
+ raise RuntimeError(result)
if not result:
result = f"\n```python\n{to_execute}\n```"
@@ -113,22 +106,32 @@ async def _run_code(to_execute: str, module, functions) -> str:
return result
-def get_text_from_facts_and_thresholds(
+def create_memory_from_fact_list(facts: List[Fact], max_num_facts: int) -> str:
+ text_fact_list = [
+        "\n\n- " + fact.text + "\n"
+ for fact in facts
+ if fact.source == Sources.FROM_TEXT
+ ][:max_num_facts]
+ rule_fact_list = [
+        "\n\n- " + fact.text + "\n"
+ for fact in facts
+ if fact.source in [None, Sources.FROM_RULES]
+ ]
+ return "".join(text_fact_list + rule_fact_list)
+
+
+def get_facts_with_metadata_from_facts_and_thresholds(
facts_and_thresholds: List[Tuple[Fact, float]], memory: str
) -> List[str]:
- text_list = []
+ fact_list = []
for item in facts_and_thresholds:
if item[0].text not in memory:
- text = item[0].text
+ new_fact = item[0].copy()
if item[0].metadata:
- text = (
- f"Metadata for the following text: {str(item[0].metadata)}"
- + "\n"
- + text
- )
- text_list.append(text)
+ new_fact.text = new_fact.text
+ fact_list.append(new_fact)
- return text_list
+ return fact_list
def add_dummy_utterances_to_continue_generation(
@@ -150,3 +153,21 @@ def add_dummy_utterances_to_continue_generation(
def add_memories_to_facts(facts: str, memories: List[str]) -> str:
return facts + "\n" + "\n".join(memories)
+
+
+async def select_best_rules_using_entailer(
+ conversation: Conversation,
+ rules_as_strings: List[str],
+ entailer: Entailer,
+ num_rules: int,
+) -> List[str]:
+ query_text = conversation.get_last_speaker_utterance("user")
+    ### Sort rules by entailment score, highest first
+    scores = []
+    for rule in rules_as_strings:
+        score = await entailer.get_score(query_text, rule)
+        scores.append(score)
+    rules_as_strings = [
+        rule
+        for _, rule in sorted(
+            zip(scores, rules_as_strings), key=lambda pair: pair[0], reverse=True
+        )
+    ]
+    return rules_as_strings[:num_rules]
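A self-contained sketch of the ranking pattern `select_best_rules_using_entailer` implements; the stub scorer below is hypothetical and stands in for the remote entailer model:

```python
import asyncio
from typing import List


class StubEntailer:
    """Hypothetical stand-in for wafl's Entailer: scores by word overlap."""

    async def get_score(self, lhs: str, rhs: str) -> float:
        lhs_words, rhs_words = set(lhs.lower().split()), set(rhs.lower().split())
        return len(lhs_words & rhs_words) / max(len(lhs_words | rhs_words), 1)


async def select_best_rules(
    query: str, rules: List[str], entailer, num_rules: int
) -> List[str]:
    # Score every candidate rule against the query, then keep the top num_rules.
    scores = [await entailer.get_score(query, rule) for rule in rules]
    ranked = sorted(zip(scores, rules), key=lambda pair: pair[0], reverse=True)
    return [rule for _, rule in ranked[:num_rules]]


rules = ["the user asks about the weather", "the user gives their name"]
print(asyncio.run(select_best_rules("my name is Jane", rules, StubEntailer(), 1)))
```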
diff --git a/wafl/answerer/base_answerer.py b/wafl/answerer/base_answerer.py
deleted file mode 100644
index a999dd0f..00000000
--- a/wafl/answerer/base_answerer.py
+++ /dev/null
@@ -1,3 +0,0 @@
-class BaseAnswerer:
- async def answer(self, query_text: str) -> "Answer":
- raise NotImplementedError
diff --git a/wafl/answerer/dialogue_answerer.py b/wafl/answerer/dialogue_answerer.py
index f12be579..9447070f 100644
--- a/wafl/answerer/dialogue_answerer.py
+++ b/wafl/answerer/dialogue_answerer.py
@@ -1,33 +1,37 @@
from importlib import import_module
from inspect import getmembers, isfunction
-from typing import List, Tuple
+from typing import List
+
+from wafl.answerer.entailer import Entailer
from wafl.answerer.answerer_implementation import (
substitute_memory_in_answer_and_get_memories_if_present,
create_one_liner,
- get_text_from_facts_and_thresholds,
+ get_facts_with_metadata_from_facts_and_thresholds,
add_dummy_utterances_to_continue_generation,
add_memories_to_facts,
execute_results_in_answer,
+ create_memory_from_fact_list,
+ select_best_rules_using_entailer,
)
-from wafl.answerer.base_answerer import BaseAnswerer
from wafl.answerer.rule_maker import RuleMaker
from wafl.connectors.clients.llm_chat_client import LLMChatClient
-from wafl.dataclasses.dataclasses import Query, Answer
-from wafl.interface.conversation import Conversation
+from wafl.data_objects.dataclasses import Query, Answer
+from wafl.interface.conversation import Conversation, Utterance
from wafl.simple_text_processing.questions import is_question
-class DialogueAnswerer(BaseAnswerer):
+class DialogueAnswerer:
def __init__(self, config, knowledge, interface, code_path, logger):
self._threshold_for_facts = 0.85
self._client = LLMChatClient(config)
+ self._entailer = Entailer(config)
self._knowledge = knowledge
self._logger = logger
self._interface = interface
self._max_num_past_utterances = 5
- self._max_num_past_utterances_for_facts = 5
- self._max_num_past_utterances_for_rules = 2
- self._prior_facts_with_timestamp = []
+ self._max_num_facts = 5
+ self._max_num_rules = 2
+ self._prior_facts = []
self._init_python_module(code_path.replace(".py", ""))
self._prior_rules = []
self._max_predictions = 3
@@ -38,6 +42,10 @@ def __init__(self, config, knowledge, interface, code_path, logger):
max_num_rules=1,
)
+ def reset(self):
+ self._prior_facts = []
+ self._prior_rules = []
+
async def answer(self, query_text: str) -> Answer:
if self._logger:
self._logger.write(f"Dialogue Answerer: the query is {query_text}")
@@ -45,35 +53,50 @@ async def answer(self, query_text: str) -> Answer:
conversation = self._interface.get_utterances_list_with_timestamp().get_last_n(
self._max_num_past_utterances
)
+ conversation.add_utterance(Utterance(speaker="user", text=query_text))
rules_text = await self._get_relevant_rules(conversation)
if not conversation:
conversation = create_one_liner(query_text)
- conversational_timestamp = len(conversation)
- facts = await self._get_relevant_facts(
+ memory = await self._get_relevant_facts(
query,
has_prior_rules=bool(rules_text),
- conversational_timestamp=conversational_timestamp,
)
final_answer_text = ""
- for _ in range(self._max_predictions):
- original_answer_text = await self._client.get_answer(
- text=facts,
- rules_text=rules_text,
- dialogue=conversation,
- )
- await self._interface.add_fact(f"The bot predicts: {original_answer_text}")
- answer_text, memories = await self._apply_substitutions(
- original_answer_text
- )
-
- final_answer_text += answer_text
+ is_finished = False
+ for num_attempts in range(self._max_predictions):
+ try:
+ original_answer_text = await self._client.get_answer(
+ text=memory,
+ rules_text=rules_text,
+ dialogue=conversation,
+ )
+ await self._interface.add_fact(
+ f"The bot predicts: {original_answer_text}"
+ )
+ answer_text, memories = await self._apply_substitutions(
+ original_answer_text
+ )
- if not memories:
- break
+ final_answer_text += answer_text
+ if not memories:
+ is_finished = True
+ break
+                memory = add_memories_to_facts(memory, memories)
+ add_dummy_utterances_to_continue_generation(conversation, answer_text)
+
+ except RuntimeError as e:
+ if self._logger:
+ self._logger.write(f"Error in generating answer: {e}")
+ conversation.add_utterance(
+ Utterance(
+ speaker="bot",
+ text=f"[when using the answer {original_answer_text} the system says {e}]\n",
+ )
+ )
- facts = add_memories_to_facts(facts, memories)
- add_dummy_utterances_to_continue_generation(conversation, answer_text)
+ if not is_finished:
+            final_answer_text += " I was unable to generate a full answer. Please see the logs for more information."
if self._logger:
self._logger.write(
@@ -82,22 +105,19 @@ async def answer(self, query_text: str) -> Answer:
return Answer.create_from_text(final_answer_text)
- async def _get_relevant_facts(
- self, query: Query, has_prior_rules: bool, conversational_timestamp: int
- ) -> str:
- memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])
- self._prior_facts_with_timestamp = self._get_prior_facts_with_timestamp(
- conversational_timestamp
- )
+ async def _get_relevant_facts(self, query: Query, has_prior_rules: bool) -> str:
+ memory = create_memory_from_fact_list(self._prior_facts, self._max_num_facts)
facts_and_thresholds = await self._knowledge.ask_for_facts_with_threshold(
query, is_from_user=True, threshold=self._threshold_for_facts
)
if facts_and_thresholds:
- facts = get_text_from_facts_and_thresholds(facts_and_thresholds, memory)
- self._prior_facts_with_timestamp.extend(
- (item, conversational_timestamp) for item in facts
+ facts = get_facts_with_metadata_from_facts_and_thresholds(
+ facts_and_thresholds, memory
+ )
+ self._prior_facts.extend(facts)
+ memory = create_memory_from_fact_list(
+ self._prior_facts, self._max_num_facts
)
- memory = "\n".join([item[0] for item in self._prior_facts_with_timestamp])
await self._interface.add_fact(f"The bot remembers the facts:\n{memory}")
else:
@@ -110,11 +130,14 @@ async def _get_relevant_facts(
return memory
async def _get_relevant_rules(self, conversation: Conversation) -> List[str]:
- rules = await self._rule_creator.create_from_query(conversation)
- for rule in rules:
+ rules_as_strings = await self._rule_creator.create_from_query(conversation)
+ rules_as_strings = await select_best_rules_using_entailer(
+ conversation, rules_as_strings, self._entailer, num_rules=1
+ )
+ for rule in rules_as_strings:
if rule not in self._prior_rules:
self._prior_rules.insert(0, rule)
- self._prior_rules = self._prior_rules[: self._max_num_past_utterances_for_rules]
+ self._prior_rules = self._prior_rules[: self._max_num_rules]
return self._prior_rules
def _init_python_module(self, module_name):
@@ -129,13 +152,3 @@ async def _apply_substitutions(self, original_answer_text):
self._functions,
)
)
-
- def _get_prior_facts_with_timestamp(
- self, conversational_timestamp: int
- ) -> List[Tuple[str, int]]:
- return [
- item
- for item in self._prior_facts_with_timestamp
- if item[1]
- > conversational_timestamp - self._max_num_past_utterances_for_facts
- ]
diff --git a/wafl/answerer/entailer.py b/wafl/answerer/entailer.py
index 54e4e3e2..3f3c2ab9 100644
--- a/wafl/answerer/entailer.py
+++ b/wafl/answerer/entailer.py
@@ -1,41 +1,14 @@
-import os
-import textwrap
-
-from wafl.connectors.factories.llm_connector_factory import LLMConnectorFactory
-from wafl.connectors.prompt_template import PromptTemplate
-from wafl.interface.conversation import Utterance, Conversation
-
-_path = os.path.dirname(__file__)
+from wafl.connectors.clients.entailer_client import EntailerClient
class Entailer:
def __init__(self, config):
- self._connector = LLMConnectorFactory.get_connector(config)
+ self.entailer_client = EntailerClient(config)
self._config = config
- async def left_entails_right(self, lhs: str, rhs: str, dialogue) -> str:
- prompt = await self._get_answer_prompt(lhs, rhs, dialogue)
- result = await self._connector.generate(prompt)
- result = self._clean_result(result)
- return result == "yes"
-
- async def _get_answer_prompt(self, lhs, rhs, dialogue):
- return PromptTemplate(
- system_prompt="",
- conversation=self._get_dialogue_prompt(lhs, rhs, dialogue),
- )
-
- def _clean_result(self, result):
- result = result.replace("", "")
- result = result.split("\n")[0]
- result = result.strip()
- return result.lower()
+ async def left_entails_right(self, lhs: str, rhs: str) -> bool:
+ prediction = await self.entailer_client.get_entailment_score(lhs, rhs)
+ return prediction > 0.5
- def _get_dialogue_prompt(self, dialogue, lhs, rhs):
- text = f"""
-Your task is to determine whether two sentences are similar.
-1) {lhs.lower()}
-2) {rhs.lower()}
-Please answer "yes" if the two sentences are similar or "no" if not:
- """.strip()
- return Conversation([Utterance(speaker="user", text=text)])
+ async def get_score(self, lhs: str, rhs: str) -> float:
+ return await self.entailer_client.get_entailment_score(lhs, rhs)
diff --git a/wafl/answerer/rule_maker.py b/wafl/answerer/rule_maker.py
index 115dfcfc..7454fe73 100644
--- a/wafl/answerer/rule_maker.py
+++ b/wafl/answerer/rule_maker.py
@@ -1,7 +1,7 @@
from typing import List
-from wafl.dataclasses.dataclasses import Query
-from wafl.dataclasses.rules import Rule
+from wafl.data_objects.dataclasses import Query
+from wafl.data_objects.rules import Rule
class RuleMaker:
diff --git a/wafl/changelog.txt b/wafl/changelog.txt
new file mode 100644
index 00000000..bfd13914
--- /dev/null
+++ b/wafl/changelog.txt
@@ -0,0 +1,7 @@
+- version 0.1.3
+* added multi-threaded support for indexing multiple files
+* TODO: add support for multiple knowledge bases.
+  It needs to index the rules and the files separately!
+* the web interface should show where the facts come from
+* add support for wafl studio where you can concatenate actions (and create corresponding yaml files)
+* use <> tags for concatenation
\ No newline at end of file
diff --git a/wafl/command_line.py b/wafl/command_line.py
index 4ebb6643..114f86be 100644
--- a/wafl/command_line.py
+++ b/wafl/command_line.py
@@ -9,6 +9,7 @@
run_testcases,
print_incipit,
download_models,
+ load_indices,
)
from wafl.runners.run_from_actions import run_action
@@ -52,26 +53,31 @@ def process_cli():
elif command == "run":
from wafl.runners.run_web_and_audio_interface import run_app
+ load_indices()
run_app()
remove_preprocessed("/")
elif command == "run-cli":
+ load_indices()
run_from_command_line()
remove_preprocessed("/")
elif command == "run-audio":
from wafl.runners.run_from_audio import run_from_audio
+ load_indices()
run_from_audio()
remove_preprocessed("/")
elif command == "run-server":
from wafl.runners.run_web_interface import run_server_only_app
+ load_indices()
run_server_only_app()
remove_preprocessed("/")
elif command == "run-tests":
+ load_indices()
run_testcases()
remove_preprocessed("/")
diff --git a/wafl/config.py b/wafl/config.py
index 45c00c8e..cb5afcb7 100644
--- a/wafl/config.py
+++ b/wafl/config.py
@@ -2,6 +2,10 @@
import os
import shutil
+from wafl.connectors.remote.remote_configuration_connector import (
+ RemoteConfigurationConnector,
+)
+
_path = os.path.dirname(__file__)
@@ -21,6 +25,10 @@ def __init__(self, filename):
with open(filename) as file:
self._data = json.load(file)
+ self._remote_config = RemoteConfigurationConnector(
+ self._data["backend"]["host"], self._data["backend"]["port"]
+ )
+
def get_value(self, key):
if key in self._data:
return self._data[key]
diff --git a/wafl/connectors/clients/clients_implementation.py b/wafl/connectors/clients/clients_implementation.py
deleted file mode 100644
index ec463c56..00000000
--- a/wafl/connectors/clients/clients_implementation.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import csv
-import os
-import joblib
-
-from wafl.knowledge.single_file_knowledge import SingleFileKnowledge
-
-_path = os.path.dirname(__file__)
-
-
-async def load_knowledge_from_file(filename, config):
- items_list = []
- with open(os.path.join(_path, "../../data/", filename + ".csv")) as file:
- csvreader = csv.reader(file)
- for row in csvreader:
- items_list.append(row[0].strip())
-
- knowledge = await SingleFileKnowledge.create_from_list(items_list, config)
- joblib.dump(knowledge, os.path.join(_path, f"../../data/{filename}.knowledge"))
- return knowledge
diff --git a/wafl/connectors/clients/entailer_client.py b/wafl/connectors/clients/entailer_client.py
new file mode 100644
index 00000000..bdde7dc1
--- /dev/null
+++ b/wafl/connectors/clients/entailer_client.py
@@ -0,0 +1,19 @@
+import os
+
+from wafl.connectors.factories.entailer_connector_factory import (
+ EntailerConnectorFactory,
+)
+
+_path = os.path.dirname(__file__)
+
+
+class EntailerClient:
+ def __init__(self, config):
+ self._connector = EntailerConnectorFactory.get_connector(config)
+ self._config = config
+
+ async def get_entailment_score(self, lhs: str, rhs: str) -> float:
+ prediction = await self._connector.predict(lhs, rhs)
+ if "score" not in prediction:
+ raise ValueError("The Entailment prediction does not contain a score.")
+ return prediction["score"]
diff --git a/wafl/connectors/clients/information_client.py b/wafl/connectors/clients/information_client.py
index 772afb00..533fc902 100644
--- a/wafl/connectors/clients/information_client.py
+++ b/wafl/connectors/clients/information_client.py
@@ -1,9 +1,6 @@
import os
-import textwrap
-from typing import List
from wafl.connectors.factories.llm_connector_factory import LLMConnectorFactory
-from wafl.connectors.prompt_template import PromptTemplate
_path = os.path.dirname(__file__)
diff --git a/wafl/connectors/factories/entailer_connector_factory.py b/wafl/connectors/factories/entailer_connector_factory.py
new file mode 100644
index 00000000..3de3cdff
--- /dev/null
+++ b/wafl/connectors/factories/entailer_connector_factory.py
@@ -0,0 +1,8 @@
+from wafl.config import Configuration
+from wafl.connectors.remote.remote_entailer_connector import RemoteEntailerConnector
+
+
+class EntailerConnectorFactory:
+ @staticmethod
+ def get_connector(config: Configuration):
+ return RemoteEntailerConnector(config)
diff --git a/wafl/connectors/factories/llm_connector_factory.py b/wafl/connectors/factories/llm_connector_factory.py
index 2d1c8714..ea82f837 100644
--- a/wafl/connectors/factories/llm_connector_factory.py
+++ b/wafl/connectors/factories/llm_connector_factory.py
@@ -1,7 +1,8 @@
+from wafl.config import Configuration
from wafl.connectors.remote.remote_llm_connector import RemoteLLMConnector
class LLMConnectorFactory:
@staticmethod
- def get_connector(config):
- return RemoteLLMConnector(config.get_value("llm_model"))
+ def get_connector(config: Configuration):
+ return RemoteLLMConnector(config)
diff --git a/wafl/connectors/factories/sentence_embedder_connector_factory.py b/wafl/connectors/factories/sentence_embedder_connector_factory.py
index 7ab1a708..2db98989 100644
--- a/wafl/connectors/factories/sentence_embedder_connector_factory.py
+++ b/wafl/connectors/factories/sentence_embedder_connector_factory.py
@@ -1,3 +1,4 @@
+from wafl.config import Configuration
from wafl.connectors.remote.remote_sentence_embedder_connector import (
RemoteSentenceEmbedderConnector,
)
@@ -5,5 +6,5 @@
class SentenceEmbedderConnectorFactory:
@staticmethod
- def get_connector(model_name, config):
- return RemoteSentenceEmbedderConnector(config.get_value(model_name))
+ def get_connector(config: Configuration):
+ return RemoteSentenceEmbedderConnector(config)
diff --git a/wafl/connectors/factories/speaker_connector_factory.py b/wafl/connectors/factories/speaker_connector_factory.py
index 578497eb..da11134b 100644
--- a/wafl/connectors/factories/speaker_connector_factory.py
+++ b/wafl/connectors/factories/speaker_connector_factory.py
@@ -1,7 +1,8 @@
+from wafl.config import Configuration
from wafl.connectors.remote.remote_speaker_connector import RemoteSpeakerConnector
class SpeakerConnectorFactory:
@staticmethod
- def get_connector(config):
- return RemoteSpeakerConnector(config.get_value("speaker_model"))
+ def get_connector(config: Configuration):
+ return RemoteSpeakerConnector(config)
diff --git a/wafl/connectors/factories/whisper_connector_factory.py b/wafl/connectors/factories/whisper_connector_factory.py
index 5b4e1c2f..d8a7068c 100644
--- a/wafl/connectors/factories/whisper_connector_factory.py
+++ b/wafl/connectors/factories/whisper_connector_factory.py
@@ -1,7 +1,8 @@
+from wafl.config import Configuration
from wafl.connectors.remote.remote_whisper_connector import RemoteWhisperConnector
class WhisperConnectorFactory:
@staticmethod
- def get_connector(config):
- return RemoteWhisperConnector(config.get_value("listener_model"))
+ def get_connector(config: Configuration):
+ return RemoteWhisperConnector(config)
diff --git a/wafl/connectors/remote/remote_configuration_connector.py b/wafl/connectors/remote/remote_configuration_connector.py
new file mode 100644
index 00000000..3590b5d5
--- /dev/null
+++ b/wafl/connectors/remote/remote_configuration_connector.py
@@ -0,0 +1,61 @@
+import aiohttp
+import asyncio
+import json
+
+from typing import Dict
+from wafl.variables import get_variables
+
+
+class RemoteConfigurationConnector:
+ _max_tries = 3
+
+ def __init__(self, host: str, port: int):
+ self._server_url = f"https://{host}:{port}/predictions/configuration"
+ try:
+ loop = asyncio.get_running_loop()
+
+ except RuntimeError:
+ loop = None
+
+ if (not loop or (loop and not loop.is_running())) and not asyncio.run(
+ self.check_connection()
+ ):
+ raise RuntimeError(
+                "Cannot connect to the Configuration handler. Is WAFL-LLM running?"
+ )
+
+ async def predict(self) -> Dict[str, str]:
+ payload = {"version": get_variables()["version"]}
+ for _ in range(self._max_tries):
+ async with aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(ssl=False)
+ ) as session:
+ async with session.post(self._server_url, json=payload) as response:
+ data = await response.text()
+ prediction = json.loads(data)
+ listener_model = prediction["listener_model"]
+ speaker_model = prediction["speaker_model"]
+ text_embedding_model = prediction["text_embedding_model"]
+ entailer_model = prediction["entailer_model"]
+ llm_model = prediction["llm_model"]
+ return {
+ "listener_model": listener_model,
+ "speaker_model": speaker_model,
+ "text_embedding_model": text_embedding_model,
+ "entailer_model": entailer_model,
+ "llm_model": llm_model,
+ }
+
+ return {}
+
+ async def check_connection(self) -> bool:
+ try:
+ async with aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(ssl=False)
+ ) as session:
+ payload = {"version": get_variables()["version"]}
+ async with session.post(self._server_url, json=payload) as response:
+ return response.status == 200
+
+ except Exception:
+ return False
diff --git a/wafl/connectors/remote/remote_entailer_connector.py b/wafl/connectors/remote/remote_entailer_connector.py
new file mode 100644
index 00000000..50430cf9
--- /dev/null
+++ b/wafl/connectors/remote/remote_entailer_connector.py
@@ -0,0 +1,60 @@
+import aiohttp
+import asyncio
+import json
+from typing import Dict
+
+from wafl.config import Configuration
+
+
+class RemoteEntailerConnector:
+ _max_tries = 3
+
+ def __init__(self, config: Configuration):
+ host = config.get_value("backend")["host"]
+ port = config.get_value("backend")["port"]
+
+ self._server_url = f"https://{host}:" f"{port}/predictions/entailer"
+ try:
+ loop = asyncio.get_running_loop()
+
+ except RuntimeError:
+ loop = None
+
+ if (not loop or (loop and not loop.is_running())) and not asyncio.run(
+ self.check_connection()
+ ):
+            raise RuntimeError("Cannot connect to a running Entailment Model.")
+
+ async def predict(self, lhs: str, rhs: str) -> Dict[str, float]:
+ payload = {"lhs": lhs, "rhs": rhs}
+ for _ in range(self._max_tries):
+ async with aiohttp.ClientSession(
+ connector=aiohttp.TCPConnector(ssl=False)
+ ) as session:
+ async with session.post(self._server_url, json=payload) as response:
+ data = await response.text()
+ prediction = json.loads(data)
+ if "score" in prediction:
+ score = prediction["score"]
+ return {"score": float(score)}
+ return {"score": -1.0}
+
+ return {"score": -1.0}
+
+ async def check_connection(self):
+ payload = {"lhs": "test", "rhs": "test"}
+ try:
+ async with aiohttp.ClientSession(
+ conn_timeout=3, connector=aiohttp.TCPConnector(ssl=False)
+ ) as session:
+ async with session.post(self._server_url, json=payload) as response:
+ await response.text()
+ return True
+
+ except aiohttp.client.InvalidURL:
+ print()
+ print("Is the entailer server running?")
+ print("Please run 'bash start-llm.sh' (see docs for explanation).")
+ print()
+
+ return False
diff --git a/wafl/connectors/remote/remote_llm_connector.py b/wafl/connectors/remote/remote_llm_connector.py
index a2b3d7f7..01b8bcd0 100644
--- a/wafl/connectors/remote/remote_llm_connector.py
+++ b/wafl/connectors/remote/remote_llm_connector.py
@@ -3,6 +3,7 @@
import aiohttp
import asyncio
+from wafl.config import Configuration
from wafl.connectors.base_llm_connector import BaseLLMConnector
from wafl.connectors.prompt_template import PromptTemplate
from wafl.variables import is_supported
@@ -11,14 +12,14 @@
class RemoteLLMConnector(BaseLLMConnector):
_max_tries = 3
_max_reply_length = 1024
- _num_prediction_tokens = 200
+ _num_prediction_tokens = 1024
_cache = {}
- def __init__(self, config, last_strings=None, num_replicas=1):
+ def __init__(self, config: Configuration, last_strings=None, num_replicas=1):
super().__init__(last_strings)
- host = config["model_host"]
- port = config["model_port"]
- self._default_temperature = config["temperature"]
+ host = config.get_value("backend")["host"]
+ port = config.get_value("backend")["port"]
+ self._default_temperature = config.get_value("generation_config")["temperature"]
self._server_url = f"https://{host}:{port}/predictions/bot"
self._num_replicas = num_replicas
diff --git a/wafl/connectors/remote/remote_sentence_embedder_connector.py b/wafl/connectors/remote/remote_sentence_embedder_connector.py
index ae71e6fe..9b75b2f0 100644
--- a/wafl/connectors/remote/remote_sentence_embedder_connector.py
+++ b/wafl/connectors/remote/remote_sentence_embedder_connector.py
@@ -10,8 +10,8 @@ class RemoteSentenceEmbedderConnector:
_max_tries = 3
def __init__(self, config):
- host = config["model_host"]
- port = config["model_port"]
+ host = config.get_value("backend")["host"]
+ port = config.get_value("backend")["port"]
self._server_url = f"https://{host}:" f"{port}/predictions/sentence_embedder"
try:
diff --git a/wafl/connectors/remote/remote_speaker_connector.py b/wafl/connectors/remote/remote_speaker_connector.py
index 3d1226b4..a0402fa8 100644
--- a/wafl/connectors/remote/remote_speaker_connector.py
+++ b/wafl/connectors/remote/remote_speaker_connector.py
@@ -4,15 +4,16 @@
import json
from typing import Dict
+from wafl.config import Configuration
class RemoteSpeakerConnector:
_max_tries = 3
- def __init__(self, config):
+ def __init__(self, config: Configuration):
self._server_url = (
- f"https://{config['model_host']}:"
- f"{config['model_port']}/predictions/speaker"
+ f"https://{config.get_value('backend')['host']}:"
+ f"{config.get_value('backend')['port']}/predictions/speaker"
)
try:
loop = asyncio.get_running_loop()
diff --git a/wafl/connectors/remote/remote_whisper_connector.py b/wafl/connectors/remote/remote_whisper_connector.py
index d9498a83..8632a62f 100644
--- a/wafl/connectors/remote/remote_whisper_connector.py
+++ b/wafl/connectors/remote/remote_whisper_connector.py
@@ -4,14 +4,16 @@
from typing import Dict
+from wafl.config import Configuration
+
class RemoteWhisperConnector:
_max_tries = 3
- def __init__(self, config):
+ def __init__(self, config: Configuration):
self._server_url = (
- f"https://{config['model_host']}:"
- f"{config['model_port']}/predictions/whisper"
+ f"https://{config.get_value('backend')['host']}:"
+ f"{config.get_value('backend')['port']}/predictions/whisper"
)
try:
loop = asyncio.get_running_loop()
diff --git a/wafl/dataclasses/__init__.py b/wafl/data_objects/__init__.py
similarity index 100%
rename from wafl/dataclasses/__init__.py
rename to wafl/data_objects/__init__.py
diff --git a/wafl/dataclasses/dataclasses.py b/wafl/data_objects/dataclasses.py
similarity index 100%
rename from wafl/dataclasses/dataclasses.py
rename to wafl/data_objects/dataclasses.py
diff --git a/wafl/data_objects/facts.py b/wafl/data_objects/facts.py
new file mode 100644
index 00000000..88926ab0
--- /dev/null
+++ b/wafl/data_objects/facts.py
@@ -0,0 +1,33 @@
+from dataclasses import dataclass
+from enum import Enum
+from typing import Union
+
+
+class Sources(Enum):
+ FROM_TEXT = 1
+ FROM_RULES = 2
+
+
+@dataclass
+class Fact:
+ text: Union[str, dict]
+ is_question: bool = False
+ variable: str = None
+ is_interruption: bool = False
+ destination: str = None
+ metadata: Union[str, dict] = None
+ source: Sources = Sources.FROM_RULES
+
+ def toJSON(self):
+ return str(self)
+
+ def copy(self):
+ return Fact(
+ self.text,
+ self.is_question,
+ self.variable,
+ self.is_interruption,
+ self.destination,
+ self.metadata,
+ self.source,
+ )
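Usage sketch for the new `Fact`/`Sources` pair (assuming `wafl` is importable); `copy()` matters because the answerer now mutates retrieved facts:

```python
from wafl.data_objects.facts import Fact, Sources

fact = Fact(
    text="The sky is blue.",
    metadata={"filename": "notes.txt", "chunk_number": 0},
    source=Sources.FROM_TEXT,
)

# copy() returns an independent Fact, so edits do not touch the original.
derived = fact.copy()
derived.text = "Metadata: notes.txt\n" + derived.text
assert fact.text == "The sky is blue."
```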
diff --git a/wafl/dataclasses/rules.py b/wafl/data_objects/rules.py
similarity index 100%
rename from wafl/dataclasses/rules.py
rename to wafl/data_objects/rules.py
diff --git a/wafl/dataclasses/facts.py b/wafl/dataclasses/facts.py
deleted file mode 100644
index 0445adff..00000000
--- a/wafl/dataclasses/facts.py
+++ /dev/null
@@ -1,16 +0,0 @@
-from dataclasses import dataclass
-from typing import Union
-
-
-@dataclass
-class Fact:
- text: Union[str, dict]
- is_question: bool = False
- variable: str = None
- is_interruption: bool = False
- source: str = None
- destination: str = None
- metadata: Union[str, dict] = None
-
- def toJSON(self):
- return str(self)
diff --git a/wafl/events/conversation_events.py b/wafl/events/conversation_events.py
index d83d52ce..4c0da959 100644
--- a/wafl/events/conversation_events.py
+++ b/wafl/events/conversation_events.py
@@ -19,6 +19,7 @@ def __init__(
config: "Configuration",
interface: "BaseInterface",
logger=None,
+ knowledge=None,
):
self._config = config
try:
@@ -29,6 +30,8 @@ def __init__(
if not loop or not loop.is_running():
self._knowledge = asyncio.run(load_knowledge(config, logger))
+ else:
+ self._knowledge = knowledge
self._answerer = create_answerer(config, self._knowledge, interface, logger)
self._answerer._client._connector._cache = {}
@@ -38,6 +41,9 @@ def __init__(
if logger:
self._logger.set_depth(0)
+ def reset(self):
+ self._answerer.reset()
+
async def output(self, text: str):
await self._interface.output(text)
@@ -74,6 +80,9 @@ async def _process_query(self, text: str):
):
await self._interface.output("I don't know what to reply")
+ if not text_is_question and not self._interface.get_utterances_list():
+ await self._interface.output("I don't know what to reply")
+
if (
not text_is_question
and answer.is_true()
diff --git a/wafl/extractors/__init__.py b/wafl/extractors/__init__.py
deleted file mode 100644
index ea96f559..00000000
--- a/wafl/extractors/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-print
diff --git a/wafl/extractors/utils.py b/wafl/extractors/utils.py
deleted file mode 100644
index 4401c22f..00000000
--- a/wafl/extractors/utils.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import re
-
-
-def get_answer_from_text(text):
- _claim_yn = "The claim makes sense:\n"
- pos = text.find(_claim_yn) + len(_claim_yn)
- return text[pos]
-
-
-def get_text_up_to_question(text):
- _claim_yn = "The claim makes sense:\n"
- return text[: text.find(_claim_yn) + len(_claim_yn)]
-
-
-def get_function_description(text):
- if "<" not in text:
- return ""
-
- return re.sub(r".*<(.*)>$", r"\1", text, re.MULTILINE).strip()
-
-
-def get_code(text):
- return re.sub(r"(.*)<.*>$", r"\1", text, re.MULTILINE).strip()
diff --git a/wafl/handlers/conversation_handler.py b/wafl/handlers/conversation_handler.py
index ccb4d1a9..18f1e7af 100644
--- a/wafl/handlers/conversation_handler.py
+++ b/wafl/handlers/conversation_handler.py
@@ -66,6 +66,7 @@ async def _main_loop(self):
and interactions == 1
):
self._interface.deactivate()
+ self._conversation_events.reset()
num_misses = 0
if (
diff --git a/wafl/inference/utils.py b/wafl/inference/utils.py
index 2f25bda1..d270814b 100644
--- a/wafl/inference/utils.py
+++ b/wafl/inference/utils.py
@@ -2,7 +2,7 @@
from typing import List, Dict, Tuple, Any
from fuzzywuzzy import process
-from wafl.dataclasses.dataclasses import Answer
+from wafl.data_objects.dataclasses import Answer
from wafl.simple_text_processing.normalize import normalized
from wafl.simple_text_processing.questions import is_question
diff --git a/wafl/interface/conversation.py b/wafl/interface/conversation.py
index 68687eb1..c65f1f5b 100644
--- a/wafl/interface/conversation.py
+++ b/wafl/interface/conversation.py
@@ -111,6 +111,15 @@ def get_last_speaker_utterances(self, speaker: str, n: int) -> List[str]:
if utterance.speaker == speaker
][-n:]
+ def get_last_speaker_utterance(self, speaker: str) -> str:
+ if not self.utterances:
+ return ""
+
+ for utterance in reversed(self.utterances):
+ if utterance.speaker == speaker:
+ return utterance.text
+ return ""
+
def get_first_timestamp(self) -> float:
return self.utterances[0].timestamp if self.utterances else None
diff --git a/wafl/interface/voice_interface.py b/wafl/interface/voice_interface.py
index 17ac6d5f..0b8a7821 100644
--- a/wafl/interface/voice_interface.py
+++ b/wafl/interface/voice_interface.py
@@ -1,4 +1,3 @@
-import asyncio
import os
import random
import re
@@ -7,7 +6,7 @@
from wafl.interface.base_interface import BaseInterface
from wafl.interface.utils import not_good_enough
from wafl.listener.whisper_listener import WhisperListener
-from wafl.speaker.fairseq_speaker import FairSeqSpeaker
+from wafl.speaker.tts_speaker import TTSSpeaker
from wafl.speaker.soundfile_speaker import SoundFileSpeaker
_path = os.path.dirname(__file__)
@@ -27,7 +26,7 @@ def __init__(self, config):
self._deactivation_sound_filename = self.__get_deactivation_sound_from_config(
config
)
- self._speaker = FairSeqSpeaker(config)
+ self._speaker = TTSSpeaker(config)
self._listener = WhisperListener(config)
self._listener.set_timeout(
config.get_value("listener_model")["listener_silence_timeout"]
@@ -77,7 +76,7 @@ async def input(self) -> str:
text = text.lower().capitalize()
print(COLOR_START + "user> " + text + COLOR_END)
utterance = remove_text_between_brackets(text)
- if utterance.strip():
+ if utterance.strip() and self._is_listening:
self._insert_utterance(speaker="user", text=text)
return remove_unclear(text)
diff --git a/wafl/knowledge/indexing_implementation.py b/wafl/knowledge/indexing_implementation.py
index a780c7de..d4c3a8f8 100644
--- a/wafl/knowledge/indexing_implementation.py
+++ b/wafl/knowledge/indexing_implementation.py
@@ -1,24 +1,50 @@
+import asyncio
import os
-
import joblib
import yaml
+import threading
+from tqdm import tqdm
from wafl.config import Configuration
from wafl.knowledge.single_file_knowledge import SingleFileKnowledge
from wafl.readers.reader_factory import ReaderFactory
+async def add_file_to_knowledge(knowledge, filename):
+ reader = ReaderFactory.get_reader(filename)
+ for chunk in reader.get_chunks(filename):
+ await knowledge.add_fact(chunk)
+
+
async def _add_indices_to_knowledge(knowledge, text):
indices = yaml.safe_load(text)
if "paths" not in indices or not indices["paths"]:
return knowledge
for path in indices["paths"]:
- for root, _, files in os.walk(path):
- for file in files:
- reader = ReaderFactory.get_reader(file)
- for chunk in reader.get_chunks(os.path.join(root, file)):
- await knowledge.add_fact(chunk)
+ print(f"Indexing path: {path}")
+ file_count = sum(len(files) for _, _, files in os.walk(path))
+ with tqdm(total=file_count) as pbar:
+ for root, _, files in os.walk(path):
+ threads = []
+ for file in files:
+ threads.append(
+ threading.Thread(
+ target=asyncio.run,
+ args=(
+ add_file_to_knowledge(
+ knowledge, os.path.join(root, file)
+ ),
+ ),
+ )
+ )
+                num_threads = max(1, min(10, len(threads)))
+                for i in range(0, len(threads), num_threads):
+                    batch = threads[i : i + num_threads]
+                    for thread in batch:
+                        thread.start()
+                    for thread in batch:
+                        thread.join()
+                    pbar.update(len(batch))
return knowledge
@@ -27,10 +53,12 @@ async def load_knowledge(config, logger=None):
if ".yaml" in config.get_value("rules") and not any(
item in config.get_value("rules") for item in [" ", "\n"]
):
+ rules_filename = config.get_value("rules")
with open(config.get_value("rules")) as file:
rules_txt = file.read()
else:
+ rules_filename = None
rules_txt = config.get_value("rules")
index_filename = config.get_value("index")
@@ -41,17 +69,18 @@ async def load_knowledge(config, logger=None):
cache_filename = config.get_value("cache_filename")
if os.path.exists(cache_filename):
- knowledge = joblib.load(cache_filename)
- if knowledge.hash == hash(rules_txt) and os.path.getmtime(
- cache_filename
- ) > os.path.getmtime(index_filename):
- await knowledge.initialize_retrievers()
+ if (
+ rules_filename
+ and os.path.getmtime(cache_filename) > os.path.getmtime(rules_filename)
+ and os.path.getmtime(cache_filename) > os.path.getmtime(index_filename)
+ ):
+ knowledge = joblib.load(cache_filename)
return knowledge
knowledge = SingleFileKnowledge(config, rules_txt, logger=logger)
knowledge = await _add_indices_to_knowledge(knowledge, index_txt)
- joblib.dump(knowledge, config.get_value("cache_filename"))
await knowledge.initialize_retrievers()
+ joblib.dump(knowledge, config.get_value("cache_filename"))
return knowledge
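The reworked cache logic trusts the pickled knowledge only when it is newer than both the rules file and the index file. A standalone sketch of that freshness check (file names are illustrative):

```python
import os


def cache_is_fresh(cache_filename: str, *source_filenames: str) -> bool:
    """True when the cache exists and is newer than every source file."""
    if not os.path.exists(cache_filename):
        return False
    cache_mtime = os.path.getmtime(cache_filename)
    return all(cache_mtime > os.path.getmtime(f) for f in source_filenames)


# e.g. cache_is_fresh("knowledge_cache", "rules.yaml", "indices.yaml")
```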
diff --git a/wafl/knowledge/single_file_knowledge.py b/wafl/knowledge/single_file_knowledge.py
index 8c9c5a25..2fc15c39 100644
--- a/wafl/knowledge/single_file_knowledge.py
+++ b/wafl/knowledge/single_file_knowledge.py
@@ -4,9 +4,10 @@
from typing import List
import nltk
+from tqdm import tqdm
from wafl.config import Configuration
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact
from wafl.knowledge.base_knowledge import BaseKnowledge
from wafl.knowledge.utils import (
text_is_exact_string,
@@ -169,7 +170,8 @@ def get_facts_and_rule_as_text(self):
return text
async def initialize_retrievers(self):
- for index, fact in self._facts_dict.items():
+ print("Initializing fact retrievers")
+ for index, fact in tqdm(self._facts_dict.items()):
if text_is_exact_string(fact.text):
continue
@@ -181,7 +183,8 @@ async def initialize_retrievers(self):
clean_text_for_retrieval(fact.text), index
)
- for index, rule in self._rules_dict.items():
+ print("Initializing rule retrievers")
+ for index, rule in tqdm(self._rules_dict.items()):
if text_is_exact_string(rule.effect.text):
continue
@@ -189,10 +192,6 @@ async def initialize_retrievers(self):
clean_text_for_retrieval(rule.effect.text), index
)
- for index, rule in self._rules_dict.items():
- if not text_is_exact_string(rule.effect.text):
- continue
-
await self._rules_string_retriever.add_text_and_index(
rule.effect.text, index
)
diff --git a/wafl/parsing/line_rules_parser.py b/wafl/parsing/line_rules_parser.py
index 73371f3e..b6bfa3ee 100644
--- a/wafl/parsing/line_rules_parser.py
+++ b/wafl/parsing/line_rules_parser.py
@@ -1,6 +1,6 @@
from wafl.simple_text_processing.questions import is_question
-from wafl.dataclasses.facts import Fact
-from wafl.dataclasses.rules import Rule
+from wafl.data_objects.facts import Fact
+from wafl.data_objects.rules import Rule
def parse_rule_from_single_line(text):
diff --git a/wafl/parsing/rules_parser.py b/wafl/parsing/rules_parser.py
index 70d3b5f1..bb813e93 100644
--- a/wafl/parsing/rules_parser.py
+++ b/wafl/parsing/rules_parser.py
@@ -1,7 +1,7 @@
import yaml
-from wafl.dataclasses.facts import Fact
-from wafl.dataclasses.rules import Rule
+from wafl.data_objects.facts import Fact
+from wafl.data_objects.rules import Rule
from wafl.simple_text_processing.deixis import from_user_to_bot
diff --git a/wafl/readers/base_reader.py b/wafl/readers/base_reader.py
index ea995601..5f0aaef2 100644
--- a/wafl/readers/base_reader.py
+++ b/wafl/readers/base_reader.py
@@ -1,6 +1,6 @@
from typing import List
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact
class BaseReader:
diff --git a/wafl/readers/pdf_reader.py b/wafl/readers/pdf_reader.py
index 4f610616..dc94f664 100644
--- a/wafl/readers/pdf_reader.py
+++ b/wafl/readers/pdf_reader.py
@@ -2,7 +2,7 @@
from logging import getLogger
from typing import List
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact, Sources
from wafl.readers.base_reader import BaseReader
_logger = getLogger(__name__)
@@ -20,6 +20,7 @@ def get_chunks(self, filename: str) -> List[Fact]:
Fact(
text=page.get_text(),
metadata={"filename": filename, "page_number": i},
+ source=Sources.FROM_TEXT,
)
for i, page in enumerate(doc)
]
diff --git a/wafl/readers/reader_factory.py b/wafl/readers/reader_factory.py
index 14ccb70c..6fc33bfc 100644
--- a/wafl/readers/reader_factory.py
+++ b/wafl/readers/reader_factory.py
@@ -4,7 +4,7 @@
class ReaderFactory:
_chunk_size = 10000
- _overlap = 100
+ _overlap = 500
_extension_to_reader_dict = {".pdf": PdfReader, ".txt": TextReader}
@staticmethod
@@ -13,7 +13,4 @@ def get_reader(filename):
if extension in filename.lower():
return reader(ReaderFactory._chunk_size, ReaderFactory._overlap)
- ### add pdf reader
- ### add metadata and show in the UI
-
return TextReader(ReaderFactory._chunk_size, ReaderFactory._overlap)
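The readers split files into windows of `_chunk_size` characters that share `_overlap` characters with their neighbours; raising the overlap to 500 makes it less likely that a sentence is cut at a boundary and lost to retrieval. A sketch of the technique (this chunker is illustrative, not the readers' exact code):

```python
from typing import List


def chunk_text(text: str, chunk_size: int = 10000, overlap: int = 500) -> List[str]:
    """Split text into overlapping windows so boundary sentences appear twice."""
    step = chunk_size - overlap
    return [text[i : i + chunk_size] for i in range(0, max(len(text) - overlap, 1), step)]


chunks = chunk_text("abcdefghij" * 5, chunk_size=20, overlap=5)
# Each chunk's tail is repeated at the head of the next chunk.
assert all(chunks[i][-5:] == chunks[i + 1][:5] for i in range(len(chunks) - 1))
```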
diff --git a/wafl/readers/text_reader.py b/wafl/readers/text_reader.py
index b22c4ffe..8457ee04 100644
--- a/wafl/readers/text_reader.py
+++ b/wafl/readers/text_reader.py
@@ -1,7 +1,7 @@
from logging import getLogger
from typing import List
-from wafl.dataclasses.facts import Fact
+from wafl.data_objects.facts import Fact, Sources
from wafl.readers.base_reader import BaseReader
_logger = getLogger(__name__)
@@ -20,6 +20,7 @@ def get_chunks(self, filename: str) -> List[Fact]:
Fact(
text=chunk,
metadata={"filename": filename, "chunk_number": i},
+ source=Sources.FROM_TEXT,
)
for i, chunk in enumerate(chunks)
]
diff --git a/wafl/retriever/dense_retriever.py b/wafl/retriever/dense_retriever.py
index 4bf3bb8c..acf28876 100644
--- a/wafl/retriever/dense_retriever.py
+++ b/wafl/retriever/dense_retriever.py
@@ -2,7 +2,6 @@
import numpy as np
from typing import List, Tuple
-from gensim.models import KeyedVectors
from wafl.connectors.factories.sentence_embedder_connector_factory import (
SentenceEmbedderConnectorFactory,
)
@@ -15,21 +14,25 @@ class DenseRetriever(BaseRetriever):
_threshold_length = 5
def __init__(self, model_name, config):
- self._connector = SentenceEmbedderConnectorFactory.get_connector(
- model_name, config
- )
- self._embeddings_model = KeyedVectors(384)
+ self._connector = SentenceEmbedderConnectorFactory.get_connector(config)
+ self._matrix = np.zeros((0, 384))
+ self._indices = []
async def add_text_and_index(self, text: str, index: str):
embeddings = await self._get_embeddings_from_text(text)
- self._embeddings_model.add_vectors([index], [embeddings])
- self._embeddings_model.fill_norms(force=True)
+ self._matrix = np.vstack([self._matrix, embeddings])
+ self._indices.append(index)
async def get_indices_and_scores_from_text(
- self, text: str
+ self, text: str, topn: int = 5
) -> List[Tuple[str, float]]:
embeddings = await self._get_embeddings_from_text(text)
- return self._embeddings_model.similar_by_vector(embeddings, topn=5)
+ scores = np.dot(self._matrix, embeddings) / (
+ np.linalg.norm(self._matrix, axis=1) * np.linalg.norm(embeddings)
+ )
+ indices_and_scores = list(zip(self._indices, scores))
+ indices_and_scores.sort(key=lambda x: x[1], reverse=True)
+ return indices_and_scores[:topn]
async def _get_embeddings_from_text(self, text: str) -> "numpy.array":
return (await self._connector.predict(text))["embedding"]
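The gensim `KeyedVectors` store is replaced with a plain numpy matrix and a cosine-similarity scan. A self-contained sketch of the scoring (384 matches the embedder's output dimension; the vectors here are random stand-ins):

```python
import numpy as np

rng = np.random.default_rng(0)
matrix = rng.normal(size=(10, 384))  # one row per indexed text
indices = [f"fact_{i}" for i in range(10)]
query = rng.normal(size=384)

# Cosine similarity between the query and every stored embedding.
scores = matrix @ query / (np.linalg.norm(matrix, axis=1) * np.linalg.norm(query))
top5 = sorted(zip(indices, scores), key=lambda pair: pair[1], reverse=True)[:5]
print(top5)
```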
diff --git a/wafl/run.py b/wafl/run.py
index b0397e84..4138ac48 100644
--- a/wafl/run.py
+++ b/wafl/run.py
@@ -4,6 +4,7 @@
from wafl.exceptions import CloseConversation
from wafl.events.conversation_events import ConversationEvents
from wafl.interface.command_line_interface import CommandLineInterface
+from wafl.knowledge.indexing_implementation import load_knowledge
from wafl.logger.local_file_logger import LocalFileLogger
from wafl.testcases import ConversationTestCases
from wafl.variables import get_variables
@@ -17,6 +18,12 @@ def print_incipit():
print()
+def load_indices():
+ print("Loading knowledge indices...")
+ config = Configuration.load_local_config()
+ asyncio.run(load_knowledge(config, _logger))
+
+
def run_from_command_line():
interface = CommandLineInterface()
config = Configuration.load_local_config()
diff --git a/wafl/runners/run_from_audio.py b/wafl/runners/run_from_audio.py
index 7b523687..7a1d8f35 100644
--- a/wafl/runners/run_from_audio.py
+++ b/wafl/runners/run_from_audio.py
@@ -1,6 +1,9 @@
+import asyncio
+
from wafl.config import Configuration
from wafl.events.conversation_events import ConversationEvents
from wafl.interface.voice_interface import VoiceInterface
+from wafl.knowledge.indexing_implementation import load_knowledge
from wafl.logger.local_file_logger import LocalFileLogger
from wafl.handlers.conversation_handler import ConversationHandler
from wafl.scheduler.scheduler import Scheduler
@@ -10,6 +13,7 @@
def run_from_audio():
config = Configuration.load_local_config()
+ asyncio.run(load_knowledge(config, _logger))
interface = VoiceInterface(config)
conversation_events = ConversationEvents(
config=config,
diff --git a/wafl/runners/run_web_and_audio_interface.py b/wafl/runners/run_web_and_audio_interface.py
index 9d5f833c..177f4501 100644
--- a/wafl/runners/run_web_and_audio_interface.py
+++ b/wafl/runners/run_web_and_audio_interface.py
@@ -1,3 +1,4 @@
+import asyncio
import random
import sys
import threading
@@ -6,6 +7,7 @@
from wafl.interface.list_interface import ListInterface
from wafl.interface.voice_interface import VoiceInterface
+from wafl.knowledge.indexing_implementation import load_knowledge
from wafl.scheduler.scheduler import Scheduler
from wafl.handlers.conversation_handler import ConversationHandler
from wafl.logger.local_file_logger import LocalFileLogger
diff --git a/wafl/speaker/fairseq_speaker.py b/wafl/speaker/tts_speaker.py
similarity index 69%
rename from wafl/speaker/fairseq_speaker.py
rename to wafl/speaker/tts_speaker.py
index 4e775f14..99066349 100644
--- a/wafl/speaker/fairseq_speaker.py
+++ b/wafl/speaker/tts_speaker.py
@@ -8,15 +8,16 @@
from wafl.speaker.utils import convert_numbers_to_words
-class FairSeqSpeaker(BaseSpeaker):
+class TTSSpeaker(BaseSpeaker):
def __init__(self, config):
self._connector = SpeakerConnectorFactory.get_connector(config)
self._p = pyaudio.PyAudio()
self._input_chunk_size = 1024
- self._output_chunk_size = 16384
+ self._output_chunk_size = 4096
self._volume_threshold = (
- config.get_value("listener_model")["listener_volume_threshold"] / 5e3
+ config.get_value("listener_model")["listener_volume_threshold"] * 1e-4
)
+ self._interruptible = config.get_value("listener_model")["interruptible"]
async def speak(self, text):
text = convert_numbers_to_words(text)
@@ -32,11 +33,15 @@ async def speak(self, text):
)
stream.start_stream()
await asyncio.sleep(0.1)
- for i in range(0, len(wav), self._output_chunk_size):
- inp = stream.read(self._input_chunk_size)
- if _rms(inp) > self._volume_threshold:
- break
- stream.write(wav[i : i + self._output_chunk_size])
+ if self._interruptible:
+ for i in range(0, len(wav), self._output_chunk_size):
+ inp = stream.read(self._input_chunk_size)
+ if _rms(inp) > self._volume_threshold:
+ break
+ stream.write(wav[i : i + self._output_chunk_size])
+ else:
+ stream.write(wav)
+
stream.stop_stream()
stream.close()
await asyncio.sleep(0.1)
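`_rms` is imported elsewhere in this module and not shown in the diff; a plausible implementation for 16-bit PCM chunks, normalised to [0, 1] so it is comparable with the scaled volume threshold (the exact scaling here is an assumption):

```python
import numpy as np


def _rms(chunk: bytes) -> float:
    """Root-mean-square volume of a 16-bit little-endian PCM buffer, in [0, 1]."""
    samples = np.frombuffer(chunk, dtype=np.int16).astype(np.float64)
    if samples.size == 0:
        return 0.0
    return float(np.sqrt(np.mean(samples**2)) / 32768.0)
```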
diff --git a/wafl/templates/config.json b/wafl/templates/config.json
index 7af9893e..57657b63 100644
--- a/wafl/templates/config.json
+++ b/wafl/templates/config.json
@@ -9,24 +9,18 @@
"functions": "functions.py",
"max_recursion": 2,
"frontend_port": 8090,
- "llm_model": {
- "model_host": "localhost",
- "model_port": 8080,
+ "backend": {
+ "host": "localhost",
+ "port": 8080,
+ "token": "secret"
+ },
+ "generation_config": {
"temperature": 0.4
},
"listener_model": {
- "model_host": "localhost",
- "model_port": 8080,
"listener_hotword_logp": -8,
"listener_volume_threshold": 0.6,
- "listener_silence_timeout": 0.7
- },
- "speaker_model": {
- "model_host": "localhost",
- "model_port": 8080
- },
- "text_embedding_model": {
- "model_host": "localhost",
- "model_port": 8080
+ "listener_silence_timeout": 0.7,
+ "interruptible": false
}
}
diff --git a/wafl/templates/main.prompt b/wafl/templates/main.prompt
index 07b45290..8cbb4d64 100644
--- a/wafl/templates/main.prompt
+++ b/wafl/templates/main.prompt
@@ -8,4 +8,6 @@ The rules that *must* be followed are:
Create a plausible dialogue based on the aforementioned summary and rules.
Do not repeat yourself. Be friendly but not too servile.
-Follow the rules if present and they apply to the dialogue. Do not improvise if rules are present.
\ No newline at end of file
+Follow the rules if present and they apply to the dialogue. Do not improvise if rules are present.
+The user query might be incomplete, ambiguous, or ungrammatical. The bot *must* ask for clarification if needed.
+The bot only answers if the query is clear and unambiguous.
\ No newline at end of file
diff --git a/wafl/testcases.py b/wafl/testcases.py
index bcc49f4d..cb97af4f 100644
--- a/wafl/testcases.py
+++ b/wafl/testcases.py
@@ -1,4 +1,6 @@
from wafl.answerer.entailer import Entailer
+from wafl.knowledge.indexing_implementation import load_knowledge
+
from wafl.simple_text_processing.deixis import from_user_to_bot, from_bot_to_user
from wafl.exceptions import CloseConversation
from wafl.events.conversation_events import ConversationEvents
@@ -25,8 +27,10 @@ async def test_single_case(self, name):
test_lines = self._testcase_data[name]["lines"]
is_negated = self._testcase_data[name]["negated"]
interface = DummyInterface(user_lines)
- conversation_events = ConversationEvents(self._config, interface=interface)
- await conversation_events._knowledge.initialize_retrievers()
+ knowledge = await load_knowledge(self._config)
+ conversation_events = ConversationEvents(
+ self._config, interface=interface, knowledge=knowledge
+ )
print(self.BLUE_COLOR_START + f"\nRunning test '{name}'." + self.COLOR_END)
continue_conversations = True
@@ -77,9 +81,7 @@ async def _lhs_is_similar_to(self, lhs, rhs, prior_dialogue):
if lhs_name != rhs_name:
return False
- return await self._entailer.left_entails_right(
- lhs, rhs, "\n".join(prior_dialogue)
- )
+ return await self._entailer.left_entails_right(lhs, rhs)
def _apply_deixis(self, line):
name = line.split(":")[0].strip()
diff --git a/wafl/variables.py b/wafl/variables.py
index 0d6490e8..b2e183e7 100644
--- a/wafl/variables.py
+++ b/wafl/variables.py
@@ -1,9 +1,9 @@
def get_variables():
return {
- "version": "0.1.1",
+ "version": "0.1.3",
}
def is_supported(wafl_llm_version):
- supported_versions = ["0.1.0"]
+ supported_versions = ["0.1.1"]
return wafl_llm_version in supported_versions