From 0079bf88e313499388bfd85ade0c0e7e91573a05 Mon Sep 17 00:00:00 2001 From: Aloha <13401668+a1ooha@users.noreply.github.com> Date: Sat, 29 Jun 2024 02:43:06 +0800 Subject: [PATCH] core[patch]: support image url prompt with mustache (#5916) * fix: support image url prompt with mustache * feat: use unescaped HTML --------- Co-authored-by: Brace Sproul --- langchain-core/src/prompts/chat.ts | 2 ++ langchain-core/src/prompts/image.ts | 14 ++++---- langchain-core/src/prompts/template.ts | 4 +++ .../src/prompts/tests/chat.mustache.test.ts | 32 +++++++++++++++++++ 4 files changed, 44 insertions(+), 8 deletions(-) diff --git a/langchain-core/src/prompts/chat.ts b/langchain-core/src/prompts/chat.ts index 66c5d6c61839..32048df3262f 100644 --- a/langchain-core/src/prompts/chat.ts +++ b/langchain-core/src/prompts/chat.ts @@ -532,6 +532,7 @@ class _StringImageMessagePromptTemplate< imgTemplateObject = new ImagePromptTemplate({ template: imgTemplate, inputVariables, + templateFormat: additionalOptions?.templateFormat, }); } else if (typeof imgTemplate === "object") { if ("url" in imgTemplate) { @@ -551,6 +552,7 @@ class _StringImageMessagePromptTemplate< imgTemplateObject = new ImagePromptTemplate({ template: imgTemplate, inputVariables, + templateFormat: additionalOptions?.templateFormat, }); } else { throw new Error("Invalid image template"); diff --git a/langchain-core/src/prompts/image.ts b/langchain-core/src/prompts/image.ts index e18337303b67..397b36bb0867 100644 --- a/langchain-core/src/prompts/image.ts +++ b/langchain-core/src/prompts/image.ts @@ -6,7 +6,11 @@ import { BasePromptTemplateInput, TypedPromptInputValues, } from "./base.js"; -import { TemplateFormat, checkValidTemplate } from "./template.js"; +import { + TemplateFormat, + checkValidTemplate, + renderTemplate, +} from "./template.js"; /** * Inputs to create a {@link ImagePromptTemplate} @@ -125,13 +129,7 @@ export class ImagePromptTemplate< const formatted: Record = {}; for (const [key, value] of Object.entries(this.template)) { if (typeof value === "string") { - formatted[key] = value.replace(/{([^{}]*)}/g, (match, group) => { - const replacement = values[group]; - return typeof replacement === "string" || - typeof replacement === "number" - ? String(replacement) - : match; - }); + formatted[key] = renderTemplate(value, this.templateFormat, values); } else { formatted[key] = value; } diff --git a/langchain-core/src/prompts/template.ts b/langchain-core/src/prompts/template.ts index 206ea8003745..245641ccbe1d 100644 --- a/langchain-core/src/prompts/template.ts +++ b/langchain-core/src/prompts/template.ts @@ -2,6 +2,10 @@ import mustache from "mustache"; import { MessageContent } from "../messages/index.js"; import type { InputValues } from "../utils/types/index.js"; +// Use unescaped HTML +// https://github.com/janl/mustache.js?tab=readme-ov-file#variables +mustache.escape = (text) => text; + /** * Type that specifies the format of a template. */ diff --git a/langchain-core/src/prompts/tests/chat.mustache.test.ts b/langchain-core/src/prompts/tests/chat.mustache.test.ts index ecac9e954634..f315a0824240 100644 --- a/langchain-core/src/prompts/tests/chat.mustache.test.ts +++ b/langchain-core/src/prompts/tests/chat.mustache.test.ts @@ -91,6 +91,24 @@ test("Mustache template with image and chat prompts inside one template (fromMes } ); + const messages = await template.formatMessages({ + name: "Bob", + image_url: "https://foo.com/bar.png", + other_var: "bar", + }); + + expect(messages).toEqual([ + new HumanMessage({ + content: [ + { type: "image_url", image_url: { url: "https://foo.com/bar.png" } }, + { type: "text", text: "bar" }, + ], + }), + new HumanMessage({ + content: "hello Bob", + }), + ]); + expect(template.inputVariables.sort()).toEqual([ "image_url", "name", @@ -115,5 +133,19 @@ test("Mustache image template with nested URL and chat prompts HumanMessagePromp } ); + const messages = await template.formatMessages({ + name: "Bob", + image_url: "https://foo.com/bar.png", + }); + + expect(messages).toEqual([ + new HumanMessage({ + content: [ + { type: "text", text: "Bob" }, + { type: "image_url", image_url: { url: "https://foo.com/bar.png" } }, + ], + }), + ]); + expect(template.inputVariables.sort()).toEqual(["image_url", "name"]); });