Merge pull request #262 from leeeizhang/lei/lancedb
[MRG] add lancedb as memory
huangyz0918 authored Nov 15, 2024
2 parents 98f504a + 298a659 commit ce39ddf
Showing 3 changed files with 148 additions and 6 deletions.
4 changes: 0 additions & 4 deletions mle/cli.py
@@ -10,7 +10,6 @@

import mle
from mle.server import app
from mle.utils import Memory
import mle.workflow as workflow
from mle.utils.system import (
get_config,
@@ -278,9 +277,6 @@ def new(name):
            'integration': {},
        }, outfile, default_flow_style=False)

    # init the memory
    Memory(project_dir)


@cli.command()
@click.option('--reset', is_flag=True, help='Reset the integration')
149 changes: 147 additions & 2 deletions mle/utils/memory.py
@@ -1,7 +1,10 @@
import uuid
import os.path
from datetime import datetime
from typing import List, Dict
from typing import List, Dict, Optional

import lancedb
from lancedb.embeddings import get_registry

import chromadb
from chromadb.utils import embedding_functions
@@ -11,7 +14,8 @@
chromadb.logger.setLevel(chromadb.logging.ERROR)


class Memory:
class ChromaDBMemory:

    def __init__(
        self,
        project_path: str,
@@ -152,3 +156,144 @@ def reset(self):
        Notice: You may need to set the environment variable `ALLOW_RESET` to `TRUE` to enable this function.
        """
        self.client.reset()


class LanceDBMemory:

    def __init__(
        self,
        project_path: str,
    ):
"""
Memory: A base class for memory and external knowledge management.
Args:
project_path: the path to store the data.
"""
        self.db_name = '.mle'
        self.table_name = 'memory'
        self.client = lancedb.connect(uri=self.db_name)

        config = get_config(project_path)
        if config["platform"] == "OpenAI":
            self.text_embedding = get_registry().get("openai").create(api_key=config["api_key"])
        else:
            raise NotImplementedError

    def add(
        self,
        texts: List[str],
        metadata: Optional[List[Dict]] = None,
        table_name: Optional[str] = None,
        ids: Optional[List[str]] = None,
    ) -> List[str]:
        """
        Adds a list of text items to the specified memory table in the database.
        Args:
            texts (List[str]): A list of text strings to be added.
            metadata (Optional[List[Dict]]): A list of metadata to be added.
            table_name (Optional[str]): The name of the table to add data to. Defaults to self.table_name.
            ids (Optional[List[str]]): A list of unique IDs for the text items.
                If not provided, random UUIDs are generated.
        Returns:
            List[str]: A list of IDs associated with the added text items.
        """
        if isinstance(texts, str):
            texts = (texts, )

        if metadata is None:
            metadata = [None, ] * len(texts)
        elif isinstance(metadata, dict):
            metadata = (metadata, )
        else:
            assert len(texts) == len(metadata)

        embeds = self.text_embedding.compute_source_embeddings(texts)

        table_name = table_name or self.table_name
        ids = ids or [str(uuid.uuid4()) for _ in range(len(texts))]

        data = [
            {
                "vector": embed,
                "text": text,
                "id": idx,
                "metadata": meta,
            } for idx, text, embed, meta in zip(ids, texts, embeds, metadata)
        ]

        if table_name not in self.client.table_names():
            self.client.create_table(table_name, data=data)
        else:
            self.client.open_table(table_name).add(data=data)

        return ids

    def query(self, query_texts: List[str], table_name: Optional[str] = None, n_results: int = 5) -> List[List[dict]]:
        """
        Queries the specified memory table for similar text embeddings.
        Args:
            query_texts (List[str]): A list of query text strings.
            table_name (Optional[str]): The name of the table to query. Defaults to self.table_name.
            n_results (int): The maximum number of results to retrieve per query. Default is 5.
        Returns:
            List[List[dict]]: A list of results for each query text, each result being a dictionary with
                keys such as "vector", "text", and "id".
        """
        table_name = table_name or self.table_name
        table = self.client.open_table(table_name)
        query_embeds = self.text_embedding.compute_source_embeddings(query_texts)

        results = [table.search(query).limit(n_results).to_list() for query in query_embeds]
        return results

    def delete(self, record_id: str, table_name: Optional[str] = None) -> bool:
        """
        Deletes a record from the specified memory table.
        Args:
            record_id (str): The ID of the record to delete.
            table_name (Optional[str]): The name of the table to delete the record from. Defaults to self.table_name.
        Returns:
            bool: True if the deletion was successful, False otherwise.
        """
        table_name = table_name or self.table_name
        table = self.client.open_table(table_name)
        return table.delete(f"id = '{record_id}'")

    def drop(self, table_name: Optional[str] = None) -> bool:
        """
        Drops (deletes) the specified memory table.
        Args:
            table_name (Optional[str]): The name of the table to delete. Defaults to self.table_name.
        Returns:
            bool: True if the table was successfully dropped, False otherwise.
        """
        table_name = table_name or self.table_name
        return self.client.drop_table(table_name)

    def count(self, table_name: Optional[str] = None) -> int:
        """
        Counts the number of records in the specified memory table.
        Args:
            table_name (Optional[str]): The name of the table to count records in. Defaults to self.table_name.
        Returns:
            int: The number of records in the table.
        """
        table_name = table_name or self.table_name
        table = self.client.open_table(table_name)
        return table.count_rows()

    def reset(self) -> None:
        """
        Resets the memory by dropping the default memory table.
        """
        self.drop()
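
For reference, a minimal usage sketch of the new LanceDBMemory backend (not part of the diff). It assumes a project whose config sets platform to "OpenAI" with a valid api_key; the project path "./my_project" and the example texts are hypothetical.

from mle.utils.memory import LanceDBMemory

memory = LanceDBMemory(project_path="./my_project")

# Add two snippets; UUIDs are generated because no ids are supplied.
ids = memory.add(
    texts=["LanceDB stores embeddings on disk.", "ChromaDB remains available."],
    metadata=[{"source": "note"}, {"source": "note"}],
)

# Retrieve the two most similar entries for a query string.
results = memory.query(["Where are embeddings stored?"], n_results=2)
print(results[0][0]["text"])

# Count records, delete one record by id, then drop the default table.
print(memory.count())
memory.delete(ids[0])
memory.reset()

Note that the LanceDB data lives in a local ".mle" directory (the connect uri), independent of the project_path passed in.
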
1 change: 1 addition & 0 deletions requirements.txt
@@ -21,3 +21,4 @@ bottleneck~=1.4.0
google-api-python-client~=2.143.0
google-auth-httplib2~=0.2.0
google-auth-oauthlib~=1.2.1
lancedb~=0.15.0
