diff --git a/libs/langchain-community/src/indexes/postgres.ts b/libs/langchain-community/src/indexes/postgres.ts index 75382415a844..2777aefabaee 100644 --- a/libs/langchain-community/src/indexes/postgres.ts +++ b/libs/langchain-community/src/indexes/postgres.ts @@ -7,6 +7,7 @@ import { export type PostgresRecordManagerOptions = { postgresConnectionOptions: PoolConfig; + pool?: Pool; tableName?: string; schema?: string; }; @@ -23,9 +24,9 @@ export class PostgresRecordManager implements RecordManagerInterface { finalTableName: string; constructor(namespace: string, config: PostgresRecordManagerOptions) { - const { postgresConnectionOptions, tableName } = config; + const { postgresConnectionOptions, tableName, pool } = config; this.namespace = namespace; - this.pool = new pg.Pool(postgresConnectionOptions); + this.pool = pool || new pg.Pool(postgresConnectionOptions); this.tableName = tableName || "upsertion_records"; this.finalTableName = config.schema ? `"${config.schema}"."${tableName}"` diff --git a/libs/langchain-community/src/indexes/tests/postgres.int.test.ts b/libs/langchain-community/src/indexes/tests/postgres.int.test.ts index b3d06be77702..79564ced15bd 100644 --- a/libs/langchain-community/src/indexes/tests/postgres.int.test.ts +++ b/libs/langchain-community/src/indexes/tests/postgres.int.test.ts @@ -1,5 +1,5 @@ import { describe, expect, test, jest } from "@jest/globals"; -import { PoolConfig } from "pg"; +import pg, { PoolConfig } from "pg"; import { PostgresRecordManager, PostgresRecordManagerOptions, @@ -36,6 +36,16 @@ describe.skip("PostgresRecordManager", () => { await recordManager.end(); }); + test("Test provided postgres pool instance", async () => { + const pool = new pg.Pool(config.postgresConnectionOptions); + const providedPoolRecordManager = new PostgresRecordManager("test", { + ...config, + pool, + }); + + expect(providedPoolRecordManager.pool).toBe(pool); + }); + test("Test explicit schema definition", async () => { // configure explicit schema with record manager config.schema = "newSchema";