Skip to content

Commit

Permalink
support metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
leeeizhang committed Nov 12, 2024
1 parent d454c18 commit 298a659
Showing 1 changed file with 26 additions and 4 deletions.
30 changes: 26 additions & 4 deletions mle/utils/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,12 +179,19 @@ def __init__(
else:
raise NotImplementedError

def add(self, texts: List[str], table_name: Optional[str] = None, ids: Optional[List[str]] = None) -> List[str]:
def add(
self,
texts: List[str],
metadata: Optional[List[Dict]] = None,
table_name: Optional[str] = None,
ids: Optional[List[str]] = None,
) -> List[str]:
"""
Adds a list of text items to the specified memory table in the database.
Args:
texts (List[str]): A list of text strings to be added.
metadata (Optional[List[Dict]]): A list of metadata to be added.
table_name (Optional[str]): The name of the table to add data to. Defaults to self.table_name.
ids (Optional[List[str]]): A list of unique IDs for the text items.
If not provided, random UUIDs are generated.
Expand All @@ -193,14 +200,29 @@ def add(self, texts: List[str], table_name: Optional[str] = None, ids: Optional[
List[str]: A list of IDs associated with the added text items.
"""
if isinstance(texts, str):
texts = [texts]
texts = (texts, )

if metadata is None:
metadata = [None, ] * len(texts)
elif isinstance(metadata, dict):
metadata = (metadata, )
else:
assert len(texts) == len(metadata)

embeds = self.text_embedding.compute_source_embeddings(texts)

table_name = table_name or self.table_name
ids = ids or [str(uuid.uuid4()) for _ in range(len(texts))]

data = [{"vector": embed, "text": text, "id": idx} for idx, text, embed in zip(ids, texts, embeds)]

data = [
{
"vector": embed,
"text": text,
"id": idx,
"metadata": meta,
} for idx, text, embed, meta in zip(ids, texts, embeds, metadata)
]

if table_name not in self.client.table_names():
self.client.create_table(table_name, data=data)
else:
Expand Down

0 comments on commit 298a659

Please sign in to comment.