diff --git a/docs/core_docs/docs/integrations/chat/ibm.ipynb b/docs/core_docs/docs/integrations/chat/ibm.ipynb
index 46cb7bf92d74..c3f60d925f99 100644
--- a/docs/core_docs/docs/integrations/chat/ibm.ipynb
+++ b/docs/core_docs/docs/integrations/chat/ibm.ipynb
@@ -21,14 +21,15 @@
"source": [
"# IBM watsonx.ai\n",
"\n",
- "This will help you getting started with IBM watsonx.ai [chat models](/docs/concepts/chat_models). For detailed documentation of all `IBM watsonx.ai` features and configurations head to the [IBM watsonx.ai](https://api.js.langchain.com/classes/_langchain_community.chat_models_ibm.html).\n",
+ "This will help you getting started with IBM watsonx.ai [chat models](/docs/concepts/chat_models). For detailed documentation of all `IBM watsonx.ai` features and configurations head to the [IBM watsonx.ai](https://api.js.langchain.com/modules/_langchain_community.chat_models_ibm.html).\n",
"\n",
"## Overview\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | [PY support](https://python.langchain.com/docs/integrations/chat/ibm_watsonx/) | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
- "| [`ChatWatsonx`](https://api.js.langchain.com/classes/_langchain_community.chat_models_ibm.html) | [@langchain/community](https://api.js.langchain.com/modules/langchain_community_llms_ibm.html) | ❌ | ✅ | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/community?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/community?style=flat-square&label=%20&) |\n",
+ "| [`ChatWatsonx`](https://api.js.langchain.com/classes/_langchain_community.chat_models_ibm.ChatWatsonx.html) | [@langchain/community](https://www.npmjs.com/package/@langchain/community) | ❌ | ✅ | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/community?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/community?style=flat-square&label=%20&) |\n",
+
"\n",
"### Model features\n",
"\n",
@@ -138,7 +139,7 @@
"\n",
"\n",
"\n",
- " __package_name__ @langchain/core\n",
+ " @langchain/community @langchain/core\n",
"\n",
"\n",
"```"
@@ -340,7 +341,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 1,
"id": "cd21e356",
"metadata": {},
"outputs": [
@@ -357,97 +358,7 @@
" only\n",
" natural\n",
" satellite\n",
- " and\n",
- " the\n",
- " fifth\n",
- " largest\n",
- " satellite\n",
- " in\n",
- " the\n",
- " Solar\n",
- " System\n",
- ".\n",
- " It\n",
- " or\n",
- "bits\n",
- " Earth\n",
- " every\n",
- " \n",
- "2\n",
- "7\n",
- ".\n",
- "3\n",
- " days\n",
- " and\n",
- " rot\n",
- "ates\n",
- " on\n",
- " its\n",
- " axis\n",
- " in\n",
- " the\n",
- " same\n",
- " amount\n",
- " of\n",
- " time\n",
- ",\n",
- " which\n",
- " is\n",
- " why\n",
- " we\n",
- " always\n",
- " see\n",
- " the\n",
- " same\n",
- " side\n",
- " of\n",
- " it\n",
- ".\n",
- " The\n",
- " Moon\n",
- "'\n",
- "s\n",
- " phases\n",
- " change\n",
- " as\n",
- " it\n",
- " or\n",
- "bits\n",
- " Earth\n",
- ",\n",
- " going\n",
- " through\n",
- " cycles\n",
- " of\n",
- " new\n",
- ",\n",
- " c\n",
- "res\n",
- "cent\n",
- ",\n",
- " half\n",
- ",\n",
- " g\n",
- "ib\n",
- "b\n",
- "ous\n",
- ",\n",
- " and\n",
- " full\n",
- " phases\n",
- ".\n",
- " Its\n",
- " gravity\n",
- " influences\n",
- " Earth\n",
- "'\n",
- "s\n",
- " t\n",
- "ides\n",
- " and\n",
- " stabil\n",
- "izes\n",
- " our\n"
+ " and\n"
]
}
],
@@ -589,4 +500,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
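The docs fixes above (corrected API links, the real package name in the install command) pair with a breaking change in the integration itself: `model` no longer has a default and the chat parameters are camelCase. A minimal sketch of the resulting quickstart, hedged because it is assembled from this PR's docs and tests rather than run here (the `WATSONX_AI_*` variables are the ones used in the tests below):

```ts
// Sketch only: mirrors the updated docs/tests in this PR, not a verified example.
import { ChatWatsonx } from "@langchain/community/chat_models/ibm";

const chat = new ChatWatsonx({
  model: "mistralai/mistral-large", // now required; this PR removes the default
  version: "2024-05-31",
  serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "",
  projectId: process.env.WATSONX_AI_PROJECT_ID,
  maxTokens: 100, // camelCase replacement for max_new_tokens
});

const aiMsg = await chat.invoke("Why do we always see the same side of the Moon?");
console.log(aiMsg.content);
```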
diff --git a/docs/core_docs/docs/integrations/llms/ibm.ipynb b/docs/core_docs/docs/integrations/llms/ibm.ipynb
index 1644f7401724..7c47f382e59e 100644
--- a/docs/core_docs/docs/integrations/llms/ibm.ipynb
+++ b/docs/core_docs/docs/integrations/llms/ibm.ipynb
@@ -22,7 +22,7 @@
"# IBM watsonx.ai\n",
"\n",
"\n",
- "This will help you get started with IBM [text completion models (LLMs)](/docs/concepts/text_llms) using LangChain. For detailed documentation on `IBM watsonx.ai` features and configuration options, please refer to the [IBM watsonx.ai](https://api.js.langchain.com/classes/_langchain_community.llms_ibm.html).\n",
+ "This will help you get started with IBM [text completion models (LLMs)](/docs/concepts/text_llms) using LangChain. For detailed documentation on `IBM watsonx.ai` features and configuration options, please refer to the [IBM watsonx.ai](https://api.js.langchain.com/modules/_langchain_community.llms_ibm.html).\n",
"\n",
"## Overview\n",
"### Integration details\n",
@@ -30,7 +30,7 @@
"\n",
"| Class | Package | Local | Serializable | [PY support](https://python.langchain.com/docs/integrations/llms/ibm_watsonx/) | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
- "| [`IBM watsonx.ai`](https://api.js.langchain.com/modules/_langchain_community.llms_ibm.html) | [@langchain/community](https://api.js.langchain.com/modules/langchain_community_llms_ibm.html) | ❌ | ✅ | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/community?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/community?style=flat-square&label=%20&) |\n",
+ "| [`WatsonxLLM`](https://api.js.langchain.com/classes/_langchain_community.llms_ibm.WatsonxLLM.html) | [@langchain/community](https://www.npmjs.com/package/@langchain/community) | ❌ | ✅ | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/community?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/community?style=flat-square&label=%20&) |\n",
"\n",
"## Setup\n",
"\n",
@@ -161,11 +161,11 @@
"\n",
"const props = {\n",
" decoding_method: \"sample\",\n",
- " max_new_tokens: 100,\n",
- " min_new_tokens: 1,\n",
+ " maxNewTokens: 100,\n",
+ " minNewTokens: 1,\n",
" temperature: 0.5,\n",
- " top_k: 50,\n",
- " top_p: 1,\n",
+ " topK: 50,\n",
+ " topP: 1,\n",
"};\n",
"const instance = new WatsonxLLM({\n",
" version: \"YYYY-MM-DD\",\n",
@@ -298,7 +298,7 @@
"source": [
"const result2 = await instance.invoke(\"Print hello world.\", {\n",
" parameters: {\n",
- " max_new_tokens: 20,\n",
+ " maxNewTokens: 100,\n",
" },\n",
"});\n",
"console.log(result2);"
@@ -358,4 +358,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
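The parameter renames in this notebook follow the `WatsonxLLM` source changes later in the diff. A condensed sketch combining the two edited cells (hedged, not executed; `decodingMethod` is the renamed field from `WatsonxInputLLM`, although the unchanged notebook cell above still shows `decoding_method`):

```ts
// Sketch only: constructor params and per-call overrides are now camelCase.
import { WatsonxLLM } from "@langchain/community/llms/ibm";

const instance = new WatsonxLLM({
  model: "ibm/granite-13b-chat-v2",
  version: "2024-05-31",
  serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "",
  projectId: process.env.WATSONX_AI_PROJECT_ID,
  decodingMethod: "sample",
  maxNewTokens: 100, // was max_new_tokens
  minNewTokens: 1, // was min_new_tokens
  temperature: 0.5,
  topK: 50, // was top_k
  topP: 1, // was top_p
});

// Call-time overrides travel through the `parameters` option:
const result2 = await instance.invoke("Print hello world.", {
  parameters: { maxNewTokens: 100 },
});
console.log(result2);
```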
diff --git a/docs/core_docs/docs/integrations/text_embedding/ibm.ipynb b/docs/core_docs/docs/integrations/text_embedding/ibm.ipynb
index dfe43b07c462..bac03b424f48 100644
--- a/docs/core_docs/docs/integrations/text_embedding/ibm.ipynb
+++ b/docs/core_docs/docs/integrations/text_embedding/ibm.ipynb
@@ -22,7 +22,7 @@
"# IBM watsonx.ai\n",
"\n",
"\n",
- "This will help you get started with IBM watsonx.ai [embedding models](/docs/concepts/embedding_models) using LangChain. For detailed documentation on `IBM watsonx.ai` features and configuration options, please refer to the [API reference](https://api.js.langchain.com/classes/_langchain_community.embeddings_ibm.html).\n",
+ "This will help you get started with IBM watsonx.ai [embedding models](/docs/concepts/embedding_models) using LangChain. For detailed documentation on `IBM watsonx.ai` features and configuration options, please refer to the [API reference](https://api.js.langchain.com/modules/_langchain_community.embeddings_ibm.html).\n",
"\n",
"## Overview\n",
"### Integration details\n",
@@ -30,7 +30,7 @@
"\n",
"| Class | Package | Local | [Py support](https://python.langchain.com/docs/integrations/text_embedding/ibm_watsonx/) | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: |\n",
- "| [`IBM watsonx.ai`](https://api.js.langchain.com/classes/_langchain_community.embeddings_ibm.WatsonxEmbeddings.html) | [@langchain/community](https://api.js.langchain.com/modules/langchain_community_llms_ibm.html)| ❌ | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/community?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/community?style=flat-square&label=%20&) |\n",
+ "| [`WatsonxEmbeddings`](https://api.js.langchain.com/classes/_langchain_community.embeddings_ibm.WatsonxEmbeddings.html) | [@langchain/community](https://www.npmjs.com/package/@langchain/community)| ❌ | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/community?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/community?style=flat-square&label=%20&) |\n",
"\n",
"## Setup\n",
"\n",
@@ -163,7 +163,6 @@
" serviceUrl: process.env.API_URL,\n",
" projectId: \"\",\n",
" spaceId: \"\",\n",
- " idOrName: \"\",\n",
" model: \"\",\n",
"});"
]
@@ -175,7 +174,7 @@
"source": [
"Note:\n",
"\n",
- "- You must provide `spaceId`, `projectId` or `idOrName`(deployment id) in order to proceed.\n",
+ "- You must provide `spaceId` or `projectId` in order to proceed.\n",
"- Depending on the region of your provisioned service instance, use correct serviceUrl."
]
},
@@ -243,7 +242,7 @@
},
{
"cell_type": "code",
- "execution_count": 6,
+ "execution_count": 1,
"id": "0d2befcd",
"metadata": {},
"outputs": [
@@ -252,33 +251,18 @@
"output_type": "stream",
"text": [
"[\n",
- " -0.017436018, -0.01469498, -0.015685871, -0.013543149, -0.0011519607,\n",
- " -0.008123747, 0.015286108, -0.023845721, -0.02454774, 0.07235078,\n",
- " -0.032333843, -0.0035843418, -0.015389036, 0.0455373, -0.021119863,\n",
- " -0.022039745, 0.021746712, -0.017774817, -0.008232582, -0.036727764,\n",
- " -0.015734928, 0.03606811, -0.005108186, -0.036052454, 0.024462992,\n",
- " 0.02359307, 0.03273164, 0.009195497, -0.0077208397, -0.0127943,\n",
- " -0.023869334, -0.029473905, -0.0080457395, -0.0021337876, 0.04949132,\n",
- " 0.013950589, -0.010046689, 0.021029025, -0.031725302, 0.004251065,\n",
- " -0.034171984, -0.03696642, -0.014253629, -0.017757406, -0.007531065,\n",
- " 0.07187789, 0.009661725, 0.041889492, -0.04660478, 0.028036641,\n",
- " 0.059334517, -0.04561291, 0.056029715, -0.00676024, 0.026493236,\n",
- " 0.0116374, 0.050126843, -0.018036349, -0.013711887, 0.042252757,\n",
- " -0.04453391, 0.04705777, -0.00044598224, -0.030227259, 0.029286578,\n",
- " 0.0252211, 0.011694125, -0.031404093, 0.02951232, 0.08812359,\n",
- " 0.023539362, -0.011082862, 0.008024676, 0.00084492035, -0.007984158,\n",
- " -0.0005008702, -0.025189219, 0.021000557, -0.0065513053, 0.036524914,\n",
- " 0.0015150858, -0.0042383806, 0.049065087, 0.000941666, 0.04447001,\n",
- " 0.012942205, -0.078316726, -0.03004237, -0.025807172, -0.03446275,\n",
- " -0.00932942, -0.044925686, 0.03190307, 0.010136769, -0.048854534,\n",
- " 0.025738232, -0.017840309, 0.023738133, 0.014214792, 0.030452395\n",
+ " -0.017436018, -0.01469498,\n",
+ " -0.015685871, -0.013543149,\n",
+ " -0.0011519607, -0.008123747,\n",
+ " 0.015286108, -0.023845721,\n",
+ " -0.02454774, 0.07235078\n",
"]\n"
]
}
],
"source": [
" const singleVector = await embeddings.embedQuery(text);\n",
- " singleVector.slice(0, 100);"
+ " singleVector.slice(0, 10);"
]
},
{
@@ -302,48 +286,18 @@
"output_type": "stream",
"text": [
"[\n",
- " -0.017436024, -0.014695002, -0.01568589, -0.013543164, -0.001151976,\n",
- " -0.008123703, 0.015286064, -0.023845702, -0.024547677, 0.07235076,\n",
- " -0.032333862, -0.0035843418, -0.015389038, 0.045537304, -0.021119865,\n",
- " -0.02203975, 0.021746716, -0.01777481, -0.008232588, -0.03672781,\n",
- " -0.015734889, 0.036068108, -0.0051082, -0.036052432, 0.024462998,\n",
- " 0.023593083, 0.03273162, 0.009195521, -0.007720828, -0.012794304,\n",
- " -0.023869323, -0.029473891, -0.008045726, -0.002133793, 0.049491342,\n",
- " 0.013950573, -0.010046691, 0.02102898, -0.03172528, 0.0042510596,\n",
- " -0.034171965, -0.036966413, -0.014253668, -0.017757434, -0.007531062,\n",
- " 0.07187787, 0.009661732, 0.041889492, -0.04660476, 0.028036654,\n",
- " 0.059334517, -0.045612894, 0.056029722, -0.00676024, 0.026493296,\n",
- " 0.0116374055, 0.050126873, -0.018036384, -0.013711868, 0.0422528,\n",
- " -0.044533912, 0.047057763, -0.00044596897, -0.030227251, 0.029286569,\n",
- " 0.025221113, 0.011694138, -0.03140413, 0.029512335, 0.08812357,\n",
- " 0.023539348, -0.011082865, 0.008024677, 0.00084490055, -0.007984145,\n",
- " -0.0005008745, -0.025189226, 0.021000564, -0.0065513197, 0.036524955,\n",
- " 0.0015150585, -0.0042383634, 0.049065102, 0.000941638, 0.044469994,\n",
- " 0.012942193, -0.078316696, -0.0300424, -0.025807157, -0.0344627,\n",
- " -0.009329439, -0.04492573, 0.031903077, 0.010136808, -0.048854522,\n",
- " 0.025738247, -0.01784033, 0.023738142, 0.014214801, 0.030452369\n",
+ " -0.017436024, -0.014695002,\n",
+ " -0.01568589, -0.013543164,\n",
+ " -0.001151976, -0.008123703,\n",
+ " 0.015286064, -0.023845702,\n",
+ " -0.024547677, 0.07235076\n",
"]\n",
"[\n",
- " 0.03278884, -0.017893745, -0.0027520044, 0.016506646, 0.028271576,\n",
- " -0.01284331, 0.014344065, -0.007968607, -0.03899479, 0.039327156,\n",
- " -0.047726233, 0.009559004, -0.05302522, 0.011498492, -0.0055542476,\n",
- " -0.0020940166, -0.029262392, -0.025919685, 0.024261741, -0.0010863725,\n",
- " 0.0074619935, 0.014191284, -0.009054746, -0.038633537, 0.039744128,\n",
- " 0.012625762, 0.030490868, 0.013526139, -0.024638629, -0.011268263,\n",
- " -0.012759613, -0.04693565, -0.013087251, -0.01971696, 0.0125782555,\n",
- " 0.024156926, -0.011638484, 0.017364893, -0.0405832, -0.0032466082,\n",
- " -0.01611277, -0.022583133, 0.019492855, -0.03664484, -0.022627067,\n",
- " 0.011026938, -0.014631298, 0.043255687, -0.029447634, 0.017212389,\n",
- " 0.029366229, -0.041978795, 0.005347565, -0.0106230285, -0.008334342,\n",
- " -0.008841154, 0.045096103, 0.03996879, -0.002039457, -0.0051824683,\n",
- " -0.019464444, 0.092018366, -0.009283633, -0.020052811, 0.0043408144,\n",
- " -0.029403884, 0.02587689, -0.027253918, 0.0159064, 0.0421537,\n",
- " 0.05078811, -0.012380686, -0.018032575, 0.01711449, 0.03636163,\n",
- " -0.014590949, -0.015076142, 0.00018201554, 0.002490666, 0.044776678,\n",
- " 0.05301749, -0.007891316, 0.028668318, -0.0016632816, 0.04487743,\n",
- " -0.032529455, -0.040372133, -0.020566158, -0.011109745, -0.01724949,\n",
- " -0.0047519016, -0.041635286, 0.0068111843, 0.039498538, -0.02491227,\n",
- " 0.016853934, -0.017926402, -0.006154979, 0.025893573, 0.015262395\n",
+ " 0.03278884, -0.017893745,\n",
+ " -0.0027520044, 0.016506646,\n",
+ " 0.028271576, -0.01284331,\n",
+ " 0.014344065, -0.007968607,\n",
+ " -0.03899479, 0.039327156\n",
"]\n"
]
}
@@ -355,9 +309,8 @@
"\n",
" const vectors = await embeddings.embedDocuments([text, text2]);\n",
" \n",
- " console.log(vectors[0].slice(0, 100));\n",
- " console.log(vectors[1].slice(0, 100));\n",
- " "
+ " console.log(vectors[0].slice(0, 10));\n",
+ " console.log(vectors[1].slice(0, 10));\n"
]
},
{
@@ -386,4 +339,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
-}
\ No newline at end of file
+}
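A matching sketch for the embeddings notebook after the `idOrName` removal (hedged, not executed; `512` is a placeholder value):

```ts
// Sketch only: `idOrName` is gone here; provide projectId or spaceId instead.
import { WatsonxEmbeddings } from "@langchain/community/embeddings/ibm";

const embeddings = new WatsonxEmbeddings({
  model: "ibm/slate-125m-english-rtrvr", // now required; the default is removed
  version: "2024-05-31",
  serviceUrl: process.env.API_URL ?? "",
  projectId: process.env.WATSONX_AI_PROJECT_ID,
  truncateInputTokens: 512, // was truncate_input_tokens
});

const singleVector = await embeddings.embedQuery("LangChain is awesome");
console.log(singleVector.slice(0, 10));
```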
diff --git a/libs/langchain-community/src/chat_models/ibm.ts b/libs/langchain-community/src/chat_models/ibm.ts
index dd468909a886..d4dc6a64ba28 100644
--- a/libs/langchain-community/src/chat_models/ibm.ts
+++ b/libs/langchain-community/src/chat_models/ibm.ts
@@ -33,7 +33,6 @@ import {
} from "@langchain/core/outputs";
import { AsyncCaller } from "@langchain/core/utils/async_caller";
import {
- TextChatConstants,
TextChatMessagesTextChatMessageAssistant,
TextChatParameterTools,
TextChatParams,
@@ -42,7 +41,6 @@ import {
TextChatResultChoice,
TextChatResultMessage,
TextChatToolCall,
- TextChatToolChoiceTool,
TextChatUsage,
} from "@ibm-cloud/watsonx-ai/dist/watsonx-ai-ml/vml_v1.js";
import { WatsonXAI } from "@ibm-cloud/watsonx-ai";
@@ -81,30 +79,8 @@ export interface WatsonxDeltaStream {
}
export interface WatsonxCallParams
- extends Partial<
- Omit<
- TextChatParams,
- | "toolChoiceOption"
- | "toolChoice"
- | "frequencyPenalty"
- | "topLogprobs"
- | "maxTokens"
- | "presencePenalty"
- | "responseFormat"
- | "timeLimit"
- | "modelId"
- >
- > {
+  extends Partial<Omit<TextChatParams, "modelId">> {
maxRetries?: number;
- tool_choice?: TextChatToolChoiceTool;
- tool_choice_option?: TextChatConstants.ToolChoiceOption | string;
- frequency_penalty?: number;
- top_logprobs?: number;
- max_new_tokens?: number;
- presence_penalty?: number;
- top_p?: number;
- time_limit?: number;
- response_format?: TextChatResponseFormat;
}
export interface WatsonxCallOptionsChat
  extends Omit<BaseChatModelCallOptions, "stop">,
@@ -114,12 +90,15 @@ export interface WatsonxCallOptionsChat
type ChatWatsonxToolType = BindToolsInput | TextChatParameterTools;
-export interface ChatWatsonxInput extends BaseChatModelParams, WatsonxParams {
+export interface ChatWatsonxInput
+ extends BaseChatModelParams,
+ WatsonxParams,
+ WatsonxCallParams {
streaming?: boolean;
}
-function _convertToValidToolId(modelId: string, tool_call_id: string) {
- if (modelId.startsWith("mistralai"))
+function _convertToValidToolId(model: string, tool_call_id: string) {
+ if (model.startsWith("mistralai"))
return _convertToolCallIdToMistralCompatible(tool_call_id);
else return tool_call_id;
}
@@ -144,7 +123,7 @@ function _convertToolToWatsonxTool(
function _convertMessagesToWatsonxMessages(
messages: BaseMessage[],
- modelId: string
+ model: string
): TextChatResultMessage[] {
const getRole = (role: MessageType) => {
switch (role) {
@@ -168,7 +147,7 @@ function _convertMessagesToWatsonxMessages(
return message.tool_calls
.map((toolCall) => ({
...toolCall,
- id: _convertToValidToolId(modelId, toolCall.id ?? ""),
+ id: _convertToValidToolId(model, toolCall.id ?? ""),
}))
.map(convertLangChainToolCallToOpenAI) as TextChatToolCall[];
}
@@ -183,7 +162,7 @@ function _convertMessagesToWatsonxMessages(
role: getRole(message._getType()),
content,
name: message.name,
- tool_call_id: _convertToValidToolId(modelId, message.tool_call_id),
+ tool_call_id: _convertToValidToolId(model, message.tool_call_id),
};
}
@@ -246,7 +225,7 @@ function _watsonxResponseToChatMessage(
function _convertDeltaToMessageChunk(
delta: WatsonxDeltaStream,
rawData: TextChatResponse,
- modelId: string,
+ model: string,
usage?: TextChatUsage,
defaultRole?: TextChatMessagesTextChatMessageAssistant.Constants.Role
) {
@@ -262,7 +241,7 @@ function _convertDeltaToMessageChunk(
} => ({
...toolCall,
index,
- id: _convertToValidToolId(modelId, toolCall.id),
+ id: _convertToValidToolId(model, toolCall.id),
type: "function",
})
)
@@ -315,7 +294,7 @@ function _convertDeltaToMessageChunk(
return new ToolMessageChunk({
content,
additional_kwargs,
- tool_call_id: _convertToValidToolId(modelId, rawToolCalls?.[0].id),
+ tool_call_id: _convertToValidToolId(model, rawToolCalls?.[0].id),
});
} else if (role === "function") {
return new FunctionMessageChunk({
@@ -379,11 +358,11 @@ export class ChatWatsonx<
};
}
- model = "mistralai/mistral-large";
+ model: string;
version = "2024-05-31";
- max_new_tokens = 100;
+ maxTokens: number;
maxRetries = 0;
@@ -393,35 +372,31 @@ export class ChatWatsonx<
projectId?: string;
- frequency_penalty?: number;
+ frequencyPenalty?: number;
logprobs?: boolean;
- top_logprobs?: number;
+ topLogprobs?: number;
n?: number;
- presence_penalty?: number;
+ presencePenalty?: number;
temperature?: number;
- top_p?: number;
+ topP?: number;
- time_limit?: number;
+ timeLimit?: number;
maxConcurrency?: number;
service: WatsonXAI;
- response_format?: TextChatResponseFormat | string;
+ responseFormat?: TextChatResponseFormat;
streaming: boolean;
- constructor(
- fields: ChatWatsonxInput &
- WatsonxAuth &
-      Partial<Omit<TextChatParams, "modelId">>
- ) {
+ constructor(fields: ChatWatsonxInput & WatsonxAuth) {
super(fields);
if (
(fields.projectId && fields.spaceId) ||
@@ -432,20 +407,20 @@ export class ChatWatsonx<
if (!fields.projectId && !fields.spaceId && !fields.idOrName)
throw new Error(
- "No id specified! At least ide of 1 type has to be specified"
+ "No id specified! At least id of 1 type has to be specified"
);
this.projectId = fields?.projectId;
this.spaceId = fields?.spaceId;
this.temperature = fields?.temperature;
this.maxRetries = fields?.maxRetries || this.maxRetries;
this.maxConcurrency = fields?.maxConcurrency;
- this.frequency_penalty = fields?.frequency_penalty;
- this.top_logprobs = fields?.top_logprobs;
- this.max_new_tokens = fields?.max_new_tokens ?? this.max_new_tokens;
- this.presence_penalty = fields?.presence_penalty;
- this.top_p = fields?.top_p;
- this.time_limit = fields?.time_limit;
- this.response_format = fields?.response_format ?? this.response_format;
+ this.frequencyPenalty = fields?.frequencyPenalty;
+ this.topLogprobs = fields?.topLogprobs;
+ this.maxTokens = fields?.maxTokens ?? this.maxTokens;
+ this.presencePenalty = fields?.presencePenalty;
+ this.topP = fields?.topP;
+ this.timeLimit = fields?.timeLimit;
+ this.responseFormat = fields?.responseFormat ?? this.responseFormat;
this.serviceUrl = fields?.serviceUrl;
this.streaming = fields?.streaming ?? this.streaming;
this.n = fields?.n ?? this.n;
@@ -483,21 +458,21 @@ export class ChatWatsonx<
invocationParams(options: this["ParsedCallOptions"]) {
return {
- maxTokens: options.max_new_tokens ?? this.max_new_tokens,
+ maxTokens: options.maxTokens ?? this.maxTokens,
temperature: options?.temperature ?? this.temperature,
- timeLimit: options?.time_limit ?? this.time_limit,
- topP: options?.top_p ?? this.top_p,
- presencePenalty: options?.presence_penalty ?? this.presence_penalty,
+ timeLimit: options?.timeLimit ?? this.timeLimit,
+ topP: options?.topP ?? this.topP,
+ presencePenalty: options?.presencePenalty ?? this.presencePenalty,
n: options?.n ?? this.n,
- topLogprobs: options?.top_logprobs ?? this.top_logprobs,
+ topLogprobs: options?.topLogprobs ?? this.topLogprobs,
logprobs: options?.logprobs ?? this?.logprobs,
- frequencyPenalty: options?.frequency_penalty ?? this.frequency_penalty,
+ frequencyPenalty: options?.frequencyPenalty ?? this.frequencyPenalty,
tools: options.tools
? _convertToolToWatsonxTool(options.tools)
: undefined,
- toolChoice: options.tool_choice,
- responseFormat: options.response_format,
- toolChoiceOption: options.tool_choice_option,
+ toolChoice: options.toolChoice,
+ responseFormat: options.responseFormat,
+ toolChoiceOption: options.toolChoiceOption,
};
}
@@ -556,7 +531,7 @@ export class ChatWatsonx<
if (message?.usage_metadata) {
const completion = chunk.generationInfo?.completion;
if (tokenUsages[completion])
- tokenUsages[completion].output_tokens +=
+ tokenUsages[completion].output_tokens =
message.usage_metadata.output_tokens;
else tokenUsages[completion] = message.usage_metadata;
}
@@ -759,7 +734,7 @@ export class ChatWatsonx<
let llm: Runnable;
if (method === "jsonMode") {
const options = {
- response_format: { type: "json_object" },
+ responseFormat: { type: "json_object" },
} as Partial<CallOptions>;
llm = this.bind(options);
@@ -783,7 +758,7 @@ export class ChatWatsonx<
},
],
// Ideally that would be set to required but this is not supported yet
- tool_choice: {
+ toolChoice: {
type: "function",
function: {
name: functionName,
@@ -819,7 +794,7 @@ export class ChatWatsonx<
},
],
// Ideally that would be set to required but this is not supported yet
- tool_choice: {
+ toolChoice: {
type: "function",
function: {
name: functionName,
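Taken together, the chat-model changes drop the hand-maintained snake_case option set and inherit the camelCase names straight from the SDK's `TextChatParams`. A hedged sketch of what that looks like at a call site (option names are the ones renamed in this diff):

```ts
// Sketch only: per-call options now use camelCase keys.
import { ChatWatsonx } from "@langchain/community/chat_models/ibm";

const chat = new ChatWatsonx({
  model: "mistralai/mistral-large",
  version: "2024-05-31",
  serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "",
  projectId: process.env.WATSONX_AI_PROJECT_ID,
});

// Was: response_format / time_limit / top_p
const bound = chat.bind({
  responseFormat: { type: "json_object" },
  timeLimit: 10000,
  topP: 1,
});
const res = await bound.invoke('Reply with the JSON object {"ok": true}.');
console.log(res.content);
```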
diff --git a/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts
index 85f33d733e6d..2f1d118d92a4 100644
--- a/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts
+++ b/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts
@@ -12,15 +12,13 @@ import { LLMResult } from "@langchain/core/outputs";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { tool } from "@langchain/core/tools";
import { NewTokenIndices } from "@langchain/core/callbacks/base";
-import * as fs from "node:fs/promises";
-import { fileURLToPath } from "node:url";
-import * as path from "node:path";
import { ChatWatsonx } from "../ibm.js";
describe("Tests for chat", () => {
describe("Test ChatWatsonx invoke and generate", () => {
test("Basic invoke", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -30,6 +28,7 @@ describe("Tests for chat", () => {
});
test("Basic generate", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -40,6 +39,7 @@ describe("Tests for chat", () => {
});
test("Invoke with system message", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -53,6 +53,7 @@ describe("Tests for chat", () => {
});
test("Invoke with output parser", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -68,6 +69,7 @@ describe("Tests for chat", () => {
});
test("Invoke with prompt", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -83,6 +85,7 @@ describe("Tests for chat", () => {
});
test("Invoke with chat conversation", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -104,6 +107,7 @@ describe("Tests for chat", () => {
totalTokens: 0,
};
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -120,6 +124,7 @@ describe("Tests for chat", () => {
});
test("Timeout", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -132,6 +137,7 @@ describe("Tests for chat", () => {
}, 5000);
test("Controller options", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -150,6 +156,7 @@ describe("Tests for chat", () => {
describe("Test ChatWatsonx invoke and generate with stream mode", () => {
test("Basic invoke", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -160,6 +167,7 @@ describe("Tests for chat", () => {
});
test("Basic generate", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -170,6 +178,7 @@ describe("Tests for chat", () => {
});
test("Generate with n>1", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -198,11 +207,12 @@ describe("Tests for chat", () => {
];
let tokenUsed = 0;
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
n: 2,
- max_new_tokens: 5,
+ maxTokens: 5,
streaming: true,
callbackManager: CallbackManager.fromHandlers({
async handleLLMEnd(output: LLMResult) {
@@ -236,6 +246,7 @@ describe("Tests for chat", () => {
});
test("Invoke with system message", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -249,6 +260,7 @@ describe("Tests for chat", () => {
});
test("Invoke with output parser", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -264,6 +276,7 @@ describe("Tests for chat", () => {
});
test("Invoke with prompt", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -279,6 +292,7 @@ describe("Tests for chat", () => {
});
test("Invoke with chat conversation", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -300,6 +314,7 @@ describe("Tests for chat", () => {
totalTokens: 0,
};
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -316,6 +331,7 @@ describe("Tests for chat", () => {
});
test("Timeout", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -328,6 +344,7 @@ describe("Tests for chat", () => {
}, 5000);
test("Controller options", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -346,6 +363,7 @@ describe("Tests for chat", () => {
describe("Test ChatWatsonx stream", () => {
test("Basic stream", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -366,6 +384,7 @@ describe("Tests for chat", () => {
});
test("Timeout", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -378,6 +397,7 @@ describe("Tests for chat", () => {
}, 5000);
test("Controller options", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -399,6 +419,7 @@ describe("Tests for chat", () => {
test("Token count and response equality", async () => {
let generation = "";
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -426,21 +447,24 @@ describe("Tests for chat", () => {
});
test("Token count usage_metadata", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
});
let res: AIMessageChunk | null = null;
+ let outputCount = 0;
const stream = await service.stream("Why is the sky blue? Be concise.");
for await (const chunk of stream) {
res = chunk;
+ outputCount += 1;
}
expect(res?.usage_metadata).toBeDefined();
if (!res?.usage_metadata) {
return;
}
expect(res.usage_metadata.input_tokens).toBeGreaterThan(1);
- expect(res.usage_metadata.output_tokens).toBe(1);
+ expect(res.usage_metadata.output_tokens).toBe(outputCount);
expect(res.usage_metadata.total_tokens).toBe(
res.usage_metadata.input_tokens + res.usage_metadata.output_tokens
);
@@ -450,6 +474,7 @@ describe("Tests for chat", () => {
describe("Test tool usage", () => {
test("Passing tool to chat model", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -503,6 +528,7 @@ describe("Tests for chat", () => {
});
test("Passing tool to chat model extended", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -563,6 +589,7 @@ describe("Tests for chat", () => {
});
test("Binding model-specific formats", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -603,6 +630,7 @@ describe("Tests for chat", () => {
});
test("Passing tool to chat model", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -655,6 +683,7 @@ describe("Tests for chat", () => {
describe("Test withStructuredOutput usage", () => {
test("Schema with zod", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -677,6 +706,7 @@ describe("Tests for chat", () => {
test("Schema with zod and stream", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -703,6 +733,7 @@ describe("Tests for chat", () => {
});
test("Schema with object", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -729,6 +760,7 @@ describe("Tests for chat", () => {
});
test("Schema with rawOutput", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -761,6 +793,7 @@ describe("Tests for chat", () => {
});
test("Schema with zod and JSON mode", async () => {
const service = new ChatWatsonx({
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
@@ -798,47 +831,4 @@ describe("Tests for chat", () => {
expect(typeof result.number2).toBe("number");
});
});
-
- describe("Test image input", () => {
- test("Image input", async () => {
- const service = new ChatWatsonx({
- version: "2024-05-31",
- serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
- model: "meta-llama/llama-3-2-11b-vision-instruct",
- projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
- max_new_tokens: 100,
- });
- const __filename = fileURLToPath(import.meta.url);
- const __dirname = path.dirname(__filename);
- const encodedString = await fs.readFile(
- path.join(__dirname, "/data/hotdog.jpg")
- );
- const question = "What is on the picture";
- const messages = [
- {
- role: "user",
- content: [
- {
- type: "text",
- text: question,
- },
- {
- type: "image_url",
- image_url: {
- url:
- "data:image/jpeg;base64," + encodedString.toString("base64"),
- },
- },
- ],
- },
- ];
- const res = await service.stream(messages);
- const chunks = [];
- for await (const chunk of res) {
- expect(chunk).toBeInstanceOf(AIMessageChunk);
- chunks.push(chunk.content);
- }
- expect(typeof chunks.join("")).toBe("string");
- });
- });
});
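The usage-metadata test above now derives the expected `output_tokens` from the number of streamed chunks instead of hard-coding `1`, matching the aggregation fix earlier in the diff (the final chunk's usage is assigned, not accumulated). A sketch of the same check outside the test harness, assuming, as the test does, one generated token per chunk:

```ts
// Sketch only: the final chunk's usage_metadata should equal the chunk count.
import { ChatWatsonx } from "@langchain/community/chat_models/ibm";
import type { AIMessageChunk } from "@langchain/core/messages";

const chat = new ChatWatsonx({
  model: "mistralai/mistral-large",
  version: "2024-05-31",
  serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "",
  projectId: process.env.WATSONX_AI_PROJECT_ID,
});

let last: AIMessageChunk | undefined;
let chunkCount = 0;
for await (const chunk of await chat.stream("Why is the sky blue? Be concise.")) {
  last = chunk;
  chunkCount += 1;
}
console.log(chunkCount, last?.usage_metadata?.output_tokens); // expected to match
```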
diff --git a/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts
index 03b8eb4b3351..545ed3c06fa9 100644
--- a/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts
+++ b/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts
@@ -26,6 +26,7 @@ class ChatWatsonxStandardIntegrationTests extends ChatModelIntegrationTests<
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
constructorArgs: {
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString",
diff --git a/libs/langchain-community/src/chat_models/tests/ibm.standard.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.standard.test.ts
index 6c7ab7d5576a..da9a624209c9 100644
--- a/libs/langchain-community/src/chat_models/tests/ibm.standard.test.ts
+++ b/libs/langchain-community/src/chat_models/tests/ibm.standard.test.ts
@@ -24,6 +24,7 @@ class ChatWatsonxStandardTests extends ChatModelUnitTests<
chatModelHasToolCalling: true,
chatModelHasStructuredOutput: true,
constructorArgs: {
+ model: "mistralai/mistral-large",
watsonxAIApikey: "testString",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString",
diff --git a/libs/langchain-community/src/chat_models/tests/ibm.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.test.ts
index 8e04c1c26c6b..f52a689f6755 100644
--- a/libs/langchain-community/src/chat_models/tests/ibm.test.ts
+++ b/libs/langchain-community/src/chat_models/tests/ibm.test.ts
@@ -52,6 +52,7 @@ describe("LLM unit tests", () => {
test("Test basic properties after init", async () => {
const testProps = {
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
@@ -63,6 +64,7 @@ describe("LLM unit tests", () => {
test("Test methods after init", () => {
const testProps: ChatWatsonxInput = {
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
@@ -83,10 +85,10 @@ describe("LLM unit tests", () => {
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
model: "ibm/granite-13b-chat-v2",
- max_new_tokens: 100,
+ maxTokens: 100,
temperature: 0.1,
- time_limit: 10000,
- top_p: 1,
+ timeLimit: 10000,
+ topP: 1,
maxRetries: 3,
maxConcurrency: 3,
};
@@ -99,6 +101,7 @@ describe("LLM unit tests", () => {
describe("Negative tests", () => {
test("Missing id", async () => {
const testProps: ChatWatsonxInput = {
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
};
@@ -149,6 +152,7 @@ describe("LLM unit tests", () => {
test("Passing more than one id", async () => {
const testProps: ChatWatsonxInput = {
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
@@ -165,6 +169,7 @@ describe("LLM unit tests", () => {
test("Not existing property passed", async () => {
const testProps = {
+ model: "mistralai/mistral-large",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
diff --git a/libs/langchain-community/src/embeddings/ibm.ts b/libs/langchain-community/src/embeddings/ibm.ts
index fc8fcaebb561..9ee0f39976f9 100644
--- a/libs/langchain-community/src/embeddings/ibm.ts
+++ b/libs/langchain-community/src/embeddings/ibm.ts
@@ -9,14 +9,20 @@ import { WatsonxAuth, WatsonxParams } from "../types/ibm.js";
import { authenticateAndSetInstance } from "../utils/ibm.js";
export interface WatsonxEmbeddingsParams
-  extends Omit<WatsonxParams, "idOrName">,
-    Pick<EmbeddingParameters, "truncate_input_tokens"> {}
+  extends Pick<WatsonxParams, "maxRetries"> {
+ truncateInputTokens?: number;
+}
+
+export interface WatsonxInputEmbeddings
+  extends Omit<WatsonxParams, "idOrName"> {
+ truncateInputTokens?: number;
+}
export class WatsonxEmbeddings
extends Embeddings
implements WatsonxEmbeddingsParams, WatsonxParams
{
- model = "ibm/slate-125m-english-rtrvr";
+ model: string;
serviceUrl: string;
@@ -26,7 +32,7 @@ export class WatsonxEmbeddings
projectId?: string;
- truncate_input_tokens?: number;
+ truncateInputTokens?: number;
maxRetries?: number;
@@ -34,18 +40,18 @@ export class WatsonxEmbeddings
private service: WatsonXAI;
- constructor(fields: WatsonxEmbeddingsParams & WatsonxAuth & WatsonxParams) {
+ constructor(fields: WatsonxInputEmbeddings & WatsonxAuth) {
const superProps = { maxConcurrency: 2, ...fields };
super(superProps);
- this.model = fields?.model ? fields.model : this.model;
+ this.model = fields.model;
this.version = fields.version;
this.serviceUrl = fields.serviceUrl;
- this.truncate_input_tokens = fields.truncate_input_tokens;
+ this.truncateInputTokens = fields.truncateInputTokens;
this.maxConcurrency = fields.maxConcurrency;
- this.maxRetries = fields.maxRetries;
+ this.maxRetries = fields.maxRetries ?? 0;
if (fields.projectId && fields.spaceId)
throw new Error("Maximum 1 id type can be specified per instance");
- else if (!fields.projectId && !fields.spaceId && !fields.idOrName)
+ else if (!fields.projectId && !fields.spaceId)
throw new Error(
"No id specified! At least id of 1 type has to be specified"
);
@@ -77,13 +83,14 @@ export class WatsonxEmbeddings
}
scopeId() {
- if (this.projectId) return { projectId: this.projectId };
- else return { spaceId: this.spaceId };
+ if (this.projectId)
+ return { projectId: this.projectId, modelId: this.model };
+ else return { spaceId: this.spaceId, modelId: this.model };
}
invocationParams(): EmbeddingParameters {
return {
- truncate_input_tokens: this.truncate_input_tokens,
+ truncate_input_tokens: this.truncateInputTokens,
};
}
@@ -104,7 +111,6 @@ export class WatsonxEmbeddings
private async embedSingleText(inputs: string[]) {
const textEmbeddingParams: TextEmbeddingsParams = {
inputs,
- modelId: this.model,
...this.scopeId(),
parameters: this.invocationParams(),
};
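With `scopeId()` now carrying `modelId`, `embedSingleText` no longer sets the model on each request. A self-contained sketch of the pattern (types are illustrative, not the SDK's):

```ts
// Illustrative only: the model id travels with the project/space scope
// instead of being added separately at every call site.
type Scope =
  | { projectId: string; modelId: string }
  | { spaceId?: string; modelId: string };

function scopeId(opts: {
  projectId?: string;
  spaceId?: string;
  model: string;
}): Scope {
  if (opts.projectId) return { projectId: opts.projectId, modelId: opts.model };
  return { spaceId: opts.spaceId, modelId: opts.model };
}

// Each request then spreads the scope: { inputs, ...scopeId(...), parameters }
console.log(scopeId({ projectId: "my-project", model: "ibm/slate-125m-english-rtrvr" }));
```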
diff --git a/libs/langchain-community/src/embeddings/tests/ibm.int.test.ts b/libs/langchain-community/src/embeddings/tests/ibm.int.test.ts
index 9361a7915213..a774181d4b91 100644
--- a/libs/langchain-community/src/embeddings/tests/ibm.int.test.ts
+++ b/libs/langchain-community/src/embeddings/tests/ibm.int.test.ts
@@ -5,6 +5,7 @@ import { WatsonxEmbeddings } from "../ibm.js";
describe("Test embeddings", () => {
test("embedQuery method", async () => {
const embeddings = new WatsonxEmbeddings({
+ model: "ibm/slate-125m-english-rtrvr",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -15,6 +16,7 @@ describe("Test embeddings", () => {
test("embedDocuments", async () => {
const embeddings = new WatsonxEmbeddings({
+ model: "ibm/slate-125m-english-rtrvr",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -27,6 +29,7 @@ describe("Test embeddings", () => {
test("Concurrency", async () => {
const embeddings = new WatsonxEmbeddings({
+ model: "ibm/slate-125m-english-rtrvr",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -50,6 +53,7 @@ describe("Test embeddings", () => {
test("List models", async () => {
const embeddings = new WatsonxEmbeddings({
+ model: "ibm/slate-125m-english-rtrvr",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
diff --git a/libs/langchain-community/src/embeddings/tests/ibm.test.ts b/libs/langchain-community/src/embeddings/tests/ibm.test.ts
index 05f033f6f1af..ad4196b3387d 100644
--- a/libs/langchain-community/src/embeddings/tests/ibm.test.ts
+++ b/libs/langchain-community/src/embeddings/tests/ibm.test.ts
@@ -1,7 +1,7 @@
/* eslint-disable no-process-env */
/* eslint-disable @typescript-eslint/no-explicit-any */
import { testProperties } from "../../llms/tests/ibm.test.js";
-import { WatsonxEmbeddings } from "../ibm.js";
+import { WatsonxEmbeddings, WatsonxInputEmbeddings } from "../ibm.js";
const fakeAuthProp = {
watsonxAIAuthType: "iam",
@@ -11,6 +11,7 @@ describe("Embeddings unit tests", () => {
describe("Positive tests", () => {
test("Basic properties", () => {
const testProps = {
+ model: "ibm/slate-125m-english-rtrvr",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
@@ -20,14 +21,14 @@ describe("Embeddings unit tests", () => {
});
test("Basic properties", () => {
- const testProps = {
+ const testProps: WatsonxInputEmbeddings = {
+ model: "ibm/slate-125m-english-rtrvr",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
- truncate_input_tokens: 10,
+ truncateInputTokens: 10,
maxConcurrency: 2,
maxRetries: 2,
- model: "ibm/slate-125m-english-rtrvr",
};
const instance = new WatsonxEmbeddings({ ...testProps, ...fakeAuthProp });
@@ -38,6 +39,7 @@ describe("Embeddings unit tests", () => {
describe("Negative tests", () => {
test("Missing id", async () => {
const testProps = {
+ model: "ibm/slate-125m-english-rtrvr",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
};
@@ -85,6 +87,7 @@ describe("Embeddings unit tests", () => {
test("Passing more than one id", async () => {
const testProps = {
+ model: "ibm/slate-125m-english-rtrvr",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
@@ -101,6 +104,7 @@ describe("Embeddings unit tests", () => {
test("Invalid properties", () => {
const testProps = {
+ model: "ibm/slate-125m-english-rtrvr",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
diff --git a/libs/langchain-community/src/llms/ibm.ts b/libs/langchain-community/src/llms/ibm.ts
index 302275158d9c..a0e8a292f0bf 100644
--- a/libs/langchain-community/src/llms/ibm.ts
+++ b/libs/langchain-community/src/llms/ibm.ts
@@ -3,12 +3,8 @@ import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager";
import { BaseLLM, BaseLLMParams } from "@langchain/core/language_models/llms";
import { WatsonXAI } from "@ibm-cloud/watsonx-ai";
import {
- DeploymentsTextGenerationParams,
- DeploymentsTextGenerationStreamParams,
DeploymentTextGenProperties,
ReturnOptionProperties,
- TextGenerationParams,
- TextGenerationStreamParams,
TextGenLengthPenalty,
TextGenParameters,
TextTokenizationParams,
@@ -34,25 +30,28 @@ import {
* Input to LLM class.
*/
-export interface WatsonxCallOptionsLLM
- extends BaseLanguageModelCallOptions,
- Omit<
- Partial<
- TextGenerationParams &
- TextGenerationStreamParams &
- DeploymentsTextGenerationParams &
- DeploymentsTextGenerationStreamParams
- >,
- "input"
- > {
+export interface WatsonxCallOptionsLLM extends BaseLanguageModelCallOptions {
maxRetries?: number;
+  parameters?: Partial<WatsonxInputLLM>;
+ idOrName?: string;
}
-export interface WatsonxInputLLM
- extends TextGenParameters,
- WatsonxParams,
- BaseLLMParams {
+export interface WatsonxInputLLM extends WatsonxParams, BaseLLMParams {
streaming?: boolean;
+ maxNewTokens?: number;
+ decodingMethod?: TextGenParameters.Constants.DecodingMethod | string;
+ lengthPenalty?: TextGenLengthPenalty;
+ minNewTokens?: number;
+ randomSeed?: number;
+ stopSequence?: string[];
+ temperature?: number;
+ timeLimit?: number;
+ topK?: number;
+ topP?: number;
+ repetitionPenalty?: number;
+  truncateInputTokens?: number;
+ returnOptions?: ReturnOptionProperties;
+ includeStopSequence?: boolean;
}
/**
@@ -73,7 +72,7 @@ export class WatsonxLLM<
streaming = false;
- model = "ibm/granite-13b-chat-v2";
+ model: string;
maxRetries = 0;
@@ -81,7 +80,7 @@ export class WatsonxLLM<
serviceUrl: string;
- max_new_tokens?: number;
+ maxNewTokens?: number;
spaceId?: string;
@@ -89,31 +88,31 @@ export class WatsonxLLM<
idOrName?: string;
- decoding_method?: TextGenParameters.Constants.DecodingMethod | string;
+ decodingMethod?: TextGenParameters.Constants.DecodingMethod | string;
- length_penalty?: TextGenLengthPenalty;
+ lengthPenalty?: TextGenLengthPenalty;
- min_new_tokens?: number;
+ minNewTokens?: number;
- random_seed?: number;
+ randomSeed?: number;
- stop_sequences?: string[];
+ stopSequence?: string[];
temperature?: number;
- time_limit?: number;
+ timeLimit?: number;
- top_k?: number;
+ topK?: number;
- top_p?: number;
+ topP?: number;
- repetition_penalty?: number;
+ repetitionPenalty?: number;
- truncate_input_tokens?: number;
+  truncateInputTokens?: number;
- return_options?: ReturnOptionProperties;
+ returnOptions?: ReturnOptionProperties;
- include_stop_sequence?: boolean;
+ includeStopSequence?: boolean;
maxConcurrency?: number;
@@ -123,21 +122,21 @@ export class WatsonxLLM<
super(fields);
this.model = fields.model ?? this.model;
this.version = fields.version;
- this.max_new_tokens = fields.max_new_tokens ?? this.max_new_tokens;
+ this.maxNewTokens = fields.maxNewTokens ?? this.maxNewTokens;
this.serviceUrl = fields.serviceUrl;
- this.decoding_method = fields.decoding_method;
- this.length_penalty = fields.length_penalty;
- this.min_new_tokens = fields.min_new_tokens;
- this.random_seed = fields.random_seed;
- this.stop_sequences = fields.stop_sequences;
+ this.decodingMethod = fields.decodingMethod;
+ this.lengthPenalty = fields.lengthPenalty;
+ this.minNewTokens = fields.minNewTokens;
+ this.randomSeed = fields.randomSeed;
+ this.stopSequence = fields.stopSequence;
this.temperature = fields.temperature;
- this.time_limit = fields.time_limit;
- this.top_k = fields.top_k;
- this.top_p = fields.top_p;
- this.repetition_penalty = fields.repetition_penalty;
- this.truncate_input_tokens = fields.truncate_input_tokens;
- this.return_options = fields.return_options;
- this.include_stop_sequence = fields.include_stop_sequence;
+ this.timeLimit = fields.timeLimit;
+ this.topK = fields.topK;
+ this.topP = fields.topP;
+ this.repetitionPenalty = fields.repetitionPenalty;
+    this.truncateInputTokens = fields.truncateInputTokens;
+ this.returnOptions = fields.returnOptions;
+ this.includeStopSequence = fields.includeStopSequence;
this.maxRetries = fields.maxRetries || this.maxRetries;
this.maxConcurrency = fields.maxConcurrency;
this.streaming = fields.streaming || this.streaming;
@@ -150,7 +149,7 @@ export class WatsonxLLM<
if (!fields.projectId && !fields.spaceId && !fields.idOrName)
throw new Error(
- "No id specified! At least ide of 1 type has to be specified"
+ "No id specified! At least id of 1 type has to be specified"
);
this.projectId = fields?.projectId;
this.spaceId = fields?.spaceId;
@@ -216,23 +215,23 @@ export class WatsonxLLM<
const { parameters } = options;
return {
- max_new_tokens: parameters?.max_new_tokens ?? this.max_new_tokens,
- decoding_method: parameters?.decoding_method ?? this.decoding_method,
- length_penalty: parameters?.length_penalty ?? this.length_penalty,
- min_new_tokens: parameters?.min_new_tokens ?? this.min_new_tokens,
- random_seed: parameters?.random_seed ?? this.random_seed,
- stop_sequences: options?.stop ?? this.stop_sequences,
+ max_new_tokens: parameters?.maxNewTokens ?? this.maxNewTokens,
+ decoding_method: parameters?.decodingMethod ?? this.decodingMethod,
+ length_penalty: parameters?.lengthPenalty ?? this.lengthPenalty,
+ min_new_tokens: parameters?.minNewTokens ?? this.minNewTokens,
+ random_seed: parameters?.randomSeed ?? this.randomSeed,
+ stop_sequences: options?.stop ?? this.stopSequence,
temperature: parameters?.temperature ?? this.temperature,
- time_limit: parameters?.time_limit ?? this.time_limit,
- top_k: parameters?.top_k ?? this.top_k,
- top_p: parameters?.top_p ?? this.top_p,
+ time_limit: parameters?.timeLimit ?? this.timeLimit,
+ top_k: parameters?.topK ?? this.topK,
+ top_p: parameters?.topP ?? this.topP,
repetition_penalty:
- parameters?.repetition_penalty ?? this.repetition_penalty,
+ parameters?.repetitionPenalty ?? this.repetitionPenalty,
truncate_input_tokens:
- parameters?.truncate_input_tokens ?? this.truncate_input_tokens,
- return_options: parameters?.return_options ?? this.return_options,
+      parameters?.truncateInputTokens ?? this.truncateInputTokens,
+ return_options: parameters?.returnOptions ?? this.returnOptions,
include_stop_sequence:
- parameters?.include_stop_sequence ?? this.include_stop_sequence,
+ parameters?.includeStopSequence ?? this.includeStopSequence,
};
}
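`invocationParams()` is where the rename stops: the public surface is camelCase, while the watsonx.ai wire format keeps its snake_case keys. A reduced, runnable sketch of that boundary (interfaces are illustrative; field names come from the diff):

```ts
// Illustrative only: camelCase on the LangChain side, snake_case on the wire.
interface PublicParams {
  maxNewTokens?: number;
  topK?: number;
  truncateInputTokens?: number;
}
interface WireParams {
  max_new_tokens?: number;
  top_k?: number;
  truncate_input_tokens?: number;
}

function toWire(p: PublicParams): WireParams {
  return {
    max_new_tokens: p.maxNewTokens,
    top_k: p.topK,
    truncate_input_tokens: p.truncateInputTokens,
  };
}

console.log(toWire({ maxNewTokens: 100, topK: 50 }));
// -> { max_new_tokens: 100, top_k: 50, truncate_input_tokens: undefined }
```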
diff --git a/libs/langchain-community/src/llms/tests/ibm.int.test.ts b/libs/langchain-community/src/llms/tests/ibm.int.test.ts
index 236fd4950be8..dfeebedd39e2 100644
--- a/libs/langchain-community/src/llms/tests/ibm.int.test.ts
+++ b/libs/langchain-community/src/llms/tests/ibm.int.test.ts
@@ -11,6 +11,7 @@ describe("Text generation", () => {
describe("Test invoke method", () => {
test("Correct value", async () => {
const watsonXInstance = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -18,8 +19,21 @@ describe("Text generation", () => {
await watsonXInstance.invoke("Hello world?");
});
+ test("Overwritte params", async () => {
+ const watsonXInstance = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
+ version: "2024-05-31",
+ serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
+ projectId: process.env.WATSONX_AI_PROJECT_ID,
+ });
+ await watsonXInstance.invoke("Hello world?", {
+ parameters: { maxNewTokens: 10 },
+ });
+ });
+
test("Invalid projectId", async () => {
const watsonXInstance = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: "Test wrong value",
@@ -29,6 +43,7 @@ describe("Text generation", () => {
test("Invalid credentials", async () => {
const watsonXInstance = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: "Test wrong value",
@@ -41,6 +56,7 @@ describe("Text generation", () => {
test("Wrong value", async () => {
const watsonXInstance = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -51,6 +67,7 @@ describe("Text generation", () => {
test("Stop", async () => {
const watsonXInstance = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -62,10 +79,11 @@ describe("Text generation", () => {
test("Stop with timeout", async () => {
const watsonXInstance = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: "sdadasdas" as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
- max_new_tokens: 5,
+ maxNewTokens: 5,
maxRetries: 3,
});
@@ -76,10 +94,11 @@ describe("Text generation", () => {
test("Signal in call options", async () => {
const watsonXInstance = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
- max_new_tokens: 5,
+ maxNewTokens: 5,
maxRetries: 3,
});
const controllerNoAbortion = new AbortController();
@@ -100,6 +119,7 @@ describe("Text generation", () => {
test("Concurenccy", async () => {
const model = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
maxConcurrency: 1,
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
@@ -119,9 +139,10 @@ describe("Text generation", () => {
input_token_count: 0,
};
const model = new WatsonxLLM({
- maxConcurrency: 1,
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
- max_new_tokens: 1,
+ maxNewTokens: 1,
+ maxConcurrency: 1,
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
callbacks: CallbackManager.fromHandlers({
@@ -150,10 +171,12 @@ describe("Text generation", () => {
let streamedText = "";
let usedTokens = 0;
const model = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
+ maxConcurrency: 1,
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
- max_new_tokens: 5,
+ maxNewTokens: 5,
streaming: true,
callbacks: CallbackManager.fromHandlers({
@@ -176,10 +199,11 @@ describe("Text generation", () => {
describe("Test generate methods", () => {
test("Basic usage", async () => {
const model = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
- max_new_tokens: 5,
+ maxNewTokens: 5,
});
const res = await model.generate([
"Print hello world!",
@@ -190,10 +214,11 @@ describe("Text generation", () => {
test("Stop", async () => {
const model = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
- max_new_tokens: 100,
+ maxNewTokens: 100,
});
const res = await model.generate(
@@ -215,10 +240,11 @@ describe("Text generation", () => {
const nrNewTokens = [0, 0, 0];
const completions = ["", "", ""];
const model = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
- max_new_tokens: 5,
+ maxNewTokens: 5,
streaming: true,
callbacks: CallbackManager.fromHandlers({
async handleLLMNewToken(token: string, idx) {
@@ -245,10 +271,11 @@ describe("Text generation", () => {
test("Prompt value", async () => {
const model = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
- max_new_tokens: 5,
+ maxNewTokens: 5,
});
const res = await model.generatePrompt([
new StringPromptValue("Print hello world!"),
@@ -264,10 +291,11 @@ describe("Text generation", () => {
let countedTokens = 0;
let streamedText = "";
const model = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
- max_new_tokens: 100,
+ maxNewTokens: 100,
callbacks: CallbackManager.fromHandlers({
async handleLLMNewToken(token: string) {
countedTokens += 1;
@@ -286,10 +314,11 @@ describe("Text generation", () => {
test("Stop", async () => {
const model = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
- max_new_tokens: 100,
+ maxNewTokens: 100,
});
const stream = await model.stream("Print hello world!", {
@@ -304,10 +333,11 @@ describe("Text generation", () => {
test("Timeout", async () => {
const model = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
- max_new_tokens: 1000,
+ maxNewTokens: 1000,
});
await expect(async () => {
const stream = await model.stream(
@@ -325,10 +355,11 @@ describe("Text generation", () => {
test("Signal in call options", async () => {
const model = new WatsonxLLM({
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
- max_new_tokens: 1000,
+ maxNewTokens: 1000,
});
const controller = new AbortController();
await expect(async () => {
@@ -354,6 +385,7 @@ describe("Text generation", () => {
describe("Test getNumToken method", () => {
test("Passing correct value", async () => {
const testProps: WatsonxInputLLM = {
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
@@ -371,6 +403,7 @@ describe("Text generation", () => {
test("Passing wrong value", async () => {
const testProps: WatsonxInputLLM = {
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID,
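
The integration tests above all gain an explicit `model` and switch to camelCase parameters. A minimal usage sketch distilled from them is below; the option names are taken from this diff, and the set shown should not be read as exhaustive.

```typescript
// Hedged sketch of the updated constructor surface: `model` is now required,
// and parameters use camelCase (e.g. maxNewTokens instead of max_new_tokens).
import { WatsonxLLM } from "@langchain/community/llms/ibm";

const llm = new WatsonxLLM({
  model: "ibm/granite-13b-chat-v2",
  version: "2024-05-31",
  serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
  projectId: process.env.WATSONX_AI_PROJECT_ID,
  maxNewTokens: 5, // camelCase replaces max_new_tokens
});

const out = await llm.invoke("Print hello world!");
console.log(out);
```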
diff --git a/libs/langchain-community/src/llms/tests/ibm.test.ts b/libs/langchain-community/src/llms/tests/ibm.test.ts
index 7dfaecd6361c..6237cb1d14c1 100644
--- a/libs/langchain-community/src/llms/tests/ibm.test.ts
+++ b/libs/langchain-community/src/llms/tests/ibm.test.ts
@@ -3,10 +3,7 @@
import WatsonxAiMlVml_v1 from "@ibm-cloud/watsonx-ai/dist/watsonx-ai-ml/vml_v1.js";
import { WatsonxLLM, WatsonxInputLLM } from "../ibm.js";
import { authenticateAndSetInstance } from "../../utils/ibm.js";
-import {
- WatsonxEmbeddings,
- WatsonxEmbeddingsParams,
-} from "../../embeddings/ibm.js";
+import { WatsonxEmbeddings } from "../../embeddings/ibm.js";
const fakeAuthProp = {
watsonxAIAuthType: "iam",
@@ -38,7 +35,7 @@ export const testProperties = (
}
});
};
- checkProperty(testProps, instance);
+ checkProperty(testProps, instance);
if (notExTestProps)
checkProperty(notExTestProps, instance, false);
};
@@ -56,6 +53,7 @@ describe("LLM unit tests", () => {
test("Test basic properties after init", async () => {
const testProps = {
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
@@ -67,6 +65,7 @@ describe("LLM unit tests", () => {
test("Test methods after init", () => {
const testProps: WatsonxInputLLM = {
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
@@ -82,33 +81,32 @@ describe("LLM unit tests", () => {
});
test("Test properties after init", async () => {
- const testProps = {
+ const testProps: WatsonxInputLLM = {
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
model: "ibm/granite-13b-chat-v2",
- max_new_tokens: 100,
- decoding_method: "sample",
- length_penalty: { decay_factor: 1, start_index: 1 },
- min_new_tokens: 10,
- random_seed: 1,
- stop_sequences: ["hello"],
+ maxNewTokens: 100,
+ decodingMethod: "sample",
+ lengthPenalty: { decay_factor: 1, start_index: 1 },
+ minNewTokens: 10,
+ randomSeed: 1,
+ stopSequence: ["hello"],
temperature: 0.1,
- time_limit: 10000,
- top_k: 1,
- top_p: 1,
- repetition_penalty: 1,
- truncate_input_tokens: 1,
- return_options: {
+ timeLimit: 10000,
+ topK: 1,
+ topP: 1,
+ repetitionPenalty: 1,
+ truncateInpuTokens: 1,
+ returnOptions: {
input_text: true,
generated_tokens: true,
input_tokens: true,
token_logprobs: true,
token_ranks: true,
-
top_n_tokens: 2,
},
- include_stop_sequence: false,
+ includeStopSequence: false,
maxRetries: 3,
maxConcurrency: 3,
};
@@ -121,6 +119,7 @@ describe("LLM unit tests", () => {
describe("Negative tests", () => {
test("Missing id", async () => {
const testProps: WatsonxInputLLM = {
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
};
@@ -171,6 +170,7 @@ describe("LLM unit tests", () => {
test("Passing more than one id", async () => {
const testProps: WatsonxInputLLM = {
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
@@ -187,6 +187,7 @@ describe("LLM unit tests", () => {
test("Not existing property passed", async () => {
const testProps = {
+ model: "ibm/granite-13b-chat-v2",
version: "2024-05-31",
serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string,
projectId: process.env.WATSONX_AI_PROJECT_ID || "testString",
diff --git a/libs/langchain-community/src/types/ibm.ts b/libs/langchain-community/src/types/ibm.ts
index cd11592ee48b..ee5db8532036 100644
--- a/libs/langchain-community/src/types/ibm.ts
+++ b/libs/langchain-community/src/types/ibm.ts
@@ -18,7 +18,7 @@ export interface WatsonxInit {
}
export interface WatsonxParams extends WatsonxInit {
- model?: string;
+ model: string;
spaceId?: string;
projectId?: string;
idOrName?: string;
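
The effect of this final type change: `model` moves from optional to required on `WatsonxParams`, so constructions that previously relied on an implicit default model now fail to type-check. A minimal sketch of the effect follows; `WatsonxInit`'s members are an assumption made for self-containment, as the real interface lives in `libs/langchain-community/src/types/ibm.ts` and is not shown in this diff.

```typescript
// Assumed stub of WatsonxInit, for illustration only.
interface WatsonxInit {
  version: string;
  serviceUrl: string;
}

interface WatsonxParams extends WatsonxInit {
  model: string; // previously `model?: string`; callers must now choose a model
  spaceId?: string;
  projectId?: string;
  idOrName?: string;
}

// Fails to type-check after the change (property 'model' is missing):
// const bad: WatsonxParams = { version: "2024-05-31", serviceUrl: "..." };

// OK: model supplied explicitly.
const ok: WatsonxParams = {
  model: "ibm/granite-13b-chat-v2",
  version: "2024-05-31",
  serviceUrl: "https://us-south.ml.cloud.ibm.com",
};
```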