Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] add lancedb as memory #262

Merged
merged 4 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions mle/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

import mle
from mle.server import app
from mle.utils import Memory
import mle.workflow as workflow
from mle.utils.system import (
get_config,
Expand Down Expand Up @@ -278,9 +277,6 @@ def new(name):
'integration': {},
}, outfile, default_flow_style=False)

# init the memory
Memory(project_dir)


@cli.command()
@click.option('--reset', is_flag=True, help='Reset the integration')
Expand Down
149 changes: 147 additions & 2 deletions mle/utils/memory.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import uuid
import os.path
from datetime import datetime
from typing import List, Dict
from typing import List, Dict, Optional

import lancedb
from lancedb.embeddings import get_registry

import chromadb
from chromadb.utils import embedding_functions
Expand All @@ -11,7 +14,8 @@
chromadb.logger.setLevel(chromadb.logging.ERROR)


class Memory:
class ChromaDBMemory:

def __init__(
self,
project_path: str,
Expand Down Expand Up @@ -152,3 +156,144 @@ def reset(self):
Notice: You may need to set the environment variable `ALLOW_RESET` to `TRUE` to enable this function.
"""
self.client.reset()


class LanceDBMemory:

def __init__(
self,
project_path: str,
):
"""
Memory: A base class for memory and external knowledge management.
Args:
project_path: the path to store the data.
"""
self.db_name = '.mle'
self.table_name = 'memory'
self.client = lancedb.connect(uri=self.db_name)

config = get_config(project_path)
if config["platform"] == "OpenAI":
self.text_embedding = get_registry().get("openai").create(api_key=config["api_key"])
else:
raise NotImplementedError

def add(
self,
texts: List[str],
metadata: Optional[List[Dict]] = None,
table_name: Optional[str] = None,
ids: Optional[List[str]] = None,
) -> List[str]:
"""
Adds a list of text items to the specified memory table in the database.

Args:
texts (List[str]): A list of text strings to be added.
metadata (Optional[List[Dict]]): A list of metadata to be added.
table_name (Optional[str]): The name of the table to add data to. Defaults to self.table_name.
ids (Optional[List[str]]): A list of unique IDs for the text items.
If not provided, random UUIDs are generated.

Returns:
List[str]: A list of IDs associated with the added text items.
"""
if isinstance(texts, str):
texts = (texts, )

if metadata is None:
metadata = [None, ] * len(texts)
elif isinstance(metadata, dict):
metadata = (metadata, )
else:
assert len(texts) == len(metadata)

embeds = self.text_embedding.compute_source_embeddings(texts)

table_name = table_name or self.table_name
ids = ids or [str(uuid.uuid4()) for _ in range(len(texts))]

data = [
{
"vector": embed,
"text": text,
"id": idx,
"metadata": meta,
} for idx, text, embed, meta in zip(ids, texts, embeds, metadata)
]

if table_name not in self.client.table_names():
self.client.create_table(table_name, data=data)
else:
self.client.open_table(table_name).add(data=data)

return ids

def query(self, query_texts: List[str], table_name: Optional[str] = None, n_results: int = 5) -> List[List[dict]]:
"""
Queries the specified memory table for similar text embeddings.

Args:
query_texts (List[str]): A list of query text strings.
table_name (Optional[str]): The name of the table to query. Defaults to self.table_name.
n_results (int): The maximum number of results to retrieve per query. Default is 5.

Returns:
List[List[dict]]: A list of results for each query text, each result being a dictionary with
keys such as "vector", "text", and "id".
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
query_embeds = self.text_embedding.compute_source_embeddings(query_texts)

results = [table.search(query).limit(n_results).to_list() for query in query_embeds]
return results

def delete(self, record_id: str, table_name: Optional[str] = None) -> bool:
"""
Deletes a record from the specified memory table.

Args:
record_id (str): The ID of the record to delete.
table_name (Optional[str]): The name of the table to delete the record from. Defaults to self.table_name.

Returns:
bool: True if the deletion was successful, False otherwise.
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
return table.delete(f"id = '{record_id}'")

def drop(self, table_name: Optional[str] = None) -> bool:
"""
Drops (deletes) the specified memory table.

Args:
table_name (Optional[str]): The name of the table to delete. Defaults to self.table_name.

Returns:
bool: True if the table was successfully dropped, False otherwise.
"""
table_name = table_name or self.table_name
return self.client.drop_table(table_name)

def count(self, table_name: Optional[str] = None) -> int:
"""
Counts the number of records in the specified memory table.

Args:
table_name (Optional[str]): The name of the table to count records in. Defaults to self.table_name.

Returns:
int: The number of records in the table.
"""
table_name = table_name or self.table_name
table = self.client.open_table(table_name)
return table.count_rows()

def reset(self) -> None:
"""
Resets the memory by dropping the default memory table.
"""
self.drop()
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ bottleneck~=1.4.0
google-api-python-client~=2.143.0
google-auth-httplib2~=0.2.0
google-auth-oauthlib~=1.2.1
lancedb~=0.15.0
Loading