From 14fa21035841be276b80994c800b08fbb9d9581f Mon Sep 17 00:00:00 2001 From: FilipZmijewski Date: Tue, 12 Nov 2024 00:42:07 +0100 Subject: [PATCH] fix(community): For IBM implementation rename variables, remove defaults, fix tests and minor docs fixes (#7129) Co-authored-by: Jacob Lee --- .../docs/integrations/chat/ibm.ipynb | 103 ++------------- .../docs/integrations/llms/ibm.ipynb | 16 +-- .../integrations/text_embedding/ibm.ipynb | 93 ++++---------- .../src/chat_models/ibm.ts | 111 +++++++--------- .../src/chat_models/tests/ibm.int.test.ts | 86 ++++++------- .../tests/ibm.standard.int.test.ts | 1 + .../chat_models/tests/ibm.standard.test.ts | 1 + .../src/chat_models/tests/ibm.test.ts | 11 +- .../langchain-community/src/embeddings/ibm.ts | 32 +++-- .../src/embeddings/tests/ibm.int.test.ts | 4 + .../src/embeddings/tests/ibm.test.ts | 12 +- libs/langchain-community/src/llms/ibm.ts | 119 +++++++++--------- .../src/llms/tests/ibm.int.test.ts | 59 +++++++-- .../src/llms/tests/ibm.test.ts | 41 +++--- libs/langchain-community/src/types/ibm.ts | 2 +- 15 files changed, 287 insertions(+), 404 deletions(-) diff --git a/docs/core_docs/docs/integrations/chat/ibm.ipynb b/docs/core_docs/docs/integrations/chat/ibm.ipynb index 46cb7bf92d74..c3f60d925f99 100644 --- a/docs/core_docs/docs/integrations/chat/ibm.ipynb +++ b/docs/core_docs/docs/integrations/chat/ibm.ipynb @@ -21,14 +21,15 @@ "source": [ "# IBM watsonx.ai\n", "\n", - "This will help you getting started with IBM watsonx.ai [chat models](/docs/concepts/chat_models). For detailed documentation of all `IBM watsonx.ai` features and configurations head to the [IBM watsonx.ai](https://api.js.langchain.com/classes/_langchain_community.chat_models_ibm.html).\n", + "This will help you getting started with IBM watsonx.ai [chat models](/docs/concepts/chat_models). For detailed documentation of all `IBM watsonx.ai` features and configurations head to the [IBM watsonx.ai](https://api.js.langchain.com/modules/_langchain_community.chat_models_ibm.html).\n", "\n", "## Overview\n", "### Integration details\n", "\n", "| Class | Package | Local | Serializable | [PY support](https://python.langchain.com/docs/integrations/chat/ibm_watsonx/) | Package downloads | Package latest |\n", "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n", - "| [`ChatWatsonx`](https://api.js.langchain.com/classes/_langchain_community.chat_models_ibm.html) | [@langchain/community](https://api.js.langchain.com/modules/langchain_community_llms_ibm.html) | ❌ | ✅ | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/community?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/community?style=flat-square&label=%20&) |\n", + "| [`ChatWatsonx`](https://api.js.langchain.com/classes/_langchain_community.chat_models_ibm.ChatWatsonx.html) | [@langchain/community](https://www.npmjs.com/package/@langchain/community) | ❌ | ✅ | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/community?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/community?style=flat-square&label=%20&) |\n", + "\n", "### Model features\n", "\n", @@ -138,7 +139,7 @@ "\n", "\n", "\n", - " __package_name__ @langchain/core\n", + " @langchain/community @langchain/core\n", "\n", "\n", "```" @@ -340,7 +341,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "id": "cd21e356", "metadata": {}, "outputs": [ @@ -357,97 +358,7 @@ " only\n", " natural\n", " satellite\n", - " and\n", - " the\n", - " fifth\n", - " largest\n", - " satellite\n", - " in\n", - " the\n", - " Solar\n", - " System\n", - ".\n", - " It\n", - " or\n", - "bits\n", - " Earth\n", - " every\n", - " \n", - "2\n", - "7\n", - ".\n", - "3\n", - " days\n", - " and\n", - " rot\n", - "ates\n", - " on\n", - " its\n", - " axis\n", - " in\n", - " the\n", - " same\n", - " amount\n", - " of\n", - " time\n", - ",\n", - " which\n", - " is\n", - " why\n", - " we\n", - " always\n", - " see\n", - " the\n", - " same\n", - " side\n", - " of\n", - " it\n", - ".\n", - " The\n", - " Moon\n", - "'\n", - "s\n", - " phases\n", - " change\n", - " as\n", - " it\n", - " or\n", - "bits\n", - " Earth\n", - ",\n", - " going\n", - " through\n", - " cycles\n", - " of\n", - " new\n", - ",\n", - " c\n", - "res\n", - "cent\n", - ",\n", - " half\n", - ",\n", - " g\n", - "ib\n", - "b\n", - "ous\n", - ",\n", - " and\n", - " full\n", - " phases\n", - ".\n", - " Its\n", - " gravity\n", - " influences\n", - " Earth\n", - "'\n", - "s\n", - " t\n", - "ides\n", - " and\n", - " stabil\n", - "izes\n", - " our\n" + " and\n" ] } ], @@ -589,4 +500,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/core_docs/docs/integrations/llms/ibm.ipynb b/docs/core_docs/docs/integrations/llms/ibm.ipynb index 1644f7401724..7c47f382e59e 100644 --- a/docs/core_docs/docs/integrations/llms/ibm.ipynb +++ b/docs/core_docs/docs/integrations/llms/ibm.ipynb @@ -22,7 +22,7 @@ "# IBM watsonx.ai\n", "\n", "\n", - "This will help you get started with IBM [text completion models (LLMs)](/docs/concepts/text_llms) using LangChain. For detailed documentation on `IBM watsonx.ai` features and configuration options, please refer to the [IBM watsonx.ai](https://api.js.langchain.com/classes/_langchain_community.llms_ibm.html).\n", + "This will help you get started with IBM [text completion models (LLMs)](/docs/concepts/text_llms) using LangChain. For detailed documentation on `IBM watsonx.ai` features and configuration options, please refer to the [IBM watsonx.ai](https://api.js.langchain.com/modules/_langchain_community.llms_ibm.html).\n", "\n", "## Overview\n", "### Integration details\n", @@ -30,7 +30,7 @@ "\n", "| Class | Package | Local | Serializable | [PY support](https://python.langchain.com/docs/integrations/llms/ibm_watsonx/) | Package downloads | Package latest |\n", "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n", - "| [`IBM watsonx.ai`](https://api.js.langchain.com/modules/_langchain_community.llms_ibm.html) | [@langchain/community](https://api.js.langchain.com/modules/langchain_community_llms_ibm.html) | ❌ | ✅ | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/community?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/community?style=flat-square&label=%20&) |\n", + "| [`WatsonxLLM`](https://api.js.langchain.com/classes/_langchain_community.llms_ibm.WatsonxLLM.html) | [@langchain/community](https://www.npmjs.com/package/@langchain/community) | ❌ | ✅ | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/community?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/community?style=flat-square&label=%20&) |\n", "\n", "## Setup\n", "\n", @@ -161,11 +161,11 @@ "\n", "const props = {\n", " decoding_method: \"sample\",\n", - " max_new_tokens: 100,\n", - " min_new_tokens: 1,\n", + " maxNewTokens: 100,\n", + " minNewTokens: 1,\n", " temperature: 0.5,\n", - " top_k: 50,\n", - " top_p: 1,\n", + " topK: 50,\n", + " topP: 1,\n", "};\n", "const instance = new WatsonxLLM({\n", " version: \"YYYY-MM-DD\",\n", @@ -298,7 +298,7 @@ "source": [ "const result2 = await instance.invoke(\"Print hello world.\", {\n", " parameters: {\n", - " max_new_tokens: 20,\n", + " maxNewTokens: 100,\n", " },\n", "});\n", "console.log(result2);" @@ -358,4 +358,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/docs/core_docs/docs/integrations/text_embedding/ibm.ipynb b/docs/core_docs/docs/integrations/text_embedding/ibm.ipynb index dfe43b07c462..bac03b424f48 100644 --- a/docs/core_docs/docs/integrations/text_embedding/ibm.ipynb +++ b/docs/core_docs/docs/integrations/text_embedding/ibm.ipynb @@ -22,7 +22,7 @@ "# IBM watsonx.ai\n", "\n", "\n", - "This will help you get started with IBM watsonx.ai [embedding models](/docs/concepts/embedding_models) using LangChain. For detailed documentation on `IBM watsonx.ai` features and configuration options, please refer to the [API reference](https://api.js.langchain.com/classes/_langchain_community.embeddings_ibm.html).\n", + "This will help you get started with IBM watsonx.ai [embedding models](/docs/concepts/embedding_models) using LangChain. For detailed documentation on `IBM watsonx.ai` features and configuration options, please refer to the [API reference](https://api.js.langchain.com/modules/_langchain_community.embeddings_ibm.html).\n", "\n", "## Overview\n", "### Integration details\n", @@ -30,7 +30,7 @@ "\n", "| Class | Package | Local | [Py support](https://python.langchain.com/docs/integrations/text_embedding/ibm_watsonx/) | Package downloads | Package latest |\n", "| :--- | :--- | :---: | :---: | :---: | :---: |\n", - "| [`IBM watsonx.ai`](https://api.js.langchain.com/classes/_langchain_community.embeddings_ibm.WatsonxEmbeddings.html) | [@langchain/community](https://api.js.langchain.com/modules/langchain_community_llms_ibm.html)| ❌ | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/community?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/community?style=flat-square&label=%20&) |\n", + "| [`WatsonxEmbeddings`](https://api.js.langchain.com/classes/_langchain_community.embeddings_ibm.WatsonxEmbeddings.html) | [@langchain/community](https://www.npmjs.com/package/@langchain/community)| ❌ | ✅ | ![NPM - Downloads](https://img.shields.io/npm/dm/@langchain/community?style=flat-square&label=%20&) | ![NPM - Version](https://img.shields.io/npm/v/@langchain/community?style=flat-square&label=%20&) |\n", "\n", "## Setup\n", "\n", @@ -163,7 +163,6 @@ " serviceUrl: process.env.API_URL,\n", " projectId: \"\",\n", " spaceId: \"\",\n", - " idOrName: \"\",\n", " model: \"\",\n", "});" ] @@ -175,7 +174,7 @@ "source": [ "Note:\n", "\n", - "- You must provide `spaceId`, `projectId` or `idOrName`(deployment id) in order to proceed.\n", + "- You must provide `spaceId` or `projectId` in order to proceed.\n", "- Depending on the region of your provisioned service instance, use correct serviceUrl." ] }, @@ -243,7 +242,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "id": "0d2befcd", "metadata": {}, "outputs": [ @@ -252,33 +251,18 @@ "output_type": "stream", "text": [ "[\n", - " -0.017436018, -0.01469498, -0.015685871, -0.013543149, -0.0011519607,\n", - " -0.008123747, 0.015286108, -0.023845721, -0.02454774, 0.07235078,\n", - " -0.032333843, -0.0035843418, -0.015389036, 0.0455373, -0.021119863,\n", - " -0.022039745, 0.021746712, -0.017774817, -0.008232582, -0.036727764,\n", - " -0.015734928, 0.03606811, -0.005108186, -0.036052454, 0.024462992,\n", - " 0.02359307, 0.03273164, 0.009195497, -0.0077208397, -0.0127943,\n", - " -0.023869334, -0.029473905, -0.0080457395, -0.0021337876, 0.04949132,\n", - " 0.013950589, -0.010046689, 0.021029025, -0.031725302, 0.004251065,\n", - " -0.034171984, -0.03696642, -0.014253629, -0.017757406, -0.007531065,\n", - " 0.07187789, 0.009661725, 0.041889492, -0.04660478, 0.028036641,\n", - " 0.059334517, -0.04561291, 0.056029715, -0.00676024, 0.026493236,\n", - " 0.0116374, 0.050126843, -0.018036349, -0.013711887, 0.042252757,\n", - " -0.04453391, 0.04705777, -0.00044598224, -0.030227259, 0.029286578,\n", - " 0.0252211, 0.011694125, -0.031404093, 0.02951232, 0.08812359,\n", - " 0.023539362, -0.011082862, 0.008024676, 0.00084492035, -0.007984158,\n", - " -0.0005008702, -0.025189219, 0.021000557, -0.0065513053, 0.036524914,\n", - " 0.0015150858, -0.0042383806, 0.049065087, 0.000941666, 0.04447001,\n", - " 0.012942205, -0.078316726, -0.03004237, -0.025807172, -0.03446275,\n", - " -0.00932942, -0.044925686, 0.03190307, 0.010136769, -0.048854534,\n", - " 0.025738232, -0.017840309, 0.023738133, 0.014214792, 0.030452395\n", + " -0.017436018, -0.01469498,\n", + " -0.015685871, -0.013543149,\n", + " -0.0011519607, -0.008123747,\n", + " 0.015286108, -0.023845721,\n", + " -0.02454774, 0.07235078\n", "]\n" ] } ], "source": [ " const singleVector = await embeddings.embedQuery(text);\n", - " singleVector.slice(0, 100);" + " singleVector.slice(0, 10);" ] }, { @@ -302,48 +286,18 @@ "output_type": "stream", "text": [ "[\n", - " -0.017436024, -0.014695002, -0.01568589, -0.013543164, -0.001151976,\n", - " -0.008123703, 0.015286064, -0.023845702, -0.024547677, 0.07235076,\n", - " -0.032333862, -0.0035843418, -0.015389038, 0.045537304, -0.021119865,\n", - " -0.02203975, 0.021746716, -0.01777481, -0.008232588, -0.03672781,\n", - " -0.015734889, 0.036068108, -0.0051082, -0.036052432, 0.024462998,\n", - " 0.023593083, 0.03273162, 0.009195521, -0.007720828, -0.012794304,\n", - " -0.023869323, -0.029473891, -0.008045726, -0.002133793, 0.049491342,\n", - " 0.013950573, -0.010046691, 0.02102898, -0.03172528, 0.0042510596,\n", - " -0.034171965, -0.036966413, -0.014253668, -0.017757434, -0.007531062,\n", - " 0.07187787, 0.009661732, 0.041889492, -0.04660476, 0.028036654,\n", - " 0.059334517, -0.045612894, 0.056029722, -0.00676024, 0.026493296,\n", - " 0.0116374055, 0.050126873, -0.018036384, -0.013711868, 0.0422528,\n", - " -0.044533912, 0.047057763, -0.00044596897, -0.030227251, 0.029286569,\n", - " 0.025221113, 0.011694138, -0.03140413, 0.029512335, 0.08812357,\n", - " 0.023539348, -0.011082865, 0.008024677, 0.00084490055, -0.007984145,\n", - " -0.0005008745, -0.025189226, 0.021000564, -0.0065513197, 0.036524955,\n", - " 0.0015150585, -0.0042383634, 0.049065102, 0.000941638, 0.044469994,\n", - " 0.012942193, -0.078316696, -0.0300424, -0.025807157, -0.0344627,\n", - " -0.009329439, -0.04492573, 0.031903077, 0.010136808, -0.048854522,\n", - " 0.025738247, -0.01784033, 0.023738142, 0.014214801, 0.030452369\n", + " -0.017436024, -0.014695002,\n", + " -0.01568589, -0.013543164,\n", + " -0.001151976, -0.008123703,\n", + " 0.015286064, -0.023845702,\n", + " -0.024547677, 0.07235076\n", "]\n", "[\n", - " 0.03278884, -0.017893745, -0.0027520044, 0.016506646, 0.028271576,\n", - " -0.01284331, 0.014344065, -0.007968607, -0.03899479, 0.039327156,\n", - " -0.047726233, 0.009559004, -0.05302522, 0.011498492, -0.0055542476,\n", - " -0.0020940166, -0.029262392, -0.025919685, 0.024261741, -0.0010863725,\n", - " 0.0074619935, 0.014191284, -0.009054746, -0.038633537, 0.039744128,\n", - " 0.012625762, 0.030490868, 0.013526139, -0.024638629, -0.011268263,\n", - " -0.012759613, -0.04693565, -0.013087251, -0.01971696, 0.0125782555,\n", - " 0.024156926, -0.011638484, 0.017364893, -0.0405832, -0.0032466082,\n", - " -0.01611277, -0.022583133, 0.019492855, -0.03664484, -0.022627067,\n", - " 0.011026938, -0.014631298, 0.043255687, -0.029447634, 0.017212389,\n", - " 0.029366229, -0.041978795, 0.005347565, -0.0106230285, -0.008334342,\n", - " -0.008841154, 0.045096103, 0.03996879, -0.002039457, -0.0051824683,\n", - " -0.019464444, 0.092018366, -0.009283633, -0.020052811, 0.0043408144,\n", - " -0.029403884, 0.02587689, -0.027253918, 0.0159064, 0.0421537,\n", - " 0.05078811, -0.012380686, -0.018032575, 0.01711449, 0.03636163,\n", - " -0.014590949, -0.015076142, 0.00018201554, 0.002490666, 0.044776678,\n", - " 0.05301749, -0.007891316, 0.028668318, -0.0016632816, 0.04487743,\n", - " -0.032529455, -0.040372133, -0.020566158, -0.011109745, -0.01724949,\n", - " -0.0047519016, -0.041635286, 0.0068111843, 0.039498538, -0.02491227,\n", - " 0.016853934, -0.017926402, -0.006154979, 0.025893573, 0.015262395\n", + " 0.03278884, -0.017893745,\n", + " -0.0027520044, 0.016506646,\n", + " 0.028271576, -0.01284331,\n", + " 0.014344065, -0.007968607,\n", + " -0.03899479, 0.039327156\n", "]\n" ] } @@ -355,9 +309,8 @@ "\n", " const vectors = await embeddings.embedDocuments([text, text2]);\n", " \n", - " console.log(vectors[0].slice(0, 100));\n", - " console.log(vectors[1].slice(0, 100));\n", - " " + " console.log(vectors[0].slice(0, 10));\n", + " console.log(vectors[1].slice(0, 10));\n" ] }, { @@ -386,4 +339,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} \ No newline at end of file +} diff --git a/libs/langchain-community/src/chat_models/ibm.ts b/libs/langchain-community/src/chat_models/ibm.ts index dd468909a886..d4dc6a64ba28 100644 --- a/libs/langchain-community/src/chat_models/ibm.ts +++ b/libs/langchain-community/src/chat_models/ibm.ts @@ -33,7 +33,6 @@ import { } from "@langchain/core/outputs"; import { AsyncCaller } from "@langchain/core/utils/async_caller"; import { - TextChatConstants, TextChatMessagesTextChatMessageAssistant, TextChatParameterTools, TextChatParams, @@ -42,7 +41,6 @@ import { TextChatResultChoice, TextChatResultMessage, TextChatToolCall, - TextChatToolChoiceTool, TextChatUsage, } from "@ibm-cloud/watsonx-ai/dist/watsonx-ai-ml/vml_v1.js"; import { WatsonXAI } from "@ibm-cloud/watsonx-ai"; @@ -81,30 +79,8 @@ export interface WatsonxDeltaStream { } export interface WatsonxCallParams - extends Partial< - Omit< - TextChatParams, - | "toolChoiceOption" - | "toolChoice" - | "frequencyPenalty" - | "topLogprobs" - | "maxTokens" - | "presencePenalty" - | "responseFormat" - | "timeLimit" - | "modelId" - > - > { + extends Partial> { maxRetries?: number; - tool_choice?: TextChatToolChoiceTool; - tool_choice_option?: TextChatConstants.ToolChoiceOption | string; - frequency_penalty?: number; - top_logprobs?: number; - max_new_tokens?: number; - presence_penalty?: number; - top_p?: number; - time_limit?: number; - response_format?: TextChatResponseFormat; } export interface WatsonxCallOptionsChat extends Omit, @@ -114,12 +90,15 @@ export interface WatsonxCallOptionsChat type ChatWatsonxToolType = BindToolsInput | TextChatParameterTools; -export interface ChatWatsonxInput extends BaseChatModelParams, WatsonxParams { +export interface ChatWatsonxInput + extends BaseChatModelParams, + WatsonxParams, + WatsonxCallParams { streaming?: boolean; } -function _convertToValidToolId(modelId: string, tool_call_id: string) { - if (modelId.startsWith("mistralai")) +function _convertToValidToolId(model: string, tool_call_id: string) { + if (model.startsWith("mistralai")) return _convertToolCallIdToMistralCompatible(tool_call_id); else return tool_call_id; } @@ -144,7 +123,7 @@ function _convertToolToWatsonxTool( function _convertMessagesToWatsonxMessages( messages: BaseMessage[], - modelId: string + model: string ): TextChatResultMessage[] { const getRole = (role: MessageType) => { switch (role) { @@ -168,7 +147,7 @@ function _convertMessagesToWatsonxMessages( return message.tool_calls .map((toolCall) => ({ ...toolCall, - id: _convertToValidToolId(modelId, toolCall.id ?? ""), + id: _convertToValidToolId(model, toolCall.id ?? ""), })) .map(convertLangChainToolCallToOpenAI) as TextChatToolCall[]; } @@ -183,7 +162,7 @@ function _convertMessagesToWatsonxMessages( role: getRole(message._getType()), content, name: message.name, - tool_call_id: _convertToValidToolId(modelId, message.tool_call_id), + tool_call_id: _convertToValidToolId(model, message.tool_call_id), }; } @@ -246,7 +225,7 @@ function _watsonxResponseToChatMessage( function _convertDeltaToMessageChunk( delta: WatsonxDeltaStream, rawData: TextChatResponse, - modelId: string, + model: string, usage?: TextChatUsage, defaultRole?: TextChatMessagesTextChatMessageAssistant.Constants.Role ) { @@ -262,7 +241,7 @@ function _convertDeltaToMessageChunk( } => ({ ...toolCall, index, - id: _convertToValidToolId(modelId, toolCall.id), + id: _convertToValidToolId(model, toolCall.id), type: "function", }) ) @@ -315,7 +294,7 @@ function _convertDeltaToMessageChunk( return new ToolMessageChunk({ content, additional_kwargs, - tool_call_id: _convertToValidToolId(modelId, rawToolCalls?.[0].id), + tool_call_id: _convertToValidToolId(model, rawToolCalls?.[0].id), }); } else if (role === "function") { return new FunctionMessageChunk({ @@ -379,11 +358,11 @@ export class ChatWatsonx< }; } - model = "mistralai/mistral-large"; + model: string; version = "2024-05-31"; - max_new_tokens = 100; + maxTokens: number; maxRetries = 0; @@ -393,35 +372,31 @@ export class ChatWatsonx< projectId?: string; - frequency_penalty?: number; + frequencyPenalty?: number; logprobs?: boolean; - top_logprobs?: number; + topLogprobs?: number; n?: number; - presence_penalty?: number; + presencePenalty?: number; temperature?: number; - top_p?: number; + topP?: number; - time_limit?: number; + timeLimit?: number; maxConcurrency?: number; service: WatsonXAI; - response_format?: TextChatResponseFormat | string; + responseFormat?: TextChatResponseFormat; streaming: boolean; - constructor( - fields: ChatWatsonxInput & - WatsonxAuth & - Partial> - ) { + constructor(fields: ChatWatsonxInput & WatsonxAuth) { super(fields); if ( (fields.projectId && fields.spaceId) || @@ -432,20 +407,20 @@ export class ChatWatsonx< if (!fields.projectId && !fields.spaceId && !fields.idOrName) throw new Error( - "No id specified! At least ide of 1 type has to be specified" + "No id specified! At least id of 1 type has to be specified" ); this.projectId = fields?.projectId; this.spaceId = fields?.spaceId; this.temperature = fields?.temperature; this.maxRetries = fields?.maxRetries || this.maxRetries; this.maxConcurrency = fields?.maxConcurrency; - this.frequency_penalty = fields?.frequency_penalty; - this.top_logprobs = fields?.top_logprobs; - this.max_new_tokens = fields?.max_new_tokens ?? this.max_new_tokens; - this.presence_penalty = fields?.presence_penalty; - this.top_p = fields?.top_p; - this.time_limit = fields?.time_limit; - this.response_format = fields?.response_format ?? this.response_format; + this.frequencyPenalty = fields?.frequencyPenalty; + this.topLogprobs = fields?.topLogprobs; + this.maxTokens = fields?.maxTokens ?? this.maxTokens; + this.presencePenalty = fields?.presencePenalty; + this.topP = fields?.topP; + this.timeLimit = fields?.timeLimit; + this.responseFormat = fields?.responseFormat ?? this.responseFormat; this.serviceUrl = fields?.serviceUrl; this.streaming = fields?.streaming ?? this.streaming; this.n = fields?.n ?? this.n; @@ -483,21 +458,21 @@ export class ChatWatsonx< invocationParams(options: this["ParsedCallOptions"]) { return { - maxTokens: options.max_new_tokens ?? this.max_new_tokens, + maxTokens: options.maxTokens ?? this.maxTokens, temperature: options?.temperature ?? this.temperature, - timeLimit: options?.time_limit ?? this.time_limit, - topP: options?.top_p ?? this.top_p, - presencePenalty: options?.presence_penalty ?? this.presence_penalty, + timeLimit: options?.timeLimit ?? this.timeLimit, + topP: options?.topP ?? this.topP, + presencePenalty: options?.presencePenalty ?? this.presencePenalty, n: options?.n ?? this.n, - topLogprobs: options?.top_logprobs ?? this.top_logprobs, + topLogprobs: options?.topLogprobs ?? this.topLogprobs, logprobs: options?.logprobs ?? this?.logprobs, - frequencyPenalty: options?.frequency_penalty ?? this.frequency_penalty, + frequencyPenalty: options?.frequencyPenalty ?? this.frequencyPenalty, tools: options.tools ? _convertToolToWatsonxTool(options.tools) : undefined, - toolChoice: options.tool_choice, - responseFormat: options.response_format, - toolChoiceOption: options.tool_choice_option, + toolChoice: options.toolChoice, + responseFormat: options.responseFormat, + toolChoiceOption: options.toolChoiceOption, }; } @@ -556,7 +531,7 @@ export class ChatWatsonx< if (message?.usage_metadata) { const completion = chunk.generationInfo?.completion; if (tokenUsages[completion]) - tokenUsages[completion].output_tokens += + tokenUsages[completion].output_tokens = message.usage_metadata.output_tokens; else tokenUsages[completion] = message.usage_metadata; } @@ -759,7 +734,7 @@ export class ChatWatsonx< let llm: Runnable; if (method === "jsonMode") { const options = { - response_format: { type: "json_object" }, + responseFormat: { type: "json_object" }, } as Partial; llm = this.bind(options); @@ -783,7 +758,7 @@ export class ChatWatsonx< }, ], // Ideally that would be set to required but this is not supported yet - tool_choice: { + toolChoice: { type: "function", function: { name: functionName, @@ -819,7 +794,7 @@ export class ChatWatsonx< }, ], // Ideally that would be set to required but this is not supported yet - tool_choice: { + toolChoice: { type: "function", function: { name: functionName, diff --git a/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts index 85f33d733e6d..2f1d118d92a4 100644 --- a/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts +++ b/libs/langchain-community/src/chat_models/tests/ibm.int.test.ts @@ -12,15 +12,13 @@ import { LLMResult } from "@langchain/core/outputs"; import { ChatPromptTemplate } from "@langchain/core/prompts"; import { tool } from "@langchain/core/tools"; import { NewTokenIndices } from "@langchain/core/callbacks/base"; -import * as fs from "node:fs/promises"; -import { fileURLToPath } from "node:url"; -import * as path from "node:path"; import { ChatWatsonx } from "../ibm.js"; describe("Tests for chat", () => { describe("Test ChatWatsonx invoke and generate", () => { test("Basic invoke", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -30,6 +28,7 @@ describe("Tests for chat", () => { }); test("Basic generate", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -40,6 +39,7 @@ describe("Tests for chat", () => { }); test("Invoke with system message", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -53,6 +53,7 @@ describe("Tests for chat", () => { }); test("Invoke with output parser", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -68,6 +69,7 @@ describe("Tests for chat", () => { }); test("Invoke with prompt", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -83,6 +85,7 @@ describe("Tests for chat", () => { }); test("Invoke with chat conversation", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -104,6 +107,7 @@ describe("Tests for chat", () => { totalTokens: 0, }; const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -120,6 +124,7 @@ describe("Tests for chat", () => { }); test("Timeout", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -132,6 +137,7 @@ describe("Tests for chat", () => { }, 5000); test("Controller options", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -150,6 +156,7 @@ describe("Tests for chat", () => { describe("Test ChatWatsonx invoke and generate with stream mode", () => { test("Basic invoke", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -160,6 +167,7 @@ describe("Tests for chat", () => { }); test("Basic generate", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -170,6 +178,7 @@ describe("Tests for chat", () => { }); test("Generate with n>1", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -198,11 +207,12 @@ describe("Tests for chat", () => { ]; let tokenUsed = 0; const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", n: 2, - max_new_tokens: 5, + maxTokens: 5, streaming: true, callbackManager: CallbackManager.fromHandlers({ async handleLLMEnd(output: LLMResult) { @@ -236,6 +246,7 @@ describe("Tests for chat", () => { }); test("Invoke with system message", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -249,6 +260,7 @@ describe("Tests for chat", () => { }); test("Invoke with output parser", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -264,6 +276,7 @@ describe("Tests for chat", () => { }); test("Invoke with prompt", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -279,6 +292,7 @@ describe("Tests for chat", () => { }); test("Invoke with chat conversation", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -300,6 +314,7 @@ describe("Tests for chat", () => { totalTokens: 0, }; const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -316,6 +331,7 @@ describe("Tests for chat", () => { }); test("Timeout", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -328,6 +344,7 @@ describe("Tests for chat", () => { }, 5000); test("Controller options", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -346,6 +363,7 @@ describe("Tests for chat", () => { describe("Test ChatWatsonx stream", () => { test("Basic stream", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -366,6 +384,7 @@ describe("Tests for chat", () => { }); test("Timeout", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -378,6 +397,7 @@ describe("Tests for chat", () => { }, 5000); test("Controller options", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -399,6 +419,7 @@ describe("Tests for chat", () => { test("Token count and response equality", async () => { let generation = ""; const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -426,21 +447,24 @@ describe("Tests for chat", () => { }); test("Token count usage_metadata", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", }); let res: AIMessageChunk | null = null; + let outputCount = 0; const stream = await service.stream("Why is the sky blue? Be concise."); for await (const chunk of stream) { res = chunk; + outputCount += 1; } expect(res?.usage_metadata).toBeDefined(); if (!res?.usage_metadata) { return; } expect(res.usage_metadata.input_tokens).toBeGreaterThan(1); - expect(res.usage_metadata.output_tokens).toBe(1); + expect(res.usage_metadata.output_tokens).toBe(outputCount); expect(res.usage_metadata.total_tokens).toBe( res.usage_metadata.input_tokens + res.usage_metadata.output_tokens ); @@ -450,6 +474,7 @@ describe("Tests for chat", () => { describe("Test tool usage", () => { test("Passing tool to chat model", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -503,6 +528,7 @@ describe("Tests for chat", () => { }); test("Passing tool to chat model extended", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -563,6 +589,7 @@ describe("Tests for chat", () => { }); test("Binding model-specific formats", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -603,6 +630,7 @@ describe("Tests for chat", () => { }); test("Passing tool to chat model", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -655,6 +683,7 @@ describe("Tests for chat", () => { describe("Test withStructuredOutput usage", () => { test("Schema with zod", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -677,6 +706,7 @@ describe("Tests for chat", () => { test("Schema with zod and stream", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -703,6 +733,7 @@ describe("Tests for chat", () => { }); test("Schema with object", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -729,6 +760,7 @@ describe("Tests for chat", () => { }); test("Schema with rawOutput", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -761,6 +793,7 @@ describe("Tests for chat", () => { }); test("Schema with zod and JSON mode", async () => { const service = new ChatWatsonx({ + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", @@ -798,47 +831,4 @@ describe("Tests for chat", () => { expect(typeof result.number2).toBe("number"); }); }); - - describe("Test image input", () => { - test("Image input", async () => { - const service = new ChatWatsonx({ - version: "2024-05-31", - serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", - model: "meta-llama/llama-3-2-11b-vision-instruct", - projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", - max_new_tokens: 100, - }); - const __filename = fileURLToPath(import.meta.url); - const __dirname = path.dirname(__filename); - const encodedString = await fs.readFile( - path.join(__dirname, "/data/hotdog.jpg") - ); - const question = "What is on the picture"; - const messages = [ - { - role: "user", - content: [ - { - type: "text", - text: question, - }, - { - type: "image_url", - image_url: { - url: - "data:image/jpeg;base64," + encodedString.toString("base64"), - }, - }, - ], - }, - ]; - const res = await service.stream(messages); - const chunks = []; - for await (const chunk of res) { - expect(chunk).toBeInstanceOf(AIMessageChunk); - chunks.push(chunk.content); - } - expect(typeof chunks.join("")).toBe("string"); - }); - }); }); diff --git a/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts index 03b8eb4b3351..545ed3c06fa9 100644 --- a/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts +++ b/libs/langchain-community/src/chat_models/tests/ibm.standard.int.test.ts @@ -26,6 +26,7 @@ class ChatWatsonxStandardIntegrationTests extends ChatModelIntegrationTests< chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, constructorArgs: { + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", projectId: process.env.WATSONX_AI_PROJECT_ID ?? "testString", diff --git a/libs/langchain-community/src/chat_models/tests/ibm.standard.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.standard.test.ts index 6c7ab7d5576a..da9a624209c9 100644 --- a/libs/langchain-community/src/chat_models/tests/ibm.standard.test.ts +++ b/libs/langchain-community/src/chat_models/tests/ibm.standard.test.ts @@ -24,6 +24,7 @@ class ChatWatsonxStandardTests extends ChatModelUnitTests< chatModelHasToolCalling: true, chatModelHasStructuredOutput: true, constructorArgs: { + model: "mistralai/mistral-large", watsonxAIApikey: "testString", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL ?? "testString", diff --git a/libs/langchain-community/src/chat_models/tests/ibm.test.ts b/libs/langchain-community/src/chat_models/tests/ibm.test.ts index 8e04c1c26c6b..f52a689f6755 100644 --- a/libs/langchain-community/src/chat_models/tests/ibm.test.ts +++ b/libs/langchain-community/src/chat_models/tests/ibm.test.ts @@ -52,6 +52,7 @@ describe("LLM unit tests", () => { test("Test basic properties after init", async () => { const testProps = { + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", @@ -63,6 +64,7 @@ describe("LLM unit tests", () => { test("Test methods after init", () => { const testProps: ChatWatsonxInput = { + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", @@ -83,10 +85,10 @@ describe("LLM unit tests", () => { serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", model: "ibm/granite-13b-chat-v2", - max_new_tokens: 100, + maxTokens: 100, temperature: 0.1, - time_limit: 10000, - top_p: 1, + timeLimit: 10000, + topP: 1, maxRetries: 3, maxConcurrency: 3, }; @@ -99,6 +101,7 @@ describe("LLM unit tests", () => { describe("Negative tests", () => { test("Missing id", async () => { const testProps: ChatWatsonxInput = { + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, }; @@ -149,6 +152,7 @@ describe("LLM unit tests", () => { test("Passing more than one id", async () => { const testProps: ChatWatsonxInput = { + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", @@ -165,6 +169,7 @@ describe("LLM unit tests", () => { test("Not existing property passed", async () => { const testProps = { + model: "mistralai/mistral-large", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", diff --git a/libs/langchain-community/src/embeddings/ibm.ts b/libs/langchain-community/src/embeddings/ibm.ts index fc8fcaebb561..9ee0f39976f9 100644 --- a/libs/langchain-community/src/embeddings/ibm.ts +++ b/libs/langchain-community/src/embeddings/ibm.ts @@ -9,14 +9,20 @@ import { WatsonxAuth, WatsonxParams } from "../types/ibm.js"; import { authenticateAndSetInstance } from "../utils/ibm.js"; export interface WatsonxEmbeddingsParams - extends Omit, - Pick {} + extends Pick { + truncateInputTokens?: number; +} + +export interface WatsonxInputEmbeddings + extends Omit { + truncateInputTokens?: number; +} export class WatsonxEmbeddings extends Embeddings implements WatsonxEmbeddingsParams, WatsonxParams { - model = "ibm/slate-125m-english-rtrvr"; + model: string; serviceUrl: string; @@ -26,7 +32,7 @@ export class WatsonxEmbeddings projectId?: string; - truncate_input_tokens?: number; + truncateInputTokens?: number; maxRetries?: number; @@ -34,18 +40,18 @@ export class WatsonxEmbeddings private service: WatsonXAI; - constructor(fields: WatsonxEmbeddingsParams & WatsonxAuth & WatsonxParams) { + constructor(fields: WatsonxInputEmbeddings & WatsonxAuth) { const superProps = { maxConcurrency: 2, ...fields }; super(superProps); - this.model = fields?.model ? fields.model : this.model; + this.model = fields.model; this.version = fields.version; this.serviceUrl = fields.serviceUrl; - this.truncate_input_tokens = fields.truncate_input_tokens; + this.truncateInputTokens = fields.truncateInputTokens; this.maxConcurrency = fields.maxConcurrency; - this.maxRetries = fields.maxRetries; + this.maxRetries = fields.maxRetries ?? 0; if (fields.projectId && fields.spaceId) throw new Error("Maximum 1 id type can be specified per instance"); - else if (!fields.projectId && !fields.spaceId && !fields.idOrName) + else if (!fields.projectId && !fields.spaceId) throw new Error( "No id specified! At least id of 1 type has to be specified" ); @@ -77,13 +83,14 @@ export class WatsonxEmbeddings } scopeId() { - if (this.projectId) return { projectId: this.projectId }; - else return { spaceId: this.spaceId }; + if (this.projectId) + return { projectId: this.projectId, modelId: this.model }; + else return { spaceId: this.spaceId, modelId: this.model }; } invocationParams(): EmbeddingParameters { return { - truncate_input_tokens: this.truncate_input_tokens, + truncate_input_tokens: this.truncateInputTokens, }; } @@ -104,7 +111,6 @@ export class WatsonxEmbeddings private async embedSingleText(inputs: string[]) { const textEmbeddingParams: TextEmbeddingsParams = { inputs, - modelId: this.model, ...this.scopeId(), parameters: this.invocationParams(), }; diff --git a/libs/langchain-community/src/embeddings/tests/ibm.int.test.ts b/libs/langchain-community/src/embeddings/tests/ibm.int.test.ts index 9361a7915213..a774181d4b91 100644 --- a/libs/langchain-community/src/embeddings/tests/ibm.int.test.ts +++ b/libs/langchain-community/src/embeddings/tests/ibm.int.test.ts @@ -5,6 +5,7 @@ import { WatsonxEmbeddings } from "../ibm.js"; describe("Test embeddings", () => { test("embedQuery method", async () => { const embeddings = new WatsonxEmbeddings({ + model: "ibm/slate-125m-english-rtrvr", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -15,6 +16,7 @@ describe("Test embeddings", () => { test("embedDocuments", async () => { const embeddings = new WatsonxEmbeddings({ + model: "ibm/slate-125m-english-rtrvr", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -27,6 +29,7 @@ describe("Test embeddings", () => { test("Concurrency", async () => { const embeddings = new WatsonxEmbeddings({ + model: "ibm/slate-125m-english-rtrvr", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -50,6 +53,7 @@ describe("Test embeddings", () => { test("List models", async () => { const embeddings = new WatsonxEmbeddings({ + model: "ibm/slate-125m-english-rtrvr", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, diff --git a/libs/langchain-community/src/embeddings/tests/ibm.test.ts b/libs/langchain-community/src/embeddings/tests/ibm.test.ts index 05f033f6f1af..ad4196b3387d 100644 --- a/libs/langchain-community/src/embeddings/tests/ibm.test.ts +++ b/libs/langchain-community/src/embeddings/tests/ibm.test.ts @@ -1,7 +1,7 @@ /* eslint-disable no-process-env */ /* eslint-disable @typescript-eslint/no-explicit-any */ import { testProperties } from "../../llms/tests/ibm.test.js"; -import { WatsonxEmbeddings } from "../ibm.js"; +import { WatsonxEmbeddings, WatsonxInputEmbeddings } from "../ibm.js"; const fakeAuthProp = { watsonxAIAuthType: "iam", @@ -11,6 +11,7 @@ describe("Embeddings unit tests", () => { describe("Positive tests", () => { test("Basic properties", () => { const testProps = { + model: "ibm/slate-125m-english-rtrvr", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", @@ -20,14 +21,14 @@ describe("Embeddings unit tests", () => { }); test("Basic properties", () => { - const testProps = { + const testProps: WatsonxInputEmbeddings = { + model: "ibm/slate-125m-english-rtrvr", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", - truncate_input_tokens: 10, + truncateInputTokens: 10, maxConcurrency: 2, maxRetries: 2, - model: "ibm/slate-125m-english-rtrvr", }; const instance = new WatsonxEmbeddings({ ...testProps, ...fakeAuthProp }); @@ -38,6 +39,7 @@ describe("Embeddings unit tests", () => { describe("Negative tests", () => { test("Missing id", async () => { const testProps = { + model: "ibm/slate-125m-english-rtrvr", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, }; @@ -85,6 +87,7 @@ describe("Embeddings unit tests", () => { test("Passing more than one id", async () => { const testProps = { + model: "ibm/slate-125m-english-rtrvr", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", @@ -101,6 +104,7 @@ describe("Embeddings unit tests", () => { test("Invalid properties", () => { const testProps = { + model: "ibm/slate-125m-english-rtrvr", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", diff --git a/libs/langchain-community/src/llms/ibm.ts b/libs/langchain-community/src/llms/ibm.ts index 302275158d9c..a0e8a292f0bf 100644 --- a/libs/langchain-community/src/llms/ibm.ts +++ b/libs/langchain-community/src/llms/ibm.ts @@ -3,12 +3,8 @@ import { CallbackManagerForLLMRun } from "@langchain/core/callbacks/manager"; import { BaseLLM, BaseLLMParams } from "@langchain/core/language_models/llms"; import { WatsonXAI } from "@ibm-cloud/watsonx-ai"; import { - DeploymentsTextGenerationParams, - DeploymentsTextGenerationStreamParams, DeploymentTextGenProperties, ReturnOptionProperties, - TextGenerationParams, - TextGenerationStreamParams, TextGenLengthPenalty, TextGenParameters, TextTokenizationParams, @@ -34,25 +30,28 @@ import { * Input to LLM class. */ -export interface WatsonxCallOptionsLLM - extends BaseLanguageModelCallOptions, - Omit< - Partial< - TextGenerationParams & - TextGenerationStreamParams & - DeploymentsTextGenerationParams & - DeploymentsTextGenerationStreamParams - >, - "input" - > { +export interface WatsonxCallOptionsLLM extends BaseLanguageModelCallOptions { maxRetries?: number; + parameters?: Partial; + idOrName?: string; } -export interface WatsonxInputLLM - extends TextGenParameters, - WatsonxParams, - BaseLLMParams { +export interface WatsonxInputLLM extends WatsonxParams, BaseLLMParams { streaming?: boolean; + maxNewTokens?: number; + decodingMethod?: TextGenParameters.Constants.DecodingMethod | string; + lengthPenalty?: TextGenLengthPenalty; + minNewTokens?: number; + randomSeed?: number; + stopSequence?: string[]; + temperature?: number; + timeLimit?: number; + topK?: number; + topP?: number; + repetitionPenalty?: number; + truncateInpuTokens?: number; + returnOptions?: ReturnOptionProperties; + includeStopSequence?: boolean; } /** @@ -73,7 +72,7 @@ export class WatsonxLLM< streaming = false; - model = "ibm/granite-13b-chat-v2"; + model: string; maxRetries = 0; @@ -81,7 +80,7 @@ export class WatsonxLLM< serviceUrl: string; - max_new_tokens?: number; + maxNewTokens?: number; spaceId?: string; @@ -89,31 +88,31 @@ export class WatsonxLLM< idOrName?: string; - decoding_method?: TextGenParameters.Constants.DecodingMethod | string; + decodingMethod?: TextGenParameters.Constants.DecodingMethod | string; - length_penalty?: TextGenLengthPenalty; + lengthPenalty?: TextGenLengthPenalty; - min_new_tokens?: number; + minNewTokens?: number; - random_seed?: number; + randomSeed?: number; - stop_sequences?: string[]; + stopSequence?: string[]; temperature?: number; - time_limit?: number; + timeLimit?: number; - top_k?: number; + topK?: number; - top_p?: number; + topP?: number; - repetition_penalty?: number; + repetitionPenalty?: number; - truncate_input_tokens?: number; + truncateInpuTokens?: number; - return_options?: ReturnOptionProperties; + returnOptions?: ReturnOptionProperties; - include_stop_sequence?: boolean; + includeStopSequence?: boolean; maxConcurrency?: number; @@ -123,21 +122,21 @@ export class WatsonxLLM< super(fields); this.model = fields.model ?? this.model; this.version = fields.version; - this.max_new_tokens = fields.max_new_tokens ?? this.max_new_tokens; + this.maxNewTokens = fields.maxNewTokens ?? this.maxNewTokens; this.serviceUrl = fields.serviceUrl; - this.decoding_method = fields.decoding_method; - this.length_penalty = fields.length_penalty; - this.min_new_tokens = fields.min_new_tokens; - this.random_seed = fields.random_seed; - this.stop_sequences = fields.stop_sequences; + this.decodingMethod = fields.decodingMethod; + this.lengthPenalty = fields.lengthPenalty; + this.minNewTokens = fields.minNewTokens; + this.randomSeed = fields.randomSeed; + this.stopSequence = fields.stopSequence; this.temperature = fields.temperature; - this.time_limit = fields.time_limit; - this.top_k = fields.top_k; - this.top_p = fields.top_p; - this.repetition_penalty = fields.repetition_penalty; - this.truncate_input_tokens = fields.truncate_input_tokens; - this.return_options = fields.return_options; - this.include_stop_sequence = fields.include_stop_sequence; + this.timeLimit = fields.timeLimit; + this.topK = fields.topK; + this.topP = fields.topP; + this.repetitionPenalty = fields.repetitionPenalty; + this.truncateInpuTokens = fields.truncateInpuTokens; + this.returnOptions = fields.returnOptions; + this.includeStopSequence = fields.includeStopSequence; this.maxRetries = fields.maxRetries || this.maxRetries; this.maxConcurrency = fields.maxConcurrency; this.streaming = fields.streaming || this.streaming; @@ -150,7 +149,7 @@ export class WatsonxLLM< if (!fields.projectId && !fields.spaceId && !fields.idOrName) throw new Error( - "No id specified! At least ide of 1 type has to be specified" + "No id specified! At least id of 1 type has to be specified" ); this.projectId = fields?.projectId; this.spaceId = fields?.spaceId; @@ -216,23 +215,23 @@ export class WatsonxLLM< const { parameters } = options; return { - max_new_tokens: parameters?.max_new_tokens ?? this.max_new_tokens, - decoding_method: parameters?.decoding_method ?? this.decoding_method, - length_penalty: parameters?.length_penalty ?? this.length_penalty, - min_new_tokens: parameters?.min_new_tokens ?? this.min_new_tokens, - random_seed: parameters?.random_seed ?? this.random_seed, - stop_sequences: options?.stop ?? this.stop_sequences, + max_new_tokens: parameters?.maxNewTokens ?? this.maxNewTokens, + decoding_method: parameters?.decodingMethod ?? this.decodingMethod, + length_penalty: parameters?.lengthPenalty ?? this.lengthPenalty, + min_new_tokens: parameters?.minNewTokens ?? this.minNewTokens, + random_seed: parameters?.randomSeed ?? this.randomSeed, + stop_sequences: options?.stop ?? this.stopSequence, temperature: parameters?.temperature ?? this.temperature, - time_limit: parameters?.time_limit ?? this.time_limit, - top_k: parameters?.top_k ?? this.top_k, - top_p: parameters?.top_p ?? this.top_p, + time_limit: parameters?.timeLimit ?? this.timeLimit, + top_k: parameters?.topK ?? this.topK, + top_p: parameters?.topP ?? this.topP, repetition_penalty: - parameters?.repetition_penalty ?? this.repetition_penalty, + parameters?.repetitionPenalty ?? this.repetitionPenalty, truncate_input_tokens: - parameters?.truncate_input_tokens ?? this.truncate_input_tokens, - return_options: parameters?.return_options ?? this.return_options, + parameters?.truncateInpuTokens ?? this.truncateInpuTokens, + return_options: parameters?.returnOptions ?? this.returnOptions, include_stop_sequence: - parameters?.include_stop_sequence ?? this.include_stop_sequence, + parameters?.includeStopSequence ?? this.includeStopSequence, }; } diff --git a/libs/langchain-community/src/llms/tests/ibm.int.test.ts b/libs/langchain-community/src/llms/tests/ibm.int.test.ts index 236fd4950be8..dfeebedd39e2 100644 --- a/libs/langchain-community/src/llms/tests/ibm.int.test.ts +++ b/libs/langchain-community/src/llms/tests/ibm.int.test.ts @@ -11,6 +11,7 @@ describe("Text generation", () => { describe("Test invoke method", () => { test("Correct value", async () => { const watsonXInstance = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -18,8 +19,21 @@ describe("Text generation", () => { await watsonXInstance.invoke("Hello world?"); }); + test("Overwritte params", async () => { + const watsonXInstance = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", + version: "2024-05-31", + serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, + projectId: process.env.WATSONX_AI_PROJECT_ID, + }); + await watsonXInstance.invoke("Hello world?", { + parameters: { maxNewTokens: 10 }, + }); + }); + test("Invalid projectId", async () => { const watsonXInstance = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: "Test wrong value", @@ -29,6 +43,7 @@ describe("Text generation", () => { test("Invalid credentials", async () => { const watsonXInstance = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: "Test wrong value", @@ -41,6 +56,7 @@ describe("Text generation", () => { test("Wrong value", async () => { const watsonXInstance = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -51,6 +67,7 @@ describe("Text generation", () => { test("Stop", async () => { const watsonXInstance = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -62,10 +79,11 @@ describe("Text generation", () => { test("Stop with timeout", async () => { const watsonXInstance = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: "sdadasdas" as string, projectId: process.env.WATSONX_AI_PROJECT_ID, - max_new_tokens: 5, + maxNewTokens: 5, maxRetries: 3, }); @@ -76,10 +94,11 @@ describe("Text generation", () => { test("Signal in call options", async () => { const watsonXInstance = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, - max_new_tokens: 5, + maxNewTokens: 5, maxRetries: 3, }); const controllerNoAbortion = new AbortController(); @@ -100,6 +119,7 @@ describe("Text generation", () => { test("Concurenccy", async () => { const model = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", maxConcurrency: 1, version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, @@ -119,9 +139,10 @@ describe("Text generation", () => { input_token_count: 0, }; const model = new WatsonxLLM({ - maxConcurrency: 1, + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", - max_new_tokens: 1, + maxNewTokens: 1, + maxConcurrency: 1, serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, callbacks: CallbackManager.fromHandlers({ @@ -150,10 +171,12 @@ describe("Text generation", () => { let streamedText = ""; let usedTokens = 0; const model = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", + maxConcurrency: 1, version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, - max_new_tokens: 5, + maxNewTokens: 5, streaming: true, callbacks: CallbackManager.fromHandlers({ @@ -176,10 +199,11 @@ describe("Text generation", () => { describe("Test generate methods", () => { test("Basic usage", async () => { const model = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, - max_new_tokens: 5, + maxNewTokens: 5, }); const res = await model.generate([ "Print hello world!", @@ -190,10 +214,11 @@ describe("Text generation", () => { test("Stop", async () => { const model = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, - max_new_tokens: 100, + maxNewTokens: 100, }); const res = await model.generate( @@ -215,10 +240,11 @@ describe("Text generation", () => { const nrNewTokens = [0, 0, 0]; const completions = ["", "", ""]; const model = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, - max_new_tokens: 5, + maxNewTokens: 5, streaming: true, callbacks: CallbackManager.fromHandlers({ async handleLLMNewToken(token: string, idx) { @@ -245,10 +271,11 @@ describe("Text generation", () => { test("Prompt value", async () => { const model = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, - max_new_tokens: 5, + maxNewTokens: 5, }); const res = await model.generatePrompt([ new StringPromptValue("Print hello world!"), @@ -264,10 +291,11 @@ describe("Text generation", () => { let countedTokens = 0; let streamedText = ""; const model = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, - max_new_tokens: 100, + maxNewTokens: 100, callbacks: CallbackManager.fromHandlers({ async handleLLMNewToken(token: string) { countedTokens += 1; @@ -286,10 +314,11 @@ describe("Text generation", () => { test("Stop", async () => { const model = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, - max_new_tokens: 100, + maxNewTokens: 100, }); const stream = await model.stream("Print hello world!", { @@ -304,10 +333,11 @@ describe("Text generation", () => { test("Timeout", async () => { const model = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, - max_new_tokens: 1000, + maxNewTokens: 1000, }); await expect(async () => { const stream = await model.stream( @@ -325,10 +355,11 @@ describe("Text generation", () => { test("Signal in call options", async () => { const model = new WatsonxLLM({ + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, - max_new_tokens: 1000, + maxNewTokens: 1000, }); const controller = new AbortController(); await expect(async () => { @@ -354,6 +385,7 @@ describe("Text generation", () => { describe("Test getNumToken method", () => { test("Passing correct value", async () => { const testProps: WatsonxInputLLM = { + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, @@ -371,6 +403,7 @@ describe("Text generation", () => { test("Passing wrong value", async () => { const testProps: WatsonxInputLLM = { + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID, diff --git a/libs/langchain-community/src/llms/tests/ibm.test.ts b/libs/langchain-community/src/llms/tests/ibm.test.ts index 7dfaecd6361c..6237cb1d14c1 100644 --- a/libs/langchain-community/src/llms/tests/ibm.test.ts +++ b/libs/langchain-community/src/llms/tests/ibm.test.ts @@ -3,10 +3,7 @@ import WatsonxAiMlVml_v1 from "@ibm-cloud/watsonx-ai/dist/watsonx-ai-ml/vml_v1.js"; import { WatsonxLLM, WatsonxInputLLM } from "../ibm.js"; import { authenticateAndSetInstance } from "../../utils/ibm.js"; -import { - WatsonxEmbeddings, - WatsonxEmbeddingsParams, -} from "../../embeddings/ibm.js"; +import { WatsonxEmbeddings } from "../../embeddings/ibm.js"; const fakeAuthProp = { watsonxAIAuthType: "iam", @@ -38,7 +35,7 @@ export const testProperties = ( } }); }; - checkProperty(testProps, instance); + checkProperty(testProps, instance); if (notExTestProps) checkProperty(notExTestProps, instance, false); }; @@ -56,6 +53,7 @@ describe("LLM unit tests", () => { test("Test basic properties after init", async () => { const testProps = { + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", @@ -67,6 +65,7 @@ describe("LLM unit tests", () => { test("Test methods after init", () => { const testProps: WatsonxInputLLM = { + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", @@ -82,33 +81,32 @@ describe("LLM unit tests", () => { }); test("Test properties after init", async () => { - const testProps = { + const testProps: WatsonxInputLLM = { version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", model: "ibm/granite-13b-chat-v2", - max_new_tokens: 100, - decoding_method: "sample", - length_penalty: { decay_factor: 1, start_index: 1 }, - min_new_tokens: 10, - random_seed: 1, - stop_sequences: ["hello"], + maxNewTokens: 100, + decodingMethod: "sample", + lengthPenalty: { decay_factor: 1, start_index: 1 }, + minNewTokens: 10, + randomSeed: 1, + stopSequence: ["hello"], temperature: 0.1, - time_limit: 10000, - top_k: 1, - top_p: 1, - repetition_penalty: 1, - truncate_input_tokens: 1, - return_options: { + timeLimit: 10000, + topK: 1, + topP: 1, + repetitionPenalty: 1, + truncateInpuTokens: 1, + returnOptions: { input_text: true, generated_tokens: true, input_tokens: true, token_logprobs: true, token_ranks: true, - top_n_tokens: 2, }, - include_stop_sequence: false, + includeStopSequence: false, maxRetries: 3, maxConcurrency: 3, }; @@ -121,6 +119,7 @@ describe("LLM unit tests", () => { describe("Negative tests", () => { test("Missing id", async () => { const testProps: WatsonxInputLLM = { + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, }; @@ -171,6 +170,7 @@ describe("LLM unit tests", () => { test("Passing more than one id", async () => { const testProps: WatsonxInputLLM = { + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", @@ -187,6 +187,7 @@ describe("LLM unit tests", () => { test("Not existing property passed", async () => { const testProps = { + model: "ibm/granite-13b-chat-v2", version: "2024-05-31", serviceUrl: process.env.WATSONX_AI_SERVICE_URL as string, projectId: process.env.WATSONX_AI_PROJECT_ID || "testString", diff --git a/libs/langchain-community/src/types/ibm.ts b/libs/langchain-community/src/types/ibm.ts index cd11592ee48b..ee5db8532036 100644 --- a/libs/langchain-community/src/types/ibm.ts +++ b/libs/langchain-community/src/types/ibm.ts @@ -18,7 +18,7 @@ export interface WatsonxInit { } export interface WatsonxParams extends WatsonxInit { - model?: string; + model: string; spaceId?: string; projectId?: string; idOrName?: string;