Skip to content

Commit

Permalink
feat(community): add filters to LibSQLVectorStore (langchain-ai#7209)
Browse files Browse the repository at this point in the history
Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
2 people authored and FilipZmijewski committed Nov 27, 2024
1 parent 9b2de04 commit d6a0663
Show file tree
Hide file tree
Showing 3 changed files with 176 additions and 41 deletions.
61 changes: 61 additions & 0 deletions libs/langchain-community/src/utils/sqlite_where_builder.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import { InStatement, InValue } from "@libsql/client";

export type WhereCondition<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
Metadata extends Record<string, any> = Record<string, any>
> = {
[Key in keyof Metadata]:
| {
operator: "=" | ">" | "<" | ">=" | "<=" | "<>" | "LIKE";
value: InValue;
}
| {
operator: "IN";
value: InValue[];
};
};

type WhereInStatement = Exclude<InStatement, string>;

export class SqliteWhereBuilder {
private conditions: WhereCondition;

constructor(conditions: WhereCondition) {
this.conditions = conditions;
}

buildWhereClause(): WhereInStatement {
const sqlParts: string[] = [];
const args: Record<string, InValue> = {};

for (const [column, condition] of Object.entries(this.conditions)) {
const { operator, value } = condition;

if (operator === "IN") {
const placeholders = value
.map((_, index) => `:${column}${index}`)
.join(", ");
sqlParts.push(
`json_extract(metadata, '$.${column}') IN (${placeholders})`
);

const values = value.reduce(
(previousValue: Record<string, InValue>, currentValue, index) => {
return { ...previousValue, [`${column}${index}`]: currentValue };
},
{}
);

Object.assign(args, values);
} else {
sqlParts.push(
`json_extract(metadata, '$.${column}') ${operator} :${column}`
);
args[column] = value;
}
}

const sql = sqlParts.length ? `${sqlParts.join(" AND ")}` : "";
return { sql, args };
}
}
38 changes: 32 additions & 6 deletions libs/langchain-community/src/vectorstores/libsql.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@ import { Document } from "@langchain/core/documents";
import type { EmbeddingsInterface } from "@langchain/core/embeddings";
import { VectorStore } from "@langchain/core/vectorstores";
import type { Client, InStatement } from "@libsql/client";
import {
SqliteWhereBuilder,
WhereCondition,
} from "../utils/sqlite_where_builder.js";

// eslint-disable-next-line @typescript-eslint/no-explicit-any
type MetadataDefault = Record<string, any>;
Expand All @@ -24,7 +28,7 @@ export interface LibSQLVectorStoreArgs {
export class LibSQLVectorStore<
Metadata extends MetadataDefault = MetadataDefault
> extends VectorStore {
declare FilterType: (doc: Document<Metadata>) => boolean;
declare FilterType: string | InStatement | WhereCondition<Metadata>;

private db;

Expand Down Expand Up @@ -111,9 +115,8 @@ export class LibSQLVectorStore<
*/
async similaritySearchVectorWithScore(
query: number[],
k: number
// filter is currently unused
// filter?: this["FilterType"]
k: number,
filter?: this["FilterType"]
): Promise<[Document<Metadata>, number][]> {
// Potential SQL injection risk if query vector is not properly sanitized.
if (!query.every((num) => typeof num === "number" && !Number.isNaN(num))) {
Expand All @@ -122,12 +125,35 @@ export class LibSQLVectorStore<

const queryVector = `[${query.join(",")}]`;

const sql: InStatement = {
const sql = {
sql: `SELECT ${this.table}.rowid as id, ${this.table}.content, ${this.table}.metadata, vector_distance_cos(${this.table}.${this.column}, vector(:queryVector)) AS distance
FROM vector_top_k('idx_${this.table}_${this.column}', vector(:queryVector), CAST(:k AS INTEGER)) as top_k
JOIN ${this.table} ON top_k.rowid = ${this.table}.rowid`,
args: { queryVector, k },
};
} satisfies InStatement;

// Filter is a raw sql where clause, so append it to the join
if (typeof filter === "string") {
sql.sql += ` AND ${filter}`;
} else if (typeof filter === "object") {
// Filter is an in statement.
if ("sql" in filter) {
sql.sql += ` AND ${filter.sql}`;
sql.args = {
...filter.args,
...sql.args,
};
} else {
const builder = new SqliteWhereBuilder(filter);
const where = builder.buildWhereClause();

sql.sql += ` AND ${where.sql}`;
sql.args = {
...where.args,
...sql.args,
};
}
}

const results = await this.db.execute(sql);

Expand Down
118 changes: 83 additions & 35 deletions libs/langchain-community/src/vectorstores/tests/libsql.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,10 @@ describe("LibSQLVectorStore (local)", () => {
const store = new LibSQLVectorStore(embeddings, config);

const ids = await store.addDocuments([
{
new Document({
pageContent: "hello",
metadata: { a: 1 },
},
}),
]);

expect(ids).toHaveLength(1);
Expand Down Expand Up @@ -117,10 +117,10 @@ describe("LibSQLVectorStore (local)", () => {
const store = new LibSQLVectorStore(embeddings, config);

const ids = await store.addDocuments([
{
new Document({
pageContent: "hello world",
metadata: { a: 1 },
},
}),
]);

expect(ids).toHaveLength(1);
Expand Down Expand Up @@ -154,18 +154,15 @@ describe("LibSQLVectorStore (local)", () => {
const store = new LibSQLVectorStore(embeddings, config);

const ids = await store.addDocuments([
{
new Document({
pageContent: "the quick brown fox",
metadata: { a: 1 },
},
{
}),
new Document({
pageContent: "jumped over the lazy dog",
metadata: { a: 2 },
},
{
}),
new Document({
pageContent: "hello world",
metadata: { a: 3 },
},
}),
]);

expect(ids).toHaveLength(3);
Expand All @@ -186,7 +183,7 @@ describe("LibSQLVectorStore (local)", () => {
).toBe(true);
});

test("a document can be deleted by id", async () => {
test("a similarity search with a filter can be performed", async () => {
await client.batch([
`DROP TABLE IF EXISTS vectors;`,
`CREATE TABLE IF NOT EXISTS vectors (
Expand All @@ -201,18 +198,72 @@ describe("LibSQLVectorStore (local)", () => {
const store = new LibSQLVectorStore(embeddings, config);

const ids = await store.addDocuments([
{
new Document({
pageContent: "the quick brown fox",
metadata: { a: 1 },
metadata: {
label: "1",
},
}),
new Document({
pageContent: "jumped over the lazy dog",
metadata: {
label: "2",
},
}),
new Document({
pageContent: "hello world",
metadata: {
label: "1",
},
}),
]);

expect(ids).toHaveLength(3);
expect(ids.every((id) => typeof id === "string")).toBe(true);

const results = await store.similaritySearch("the quick brown dog", 10, {
label: {
operator: "=",
value: "1",
},
{
});

expect(results).toHaveLength(2);
expect(results.map((result) => result.pageContent)).toEqual([
"the quick brown fox",
"hello world",
]);
expect(
results.map((result) => result.id).every((id) => typeof id === "string")
).toBe(true);
});

test("a document can be deleted by id", async () => {
await client.batch([
`DROP TABLE IF EXISTS vectors;`,
`CREATE TABLE IF NOT EXISTS vectors (
content TEXT,
metadata JSON,
embedding F32_BLOB(1024)
);`,
`CREATE INDEX IF NOT EXISTS idx_vectors_embedding
ON vectors (libsql_vector_idx(embedding));`,
]);

const store = new LibSQLVectorStore(embeddings, config);

const ids = await store.addDocuments([
new Document({
pageContent: "the quick brown fox",
}),
new Document({
pageContent: "jumped over the lazy dog",
metadata: { a: 2 },
},
{
}),
new Document({
pageContent: "hello world",
metadata: { a: 3 },
},
}),
]);

expect(ids).toHaveLength(3);
Expand Down Expand Up @@ -247,18 +298,15 @@ describe("LibSQLVectorStore (local)", () => {
const store = new LibSQLVectorStore(embeddings, config);

const ids = await store.addDocuments([
{
new Document({
pageContent: "the quick brown fox",
metadata: { a: 1 },
},
{
}),
new Document({
pageContent: "jumped over the lazy dog",
metadata: { a: 2 },
},
{
}),
new Document({
pageContent: "hello world",
metadata: { a: 3 },
},
}),
]);

expect(ids).toHaveLength(3);
Expand Down Expand Up @@ -289,18 +337,18 @@ describe("LibSQLVectorStore (local)", () => {
const store = new LibSQLVectorStore(embeddings, config);

const ids = await store.addDocuments([
{
new Document({
pageContent: "the quick brown fox",
metadata: { a: 1 },
},
{
}),
new Document({
pageContent: "jumped over the lazy dog",
metadata: { a: 2 },
},
{
}),
new Document({
pageContent: "hello world",
metadata: { a: 3 },
},
}),
]);

expect(ids).toHaveLength(3);
Expand Down

0 comments on commit d6a0663

Please sign in to comment.