Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(community): Add rerank solution to existing IBM community implementation #7200

Merged
merged 8 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
405 changes: 405 additions & 0 deletions docs/core_docs/docs/integrations/retrievers/ibm.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions libs/langchain-community/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -626,6 +626,10 @@ retrievers/dria.cjs
retrievers/dria.js
retrievers/dria.d.ts
retrievers/dria.d.cts
retrievers/ibm.cjs
retrievers/ibm.js
retrievers/ibm.d.ts
retrievers/ibm.d.cts
retrievers/metal.cjs
retrievers/metal.js
retrievers/metal.d.ts
Expand Down
2 changes: 2 additions & 0 deletions libs/langchain-community/langchain.config.js
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ export const config = {
"retrievers/chaindesk": "retrievers/chaindesk",
"retrievers/databerry": "retrievers/databerry",
"retrievers/dria": "retrievers/dria",
"retrievers/ibm": "retrievers/ibm",
"retrievers/metal": "retrievers/metal",
"retrievers/remote": "retrievers/remote/index",
"retrievers/supabase": "retrievers/supabase",
Expand Down Expand Up @@ -428,6 +429,7 @@ export const config = {
"retrievers/amazon_kendra",
"retrievers/amazon_knowledge_base",
"retrievers/dria",
"retrievers/ibm",
"retrievers/metal",
"retrievers/supabase",
"retrievers/vectara_summary",
Expand Down
13 changes: 13 additions & 0 deletions libs/langchain-community/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -2121,6 +2121,15 @@
"import": "./retrievers/dria.js",
"require": "./retrievers/dria.cjs"
},
"./retrievers/ibm": {
"types": {
"import": "./retrievers/ibm.d.ts",
"require": "./retrievers/ibm.d.cts",
"default": "./retrievers/ibm.d.ts"
},
"import": "./retrievers/ibm.js",
"require": "./retrievers/ibm.cjs"
},
"./retrievers/metal": {
"types": {
"import": "./retrievers/metal.d.ts",
Expand Down Expand Up @@ -3698,6 +3707,10 @@
"retrievers/dria.js",
"retrievers/dria.d.ts",
"retrievers/dria.d.cts",
"retrievers/ibm.cjs",
"retrievers/ibm.js",
"retrievers/ibm.d.ts",
"retrievers/ibm.d.cts",
"retrievers/metal.cjs",
"retrievers/metal.js",
"retrievers/metal.d.ts",
Expand Down
4 changes: 3 additions & 1 deletion libs/langchain-community/src/chat_models/ibm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ import {
TextChatResultChoice,
TextChatResultMessage,
TextChatToolCall,
TextChatToolChoiceTool,
TextChatUsage,
} from "@ibm-cloud/watsonx-ai/dist/watsonx-ai-ml/vml_v1.js";
import { WatsonXAI } from "@ibm-cloud/watsonx-ai";
Expand Down Expand Up @@ -86,6 +87,7 @@ export interface WatsonxCallOptionsChat
extends Omit<BaseLanguageModelCallOptions, "stop">,
WatsonxCallParams {
promptIndex?: number;
tool_choice?: TextChatToolChoiceTool;
}

type ChatWatsonxToolType = BindToolsInput | TextChatParameterTools;
Expand Down Expand Up @@ -470,7 +472,7 @@ export class ChatWatsonx<
tools: options.tools
? _convertToolToWatsonxTool(options.tools)
: undefined,
toolChoice: options.toolChoice,
toolChoice: options.tool_choice,
responseFormat: options.responseFormat,
toolChoiceOption: options.toolChoiceOption,
};
Expand Down
2 changes: 1 addition & 1 deletion libs/langchain-community/src/llms/tests/ibm.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export const testProperties = (
checkProperty<typeof notExTestProps>(notExTestProps, instance, false);
};

describe("LLM unit tests", () => {
describe("Rerank unit tests", () => {
FilipZmijewski marked this conversation as resolved.
Show resolved Hide resolved
describe("Positive tests", () => {
test("Test authentication function", () => {
const instance = authenticateAndSetInstance({
Expand Down
1 change: 1 addition & 0 deletions libs/langchain-community/src/load/import_constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ export const optionalImportEntrypoints: string[] = [
"langchain_community/retrievers/amazon_kendra",
"langchain_community/retrievers/amazon_knowledge_base",
"langchain_community/retrievers/dria",
"langchain_community/retrievers/ibm",
"langchain_community/retrievers/metal",
"langchain_community/retrievers/supabase",
"langchain_community/retrievers/vectara_summary",
Expand Down
168 changes: 168 additions & 0 deletions libs/langchain-community/src/retrievers/ibm.ts
FilipZmijewski marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import { DocumentInterface } from "@langchain/core/documents";
import { BaseDocumentCompressor } from "@langchain/core/retrievers/document_compressors";
import { WatsonXAI } from "@ibm-cloud/watsonx-ai";
import { AsyncCaller } from "@langchain/core/utils/async_caller";
import { WatsonxAuth, WatsonxParams } from "../types/ibm.js";
import { authenticateAndSetInstance } from "../utils/ibm.js";

export interface WatsonxInputRerank extends Omit<WatsonxParams, "idOrName"> {
truncateInputTokens?: number;
returnOptions?: {
topN?: number;
inputs?: boolean;
};
}
export class WatsonxRerank
extends BaseDocumentCompressor
implements WatsonxInputRerank
{
maxRetries = 0;

version = "2024-05-31";

truncateInputTokens?: number | undefined;

returnOptions?:
| { topN?: number; inputs?: boolean; query?: boolean }
| undefined;

model: string;

spaceId?: string | undefined;

projectId?: string | undefined;

maxConcurrency?: number | undefined;

serviceUrl: string;

service: WatsonXAI;

constructor(fields: WatsonxInputRerank & WatsonxAuth) {
super();
FilipZmijewski marked this conversation as resolved.
Show resolved Hide resolved
if (fields.projectId && fields.spaceId)
throw new Error("Maximum 1 id type can be specified per instance");

if (!fields.projectId && !fields.spaceId)
throw new Error(
"No id specified! At least id of 1 type has to be specified"
);
this.model = fields.model;
this.serviceUrl = fields.serviceUrl;
this.version = fields.version;
this.projectId = fields?.projectId;
this.spaceId = fields?.spaceId;
this.maxRetries = fields.maxRetries ?? this.maxRetries;
this.maxConcurrency = fields.maxConcurrency;
this.truncateInputTokens = fields.truncateInputTokens;
this.returnOptions = fields.returnOptions;

const {
watsonxAIApikey,
watsonxAIAuthType,
watsonxAIBearerToken,
watsonxAIUsername,
watsonxAIPassword,
watsonxAIUrl,
version,
serviceUrl,
} = fields;

const auth = authenticateAndSetInstance({
watsonxAIApikey,
watsonxAIAuthType,
watsonxAIBearerToken,
watsonxAIUsername,
watsonxAIPassword,
watsonxAIUrl,
version,
serviceUrl,
});
if (auth) this.service = auth;
else throw new Error("You have not provided one type of authentication");
}

scopeId() {
if (this.projectId)
return { projectId: this.projectId, modelId: this.model };
else return { spaceId: this.spaceId, modelId: this.model };
}

invocationParams(options?: Partial<WatsonxInputRerank>) {
return {
truncate_input_tokens:
options?.truncateInputTokens ?? this.truncateInputTokens,
return_options: {
top_n: options?.returnOptions?.topN ?? this.returnOptions?.topN,
inputs: options?.returnOptions?.inputs ?? this.returnOptions?.inputs,
},
};
}

async compressDocuments(
documents: DocumentInterface[],
query: string
): Promise<DocumentInterface[]> {
const caller = new AsyncCaller({
maxConcurrency: this.maxConcurrency,
maxRetries: this.maxRetries,
});
const inputs = documents.map((document) => ({
text: document.pageContent,
}));
const { result } = await caller.call(() =>
this.service.textRerank({
...this.scopeId(),
inputs,
query,
})
);
const resultDocuments = result.results.map(({ index, score }) => {
const rankedDocument = documents[index];
rankedDocument.metadata.relevanceScore = score;
return rankedDocument;
});
return resultDocuments;
}

async rerank(
documents: Array<
DocumentInterface | string | Record<"pageContent", string>
>,
query: string,
options?: Partial<WatsonxInputRerank>
): Promise<Array<{ index: number; relevanceScore: number; input?: string }>> {
const inputs = documents.map((document) => {
if (typeof document === "string") {
return { text: document };
}
return { text: document.pageContent };
});

const caller = new AsyncCaller({
maxConcurrency: this.maxConcurrency,
maxRetries: this.maxRetries,
});
const { result } = await caller.call(() =>
this.service.textRerank({
...this.scopeId(),
inputs,
query,
parameters: this.invocationParams(options),
})
);
const response = result.results.map((document) => {
return document?.input
? {
index: document.index,
relevanceScore: document.score,
input: document?.input,
}
: {
index: document.index,
relevanceScore: document.score,
};
});
return response;
}
}
80 changes: 80 additions & 0 deletions libs/langchain-community/src/retrievers/tests/ibm.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/* eslint-disable no-process-env */
import { Document } from "@langchain/core/documents";
import { WatsonxRerank } from "../ibm.js";

const query = "What is the capital of the United States?";
const docs = [
new Document({
pageContent:
"Carson City is the capital city of the American state of Nevada. At the 2010 United States Census, Carson City had a population of 55,274.",
}),
new Document({
pageContent:
"The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that are a political division controlled by the United States. Its capital is Saipan.",
}),
new Document({
pageContent:
"Charlotte Amalie is the capital and largest city of the United States Virgin Islands. It has about 20,000 people. The city is on the island of Saint Thomas.",
}),
new Document({
pageContent:
"Washington, D.C. (also known as simply Washington or D.C., and officially as the District of Columbia) is the capital of the United States. It is a federal district. The President of the USA and many major national government offices are in the territory. This makes it the political center of the United States of America.",
}),
new Document({
pageContent:
"Capital punishment (the death penalty) has existed in the United States since before the United States was a country. As of 2017, capital punishment is legal in 30 of the 50 states. The federal government (including the United States military) also uses capital punishment.",
}),
];
describe("Integration tests on WatsonxRerank", () => {
describe(".compressDocuments() method", () => {
test("Basic call", async () => {
const instance = new WatsonxRerank({
model: "cross-encoder/ms-marco-minilm-l-12-v2",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
version: "2024-05-31",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
});
const result = await instance.compressDocuments(docs, query);
expect(result.length).toBe(docs.length);
result.forEach((item) =>
expect(typeof item.metadata.relevanceScore).toBe("number")
);
});
});

describe(".rerank() method", () => {
test("Basic call", async () => {
const instance = new WatsonxRerank({
model: "cross-encoder/ms-marco-minilm-l-12-v2",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
version: "2024-05-31",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
});
const result = await instance.rerank(docs, query);
expect(result.length).toBe(docs.length);
result.forEach((item) => {
expect(typeof item.relevanceScore).toBe("number");
expect(item.input).toBeUndefined();
});
});
});
test("Basic call with options", async () => {
const instance = new WatsonxRerank({
model: "cross-encoder/ms-marco-minilm-l-12-v2",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
version: "2024-05-31",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
});
const result = await instance.rerank(docs, query, {
returnOptions: {
topN: 3,
inputs: true,
},
});
expect(result.length).toBe(3);
result.forEach((item) => {
expect(typeof item.relevanceScore).toBe("number");
expect(item.input).toBeDefined();
});
});
});
Loading