From 101421e956ac8ccdbd52d8f57a61943d3e2736e6 Mon Sep 17 00:00:00 2001
From: "leizhang.real@gmail.com"
Date: Mon, 11 Nov 2024 14:23:49 +0000
Subject: [PATCH 1/4] add lancedb as memory

---
 mle/cli.py | 4 --
 mle/utils/memory.py | 93 ++++++++++++++++++++++++++++++++++++++++++++
 requirements.txt | 1 +
 3 files changed, 93 insertions(+), 5 deletions(-)

diff --git a/mle/cli.py b/mle/cli.py
index 98b650e..8f640b3 100644
--- a/mle/cli.py
+++ b/mle/cli.py
@@ -10,7 +10,6 @@
 
 import mle
 from mle.server import app
-from mle.utils import Memory
 import mle.workflow as workflow
 from mle.utils.system import (
     get_config,
@@ -278,9 +277,6 @@ def new(name):
             'integration': {},
         }, outfile, default_flow_style=False)
 
-    # init the memory
-    Memory(project_dir)
-
 
 @cli.command()
 @click.option('--reset', is_flag=True, help='Reset the integration')
diff --git a/mle/utils/memory.py b/mle/utils/memory.py
index eef2fcb..5516650 100644
--- a/mle/utils/memory.py
+++ b/mle/utils/memory.py
@@ -3,6 +3,9 @@
 from datetime import datetime
 from typing import List, Dict
 
+import lancedb
+from lancedb.embeddings import get_registry
+
 import chromadb
 from chromadb.utils import embedding_functions
 
@@ -11,7 +14,8 @@
 chromadb.logger.setLevel(chromadb.logging.ERROR)
 
 
-class Memory:
+class ChromaDBMemory:
+
     def __init__(
         self,
         project_path: str,
@@ -152,3 +156,90 @@ def reset(self):
         Notice: You may need to set the environment variable `ALLOW_RESET` to `TRUE` to enable this function.
         """
         self.client.reset()
+
+
+class LanceDBMemory:
+
+    def __init__(
+        self,
+        project_path: str,
+        *,
+        embedding_model: str = "openai",
+    ):
+        """
+        Memory: A base class for memory and external knowledge management.
+        Args:
+            project_path: the path to store the data.
+            embedding_model: the embedding model to use.
+        """
+        self.db_name = '.mle'
+        self.table_name = 'memory'
+        self.client = lancedb.connect(uri=self.db_name)
+
+        config = get_config(project_path)
+        if embedding_model == "openai":
+            self.text_embedding = get_registry().get(embedding_model).create(api_key=config["api_key"])
+
+    def add(self, texts: List[str], table_name: str = None, ids: List[str] = None) -> List[str]:
+        """
+        Adds a list of text items to the specified memory table in the database.
+
+        Args:
+            texts (List[str]): A list of text strings to be added.
+            table_name (str, optional): The name of the table to add data to. Defaults to self.table_name.
+            ids (List[str], optional): A list of unique IDs for the text items.
+                If not provided, random UUIDs are generated.
+
+        Returns:
+            List[str]: A list of IDs associated with the added text items.
+        """
+        if isinstance(texts, str):
+            texts = (texts, )
+        embeds = self.text_embedding.compute_source_embeddings(texts)
+
+        if table_name is None:
+            table_name = self.table_name
+
+        if ids is None:
+            ids = [str(uuid.uuid4()) for _ in range(len(texts))]
+
+        if table_name not in self.client.table_names():
+            self.client.create_table(table_name, data=[
+                {"vector": embed, "text": text, "id": idx}
+                for idx, text, embed in zip(ids, texts, embeds)
+            ])
+        else:
+            self.client.open_table(table_name).add(
+                data=[
+                    {"vector": embed, "text": text, "id": idx}
+                    for idx, text, embed in zip(ids, texts, embeds)
+                ]
+            )
+
+        return ids
+
+    def query(self, query_texts: List[str], table_name: str = None, n_results: int = 5) -> List[List[dict]]:
+        """
+        Queries the specified memory table for similar text embeddings.
+
+        Args:
+            query_texts (List[str]): A list of query text strings.
+            table_name (str, optional): The name of the table to query. Defaults to self.table_name.
+            n_results (int, optional): The maximum number of results to retrieve per query. Default is 5.
+
+        Returns:
+            List[List[dict]]: A list of results for each query text, each result being a dictionary with 
+                keys such as "vector", "text", and "id".
+        """
+        if table_name is None:
+            table_name = self.table_name
+        table = self.client.open_table(table_name)
+
+        query_embeds = self.text_embedding.compute_source_embeddings(query_texts)
+
+        results = []
+        for query in query_embeds:
+            result = table.search(query).limit(n_results).to_list()
+            results.append(result)
+
+        return results
diff --git a/requirements.txt b/requirements.txt
index a9aaf16..d0a5d1f 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,3 +21,4 @@ bottleneck~=1.4.0
 google-api-python-client~=2.143.0
 google-auth-httplib2~=0.2.0
 google-auth-oauthlib~=1.2.1
+lancedb~=0.15.0
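
Usage sketch (not part of the patch): a minimal example of the LanceDBMemory class added above, assuming get_config(project_path) can read a project config that holds a valid OpenAI api_key. The "./my_project" path and the stored texts are hypothetical, and the signatures are refined further by the later patches in this series.

    from mle.utils.memory import LanceDBMemory

    # Connects to the local '.mle' LanceDB database and builds an OpenAI
    # embedding function from the project's api_key (hypothetical project path).
    memory = LanceDBMemory("./my_project")

    # Embed and store two notes in the default 'memory' table; a list of
    # generated UUIDs comes back as the row ids.
    ids = memory.add(["the training data lives in data/train.csv",
                      "evaluate candidate models with ROC-AUC"])

    # Embed the query and return the closest rows as dicts with keys such as
    # "vector", "text" and "id".
    hits = memory.query(["where is the training data?"], n_results=2)
    print(hits[0][0]["text"])
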
From 47d3b0ba9e613d15dcf1b32e12c7f6f41a36f066 Mon Sep 17 00:00:00 2001
From: "leizhang.real@gmail.com"
Date: Tue, 12 Nov 2024 14:35:43 +0000
Subject: [PATCH 2/4] complete the apis

---
 mle/utils/memory.py | 104 +++++++++++++++++++++++++++++---------------
 1 file changed, 68 insertions(+), 36 deletions(-)

diff --git a/mle/utils/memory.py b/mle/utils/memory.py
index 5516650..c77d1cb 100644
--- a/mle/utils/memory.py
+++ b/mle/utils/memory.py
@@ -1,7 +1,7 @@
 import uuid
 import os.path
 from datetime import datetime
-from typing import List, Dict
+from typing import List, Dict, Optional
 
 import lancedb
 from lancedb.embeddings import get_registry
@@ -163,83 +163,115 @@ class LanceDBMemory:
     def __init__(
         self,
         project_path: str,
-        *,
-        embedding_model: str = "openai",
     ):
         """
         Memory: A base class for memory and external knowledge management.
         Args:
            project_path: the path to store the data.
-            embedding_model: the embedding model to use.
         """
         self.db_name = '.mle'
         self.table_name = 'memory'
         self.client = lancedb.connect(uri=self.db_name)
 
         config = get_config(project_path)
-        if embedding_model == "openai":
-            self.text_embedding = get_registry().get(embedding_model).create(api_key=config["api_key"])
+        if config["platform"] == "OpenAI":
+            self.text_embedding = get_registry().get("openai").create(api_key=config["api_key"])
+        else:
+            raise NotImplementedError
 
-    def add(self, texts: List[str], table_name: str = None, ids: List[str] = None) -> List[str]:
+    def add(self, texts: List[str], table_name: Optional[str] = None, ids: Optional[List[str]] = None) -> List[str]:
         """
         Adds a list of text items to the specified memory table in the database.
 
         Args:
             texts (List[str]): A list of text strings to be added.
-            table_name (str, optional): The name of the table to add data to. Defaults to self.table_name.
-            ids (List[str], optional): A list of unique IDs for the text items.
+            table_name (Optional[str]): The name of the table to add data to. Defaults to self.table_name.
+            ids (Optional[List[str]]): A list of unique IDs for the text items.
                 If not provided, random UUIDs are generated.
 
         Returns:
             List[str]: A list of IDs associated with the added text items.
         """
         if isinstance(texts, str):
-            texts = (texts, )
+            texts = [texts]
         embeds = self.text_embedding.compute_source_embeddings(texts)
 
-        if table_name is None:
-            table_name = self.table_name
-
-        if ids is None:
-            ids = [str(uuid.uuid4()) for _ in range(len(texts))]
+        table_name = table_name or self.table_name
+        ids = ids or [str(uuid.uuid4()) for _ in range(len(texts))]
+        data = [{"vector": embed, "text": text, "id": idx} for idx, text, embed in zip(ids, texts, embeds)]
 
         if table_name not in self.client.table_names():
-            self.client.create_table(table_name, data=[
-                {"vector": embed, "text": text, "id": idx}
-                for idx, text, embed in zip(ids, texts, embeds)
-            ])
+            self.client.create_table(table_name, data=data)
         else:
-            self.client.open_table(table_name).add(
-                data=[
-                    {"vector": embed, "text": text, "id": idx}
-                    for idx, text, embed in zip(ids, texts, embeds)
-                ]
-            )
+            self.client.open_table(table_name).add(data=data)
 
         return ids
 
-    def query(self, query_texts: List[str], table_name: str = None, n_results: int = 5) -> List[List[dict]]:
+    def query(self, query_texts: List[str], table_name: Optional[str] = None, n_results: int = 5) -> List[List[dict]]:
         """
         Queries the specified memory table for similar text embeddings.
 
         Args:
             query_texts (List[str]): A list of query text strings.
-            table_name (str, optional): The name of the table to query. Defaults to self.table_name.
-            n_results (int, optional): The maximum number of results to retrieve per query. Default is 5.
+            table_name (Optional[str]): The name of the table to query. Defaults to self.table_name.
+            n_results (int): The maximum number of results to retrieve per query. Default is 5.
 
         Returns:
             List[List[dict]]: A list of results for each query text, each result being a dictionary with 
                 keys such as "vector", "text", and "id".
         """
-        if table_name is None:
-            table_name = self.table_name
+        table_name = table_name or self.table_name
         table = self.client.open_table(table_name)
-
-        query_embeds = self.text_embedding.compute_source_embeddings(query_texts)
-
-        results = []
-        for query in query_embeds:
-            result = table.search(query).limit(n_results).to_list()
-            results.append(result)
-
+        query_embeds = self.text_embedding.compute_source_embeddings(query_texts)
+
+        results = [table.search(query).limit(n_results).to_list() for query in query_embeds]
         return results
+
+    def delete(self, record_id: str, table_name: Optional[str] = None) -> bool:
+        """
+        Deletes a record from the specified memory table.
+
+        Args:
+            record_id (str): The ID of the record to delete.
+            table_name (Optional[str]): The name of the table to delete the record from. Defaults to self.table_name.
+
+        Returns:
+            bool: True if the deletion was successful, False otherwise.
+        """
+        table_name = table_name or self.table_name
+        table = self.client.open_table(table_name)
+        return table.delete(f"id = '{record_id}'")
+
+    def drop(self, table_name: Optional[str] = None) -> bool:
+        """
+        Drops (deletes) the specified memory table.
+
+        Args:
+            table_name (Optional[str]): The name of the table to delete. Defaults to self.table_name.
+
+        Returns:
+            bool: True if the table was successfully dropped, False otherwise.
+        """
+        table_name = table_name or self.table_name
+        return self.client.drop_table(table_name)
+
+    def count(self, table_name: Optional[str] = None) -> int:
+        """
+        Counts the number of records in the specified memory table.
+
+        Args:
+            table_name (Optional[str]): The name of the table to count records in. Defaults to self.table_name.
+
+        Returns:
+            int: The number of records in the table.
+        """
+        table_name = table_name or self.table_name
+        table = self.client.open_table(table_name)
+        return table.count_rows()
+
+    def reset(self) -> None:
+        """
+        Resets the memory by dropping the default memory table.
+        """
+        self.drop()
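
Usage sketch (not part of the patch): the housekeeping APIs completed here, assuming a LanceDBMemory instance `memory` created as in the earlier sketch from a project whose config sets platform to OpenAI with a valid api_key.

    # add() embeds and stores a note in the default 'memory' table.
    ids = memory.add(["scale the features before fitting the model"])

    # count() reports how many rows the default table currently holds.
    print(memory.count())

    # delete() removes a single row by id; reset() drops the default table.
    memory.delete(ids[0])
    memory.reset()
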
""" if isinstance(texts, str): - texts = (texts, ) + texts = [texts] embeds = self.text_embedding.compute_source_embeddings(texts) - if table_name is None: - table_name = self.table_name - - if ids is None: - ids = [str(uuid.uuid4()) for _ in range(len(texts))] + table_name = table_name or self.table_name + ids = ids or [str(uuid.uuid4()) for _ in range(len(texts))] + data = [{"vector": embed, "text": text, "id": idx} for idx, text, embed in zip(ids, texts, embeds)] + if table_name not in self.client.table_names(): - self.client.create_table(table_name, data=[ - {"vector": embed, "text": text, "id": idx} - for idx, text, embed in zip(ids, texts, embeds) - ]) + self.client.create_table(table_name, data=data) else: - self.client.open_table(table_name).add( - data=[ - {"vector": embed, "text": text, "id": idx} - for idx, text, embed in zip(ids, texts, embeds) - ] - ) + self.client.open_table(table_name).add(data=data) return ids - def query(self, query_texts: List[str], table_name: str = None, n_results: int = 5) -> List[List[dict]]: + def query(self, query_texts: List[str], table_name: Optional[str] = None, n_results: int = 5) -> List[List[dict]]: """ Queries the specified memory table for similar text embeddings. Args: query_texts (List[str]): A list of query text strings. - table_name (str, optional): The name of the table to query. Defaults to self.table_name. - n_results (int, optional): The maximum number of results to retrieve per query. Default is 5. + table_name (Optional[str]): The name of the table to query. Defaults to self.table_name. + n_results (int): The maximum number of results to retrieve per query. Default is 5. Returns: List[List[dict]]: A list of results for each query text, each result being a dictionary with keys such as "vector", "text", and "id". """ - if table_name is None: - table_name = self.table_name + table_name = table_name or self.table_name table = self.client.open_table(table_name) - query_embeds = self.text_embedding.compute_source_embeddings(query_texts) - results = [] - for query in query_embeds: - result = table.search(query).limit(n_results).to_list() - results.append(result) - + results = [table.search(query).limit(n_results).to_list() for query in query_embeds] return results + + def delete(self, record_id: str, table_name: Optional[str] = None) -> bool: + """ + Deletes a record from the specified memory table. + + Args: + record_id (str): The ID of the record to delete. + table_name (Optional[str]): The name of the table to delete the record from. Defaults to self.table_name. + + Returns: + bool: True if the deletion was successful, False otherwise. + """ + table_name = table_name or self.table_name + table = self.client.open_table(table_name) + return table.delete(f"id = '{record_id}'") + + def drop(self, table_name: Optional[str] = None) -> bool: + """ + Drops (deletes) the specified memory table. + + Args: + table_name (Optional[str]): The name of the table to delete. Defaults to self.table_name. + + Returns: + bool: True if the table was successfully dropped, False otherwise. + """ + table_name = table_name or self.table_name + return self.client.drop_table(table_name) + + def count(self, table_name: Optional[str] = None) -> int: + """ + Counts the number of records in the specified memory table. + + Args: + table_name (Optional[str]): The name of the table to count records in. Defaults to self.table_name. + + Returns: + int: The number of records in the table. 
+ """ + table_name = table_name or self.table_name + table = self.client.open_table(table_name) + return table.count_rows() + + def reset(self) -> None: + """ + Resets the memory by dropping the default memory table. + """ + self.drop() From d454c188388ce77440bd96b4fba584009279d647 Mon Sep 17 00:00:00 2001 From: "leizhang.real@gmail.com" Date: Tue, 12 Nov 2024 14:37:58 +0000 Subject: [PATCH 3/4] lint check --- mle/utils/memory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mle/utils/memory.py b/mle/utils/memory.py index c77d1cb..c0a7e6c 100644 --- a/mle/utils/memory.py +++ b/mle/utils/memory.py @@ -218,7 +218,7 @@ def query(self, query_texts: List[str], table_name: Optional[str] = None, n_resu n_results (int): The maximum number of results to retrieve per query. Default is 5. Returns: - List[List[dict]]: A list of results for each query text, each result being a dictionary with + List[List[dict]]: A list of results for each query text, each result being a dictionary with keys such as "vector", "text", and "id". """ table_name = table_name or self.table_name From 298a659ca64239c619e374fdc213d1feb3d3b59f Mon Sep 17 00:00:00 2001 From: "leizhang.real@gmail.com" Date: Tue, 12 Nov 2024 14:48:39 +0000 Subject: [PATCH 4/4] support metadata --- mle/utils/memory.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/mle/utils/memory.py b/mle/utils/memory.py index c0a7e6c..f1ae313 100644 --- a/mle/utils/memory.py +++ b/mle/utils/memory.py @@ -179,12 +179,19 @@ def __init__( else: raise NotImplementedError - def add(self, texts: List[str], table_name: Optional[str] = None, ids: Optional[List[str]] = None) -> List[str]: + def add( + self, + texts: List[str], + metadata: Optional[List[Dict]] = None, + table_name: Optional[str] = None, + ids: Optional[List[str]] = None, + ) -> List[str]: """ Adds a list of text items to the specified memory table in the database. Args: texts (List[str]): A list of text strings to be added. + metadata (Optional[List[Dict]]): A list of metadata to be added. table_name (Optional[str]): The name of the table to add data to. Defaults to self.table_name. ids (Optional[List[str]]): A list of unique IDs for the text items. If not provided, random UUIDs are generated. @@ -193,14 +200,29 @@ def add(self, texts: List[str], table_name: Optional[str] = None, ids: Optional[ List[str]: A list of IDs associated with the added text items. """ if isinstance(texts, str): - texts = [texts] + texts = (texts, ) + + if metadata is None: + metadata = [None, ] * len(texts) + elif isinstance(metadata, dict): + metadata = (metadata, ) + else: + assert len(texts) == len(metadata) + embeds = self.text_embedding.compute_source_embeddings(texts) table_name = table_name or self.table_name ids = ids or [str(uuid.uuid4()) for _ in range(len(texts))] - data = [{"vector": embed, "text": text, "id": idx} for idx, text, embed in zip(ids, texts, embeds)] - + data = [ + { + "vector": embed, + "text": text, + "id": idx, + "metadata": meta, + } for idx, text, embed, meta in zip(ids, texts, embeds, metadata) + ] + if table_name not in self.client.table_names(): self.client.create_table(table_name, data=data) else: