-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(pinecone): Add support for Pinecone `/embed` endpoint (#7203)
Co-authored-by: jacoblee93 <[email protected]>
- Loading branch information
1 parent
ae80cbf
commit 762ed46
Showing 20 changed files with 691 additions and 82 deletions.
There are no files selected for viewing
344 changes: 344 additions & 0 deletions
344
docs/core_docs/docs/integrations/text_embedding/pinecone.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
import { PineconeEmbeddings } from "@langchain/pinecone"; | ||
|
||
/**
 * Example: build a PineconeEmbeddings instance with default settings, log it,
 * then embed a single query string and log the resulting vector.
 */
export const run = async () => {
  const embedder = new PineconeEmbeddings();
  // Logged under the `model` key so the output shows the model metadata.
  console.log({ model: embedder });

  const question =
    "What would be a good company name a company that makes colorful socks?";
  const vector = await embedder.embedQuery(question);
  console.log({ res: vector });
};

await run();
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import { Pinecone, PineconeConfiguration } from "@pinecone-database/pinecone"; | ||
import { getEnvironmentVariable } from "@langchain/core/utils/env"; | ||
|
||
/**
 * Create a Pinecone client.
 *
 * An API key must be available from one of two places: `config.apiKey`, or the
 * `PINECONE_API_KEY` environment variable. The environment variable is only
 * required when no key was supplied explicitly, so callers that pass their own
 * key are not forced to also export it.
 *
 * @param config - Optional Pinecone client configuration. When omitted, the
 *   client is constructed with no arguments.
 * @returns A new `Pinecone` client instance.
 * @throws Error if no API key is present in `config` and `PINECONE_API_KEY` is
 *   unset or empty.
 */
export function getPineconeClient(config?: PineconeConfiguration): Pinecone {
  // An explicitly supplied key makes the environment variable unnecessary.
  if (config?.apiKey) {
    return new Pinecone(config);
  }

  const envApiKey = getEnvironmentVariable("PINECONE_API_KEY");
  if (envApiKey === undefined || envApiKey === "") {
    throw new Error("PINECONE_API_KEY must be set in environment");
  }

  return config ? new Pinecone(config) : new Pinecone();
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,139 @@ | ||
/* eslint-disable arrow-body-style */ | ||
|
||
import { Embeddings, type EmbeddingsParams } from "@langchain/core/embeddings"; | ||
import { | ||
EmbeddingsList, | ||
Pinecone, | ||
PineconeConfiguration, | ||
} from "@pinecone-database/pinecone"; | ||
import { getPineconeClient } from "./client.js"; | ||
|
||
/**
 * PineconeEmbeddingsParams holds the optional fields a user can pass to a
 * Pinecone embedding model.
 */
export interface PineconeEmbeddingsParams extends EmbeddingsParams {
  /** Model to use to generate embeddings. Default is "multilingual-e5-large". */
  model?: string;
  /**
   * Additional parameters to pass to the embedding model. Note: parameters are
   * model-specific. Read more about model-specific parameters in the
   * [Pinecone documentation](https://docs.pinecone.io/guides/inference/understanding-inference#model-specific-parameters).
   */
  params?: Record<string, string>;
}
|
||
/**
 * PineconeEmbeddings generates embeddings using the Pinecone Inference API.
 *
 * Documents are always embedded with `inputType: "passage"` and queries with
 * `inputType: "query"`. Each request is sent with its own copy of the
 * parameters, so a query call can no longer change how later document calls
 * are embedded.
 */
export class PineconeEmbeddings
  extends Embeddings
  implements PineconeEmbeddingsParams
{
  /** Pinecone client used to call the inference endpoint. */
  client: Pinecone;

  /** Name of the embedding model; defaults to "multilingual-e5-large". */
  model: string;

  /** Model-specific parameters sent along with each embed request. */
  params: Record<string, string>;

  /**
   * @param fields - Optional embedding settings (`model`, `params`, caller
   *   options such as `maxRetries`) mixed with optional Pinecone client
   *   configuration. When `apiKey` is provided, a client is built from the
   *   supplied configuration fields; otherwise the client is created without
   *   explicit configuration.
   */
  constructor(
    fields?: Partial<PineconeEmbeddingsParams> & Partial<PineconeConfiguration>
  ) {
    const defaultFields = { maxRetries: 3, ...fields };
    super(defaultFields);

    if (defaultFields.apiKey) {
      const config = {
        apiKey: defaultFields.apiKey,
        controllerHostUrl: defaultFields.controllerHostUrl,
        fetchApi: defaultFields.fetchApi,
        additionalHeaders: defaultFields.additionalHeaders,
        sourceTag: defaultFields.sourceTag,
      } as PineconeConfiguration;
      this.client = getPineconeClient(config);
    } else {
      this.client = getPineconeClient();
    }

    // `||` (not `??`) on purpose: an empty model name also falls back to the default.
    this.model = defaultFields.model || "multilingual-e5-large";

    // The default input type is applied last, so it wins over a user-supplied one.
    this.params = { ...defaultFields.params, inputType: "passage" };
  }

  /**
   * Generate embeddings for a list of input strings using the configured
   * embedding model.
   *
   * @param texts - List of input strings for which to generate embeddings.
   * @returns One vector per returned embedding that carries `values`; entries
   *   without `values` are skipped.
   * @throws Error if `texts` is empty.
   */
  async embedDocuments(texts: string[]): Promise<number[][]> {
    if (texts.length === 0) {
      throw new Error(
        "At least one document is required to generate embeddings"
      );
    }

    // Per-call copy pinned to "passage": a previous embedQuery() call leaves
    // `this.params.inputType` set to "query", which must not leak into
    // document embeddings.
    const requestParams = { ...this.params, inputType: "passage" };

    const embeddings: EmbeddingsList = await this.caller.call(() =>
      this.client.inference.embed(this.model, texts, requestParams)
    );

    const embeddingsList: number[][] = [];
    for (let i = 0; i < embeddings.length; i += 1) {
      const { values } = embeddings[i];
      if (values) {
        embeddingsList.push(values as number[]);
      }
    }
    return embeddingsList;
  }

  /**
   * Generate an embedding for a given query string using the configured
   * embedding model.
   *
   * @param text - Query string for which to generate embeddings.
   * @returns The embedding vector, or an empty array if the response contains
   *   no values.
   * @throws Error if `text` is empty.
   */
  async embedQuery(text: string): Promise<number[]> {
    // Validate first so a rejected call leaves the instance untouched.
    if (!text) {
      throw new Error("No query passed for which to generate embeddings");
    }

    // Change inputType to query-specific param for multilingual-e5-large embedding model.
    // The stored value is still updated because existing callers observe it on
    // `params`; the request itself uses a private snapshot.
    this.params.inputType = "query";
    const requestParams = { ...this.params };

    const embeddings: EmbeddingsList = await this.caller.call(() =>
      this.client.inference.embed(this.model, [text], requestParams)
    );

    const first = embeddings[0];
    return first?.values ? (first.values as number[]) : [];
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
export * from "./vectorstores.js"; | ||
export * from "./translator.js"; | ||
export * from "./embeddings.js"; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import { Pinecone } from "@pinecone-database/pinecone"; | ||
import { getPineconeClient } from "../client.js"; | ||
|
||
// Integration tests for getPineconeClient; the happy path expects a real
// PINECONE_API_KEY to be present in the environment.
describe("Tests for getPineconeClient", () => {
  test("Happy path for getPineconeClient with and without `config` obj passed", async () => {
    const client = getPineconeClient();
    expect(client).toBeInstanceOf(Pinecone);
    expect(client).toHaveProperty("config"); // Config is always set to *at least* the user's api key

    const clientWithConfig = getPineconeClient({
      // eslint-disable-next-line no-process-env
      apiKey: process.env.PINECONE_API_KEY!,
      additionalHeaders: { header: "value" },
    });
    expect(clientWithConfig).toBeInstanceOf(Pinecone);
    // Unfortunately cannot assert on contents of config b/c it's a private
    // attribute of the Pinecone class
    expect(clientWithConfig).toHaveProperty("config");
  });

  test("Unhappy path: expect getPineconeClient to throw error if reset PINECONE_API_KEY to empty string", () => {
    // eslint-disable-next-line no-process-env
    const originalApiKey = process.env.PINECONE_API_KEY;
    try {
      // eslint-disable-next-line no-process-env
      process.env.PINECONE_API_KEY = "";
      const callWithoutKey = () => getPineconeClient();
      expect(callWithoutKey).toThrow(Error);
      expect(callWithoutKey).toThrow(
        "PINECONE_API_KEY must be set in environment"
      );
    } finally {
      // Restore the original value of PINECONE_API_KEY. Assigning `undefined`
      // to a process.env entry would store the string "undefined", so the
      // variable is deleted when it was not set before the test.
      if (originalApiKey === undefined) {
        // eslint-disable-next-line no-process-env
        delete process.env.PINECONE_API_KEY;
      } else {
        // eslint-disable-next-line no-process-env
        process.env.PINECONE_API_KEY = originalApiKey;
      }
    }
  });
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import { getPineconeClient } from "../client.js"; | ||
|
||
// Unit test for getPineconeClient's API-key guard.
describe("Tests for getPineconeClient", () => {
  test("Confirm getPineconeClient throws error when PINECONE_API_KEY is not set", () => {
    // eslint-disable-next-line no-process-env
    const originalApiKey = process.env.PINECONE_API_KEY;
    try {
      // eslint-disable-next-line no-process-env
      process.env.PINECONE_API_KEY = "";
      const callWithoutKey = () => getPineconeClient();
      expect(callWithoutKey).toThrow(Error);
      expect(callWithoutKey).toThrow(
        "PINECONE_API_KEY must be set in environment"
      );
    } finally {
      // Put the environment back so other tests in the same process are not
      // affected. Assigning `undefined` would store the string "undefined",
      // so delete the variable when it was originally unset.
      if (originalApiKey === undefined) {
        // eslint-disable-next-line no-process-env
        delete process.env.PINECONE_API_KEY;
      } else {
        // eslint-disable-next-line no-process-env
        process.env.PINECONE_API_KEY = originalApiKey;
      }
    }
  });
});
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import { PineconeEmbeddings } from "../embeddings.js"; | ||
|
||
// Integration tests: these call the live Pinecone Inference API.
describe("Integration tests for Pinecone embeddings", () => {
  test("Happy path: defaults for both embedDocuments and embedQuery", async () => {
    const model = new PineconeEmbeddings();
    expect(model.model).toBe("multilingual-e5-large");
    expect(model.params).toEqual({ inputType: "passage" });

    const docs = ["hello", "world"];
    const embeddings = await model.embedDocuments(docs);
    expect(embeddings.length).toBe(docs.length);

    const query = "hello";
    const queryEmbedding = await model.embedQuery(query);
    expect(queryEmbedding.length).toBeGreaterThan(0);
  });

  test("Happy path: custom `params` obj passed to embedDocuments and embedQuery", async () => {
    const model = new PineconeEmbeddings({
      params: { customParam: "value" },
    });
    expect(model.model).toBe("multilingual-e5-large");
    expect(model.params).toEqual({
      inputType: "passage",
      customParam: "value",
    });

    const docs = ["hello", "world"];
    const embeddings = await model.embedDocuments(docs);
    expect(embeddings.length).toBe(docs.length);
    expect(embeddings[0].length).toBe(1024); // Assert correct dims on random doc
    expect(model.model).toBe("multilingual-e5-large");
    expect(model.params).toEqual({
      inputType: "passage", // Maintain default inputType for docs
      customParam: "value",
    });

    const query = "hello";
    const queryEmbedding = await model.embedQuery(query);
    expect(model.model).toBe("multilingual-e5-large");
    expect(queryEmbedding.length).toBe(1024);
    expect(model.params).toEqual({
      inputType: "query", // Change inputType for query
      customParam: "value",
    });
  });

  test("Unhappy path: embedDocuments and embedQuery throw when empty objs are passed", async () => {
    const model = new PineconeEmbeddings();
    await expect(model.embedDocuments([])).rejects.toThrow();
    await expect(model.embedQuery("")).rejects.toThrow();
  });

  test("Unhappy path: PineconeEmbeddings throws when invalid model is passed", async () => {
    // Non-empty inputs are required here: empty inputs are rejected by the
    // local validation before the model name is ever used, which would make
    // this test pass for any model. Retries are disabled so the expected
    // failure is reported without retry delays.
    const model = new PineconeEmbeddings({
      model: "invalid-model",
      maxRetries: 0,
    });
    await expect(model.embedDocuments(["hello"])).rejects.toThrow();
    await expect(model.embedQuery("hello")).rejects.toThrow();
  });
});
Oops, something went wrong.