fix(langchain): Fix serialization for initChatModel (#7222)
jacoblee93 authored Nov 17, 2024
1 parent d420b71 commit cf672b5
Showing 2 changed files with 63 additions and 21 deletions.
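The heart of the fix is in `_initChatModelHelper` (second file below): the `modelProvider` option is now destructured out of the user-supplied params before they are spread into the provider constructor, so the instantiated model, and therefore its serialized representation, no longer picks up a stray `modelProvider` field. Below is a minimal sketch of that destructure-and-spread pattern; the `InitOptions` type and `buildConstructorParams` helper are illustrative stand-ins, not the actual LangChain signatures.

```ts
// Sketch of the destructure-and-spread pattern this commit applies.
// InitOptions and buildConstructorParams are illustrative, not real exports.
interface InitOptions {
  modelProvider?: string;
  temperature?: number;
  apiKey?: string;
}

function buildConstructorParams(model: string, params: InitOptions) {
  // Pull modelProvider off so it is not forwarded to the provider class;
  // everything else passes through untouched.
  const { modelProvider: _unused, ...passedParams } = params;
  return { model, ...passedParams };
}

// buildConstructorParams("gpt-4o-mini", { modelProvider: "openai", temperature: 0.25 })
// => { model: "gpt-4o-mini", temperature: 0.25 }
```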
44 changes: 40 additions & 4 deletions langchain/src/chat_models/tests/universal.int.test.ts
@@ -6,6 +6,7 @@ import { ChatPromptTemplate, PromptTemplate } from "@langchain/core/prompts";
 import { RunLogPatch, StreamEvent } from "@langchain/core/tracers/log_stream";
 import { AIMessageChunk } from "@langchain/core/messages";
 import { concat } from "@langchain/core/utils/stream";
+import { awaitAllCallbacks } from "@langchain/core/callbacks/promises";
 import { AgentExecutor, createReactAgent } from "../../agents/index.js";
 import { pull } from "../../hub.js";
 import { initChatModel } from "../universal.js";
@@ -32,7 +33,7 @@ const googleApiKey = process.env.GOOGLE_API_KEY;
 process.env.GOOGLE_API_KEY = "";
 
 test("Initialize non-configurable models", async () => {
-  const gpt4 = await initChatModel("gpt-4", {
+  const gpt4 = await initChatModel("gpt-4o-mini", {
     modelProvider: "openai",
     temperature: 0.25, // Funky temperature to verify it's being set properly.
     apiKey: openAIApiKey,
@@ -67,7 +68,7 @@ test("Create a partially configurable model with no default model", async () =>
 
   const gpt4Result = await configurableModel.invoke("what's your name", {
     configurable: {
-      model: "gpt-4",
+      model: "gpt-4o-mini",
       apiKey: openAIApiKey,
     },
   });
@@ -85,7 +86,7 @@ test("Create a partially configurable model with no default model", async () =>
 });
 
 test("Create a fully configurable model with a default model and a config prefix", async () => {
-  const configurableModelWithDefault = await initChatModel("gpt-4", {
+  const configurableModelWithDefault = await initChatModel("gpt-4o-mini", {
     modelProvider: "openai",
     configurableFields: "any",
     configPrefix: "foo",
@@ -155,7 +156,7 @@ test("Bind tools to a configurable model", async () => {
     }
   );
 
-  const configurableModel = await initChatModel("gpt-4", {
+  const configurableModel = await initChatModel("gpt-4o-mini", {
     configurableFields: ["model", "modelProvider", "apiKey"],
     temperature: 0,
   });
@@ -602,3 +603,38 @@ describe("Can call base runnable methods", () => {
     expect(result.tool_calls?.[0].name).toBe("GetWeather");
   });
 });
+
+describe("Serialization", () => {
+  it("does not contain additional fields", async () => {
+    const gpt4 = await initChatModel("gpt-4o-mini", {
+      modelProvider: "openai",
+      temperature: 0.25, // Funky temperature to verify it's being set properly.
+      apiKey: openAIApiKey,
+    });
+    let serializedRepresentation;
+    const res = await gpt4.invoke("foo", {
+      callbacks: [
+        {
+          handleChatModelStart(llm) {
+            serializedRepresentation = llm;
+          },
+        },
+      ],
+      configurable: { extra: "bar" },
+    });
+    await awaitAllCallbacks();
+    expect(res).toBeDefined();
+    const { ChatOpenAI } = await import("@langchain/openai");
+    expect(serializedRepresentation).toEqual(
+      JSON.parse(
+        JSON.stringify(
+          new ChatOpenAI({
+            model: "gpt-4o-mini",
+            temperature: 0.25,
+            apiKey: openAIApiKey,
+          })
+        )
+      )
+    );
+  });
+});
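The new test exercises the fix through the callback system rather than by inspecting the wrapper returned from `initChatModel`: `handleChatModelStart` receives the serialized form of the model that actually runs, and that is compared against a directly constructed `ChatOpenAI` via a JSON round trip. Here is a standalone sketch of the same inspection technique, assuming `@langchain/openai` and `@langchain/core` are installed and an OpenAI key is configured; the prompt and model name are placeholders.

```ts
// Sketch: inspect the serialized model that handles a request, via a
// callback handler. Mirrors the technique used in the new test above.
import { ChatOpenAI } from "@langchain/openai";
import { awaitAllCallbacks } from "@langchain/core/callbacks/promises";

const model = new ChatOpenAI({ model: "gpt-4o-mini", temperature: 0.25 });

let serialized: unknown;
await model.invoke("ping", {
  callbacks: [
    {
      // handleChatModelStart is passed the serialized representation of the
      // chat model instance that is about to run.
      handleChatModelStart(llm) {
        serialized = llm;
      },
    },
  ],
});

// Callback handlers may run in the background; drain them before inspecting.
await awaitAllCallbacks();
console.log(JSON.stringify(serialized, null, 2));
```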
40 changes: 23 additions & 17 deletions langchain/src/chat_models/universal.ts
@@ -73,50 +73,51 @@ async function _initChatModelHelper(
       `Unable to infer model provider for { model: ${model} }, please specify modelProvider directly.`
     );
   }
+  const { modelProvider: _unused, ...passedParams } = params;
 
   try {
     switch (modelProviderCopy) {
       case "openai": {
         const { ChatOpenAI } = await import("@langchain/openai");
-        return new ChatOpenAI({ model, ...params });
+        return new ChatOpenAI({ model, ...passedParams });
       }
       case "anthropic": {
         const { ChatAnthropic } = await import("@langchain/anthropic");
-        return new ChatAnthropic({ model, ...params });
+        return new ChatAnthropic({ model, ...passedParams });
       }
       case "azure_openai": {
         const { AzureChatOpenAI } = await import("@langchain/openai");
-        return new AzureChatOpenAI({ model, ...params });
+        return new AzureChatOpenAI({ model, ...passedParams });
       }
       case "cohere": {
         const { ChatCohere } = await import("@langchain/cohere");
-        return new ChatCohere({ model, ...params });
+        return new ChatCohere({ model, ...passedParams });
       }
       case "google-vertexai": {
         const { ChatVertexAI } = await import("@langchain/google-vertexai");
-        return new ChatVertexAI({ model, ...params });
+        return new ChatVertexAI({ model, ...passedParams });
       }
       case "google-genai": {
         const { ChatGoogleGenerativeAI } = await import(
           "@langchain/google-genai"
         );
-        return new ChatGoogleGenerativeAI({ model, ...params });
+        return new ChatGoogleGenerativeAI({ model, ...passedParams });
       }
       case "ollama": {
         const { ChatOllama } = await import("@langchain/ollama");
-        return new ChatOllama({ model, ...params });
+        return new ChatOllama({ model, ...passedParams });
       }
       case "mistralai": {
         const { ChatMistralAI } = await import("@langchain/mistralai");
-        return new ChatMistralAI({ model, ...params });
+        return new ChatMistralAI({ model, ...passedParams });
       }
       case "groq": {
         const { ChatGroq } = await import("@langchain/groq");
-        return new ChatGroq({ model, ...params });
+        return new ChatGroq({ model, ...passedParams });
       }
       case "bedrock": {
         const { ChatBedrockConverse } = await import("@langchain/aws");
-        return new ChatBedrockConverse({ model, ...params });
+        return new ChatBedrockConverse({ model, ...passedParams });
       }
       case "fireworks": {
         const { ChatFireworks } = await import(
@@ -127,7 +128,7 @@ async function _initChatModelHelper(
           // @ts-ignore - Can not install as a proper dependency due to circular dependency
           "@langchain/community/chat_models/fireworks"
         );
-        return new ChatFireworks({ model, ...params });
+        return new ChatFireworks({ model, ...passedParams });
       }
       case "together": {
         const { ChatTogetherAI } = await import(
@@ -138,7 +139,7 @@ async function _initChatModelHelper(
           // @ts-ignore - Can not install as a proper dependency due to circular dependency
           "@langchain/community/chat_models/togetherai"
         );
-        return new ChatTogetherAI({ model, ...params });
+        return new ChatTogetherAI({ model, ...passedParams });
       }
       default: {
         const supported = _SUPPORTED_PROVIDERS.join(", ");
@@ -247,7 +248,10 @@ class _ConfigurableModel<
     if (fields.configurableFields === "any") {
       this._configurableFields = "any";
     } else {
-      this._configurableFields = fields.configurableFields ?? "any";
+      this._configurableFields = fields.configurableFields ?? [
+        "model",
+        "modelProvider",
+      ];
     }
 
     if (fields.configPrefix) {
Expand Down Expand Up @@ -786,12 +790,14 @@ export async function initChatModel<
configPrefix: "",
...(fields ?? {}),
};
let configurableFieldsCopy = configurableFields;
let configurableFieldsCopy = Array.isArray(configurableFields)
? [...configurableFields]
: configurableFields;

if (!model && !configurableFieldsCopy) {
if (!model && configurableFieldsCopy === undefined) {
configurableFieldsCopy = ["model", "modelProvider"];
}
if (configPrefix && !configurableFieldsCopy) {
if (configPrefix && configurableFieldsCopy === undefined) {
console.warn(
`{ configPrefix: ${configPrefix} } has been set but no fields are configurable. Set ` +
`{ configurableFields: [...] } to specify the model params that are ` +
@@ -802,7 +808,7 @@ export async function initChatModel<
   // eslint-disable-next-line @typescript-eslint/no-explicit-any
   const paramsCopy: Record<string, any> = { ...params };
 
-  if (!configurableFieldsCopy) {
+  if (configurableFieldsCopy === undefined) {
     return new _ConfigurableModel<RunInput, CallOptions>({
       defaultConfig: {
         ...paramsCopy,
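The remaining hunks tighten how `initChatModel` handles `configurableFields`: a caller-supplied array is copied before use, the `["model", "modelProvider"]` default is applied only when the option was not passed at all, and explicit `=== undefined` checks replace the earlier truthiness tests. A rough sketch of that defaulting logic follows; `resolveConfigurableFields` is an illustrative helper, not a LangChain export.

```ts
// Sketch of the configurableFields defaulting behavior after this change.
// resolveConfigurableFields is an illustrative helper, not a real export.
type ConfigurableFields = string[] | "any" | undefined;

function resolveConfigurableFields(
  model: string | undefined,
  configurableFields: ConfigurableFields
): ConfigurableFields {
  // Copy caller-provided arrays so the original input is never mutated.
  let copy: ConfigurableFields = Array.isArray(configurableFields)
    ? [...configurableFields]
    : configurableFields;

  // Only fall back to the default field list when nothing was passed at all;
  // an explicit value (including "any") is left untouched.
  if (!model && copy === undefined) {
    copy = ["model", "modelProvider"];
  }
  return copy;
}

// resolveConfigurableFields(undefined, undefined) => ["model", "modelProvider"]
// resolveConfigurableFields("gpt-4o-mini", undefined) => undefined
```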
