Skip to content

Commit

Permalink
fix(community): Fix bugs in IBM community, provide some tests and fix…
Browse files Browse the repository at this point in the history
… existing tests (#7282)

Co-authored-by: jacoblee93 <[email protected]>
  • Loading branch information
FilipZmijewski and jacoblee93 authored Dec 3, 2024
1 parent e8822ad commit 782fe71
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 120 deletions.
72 changes: 59 additions & 13 deletions libs/langchain-community/src/chat_models/ibm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ import {
BaseLanguageModelInput,
FunctionDefinition,
StructuredOutputMethodOptions,
type BaseLanguageModelCallOptions,
} from "@langchain/core/language_models/base";
import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import {
BaseChatModel,
BaseChatModelCallOptions,
BindToolsInput,
LangSmithParams,
type BaseChatModelParams,
Expand All @@ -41,7 +41,6 @@ import {
TextChatResultChoice,
TextChatResultMessage,
TextChatToolCall,
TextChatToolChoiceTool,
TextChatUsage,
} from "@ibm-cloud/watsonx-ai/dist/watsonx-ai-ml/vml_v1.js";
import { WatsonXAI } from "@ibm-cloud/watsonx-ai";
Expand Down Expand Up @@ -80,14 +79,14 @@ export interface WatsonxDeltaStream {
}

export interface WatsonxCallParams
extends Partial<Omit<TextChatParams, "modelId">> {
extends Partial<Omit<TextChatParams, "modelId" | "toolChoice">> {
maxRetries?: number;
}
export interface WatsonxCallOptionsChat
extends Omit<BaseLanguageModelCallOptions, "stop">,
extends Omit<BaseChatModelCallOptions, "stop">,
WatsonxCallParams {
promptIndex?: number;
tool_choice?: TextChatToolChoiceTool;
tool_choice?: TextChatParameterTools | string | "auto" | "any";
}

type ChatWatsonxToolType = BindToolsInput | TextChatParameterTools;
Expand Down Expand Up @@ -309,6 +308,29 @@ function _convertDeltaToMessageChunk(
return null;
}

/**
 * Normalizes a LangChain-style `tool_choice` value into the shape expected
 * by the watsonx.ai text-chat request parameters.
 *
 * Mapping (as implemented below):
 * - `"any"` / `"required"`  -> `{ toolChoiceOption: "required" }` (force some tool call)
 * - `"auto"` / `"none"`     -> `{ toolChoiceOption: toolChoice }` (passed through)
 * - any other string        -> treated as a tool *name* and wrapped as a
 *                              function-typed `toolChoice`
 * - object with a `type` key -> passed through unchanged as `toolChoice`
 *
 * @param toolChoice Tool-choice value taken from call options (`tool_choice`).
 * @returns A partial params object meant to be spread into the watsonx
 *          chat request (either `toolChoiceOption` or `toolChoice`).
 * @throws Error when the value is neither a string nor a `type`-tagged object.
 */
function _convertToolChoiceToWatsonxToolChoice(
  toolChoice: TextChatParameterTools | string | "auto" | "any"
) {
  if (typeof toolChoice === "string") {
    if (toolChoice === "any" || toolChoice === "required") {
      return { toolChoiceOption: "required" };
    } else if (toolChoice === "auto" || toolChoice === "none") {
      return { toolChoiceOption: toolChoice };
    } else {
      // Bare string that is not a keyword: interpret it as the name of the
      // single tool the model must call.
      return {
        toolChoice: {
          type: "function",
          function: { name: toolChoice },
        },
      };
    }
  } else if ("type" in toolChoice) {
    return { toolChoice };
  }
  // JSON.stringify instead of template interpolation so the error shows the
  // actual value rather than "[object Object]"; also fixes "Recieved" typo.
  throw new Error(
    `Unrecognized tool_choice type. Expected string or TextChatParameterTools. Received ${JSON.stringify(
      toolChoice
    )}`
  );
}

export class ChatWatsonx<
CallOptions extends WatsonxCallOptionsChat = WatsonxCallOptionsChat
>
Expand Down Expand Up @@ -459,7 +481,7 @@ export class ChatWatsonx<
}

invocationParams(options: this["ParsedCallOptions"]) {
return {
const params = {
maxTokens: options.maxTokens ?? this.maxTokens,
temperature: options?.temperature ?? this.temperature,
timeLimit: options?.timeLimit ?? this.timeLimit,
Expand All @@ -472,10 +494,12 @@ export class ChatWatsonx<
tools: options.tools
? _convertToolToWatsonxTool(options.tools)
: undefined,
toolChoice: options.tool_choice,
responseFormat: options.responseFormat,
toolChoiceOption: options.toolChoiceOption,
};
const toolChoiceResult = options.tool_choice
? _convertToolChoiceToWatsonxToolChoice(options.tool_choice)
: {};
return { ...params, ...toolChoiceResult };
}

override bindTools(
Expand Down Expand Up @@ -562,7 +586,7 @@ export class ChatWatsonx<
.map(([_, value]) => value);
return { generations, llmOutput: { tokenUsage } };
} else {
const params: Omit<TextChatParams, "messages"> = {
const params = {
...this.invocationParams(options),
...this.scopeId(),
};
Expand All @@ -576,7 +600,6 @@ export class ChatWatsonx<
messages: watsonxMessages,
});
const { result } = await this.completionWithRetry(callback, options);

const generations: ChatGeneration[] = [];
for (const part of result.choices) {
const generation: ChatGeneration = {
Expand Down Expand Up @@ -623,10 +646,13 @@ export class ChatWatsonx<
});
const stream = await this.completionWithRetry(callback, options);
let defaultRole;
let usage: TextChatUsage | undefined;
let currentCompletion = 0;
for await (const chunk of stream) {
if (options.signal?.aborted) {
throw new Error("AbortError");
}
if (chunk?.data?.usage) usage = chunk.data.usage;
const { data } = chunk;
const choice = data.choices[0] as TextChatResultChoice &
Record<"delta", TextChatResultMessage>;
Expand All @@ -638,7 +664,7 @@ export class ChatWatsonx<
if (!delta) {
continue;
}

currentCompletion = choice.index ?? 0;
const newTokenIndices = {
prompt: options.promptIndex ?? 0,
completion: choice.index ?? 0,
Expand Down Expand Up @@ -682,6 +708,26 @@ export class ChatWatsonx<
{ chunk: generationChunk }
);
}

const generationChunk = new ChatGenerationChunk({
message: new AIMessageChunk({
content: "",
response_metadata: {
usage,
},
usage_metadata: {
input_tokens: usage?.prompt_tokens ?? 0,
output_tokens: usage?.completion_tokens ?? 0,
total_tokens: usage?.total_tokens ?? 0,
},
}),
text: "",
generationInfo: {
prompt: options.promptIndex ?? 0,
completion: currentCompletion ?? 0,
},
});
yield generationChunk;
}

/** @ignore */
Expand Down Expand Up @@ -760,7 +806,7 @@ export class ChatWatsonx<
},
],
// Ideally that would be set to required but this is not supported yet
toolChoice: {
tool_choice: {
type: "function",
function: {
name: functionName,
Expand Down Expand Up @@ -796,7 +842,7 @@ export class ChatWatsonx<
},
],
// Ideally that would be set to required but this is not supported yet
toolChoice: {
tool_choice: {
type: "function",
function: {
name: functionName,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ describe("Tests for chat", () => {
controller.abort();
return res;
}).rejects.toThrow();
}, 5000);
});
});

describe("Test ChatWatsonx invoke and generate with stream mode", () => {
Expand Down Expand Up @@ -357,7 +357,7 @@ describe("Tests for chat", () => {
controller.abort();
return res;
}).rejects.toThrow();
}, 5000);
});
});

describe("Test ChatWatsonx stream", () => {
Expand Down Expand Up @@ -415,7 +415,7 @@ describe("Tests for chat", () => {
}
expect(hasEntered).toBe(true);
}).rejects.toThrow();
}, 5000);
});
test("Token count and response equality", async () => {
let generation = "";
const service = new ChatWatsonx({
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ChatWatsonxStandardIntegrationTests extends ChatModelIntegrationTests<
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
constructorArgs: {
model: "mistralai/mistral-large",
model: "meta-llama/llama-3-1-70b-instruct",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
Expand Down
3 changes: 3 additions & 0 deletions libs/langchain-community/src/document_compressors/ibm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,9 @@ export class WatsonxRerank
...this.scopeId(),
inputs,
query,
parameters: {
truncate_input_tokens: this.truncateInputTokens,
},
})
);
const resultDocuments = result.results.map(({ index, score }) => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,25 @@ describe("Integration tests on WatsonxRerank", () => {
expect(typeof item.metadata.relevanceScore).toBe("number")
);
});

test("Basic call with truncation", async () => {
const instance = new WatsonxRerank({
model: "cross-encoder/ms-marco-minilm-l-12-v2",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
version: "2024-05-31",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
truncateInputTokens: 512,
});
const longerDocs: Document[] = docs.map((item) => ({
pageContent: item.pageContent.repeat(100),
metadata: {},
}));
const result = await instance.compressDocuments(longerDocs, query);
expect(result.length).toBe(docs.length);
result.forEach((item) =>
expect(typeof item.metadata.relevanceScore).toBe("number")
);
});
});

describe(".rerank() method", () => {
Expand All @@ -57,24 +76,42 @@ describe("Integration tests on WatsonxRerank", () => {
expect(item.input).toBeUndefined();
});
});
});
test("Basic call with options", async () => {
const instance = new WatsonxRerank({
model: "cross-encoder/ms-marco-minilm-l-12-v2",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
version: "2024-05-31",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
});
const result = await instance.rerank(docs, query, {
returnOptions: {
topN: 3,
inputs: true,
},
test("Basic call with options", async () => {
const instance = new WatsonxRerank({
model: "cross-encoder/ms-marco-minilm-l-12-v2",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
version: "2024-05-31",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
});
const result = await instance.rerank(docs, query, {
returnOptions: {
topN: 3,
inputs: true,
},
});
expect(result.length).toBe(3);
result.forEach((item) => {
expect(typeof item.relevanceScore).toBe("number");
expect(item.input).toBeDefined();
});
});
expect(result.length).toBe(3);
result.forEach((item) => {
expect(typeof item.relevanceScore).toBe("number");
expect(item.input).toBeDefined();
test("Basic call with truncation", async () => {
const instance = new WatsonxRerank({
model: "cross-encoder/ms-marco-minilm-l-12-v2",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
version: "2024-05-31",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
});
const longerDocs = docs.map((item) => ({
pageContent: item.pageContent.repeat(100),
}));
const result = await instance.rerank(longerDocs, query, {
truncateInputTokens: 512,
});
result.forEach((item) => {
expect(typeof item.relevanceScore).toBe("number");
expect(item.input).toBeUndefined();
});
});
});
});
Loading

0 comments on commit 782fe71

Please sign in to comment.