From 8eadded3571fe77aa77a211619d0a2c7d51976be Mon Sep 17 00:00:00 2001 From: Christopher Nathanael Date: Tue, 3 Dec 2024 17:17:37 -0500 Subject: [PATCH] feat(google-genai): Support Gemini system instructions (#7235) Co-authored-by: Gary Chen Co-authored-by: martinl498 Co-authored-by: Shannon Budiman Co-authored-by: Jacob Lee --- .../langchain-google-genai/src/chat_models.ts | 56 +++++- .../src/tests/chat_models.test.ts | 178 ++++++++++++++++++ .../src/utils/common.ts | 12 +- 3 files changed, 239 insertions(+), 7 deletions(-) diff --git a/libs/langchain-google-genai/src/chat_models.ts b/libs/langchain-google-genai/src/chat_models.ts index 6fc5433babe3..7b22c8a93ce9 100644 --- a/libs/langchain-google-genai/src/chat_models.ts +++ b/libs/langchain-google-genai/src/chat_models.ts @@ -180,6 +180,15 @@ export interface GoogleGenerativeAIChatInput * @default false */ json?: boolean; + + /** + * Whether or not model supports system instructions. + * The following models support system instructions: + * - All Gemini 1.5 Pro model versions + * - All Gemini 1.5 Flash model versions + * - Gemini 1.0 Pro version gemini-1.0-pro-002 + */ + convertSystemMessageToHumanContent?: boolean | undefined; } /** @@ -563,6 +572,8 @@ export class ChatGoogleGenerativeAI streamUsage = true; + convertSystemMessageToHumanContent: boolean | undefined; + private client: GenerativeModel; get _isMultimodalModel() { @@ -651,6 +662,29 @@ export class ChatGoogleGenerativeAI this.streamUsage = fields?.streamUsage ?? this.streamUsage; } + get useSystemInstruction(): boolean { + return typeof this.convertSystemMessageToHumanContent === "boolean" + ? !this.convertSystemMessageToHumanContent + : this.computeUseSystemInstruction; + } + + get computeUseSystemInstruction(): boolean { + // This works on models from April 2024 and later + // Vertex AI: gemini-1.5-pro and gemini-1.0-002 and later + // AI Studio: gemini-1.5-pro-latest + if (this.modelName === "gemini-1.0-pro-001") { + return false; + } else if (this.modelName.startsWith("gemini-pro-vision")) { + return false; + } else if (this.modelName.startsWith("gemini-1.0-pro-vision")) { + return false; + } else if (this.modelName === "gemini-pro") { + // on AI Studio gemini-pro is still pointing at gemini-1.0-pro-001 + return false; + } + return true; + } + getLsParams(options: this["ParsedCallOptions"]): LangSmithParams { return { ls_provider: "google_genai", @@ -706,8 +740,15 @@ export class ChatGoogleGenerativeAI ): Promise { const prompt = convertBaseMessagesToContent( messages, - this._isMultimodalModel + this._isMultimodalModel, + this.useSystemInstruction ); + let actualPrompt = prompt; + if (prompt[0].role === "system") { + const [systemInstruction] = prompt; + this.client.systemInstruction = systemInstruction; + actualPrompt = prompt.slice(1); + } const parameters = this.invocationParams(options); // Handle streaming @@ -734,7 +775,7 @@ export class ChatGoogleGenerativeAI const res = await this.completionWithRetry({ ...parameters, - contents: prompt, + contents: actualPrompt, }); let usageMetadata: UsageMetadata | undefined; @@ -770,12 +811,19 @@ export class ChatGoogleGenerativeAI ): AsyncGenerator { const prompt = convertBaseMessagesToContent( messages, - this._isMultimodalModel + this._isMultimodalModel, + this.useSystemInstruction ); + let actualPrompt = prompt; + if (prompt[0].role === "system") { + const [systemInstruction] = prompt; + this.client.systemInstruction = systemInstruction; + actualPrompt = prompt.slice(1); + } const parameters = this.invocationParams(options); const request = { ...parameters, - contents: prompt, + contents: actualPrompt, }; const stream = await this.caller.callWithOptions( { signal: options?.signal }, diff --git a/libs/langchain-google-genai/src/tests/chat_models.test.ts b/libs/langchain-google-genai/src/tests/chat_models.test.ts index 97015725fc7e..73cc321abd7c 100644 --- a/libs/langchain-google-genai/src/tests/chat_models.test.ts +++ b/libs/langchain-google-genai/src/tests/chat_models.test.ts @@ -253,3 +253,181 @@ test("convertBaseMessagesToContent correctly creates properly formatted content" }, ]); }); + +test("Input has single system message followed by one user message, convert system message is false", async () => { + const messages = [ + new SystemMessage("You are a helpful assistant"), + new HumanMessage("What's the weather like in new york?"), + ]; + const messagesAsGoogleContent = convertBaseMessagesToContent( + messages, + false, + false + ); + + expect(messagesAsGoogleContent).toEqual([ + { + role: "user", + parts: [ + { text: "You are a helpful assistant" }, + { text: "What's the weather like in new york?" }, + ], + }, + ]); +}); + +test("Input has a system message that is not the first message, convert system message is false", async () => { + const messages = [ + new HumanMessage("What's the weather like in new york?"), + new SystemMessage("You are a helpful assistant"), + ]; + expect(() => { + convertBaseMessagesToContent(messages, false, false); + }).toThrow("System message should be the first one"); +}); + +test("Input has multiple system messages, convert system message is false", async () => { + const messages = [ + new SystemMessage("You are a helpful assistant"), + new SystemMessage("You are not a helpful assistant"), + ]; + expect(() => { + convertBaseMessagesToContent(messages, false, false); + }).toThrow("System message should be the first one"); +}); + +test("Input has no system message and one user message, convert system message is false", async () => { + const messages = [new HumanMessage("What's the weather like in new york?")]; + const messagesAsGoogleContent = convertBaseMessagesToContent( + messages, + false, + false + ); + + expect(messagesAsGoogleContent).toEqual([ + { + role: "user", + parts: [{ text: "What's the weather like in new york?" }], + }, + ]); +}); + +test("Input has no system message and multiple user message, convert system message is false", async () => { + const messages = [ + new HumanMessage("What's the weather like in new york?"), + new HumanMessage("What's the weather like in toronto?"), + new HumanMessage("What's the weather like in los angeles?"), + ]; + const messagesAsGoogleContent = convertBaseMessagesToContent( + messages, + false, + false + ); + + expect(messagesAsGoogleContent).toEqual([ + { + role: "user", + parts: [{ text: "What's the weather like in new york?" }], + }, + { + role: "user", + parts: [{ text: "What's the weather like in toronto?" }], + }, + { + role: "user", + parts: [{ text: "What's the weather like in los angeles?" }], + }, + ]); +}); + +test("Input has single system message followed by one user message, convert system message is true", async () => { + const messages = [ + new SystemMessage("You are a helpful assistant"), + new HumanMessage("What's the weather like in new york?"), + ]; + + const messagesAsGoogleContent = convertBaseMessagesToContent( + messages, + false, + true + ); + + expect(messagesAsGoogleContent).toEqual([ + { + role: "system", + parts: [{ text: "You are a helpful assistant" }], + }, + { + role: "user", + parts: [{ text: "What's the weather like in new york?" }], + }, + ]); +}); + +test("Input has single system message that is not the first message, convert system message is true", async () => { + const messages = [ + new HumanMessage("What's the weather like in new york?"), + new SystemMessage("You are a helpful assistant"), + ]; + + expect(() => convertBaseMessagesToContent(messages, false, true)).toThrow( + "System message should be the first one" + ); +}); + +test("Input has multiple system message, convert system message is true", async () => { + const messages = [ + new SystemMessage("What's the weather like in new york?"), + new SystemMessage("You are a helpful assistant"), + ]; + + expect(() => convertBaseMessagesToContent(messages, false, true)).toThrow( + "System message should be the first one" + ); +}); + +test("Input has no system message and one user message, convert system message is true", async () => { + const messages = [new HumanMessage("What's the weather like in new york?")]; + + const messagesAsGoogleContent = convertBaseMessagesToContent( + messages, + false, + true + ); + + expect(messagesAsGoogleContent).toEqual([ + { + role: "user", + parts: [{ text: "What's the weather like in new york?" }], + }, + ]); +}); + +test("Input has no system message and multiple user messages, convert system message is true", async () => { + const messages = [ + new HumanMessage("What's the weather like in new york?"), + new HumanMessage("Will it rain today?"), + new HumanMessage("How about next week?"), + ]; + + const messagesAsGoogleContent = convertBaseMessagesToContent( + messages, + false, + true + ); + + expect(messagesAsGoogleContent).toEqual([ + { + role: "user", + parts: [{ text: "What's the weather like in new york?" }], + }, + { + role: "user", + parts: [{ text: "Will it rain today?" }], + }, + { + role: "user", + parts: [{ text: "How about next week?" }], + }, + ]); +}); diff --git a/libs/langchain-google-genai/src/utils/common.ts b/libs/langchain-google-genai/src/utils/common.ts index 2670760c7115..e2e07f3c61df 100644 --- a/libs/langchain-google-genai/src/utils/common.ts +++ b/libs/langchain-google-genai/src/utils/common.ts @@ -61,6 +61,7 @@ export function convertAuthorToRole( case "model": // getMessageAuthor returns message.name. code ex.: return message.name ?? type; return "model"; case "system": + return "system"; case "human": return "user"; case "tool": @@ -179,7 +180,8 @@ export function convertMessageContentToParts( export function convertBaseMessagesToContent( messages: BaseMessage[], - isMultimodalModel: boolean + isMultimodalModel: boolean, + convertSystemMessageToHumanContent: boolean = false ) { return messages.reduce<{ content: Content[]; @@ -223,7 +225,10 @@ export function convertBaseMessagesToContent( }; } let actualRole = role; - if (actualRole === "function") { + if ( + actualRole === "function" || + (actualRole === "system" && !convertSystemMessageToHumanContent) + ) { // GenerativeAI API will throw an error if the role is not "user" or "model." actualRole = "user"; } @@ -232,7 +237,8 @@ export function convertBaseMessagesToContent( parts, }; return { - mergeWithPreviousContent: author === "system", + mergeWithPreviousContent: + author === "system" && !convertSystemMessageToHumanContent, content: [...acc.content, content], }; },