From b2afdf160f538d10edb42de34094e172df1cf7c0 Mon Sep 17 00:00:00 2001 From: Bernard F Faucher Date: Fri, 13 Dec 2024 23:24:31 -0500 Subject: [PATCH] feat(community): Neo4j chat message history (#7331) Co-authored-by: jacoblee93 --- libs/langchain-community/.gitignore | 4 + libs/langchain-community/langchain.config.js | 2 + libs/langchain-community/package.json | 13 ++ .../src/load/import_constants.ts | 1 + .../src/stores/message/neo4j.ts | 160 ++++++++++++++++++ .../src/stores/tests/neo4j.int.test.ts | 138 +++++++++++++++ test-int-deps-docker-compose.yml | 15 +- 7 files changed, 332 insertions(+), 1 deletion(-) create mode 100644 libs/langchain-community/src/stores/message/neo4j.ts create mode 100644 libs/langchain-community/src/stores/tests/neo4j.int.test.ts diff --git a/libs/langchain-community/.gitignore b/libs/langchain-community/.gitignore index 8dc708cbe23e..b0ce628119ae 100644 --- a/libs/langchain-community/.gitignore +++ b/libs/langchain-community/.gitignore @@ -802,6 +802,10 @@ stores/message/mongodb.cjs stores/message/mongodb.js stores/message/mongodb.d.ts stores/message/mongodb.d.cts +stores/message/neo4j.cjs +stores/message/neo4j.js +stores/message/neo4j.d.ts +stores/message/neo4j.d.cts stores/message/planetscale.cjs stores/message/planetscale.js stores/message/planetscale.d.ts diff --git a/libs/langchain-community/langchain.config.js b/libs/langchain-community/langchain.config.js index 547960384372..cbabc1c68841 100644 --- a/libs/langchain-community/langchain.config.js +++ b/libs/langchain-community/langchain.config.js @@ -250,6 +250,7 @@ export const config = { "stores/message/ioredis": "stores/message/ioredis", "stores/message/momento": "stores/message/momento", "stores/message/mongodb": "stores/message/mongodb", + "stores/message/neo4j": "stores/message/neo4j", "stores/message/planetscale": "stores/message/planetscale", "stores/message/postgres": "stores/message/postgres", "stores/message/redis": "stores/message/redis", @@ -473,6 +474,7 @@ export const config = { "stores/message/ipfs_datastore", "stores/message/momento", "stores/message/mongodb", + "stores/message/neo4j", "stores/message/planetscale", "stores/message/postgres", "stores/message/redis", diff --git a/libs/langchain-community/package.json b/libs/langchain-community/package.json index 108a51c174ce..31da4591d64c 100644 --- a/libs/langchain-community/package.json +++ b/libs/langchain-community/package.json @@ -2521,6 +2521,15 @@ "import": "./stores/message/mongodb.js", "require": "./stores/message/mongodb.cjs" }, + "./stores/message/neo4j": { + "types": { + "import": "./stores/message/neo4j.d.ts", + "require": "./stores/message/neo4j.d.cts", + "default": "./stores/message/neo4j.d.ts" + }, + "import": "./stores/message/neo4j.js", + "require": "./stores/message/neo4j.cjs" + }, "./stores/message/planetscale": { "types": { "import": "./stores/message/planetscale.d.ts", @@ -3932,6 +3941,10 @@ "stores/message/mongodb.js", "stores/message/mongodb.d.ts", "stores/message/mongodb.d.cts", + "stores/message/neo4j.cjs", + "stores/message/neo4j.js", + "stores/message/neo4j.d.ts", + "stores/message/neo4j.d.cts", "stores/message/planetscale.cjs", "stores/message/planetscale.js", "stores/message/planetscale.d.ts", diff --git a/libs/langchain-community/src/load/import_constants.ts b/libs/langchain-community/src/load/import_constants.ts index 722dd82e678b..5930f82690db 100644 --- a/libs/langchain-community/src/load/import_constants.ts +++ b/libs/langchain-community/src/load/import_constants.ts @@ -130,6 +130,7 @@ export const optionalImportEntrypoints: string[] = [ "langchain_community/stores/message/ioredis", "langchain_community/stores/message/momento", "langchain_community/stores/message/mongodb", + "langchain_community/stores/message/neo4j", "langchain_community/stores/message/planetscale", "langchain_community/stores/message/postgres", "langchain_community/stores/message/redis", diff --git a/libs/langchain-community/src/stores/message/neo4j.ts b/libs/langchain-community/src/stores/message/neo4j.ts new file mode 100644 index 000000000000..a5f132900470 --- /dev/null +++ b/libs/langchain-community/src/stores/message/neo4j.ts @@ -0,0 +1,160 @@ +import neo4j, { Driver, Record, auth } from "neo4j-driver"; +import { v4 as uuidv4 } from "uuid"; +import { BaseListChatMessageHistory } from "@langchain/core/chat_history"; +import { + BaseMessage, + mapStoredMessagesToChatMessages, +} from "@langchain/core/messages"; + +export type Neo4jChatMessageHistoryConfigInput = { + sessionId?: string | number; + sessionNodeLabel?: string; + messageNodeLabel?: string; + url: string; + username: string; + password: string; + windowSize?: number; +}; + +const defaultConfig = { + sessionNodeLabel: "ChatSession", + messageNodeLabel: "ChatMessage", + windowSize: 3, +}; + +export class Neo4jChatMessageHistory extends BaseListChatMessageHistory { + lc_namespace: string[] = ["langchain", "stores", "message", "neo4j"]; + + sessionId: string | number; + + sessionNodeLabel: string; + + messageNodeLabel: string; + + windowSize: number; + + private driver: Driver; + + constructor({ + sessionId = uuidv4(), + sessionNodeLabel = defaultConfig.sessionNodeLabel, + messageNodeLabel = defaultConfig.messageNodeLabel, + url, + username, + password, + windowSize = defaultConfig.windowSize, + }: Neo4jChatMessageHistoryConfigInput) { + super(); + + this.sessionId = sessionId; + this.sessionNodeLabel = sessionNodeLabel; + this.messageNodeLabel = messageNodeLabel; + this.windowSize = windowSize; + + if (url && username && password) { + try { + this.driver = neo4j.driver(url, auth.basic(username, password)); + } catch (e: any) { + throw new Error( + `Could not create a Neo4j driver instance. Please check the connection details.\nCause: ${e.message}` + ); + } + } else { + throw new Error("Neo4j connection details not provided."); + } + } + + static async initialize( + props: Neo4jChatMessageHistoryConfigInput + ): Promise { + const instance = new Neo4jChatMessageHistory(props); + + try { + await instance.verifyConnectivity(); + } catch (e: any) { + throw new Error( + `Could not verify connection to the Neo4j database.\nCause: ${e.message}` + ); + } + + return instance; + } + + async verifyConnectivity() { + const connectivity = await this.driver.getServerInfo(); + return connectivity; + } + + async getMessages(): Promise { + const getMessagesCypherQuery = ` + MERGE (chatSession:${this.sessionNodeLabel} {id: $sessionId}) + WITH chatSession + MATCH (chatSession)-[:LAST_MESSAGE]->(lastMessage) + MATCH p=(lastMessage)<-[:NEXT*0..${this.windowSize * 2 - 1}]-() + WITH p, length(p) AS length + ORDER BY length DESC LIMIT 1 + UNWIND reverse(nodes(p)) AS node + RETURN {data:{content: node.content}, type:node.type} AS result + `; + + try { + const { records } = await this.driver.executeQuery( + getMessagesCypherQuery, + { + sessionId: this.sessionId, + } + ); + const results = records.map((record: Record) => record.get("result")); + + return mapStoredMessagesToChatMessages(results); + } catch (e: any) { + throw new Error(`Ohno! Couldn't get messages.\nCause: ${e.message}`); + } + } + + async addMessage(message: BaseMessage): Promise { + const addMessageCypherQuery = ` + MERGE (chatSession:${this.sessionNodeLabel} {id: $sessionId}) + WITH chatSession + OPTIONAL MATCH (chatSession)-[lastMessageRel:LAST_MESSAGE]->(lastMessage) + CREATE (chatSession)-[:LAST_MESSAGE]->(newLastMessage:${this.messageNodeLabel}) + SET newLastMessage += {type:$type, content:$content} + WITH newLastMessage, lastMessageRel, lastMessage + WHERE lastMessage IS NOT NULL + CREATE (lastMessage)-[:NEXT]->(newLastMessage) + DELETE lastMessageRel + `; + + try { + await this.driver.executeQuery(addMessageCypherQuery, { + sessionId: this.sessionId, + type: message.getType(), + content: message.content, + }); + } catch (e: any) { + throw new Error(`Ohno! Couldn't add message.\nCause: ${e.message}`); + } + } + + async clear() { + const clearMessagesCypherQuery = ` + MATCH p=(chatSession:${this.sessionNodeLabel} {id: $sessionId})-[:LAST_MESSAGE]->(lastMessage)<-[:NEXT*0..]-() + UNWIND nodes(p) as node + DETACH DELETE node + `; + + try { + await this.driver.executeQuery(clearMessagesCypherQuery, { + sessionId: this.sessionId, + }); + } catch (e: any) { + throw new Error( + `Ohno! Couldn't clear chat history.\nCause: ${e.message}` + ); + } + } + + async close() { + await this.driver.close(); + } +} diff --git a/libs/langchain-community/src/stores/tests/neo4j.int.test.ts b/libs/langchain-community/src/stores/tests/neo4j.int.test.ts new file mode 100644 index 000000000000..2f6c17d01ed6 --- /dev/null +++ b/libs/langchain-community/src/stores/tests/neo4j.int.test.ts @@ -0,0 +1,138 @@ +import { describe, it, expect, beforeEach, afterEach } from "@jest/globals"; +import { HumanMessage, AIMessage } from "@langchain/core/messages"; +import neo4j from "neo4j-driver"; +import { Neo4jChatMessageHistory } from "../message/neo4j.js"; + +const goodConfig = { + url: "bolt://host.docker.internal:7687", + username: "neo4j", + password: "langchain", +}; + +describe("The Neo4jChatMessageHistory class", () => { + describe("Test suite", () => { + it("Runs at all", () => { + expect(true).toEqual(true); + }); + }); + + describe("Class instantiation", () => { + it("Requires a url, username and password, throwing an error if not provided", async () => { + const badConfig = {}; + await expect( + // @ts-expect-error Bad config + Neo4jChatMessageHistory.initialize(badConfig) + ).rejects.toThrow(neo4j.Neo4jError); + }); + + it("Creates a class instance from - at minimum - a url, username and password", async () => { + const instance = await Neo4jChatMessageHistory.initialize(goodConfig); + expect(instance).toBeInstanceOf(Neo4jChatMessageHistory); + await instance.close(); + }); + + it("Class instances have expected, configurable fields, and sensible defaults", async () => { + const instance = await Neo4jChatMessageHistory.initialize(goodConfig); + + expect(instance.sessionId).toBeDefined(); + expect(instance.sessionNodeLabel).toEqual("ChatSession"); + expect(instance.windowSize).toEqual(3); + expect(instance.messageNodeLabel).toEqual("ChatMessage"); + + const secondInstance = await Neo4jChatMessageHistory.initialize({ + ...goodConfig, + sessionId: "Shibboleet", + sessionNodeLabel: "Conversation", + messageNodeLabel: "Communication", + windowSize: 4, + }); + + expect(secondInstance.sessionId).toBeDefined(); + expect(secondInstance.sessionId).toEqual("Shibboleet"); + expect(instance.sessionId).not.toEqual(secondInstance.sessionId); + expect(secondInstance.sessionNodeLabel).toEqual("Conversation"); + expect(secondInstance.messageNodeLabel).toEqual("Communication"); + expect(secondInstance.windowSize).toEqual(4); + + await instance.close(); + await secondInstance.close(); + }); + }); + + describe("Core functionality", () => { + let instance: undefined | Neo4jChatMessageHistory; + + beforeEach(async () => { + instance = await Neo4jChatMessageHistory.initialize(goodConfig); + }); + + afterEach(async () => { + await instance?.clear(); + await instance?.close(); + }); + + it("Connects verifiably to the underlying Neo4j database", async () => { + const connected = await instance?.verifyConnectivity(); + expect(connected).toBeDefined(); + }); + + it("getMessages()", async () => { + let results = await instance?.getMessages(); + expect(results).toEqual([]); + const messages = [ + new HumanMessage( + "My first name is a random set of numbers and letters" + ), + new AIMessage("And other alphanumerics that changes hourly forever"), + new HumanMessage( + "My last name, a thousand vowels fading down a sinkhole to a susurrus" + ), + new AIMessage("It couldn't just be John Doe or Bingo"), + new HumanMessage( + "My address, a made-up language written out in living glyphs" + ), + new AIMessage("Lifted from demonic literature and religious text"), + new HumanMessage("Telephone: uncovered by purveyors of the ouija"), + new AIMessage("When checked against the CBGB women's room graffiti"), + new HumanMessage("My social: a sudoku"), + new AIMessage("My age is obscure"), + ]; + await instance?.addMessages(messages); + results = (await instance?.getMessages()) || []; + const windowSize = instance?.windowSize || 0; + expect(results.length).toEqual(windowSize * 2); + expect(results).toEqual(messages.slice(windowSize * -2)); + }); + + it("addMessage()", async () => { + const messages = [ + new HumanMessage("99 Bottles of beer on the wall, 99 bottles of beer!"), + new AIMessage( + "Take one down, pass it around, 98 bottles of beer on the wall." + ), + new HumanMessage("How many bottles of beer are currently on the wall?"), + new AIMessage("There are currently 98 bottles of beer on the wall."), + ]; + for (const message of messages) { + await instance?.addMessage(message); + } + const results = await instance?.getMessages(); + expect(results).toEqual(messages); + }); + + it("clear()", async () => { + const messages = [ + new AIMessage("I'm not your enemy."), + new HumanMessage("That sounds like something that my enemy would say."), + new AIMessage("You're being difficult."), + new HumanMessage("I'm being guarded."), + ]; + await instance?.addMessages(messages); + let results = await instance?.getMessages(); + expect(results).toEqual(messages); + await instance?.clear(); + results = await instance?.getMessages(); + expect(results).toEqual([]); + }); + }); +}); diff --git a/test-int-deps-docker-compose.yml b/test-int-deps-docker-compose.yml index e14c4d65779a..2c875f6f5221 100644 --- a/test-int-deps-docker-compose.yml +++ b/test-int-deps-docker-compose.yml @@ -36,4 +36,17 @@ services: qdrant: image: qdrant/qdrant:v1.9.1 ports: - - 6333:6333 \ No newline at end of file + - 6333:6333 + neo4j: + image: neo4j:latest + volumes: + - $HOME/neo4j/logs:/var/lib/neo4j/logs + - $HOME/neo4j/config:/var/lib/neo4j/config + - $HOME/neo4j/data:/var/lib/neo4j/data + - $HOME/neo4j/plugins:/var/lib/neo4j/plugins + environment: + - NEO4J_dbms_security_auth__enabled=false + ports: + - "7474:7474" + - "7687:7687" + restart: always \ No newline at end of file