generated from langchain-ai/memory-template
-
Notifications
You must be signed in to change notification settings - Fork 12
/
test_graph.py
55 lines (46 loc) · 1.97 KB
/
test_graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from typing import List
import langsmith as ls
import pytest
from langgraph.checkpoint.memory import MemorySaver
from langgraph.store.memory import InMemoryStore
from memory_agent.graph import builder
@pytest.mark.asyncio
@ls.unit
@pytest.mark.parametrize(
    "conversation",
    [
        ["My name is Alice and I love pizza. Remember this."],
        [
            "Hi, I'm Bob and I enjoy playing tennis. Remember this.",
            "Yes, I also have a pet dog named Max.",
            "Max is a golden retriever and he's 5 years old. Please remember this too.",
        ],
        [
            "Hello, I'm Charlie. I work as a software engineer and I'm passionate about AI. Remember this.",
            "I specialize in machine learning algorithms and I'm currently working on a project involving natural language processing.",
            "My main goal is to improve sentiment analysis accuracy in multi-lingual texts. It's challenging but exciting.",
            "We've made some progress using transformer models, but we're still working on handling context and idioms across languages.",
            "Chinese and English have been the most challenging pair so far due to their vast differences in structure and cultural contexts.",
        ],
    ],
    ids=["short", "medium", "long"],
)
async def test_memory_storage(conversation: List[str]):
    """Check that the graph stores memories under the invoking user's namespace only.

    Sends each message of ``conversation`` through the compiled graph on a
    single thread. It then asserts two things: the ``("memories", user_id)``
    namespace in the store is non-empty, and another user's namespace stays
    empty.

    Args:
        conversation: User messages to send, one ``ainvoke`` call per message.
    """
    # Fresh store and checkpointer for each parametrized case, so cases
    # cannot observe each other's memories or thread state.
    mem_store = InMemoryStore()
    graph = builder.compile(store=mem_store, checkpointer=MemorySaver())
    user_id = "test-user"
    # Per-run values go under config["configurable"], as LangGraph documents.
    # Putting them there explicitly avoids relying on top-level keys being
    # implicitly folded into "configurable" during config normalization.
    config = {
        "configurable": {
            "user_id": user_id,
            # One thread id for every message, so the checkpointer treats
            # them as a single conversation.
            "thread_id": "thread",
        },
    }
    for content in conversation:
        await graph.ainvoke({"messages": [("user", content)]}, config)

    namespace = ("memories", user_id)
    memories = mem_store.search(namespace)
    ls.expect(len(memories)).to_be_greater_than(0)

    # Memories must be scoped per user: a different user's namespace stays empty.
    bad_namespace = ("memories", "wrong-user")
    bad_memories = mem_store.search(bad_namespace)
    ls.expect(len(bad_memories)).to_equal(0)