-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(community): Neo4j chat message history (#7331)
Co-authored-by: jacoblee93 <[email protected]>
- Loading branch information
1 parent
a84856e
commit b2afdf1
Showing
7 changed files
with
332 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,160 @@ | ||
import neo4j, { Driver, Record, auth } from "neo4j-driver"; | ||
import { v4 as uuidv4 } from "uuid"; | ||
import { BaseListChatMessageHistory } from "@langchain/core/chat_history"; | ||
import { | ||
BaseMessage, | ||
mapStoredMessagesToChatMessages, | ||
} from "@langchain/core/messages"; | ||
|
||
export type Neo4jChatMessageHistoryConfigInput = { | ||
sessionId?: string | number; | ||
sessionNodeLabel?: string; | ||
messageNodeLabel?: string; | ||
url: string; | ||
username: string; | ||
password: string; | ||
windowSize?: number; | ||
}; | ||
|
||
const defaultConfig = { | ||
sessionNodeLabel: "ChatSession", | ||
messageNodeLabel: "ChatMessage", | ||
windowSize: 3, | ||
}; | ||
|
||
export class Neo4jChatMessageHistory extends BaseListChatMessageHistory { | ||
lc_namespace: string[] = ["langchain", "stores", "message", "neo4j"]; | ||
|
||
sessionId: string | number; | ||
|
||
sessionNodeLabel: string; | ||
|
||
messageNodeLabel: string; | ||
|
||
windowSize: number; | ||
|
||
private driver: Driver; | ||
|
||
constructor({ | ||
sessionId = uuidv4(), | ||
sessionNodeLabel = defaultConfig.sessionNodeLabel, | ||
messageNodeLabel = defaultConfig.messageNodeLabel, | ||
url, | ||
username, | ||
password, | ||
windowSize = defaultConfig.windowSize, | ||
}: Neo4jChatMessageHistoryConfigInput) { | ||
super(); | ||
|
||
this.sessionId = sessionId; | ||
this.sessionNodeLabel = sessionNodeLabel; | ||
this.messageNodeLabel = messageNodeLabel; | ||
this.windowSize = windowSize; | ||
|
||
if (url && username && password) { | ||
try { | ||
this.driver = neo4j.driver(url, auth.basic(username, password)); | ||
} catch (e: any) { | ||
throw new Error( | ||
`Could not create a Neo4j driver instance. Please check the connection details.\nCause: ${e.message}` | ||
); | ||
} | ||
} else { | ||
throw new Error("Neo4j connection details not provided."); | ||
} | ||
} | ||
|
||
static async initialize( | ||
props: Neo4jChatMessageHistoryConfigInput | ||
): Promise<Neo4jChatMessageHistory> { | ||
const instance = new Neo4jChatMessageHistory(props); | ||
|
||
try { | ||
await instance.verifyConnectivity(); | ||
} catch (e: any) { | ||
throw new Error( | ||
`Could not verify connection to the Neo4j database.\nCause: ${e.message}` | ||
); | ||
} | ||
|
||
return instance; | ||
} | ||
|
||
async verifyConnectivity() { | ||
const connectivity = await this.driver.getServerInfo(); | ||
return connectivity; | ||
} | ||
|
||
async getMessages(): Promise<BaseMessage[]> { | ||
const getMessagesCypherQuery = ` | ||
MERGE (chatSession:${this.sessionNodeLabel} {id: $sessionId}) | ||
WITH chatSession | ||
MATCH (chatSession)-[:LAST_MESSAGE]->(lastMessage) | ||
MATCH p=(lastMessage)<-[:NEXT*0..${this.windowSize * 2 - 1}]-() | ||
WITH p, length(p) AS length | ||
ORDER BY length DESC LIMIT 1 | ||
UNWIND reverse(nodes(p)) AS node | ||
RETURN {data:{content: node.content}, type:node.type} AS result | ||
`; | ||
|
||
try { | ||
const { records } = await this.driver.executeQuery( | ||
getMessagesCypherQuery, | ||
{ | ||
sessionId: this.sessionId, | ||
} | ||
); | ||
const results = records.map((record: Record) => record.get("result")); | ||
|
||
return mapStoredMessagesToChatMessages(results); | ||
} catch (e: any) { | ||
throw new Error(`Ohno! Couldn't get messages.\nCause: ${e.message}`); | ||
} | ||
} | ||
|
||
async addMessage(message: BaseMessage): Promise<void> { | ||
const addMessageCypherQuery = ` | ||
MERGE (chatSession:${this.sessionNodeLabel} {id: $sessionId}) | ||
WITH chatSession | ||
OPTIONAL MATCH (chatSession)-[lastMessageRel:LAST_MESSAGE]->(lastMessage) | ||
CREATE (chatSession)-[:LAST_MESSAGE]->(newLastMessage:${this.messageNodeLabel}) | ||
SET newLastMessage += {type:$type, content:$content} | ||
WITH newLastMessage, lastMessageRel, lastMessage | ||
WHERE lastMessage IS NOT NULL | ||
CREATE (lastMessage)-[:NEXT]->(newLastMessage) | ||
DELETE lastMessageRel | ||
`; | ||
|
||
try { | ||
await this.driver.executeQuery(addMessageCypherQuery, { | ||
sessionId: this.sessionId, | ||
type: message.getType(), | ||
content: message.content, | ||
}); | ||
} catch (e: any) { | ||
throw new Error(`Ohno! Couldn't add message.\nCause: ${e.message}`); | ||
} | ||
} | ||
|
||
async clear() { | ||
const clearMessagesCypherQuery = ` | ||
MATCH p=(chatSession:${this.sessionNodeLabel} {id: $sessionId})-[:LAST_MESSAGE]->(lastMessage)<-[:NEXT*0..]-() | ||
UNWIND nodes(p) as node | ||
DETACH DELETE node | ||
`; | ||
|
||
try { | ||
await this.driver.executeQuery(clearMessagesCypherQuery, { | ||
sessionId: this.sessionId, | ||
}); | ||
} catch (e: any) { | ||
throw new Error( | ||
`Ohno! Couldn't clear chat history.\nCause: ${e.message}` | ||
); | ||
} | ||
} | ||
|
||
async close() { | ||
await this.driver.close(); | ||
} | ||
} |
138 changes: 138 additions & 0 deletions
138
libs/langchain-community/src/stores/tests/neo4j.int.test.ts
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
import { describe, it, expect, beforeEach, afterEach } from "@jest/globals"; | ||
import { HumanMessage, AIMessage } from "@langchain/core/messages"; | ||
import neo4j from "neo4j-driver"; | ||
import { Neo4jChatMessageHistory } from "../message/neo4j.js"; | ||
|
||
const goodConfig = { | ||
url: "bolt://host.docker.internal:7687", | ||
username: "neo4j", | ||
password: "langchain", | ||
}; | ||
|
||
describe("The Neo4jChatMessageHistory class", () => { | ||
describe("Test suite", () => { | ||
it("Runs at all", () => { | ||
expect(true).toEqual(true); | ||
}); | ||
}); | ||
|
||
describe("Class instantiation", () => { | ||
it("Requires a url, username and password, throwing an error if not provided", async () => { | ||
const badConfig = {}; | ||
await expect( | ||
// @ts-expect-error Bad config | ||
Neo4jChatMessageHistory.initialize(badConfig) | ||
).rejects.toThrow(neo4j.Neo4jError); | ||
}); | ||
|
||
it("Creates a class instance from - at minimum - a url, username and password", async () => { | ||
const instance = await Neo4jChatMessageHistory.initialize(goodConfig); | ||
expect(instance).toBeInstanceOf(Neo4jChatMessageHistory); | ||
await instance.close(); | ||
}); | ||
|
||
it("Class instances have expected, configurable fields, and sensible defaults", async () => { | ||
const instance = await Neo4jChatMessageHistory.initialize(goodConfig); | ||
|
||
expect(instance.sessionId).toBeDefined(); | ||
expect(instance.sessionNodeLabel).toEqual("ChatSession"); | ||
expect(instance.windowSize).toEqual(3); | ||
expect(instance.messageNodeLabel).toEqual("ChatMessage"); | ||
|
||
const secondInstance = await Neo4jChatMessageHistory.initialize({ | ||
...goodConfig, | ||
sessionId: "Shibboleet", | ||
sessionNodeLabel: "Conversation", | ||
messageNodeLabel: "Communication", | ||
windowSize: 4, | ||
}); | ||
|
||
expect(secondInstance.sessionId).toBeDefined(); | ||
expect(secondInstance.sessionId).toEqual("Shibboleet"); | ||
expect(instance.sessionId).not.toEqual(secondInstance.sessionId); | ||
expect(secondInstance.sessionNodeLabel).toEqual("Conversation"); | ||
expect(secondInstance.messageNodeLabel).toEqual("Communication"); | ||
expect(secondInstance.windowSize).toEqual(4); | ||
|
||
await instance.close(); | ||
await secondInstance.close(); | ||
}); | ||
}); | ||
|
||
describe("Core functionality", () => { | ||
let instance: undefined | Neo4jChatMessageHistory; | ||
|
||
beforeEach(async () => { | ||
instance = await Neo4jChatMessageHistory.initialize(goodConfig); | ||
}); | ||
|
||
afterEach(async () => { | ||
await instance?.clear(); | ||
await instance?.close(); | ||
}); | ||
|
||
it("Connects verifiably to the underlying Neo4j database", async () => { | ||
const connected = await instance?.verifyConnectivity(); | ||
expect(connected).toBeDefined(); | ||
}); | ||
|
||
it("getMessages()", async () => { | ||
let results = await instance?.getMessages(); | ||
expect(results).toEqual([]); | ||
const messages = [ | ||
new HumanMessage( | ||
"My first name is a random set of numbers and letters" | ||
), | ||
new AIMessage("And other alphanumerics that changes hourly forever"), | ||
new HumanMessage( | ||
"My last name, a thousand vowels fading down a sinkhole to a susurrus" | ||
), | ||
new AIMessage("It couldn't just be John Doe or Bingo"), | ||
new HumanMessage( | ||
"My address, a made-up language written out in living glyphs" | ||
), | ||
new AIMessage("Lifted from demonic literature and religious text"), | ||
new HumanMessage("Telephone: uncovered by purveyors of the ouija"), | ||
new AIMessage("When checked against the CBGB women's room graffiti"), | ||
new HumanMessage("My social: a sudoku"), | ||
new AIMessage("My age is obscure"), | ||
]; | ||
await instance?.addMessages(messages); | ||
results = (await instance?.getMessages()) || []; | ||
const windowSize = instance?.windowSize || 0; | ||
expect(results.length).toEqual(windowSize * 2); | ||
expect(results).toEqual(messages.slice(windowSize * -2)); | ||
}); | ||
|
||
it("addMessage()", async () => { | ||
const messages = [ | ||
new HumanMessage("99 Bottles of beer on the wall, 99 bottles of beer!"), | ||
new AIMessage( | ||
"Take one down, pass it around, 98 bottles of beer on the wall." | ||
), | ||
new HumanMessage("How many bottles of beer are currently on the wall?"), | ||
new AIMessage("There are currently 98 bottles of beer on the wall."), | ||
]; | ||
for (const message of messages) { | ||
await instance?.addMessage(message); | ||
} | ||
const results = await instance?.getMessages(); | ||
expect(results).toEqual(messages); | ||
}); | ||
|
||
it("clear()", async () => { | ||
const messages = [ | ||
new AIMessage("I'm not your enemy."), | ||
new HumanMessage("That sounds like something that my enemy would say."), | ||
new AIMessage("You're being difficult."), | ||
new HumanMessage("I'm being guarded."), | ||
]; | ||
await instance?.addMessages(messages); | ||
let results = await instance?.getMessages(); | ||
expect(results).toEqual(messages); | ||
await instance?.clear(); | ||
results = await instance?.getMessages(); | ||
expect(results).toEqual([]); | ||
}); | ||
}); | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters