Skip to content

Commit

Permalink
feat(community): Neo4j chat message history (#7331)
Browse files Browse the repository at this point in the history
Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
BernardFaucher and jacoblee93 authored Dec 14, 2024
1 parent a84856e commit b2afdf1
Show file tree
Hide file tree
Showing 7 changed files with 332 additions and 1 deletion.
4 changes: 4 additions & 0 deletions libs/langchain-community/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -802,6 +802,10 @@ stores/message/mongodb.cjs
stores/message/mongodb.js
stores/message/mongodb.d.ts
stores/message/mongodb.d.cts
stores/message/neo4j.cjs
stores/message/neo4j.js
stores/message/neo4j.d.ts
stores/message/neo4j.d.cts
stores/message/planetscale.cjs
stores/message/planetscale.js
stores/message/planetscale.d.ts
Expand Down
2 changes: 2 additions & 0 deletions libs/langchain-community/langchain.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ export const config = {
"stores/message/ioredis": "stores/message/ioredis",
"stores/message/momento": "stores/message/momento",
"stores/message/mongodb": "stores/message/mongodb",
"stores/message/neo4j": "stores/message/neo4j",
"stores/message/planetscale": "stores/message/planetscale",
"stores/message/postgres": "stores/message/postgres",
"stores/message/redis": "stores/message/redis",
Expand Down Expand Up @@ -473,6 +474,7 @@ export const config = {
"stores/message/ipfs_datastore",
"stores/message/momento",
"stores/message/mongodb",
"stores/message/neo4j",
"stores/message/planetscale",
"stores/message/postgres",
"stores/message/redis",
Expand Down
13 changes: 13 additions & 0 deletions libs/langchain-community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2521,6 +2521,15 @@
"import": "./stores/message/mongodb.js",
"require": "./stores/message/mongodb.cjs"
},
"./stores/message/neo4j": {
"types": {
"import": "./stores/message/neo4j.d.ts",
"require": "./stores/message/neo4j.d.cts",
"default": "./stores/message/neo4j.d.ts"
},
"import": "./stores/message/neo4j.js",
"require": "./stores/message/neo4j.cjs"
},
"./stores/message/planetscale": {
"types": {
"import": "./stores/message/planetscale.d.ts",
Expand Down Expand Up @@ -3932,6 +3941,10 @@
"stores/message/mongodb.js",
"stores/message/mongodb.d.ts",
"stores/message/mongodb.d.cts",
"stores/message/neo4j.cjs",
"stores/message/neo4j.js",
"stores/message/neo4j.d.ts",
"stores/message/neo4j.d.cts",
"stores/message/planetscale.cjs",
"stores/message/planetscale.js",
"stores/message/planetscale.d.ts",
Expand Down
1 change: 1 addition & 0 deletions libs/langchain-community/src/load/import_constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ export const optionalImportEntrypoints: string[] = [
"langchain_community/stores/message/ioredis",
"langchain_community/stores/message/momento",
"langchain_community/stores/message/mongodb",
"langchain_community/stores/message/neo4j",
"langchain_community/stores/message/planetscale",
"langchain_community/stores/message/postgres",
"langchain_community/stores/message/redis",
Expand Down
160 changes: 160 additions & 0 deletions libs/langchain-community/src/stores/message/neo4j.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import neo4j, { Driver, Record, auth } from "neo4j-driver";
import { v4 as uuidv4 } from "uuid";
import { BaseListChatMessageHistory } from "@langchain/core/chat_history";
import {
BaseMessage,
mapStoredMessagesToChatMessages,
} from "@langchain/core/messages";

export type Neo4jChatMessageHistoryConfigInput = {
sessionId?: string | number;
sessionNodeLabel?: string;
messageNodeLabel?: string;
url: string;
username: string;
password: string;
windowSize?: number;
};

const defaultConfig = {
sessionNodeLabel: "ChatSession",
messageNodeLabel: "ChatMessage",
windowSize: 3,
};

export class Neo4jChatMessageHistory extends BaseListChatMessageHistory {
lc_namespace: string[] = ["langchain", "stores", "message", "neo4j"];

sessionId: string | number;

sessionNodeLabel: string;

messageNodeLabel: string;

windowSize: number;

private driver: Driver;

constructor({
sessionId = uuidv4(),
sessionNodeLabel = defaultConfig.sessionNodeLabel,
messageNodeLabel = defaultConfig.messageNodeLabel,
url,
username,
password,
windowSize = defaultConfig.windowSize,
}: Neo4jChatMessageHistoryConfigInput) {
super();

this.sessionId = sessionId;
this.sessionNodeLabel = sessionNodeLabel;
this.messageNodeLabel = messageNodeLabel;
this.windowSize = windowSize;

if (url && username && password) {
try {
this.driver = neo4j.driver(url, auth.basic(username, password));
} catch (e: any) {
throw new Error(
`Could not create a Neo4j driver instance. Please check the connection details.\nCause: ${e.message}`
);
}
} else {
throw new Error("Neo4j connection details not provided.");
}
}

static async initialize(
props: Neo4jChatMessageHistoryConfigInput
): Promise<Neo4jChatMessageHistory> {
const instance = new Neo4jChatMessageHistory(props);

try {
await instance.verifyConnectivity();
} catch (e: any) {
throw new Error(
`Could not verify connection to the Neo4j database.\nCause: ${e.message}`
);
}

return instance;
}

async verifyConnectivity() {
const connectivity = await this.driver.getServerInfo();
return connectivity;
}

async getMessages(): Promise<BaseMessage[]> {
const getMessagesCypherQuery = `
MERGE (chatSession:${this.sessionNodeLabel} {id: $sessionId})
WITH chatSession
MATCH (chatSession)-[:LAST_MESSAGE]->(lastMessage)
MATCH p=(lastMessage)<-[:NEXT*0..${this.windowSize * 2 - 1}]-()
WITH p, length(p) AS length
ORDER BY length DESC LIMIT 1
UNWIND reverse(nodes(p)) AS node
RETURN {data:{content: node.content}, type:node.type} AS result
`;

try {
const { records } = await this.driver.executeQuery(
getMessagesCypherQuery,
{
sessionId: this.sessionId,
}
);
const results = records.map((record: Record) => record.get("result"));

return mapStoredMessagesToChatMessages(results);
} catch (e: any) {
throw new Error(`Ohno! Couldn't get messages.\nCause: ${e.message}`);
}
}

async addMessage(message: BaseMessage): Promise<void> {
const addMessageCypherQuery = `
MERGE (chatSession:${this.sessionNodeLabel} {id: $sessionId})
WITH chatSession
OPTIONAL MATCH (chatSession)-[lastMessageRel:LAST_MESSAGE]->(lastMessage)
CREATE (chatSession)-[:LAST_MESSAGE]->(newLastMessage:${this.messageNodeLabel})
SET newLastMessage += {type:$type, content:$content}
WITH newLastMessage, lastMessageRel, lastMessage
WHERE lastMessage IS NOT NULL
CREATE (lastMessage)-[:NEXT]->(newLastMessage)
DELETE lastMessageRel
`;

try {
await this.driver.executeQuery(addMessageCypherQuery, {
sessionId: this.sessionId,
type: message.getType(),
content: message.content,
});
} catch (e: any) {
throw new Error(`Ohno! Couldn't add message.\nCause: ${e.message}`);
}
}

async clear() {
const clearMessagesCypherQuery = `
MATCH p=(chatSession:${this.sessionNodeLabel} {id: $sessionId})-[:LAST_MESSAGE]->(lastMessage)<-[:NEXT*0..]-()
UNWIND nodes(p) as node
DETACH DELETE node
`;

try {
await this.driver.executeQuery(clearMessagesCypherQuery, {
sessionId: this.sessionId,
});
} catch (e: any) {
throw new Error(
`Ohno! Couldn't clear chat history.\nCause: ${e.message}`
);
}
}

async close() {
await this.driver.close();
}
}
138 changes: 138 additions & 0 deletions libs/langchain-community/src/stores/tests/neo4j.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
import { describe, it, expect, beforeEach, afterEach } from "@jest/globals";
import { HumanMessage, AIMessage } from "@langchain/core/messages";
import neo4j from "neo4j-driver";
import { Neo4jChatMessageHistory } from "../message/neo4j.js";

const goodConfig = {
url: "bolt://host.docker.internal:7687",
username: "neo4j",
password: "langchain",
};

describe("The Neo4jChatMessageHistory class", () => {
describe("Test suite", () => {
it("Runs at all", () => {
expect(true).toEqual(true);
});
});

describe("Class instantiation", () => {
it("Requires a url, username and password, throwing an error if not provided", async () => {
const badConfig = {};
await expect(
// @ts-expect-error Bad config
Neo4jChatMessageHistory.initialize(badConfig)
).rejects.toThrow(neo4j.Neo4jError);
});

it("Creates a class instance from - at minimum - a url, username and password", async () => {
const instance = await Neo4jChatMessageHistory.initialize(goodConfig);
expect(instance).toBeInstanceOf(Neo4jChatMessageHistory);
await instance.close();
});

it("Class instances have expected, configurable fields, and sensible defaults", async () => {
const instance = await Neo4jChatMessageHistory.initialize(goodConfig);

expect(instance.sessionId).toBeDefined();
expect(instance.sessionNodeLabel).toEqual("ChatSession");
expect(instance.windowSize).toEqual(3);
expect(instance.messageNodeLabel).toEqual("ChatMessage");

const secondInstance = await Neo4jChatMessageHistory.initialize({
...goodConfig,
sessionId: "Shibboleet",
sessionNodeLabel: "Conversation",
messageNodeLabel: "Communication",
windowSize: 4,
});

expect(secondInstance.sessionId).toBeDefined();
expect(secondInstance.sessionId).toEqual("Shibboleet");
expect(instance.sessionId).not.toEqual(secondInstance.sessionId);
expect(secondInstance.sessionNodeLabel).toEqual("Conversation");
expect(secondInstance.messageNodeLabel).toEqual("Communication");
expect(secondInstance.windowSize).toEqual(4);

await instance.close();
await secondInstance.close();
});
});

describe("Core functionality", () => {
let instance: undefined | Neo4jChatMessageHistory;

beforeEach(async () => {
instance = await Neo4jChatMessageHistory.initialize(goodConfig);
});

afterEach(async () => {
await instance?.clear();
await instance?.close();
});

it("Connects verifiably to the underlying Neo4j database", async () => {
const connected = await instance?.verifyConnectivity();
expect(connected).toBeDefined();
});

it("getMessages()", async () => {
let results = await instance?.getMessages();
expect(results).toEqual([]);
const messages = [
new HumanMessage(
"My first name is a random set of numbers and letters"
),
new AIMessage("And other alphanumerics that changes hourly forever"),
new HumanMessage(
"My last name, a thousand vowels fading down a sinkhole to a susurrus"
),
new AIMessage("It couldn't just be John Doe or Bingo"),
new HumanMessage(
"My address, a made-up language written out in living glyphs"
),
new AIMessage("Lifted from demonic literature and religious text"),
new HumanMessage("Telephone: uncovered by purveyors of the ouija"),
new AIMessage("When checked against the CBGB women's room graffiti"),
new HumanMessage("My social: a sudoku"),
new AIMessage("My age is obscure"),
];
await instance?.addMessages(messages);
results = (await instance?.getMessages()) || [];
const windowSize = instance?.windowSize || 0;
expect(results.length).toEqual(windowSize * 2);
expect(results).toEqual(messages.slice(windowSize * -2));
});

it("addMessage()", async () => {
const messages = [
new HumanMessage("99 Bottles of beer on the wall, 99 bottles of beer!"),
new AIMessage(
"Take one down, pass it around, 98 bottles of beer on the wall."
),
new HumanMessage("How many bottles of beer are currently on the wall?"),
new AIMessage("There are currently 98 bottles of beer on the wall."),
];
for (const message of messages) {
await instance?.addMessage(message);
}
const results = await instance?.getMessages();
expect(results).toEqual(messages);
});

it("clear()", async () => {
const messages = [
new AIMessage("I'm not your enemy."),
new HumanMessage("That sounds like something that my enemy would say."),
new AIMessage("You're being difficult."),
new HumanMessage("I'm being guarded."),
];
await instance?.addMessages(messages);
let results = await instance?.getMessages();
expect(results).toEqual(messages);
await instance?.clear();
results = await instance?.getMessages();
expect(results).toEqual([]);
});
});
});
15 changes: 14 additions & 1 deletion test-int-deps-docker-compose.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,17 @@ services:
qdrant:
image: qdrant/qdrant:v1.9.1
ports:
- 6333:6333
- 6333:6333
neo4j:
image: neo4j:latest
volumes:
- $HOME/neo4j/logs:/var/lib/neo4j/logs
- $HOME/neo4j/config:/var/lib/neo4j/config
- $HOME/neo4j/data:/var/lib/neo4j/data
- $HOME/neo4j/plugins:/var/lib/neo4j/plugins
environment:
- NEO4J_dbms_security_auth__enabled=false
ports:
- "7474:7474"
- "7687:7687"
restart: always

0 comments on commit b2afdf1

Please sign in to comment.