feat(openai): Alias system messages as developer messages for o1, add reasoning_effort param #7398

Merged · 8 commits · Dec 19, 2024
2 changes: 1 addition & 1 deletion libs/langchain-openai/package.json
@@ -36,7 +36,7 @@
"license": "MIT",
"dependencies": {
"js-tiktoken": "^1.0.12",
"openai": "^4.71.0",
"openai": "^4.77.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.3"
},
44 changes: 39 additions & 5 deletions libs/langchain-openai/src/chat_models.ts
@@ -95,7 +95,13 @@ interface OpenAILLMOutput {
}

// TODO import from SDK when available
type OpenAIRoleEnum = "system" | "assistant" | "user" | "function" | "tool";
type OpenAIRoleEnum =
| "system"
| "developer"
| "assistant"
| "user"
| "function"
| "tool";

type OpenAICompletionParam =
OpenAIClient.Chat.Completions.ChatCompletionMessageParam;
@@ -105,6 +111,7 @@ type OpenAIFnCallOption = OpenAIClient.Chat.ChatCompletionFunctionCallOption;
function extractGenericMessageCustomRole(message: ChatMessage) {
if (
message.role !== "system" &&
message.role !== "developer" &&
message.role !== "assistant" &&
message.role !== "user" &&
message.role !== "function" &&
@@ -249,6 +256,14 @@ function _convertDeltaToMessageChunk(
});
} else if (role === "system") {
return new SystemMessageChunk({ content, response_metadata });
} else if (role === "developer") {
return new SystemMessageChunk({
content,
response_metadata,
additional_kwargs: {
__openai_role__: "developer",
},
});
} else if (role === "function") {
return new FunctionMessageChunk({
content,
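Streamed "developer" deltas are surfaced as SystemMessageChunk instances tagged with the original wire role in additional_kwargs, so the role can round-trip back to OpenAI. A minimal sketch of the resulting chunk shape — the content string is illustrative, not from the PR:

```typescript
import { SystemMessageChunk } from "@langchain/core/messages";

// Shape of the chunk produced by the "developer" branch above.
const chunk = new SystemMessageChunk({
  content: "Always answer in French.",
  additional_kwargs: {
    __openai_role__: "developer", // preserves the wire-level role
  },
});

console.log(chunk.additional_kwargs.__openai_role__); // "developer"
```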
Expand All @@ -270,13 +285,18 @@ function _convertDeltaToMessageChunk(

// Used in LangSmith, export is important here
export function _convertMessagesToOpenAIParams(
- messages: BaseMessage[]
+ messages: BaseMessage[],
+ model?: string
): OpenAICompletionParam[] {
// TODO: Function messages do not support array content, fix cast
return messages.flatMap((message) => {
+ let role = messageToOpenAIRole(message);
+ if (role === "system" && model?.startsWith("o1")) {
+   role = "developer";
+ }
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const completionParam: Record<string, any> = {
- role: messageToOpenAIRole(message),
+ role,
content: message.content,
};
if (message.name != null) {
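The aliasing rule this hunk adds is small enough to restate standalone. A sketch (not the library's code), assuming only that o1-family model names start with "o1":

```typescript
type OpenAIRole =
  | "system"
  | "developer"
  | "assistant"
  | "user"
  | "function"
  | "tool";

// Standalone sketch of the aliasing done in _convertMessagesToOpenAIParams:
// "system" becomes "developer" only when targeting an o1-family model.
function aliasRole(role: OpenAIRole, model?: string): OpenAIRole {
  if (role === "system" && model?.startsWith("o1")) {
    return "developer";
  }
  return role;
}

console.log(aliasRole("system", "o1-mini")); // "developer"
console.log(aliasRole("system", "gpt-4o")); // "system"
console.log(aliasRole("system")); // "system" (no model, no aliasing)
```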
@@ -428,6 +448,12 @@ export interface ChatOpenAICallOptions
* [Learn more](https://platform.openai.com/docs/guides/latency-optimization#use-predicted-outputs).
*/
prediction?: OpenAIClient.ChatCompletionPredictionContent;

+ /**
+  * Constrains effort on reasoning for reasoning models. Currently supported values are low, medium, and high.
+  * Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response.
+  */
+ reasoning_effort?: OpenAIClient.Chat.ChatCompletionReasoningEffort;
}

export interface ChatOpenAIFields
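Since reasoning_effort is a call option, it can be set per request rather than per model instance. A usage sketch — the model name and prompt are illustrative:

```typescript
import { ChatOpenAI } from "@langchain/openai";

const model = new ChatOpenAI({ model: "o1" });

// Per-call override; accepted values are "low" | "medium" | "high".
const res = await model.invoke(
  [{ role: "user", content: "Summarize this change in one sentence." }],
  { reasoning_effort: "low" }
);
```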
@@ -994,6 +1020,7 @@ export class ChatOpenAI<
"promptIndex",
"response_format",
"seed",
"reasoning_effort",
];
}

@@ -1092,6 +1119,8 @@

modalities?: Array<OpenAIClient.Chat.ChatCompletionModality>;

+ reasoningEffort?: OpenAIClient.Chat.ChatCompletionReasoningEffort;

constructor(
fields?: ChatOpenAIFields,
/** @deprecated */
@@ -1162,6 +1191,7 @@
this.__includeRawResponse = fields?.__includeRawResponse;
this.audio = fields?.audio;
this.modalities = fields?.modalities;
+ this.reasoningEffort = fields?.reasoningEffort;

if (this.azureOpenAIApiKey || this.azureADTokenProvider) {
if (
@@ -1337,6 +1367,10 @@
if (options?.prediction !== undefined) {
params.prediction = options.prediction;
}
+ const reasoningEffort = options?.reasoning_effort ?? this.reasoningEffort;
+ if (reasoningEffort !== undefined) {
+   params.reasoning_effort = reasoningEffort;
+ }
return params;
}
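Note the precedence encoded by ?? above: a per-call reasoning_effort wins over the constructor-level reasoningEffort, and if neither is set the parameter is omitted from the request entirely. Illustrative example:

```typescript
import { ChatOpenAI } from "@langchain/openai";

// Constructor-level default applies to every call from this instance.
const model = new ChatOpenAI({ model: "o1", reasoningEffort: "medium" });

// Sends reasoning_effort: "medium" (falls back to the constructor field).
await model.invoke("hi");

// Sends reasoning_effort: "high" (the call option takes precedence).
await model.invoke("hi", { reasoning_effort: "high" });
```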

@@ -1360,7 +1394,7 @@
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const messagesMapped: OpenAICompletionParam[] =
- _convertMessagesToOpenAIParams(messages);
+ _convertMessagesToOpenAIParams(messages, this.model);
const params = {
...this.invocationParams(options, {
streaming: true,
@@ -1489,7 +1523,7 @@
const usageMetadata = {} as UsageMetadata;
const params = this.invocationParams(options);
const messagesMapped: OpenAICompletionParam[] =
- _convertMessagesToOpenAIParams(messages);
+ _convertMessagesToOpenAIParams(messages, this.model);

if (params.stream) {
const stream = this._streamResponseChunks(messages, options, runManager);
21 changes: 18 additions & 3 deletions libs/langchain-openai/src/tests/chat_models.int.test.ts
@@ -1157,8 +1157,6 @@ describe("Audio output", () => {
content: [userInput],
}),
]);
// console.log("userInputRes.content", userInputRes.content);
// console.log("userInputRes.additional_kwargs.audio", userInputRes.additional_kwargs.audio);
expect(userInputRes.additional_kwargs.audio).toBeTruthy();
expect(
(userInputRes.additional_kwargs.audio as Record<string, any>).transcript
@@ -1191,6 +1189,23 @@ test("Can stream o1 requests", async () => {
expect(finalMsg.content.length).toBeGreaterThanOrEqual(1);
}

// A
expect(numChunks).toBeGreaterThan(3);
});

test("Allows developer messages with o1", async () => {
const model = new ChatOpenAI({
model: "o1",
reasoningEffort: "low",
});
const res = await model.invoke([
{
role: "developer",
content: `Always respond only with the word "testing"`,
},
{
role: "user",
content: "hi",
},
]);
expect(res.content).toEqual("testing");
});
@@ -36,6 +36,35 @@ test("withStructuredOutput zod schema function calling", async () => {
expect("number2" in result).toBe(true);
});

test("withStructuredOutput with o1", async () => {
const model = new ChatOpenAI({
model: "o1",
});

const calculatorSchema = z.object({
operation: z.enum(["add", "subtract", "multiply", "divide"]),
number1: z.number(),
number2: z.number(),
});
const modelWithStructuredOutput = model.withStructuredOutput(
calculatorSchema,
{
name: "calculator",
}
);

const prompt = ChatPromptTemplate.fromMessages([
["developer", "You are VERY bad at math and must always use a calculator."],
["human", "Please help me!! What is 2 + 2?"],
]);
const chain = prompt.pipe(modelWithStructuredOutput);
const result = await chain.invoke({});
// console.log(result);
expect("operation" in result).toBe(true);
expect("number1" in result).toBe(true);
expect("number2" in result).toBe(true);
});

test("withStructuredOutput zod schema streaming", async () => {
const model = new ChatOpenAI({
temperature: 0,
6 changes: 6 additions & 0 deletions libs/langchain-openai/src/types.ts
@@ -188,6 +188,12 @@ export interface OpenAIChatInput extends OpenAIBaseInput {
* [Learn more](https://platform.openai.com/docs/guides/audio).
*/
audio?: OpenAIClient.Chat.ChatCompletionAudioParam;

+ /**
+  * Constrains effort on reasoning for reasoning models. Currently supported values are low, medium, and high.
+  * Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response.
+  */
+ reasoningEffort?: OpenAIClient.ChatCompletionReasoningEffort;
}

export declare interface AzureOpenAIInput {
10 changes: 5 additions & 5 deletions yarn.lock
@@ -12904,7 +12904,7 @@ __metadata:
jest: ^29.5.0
jest-environment-node: ^29.6.4
js-tiktoken: ^1.0.12
- openai: ^4.71.0
+ openai: ^4.77.0
prettier: ^2.8.3
release-it: ^17.6.0
rimraf: ^5.0.1
@@ -36216,9 +36216,9 @@ __metadata:
languageName: node
linkType: hard

"openai@npm:^4.71.0":
version: 4.71.0
resolution: "openai@npm:4.71.0"
"openai@npm:^4.77.0":
version: 4.77.0
resolution: "openai@npm:4.77.0"
dependencies:
"@types/node": ^18.11.18
"@types/node-fetch": ^2.6.4
@@ -36234,7 +36234,7 @@
optional: true
bin:
openai: bin/cli
- checksum: ba4b3772e806c59b1ea1235a40486392c797906e45dd97914f2cd819b4be2996e207c7b7c67d43236692300354f4e9ffa8ebfca6e97d3555655ebf0f3f01e3f2
+ checksum: e311130e3b35a7dc924e7125cca3246b7ac958b3c451f2f4ef2cae72144c82429f9c6db7bf67059fc9e10f3911087ebf8a9a4b2919ac915235ed3c324897b146
languageName: node
linkType: hard
