Skip to content

Commit

Permalink
feat(openai): Alias system messages as developer messages for o1, add…
Browse files Browse the repository at this point in the history
… reasoning_effort param (#7398)
  • Loading branch information
jacoblee93 authored Dec 19, 2024
1 parent a47583a commit c8a1cdf
Show file tree
Hide file tree
Showing 6 changed files with 98 additions and 14 deletions.
2 changes: 1 addition & 1 deletion libs/langchain-openai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
"license": "MIT",
"dependencies": {
"js-tiktoken": "^1.0.12",
"openai": "^4.71.0",
"openai": "^4.77.0",
"zod": "^3.22.4",
"zod-to-json-schema": "^3.22.3"
},
Expand Down
44 changes: 39 additions & 5 deletions libs/langchain-openai/src/chat_models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,13 @@ interface OpenAILLMOutput {
}

// TODO import from SDK when available
type OpenAIRoleEnum = "system" | "assistant" | "user" | "function" | "tool";
/**
 * Chat message roles accepted by the OpenAI Chat Completions API.
 * Includes "developer", which o1-class models use in place of "system".
 */
type OpenAIRoleEnum = "system" | "developer" | "assistant" | "user" | "function" | "tool";

type OpenAICompletionParam =
OpenAIClient.Chat.Completions.ChatCompletionMessageParam;
Expand All @@ -105,6 +111,7 @@ type OpenAIFnCallOption = OpenAIClient.Chat.ChatCompletionFunctionCallOption;
function extractGenericMessageCustomRole(message: ChatMessage) {
if (
message.role !== "system" &&
message.role !== "developer" &&
message.role !== "assistant" &&
message.role !== "user" &&
message.role !== "function" &&
Expand Down Expand Up @@ -249,6 +256,14 @@ function _convertDeltaToMessageChunk(
});
} else if (role === "system") {
return new SystemMessageChunk({ content, response_metadata });
} else if (role === "developer") {
return new SystemMessageChunk({
content,
response_metadata,
additional_kwargs: {
__openai_role__: "developer",
},
});
} else if (role === "function") {
return new FunctionMessageChunk({
content,
Expand All @@ -270,13 +285,18 @@ function _convertDeltaToMessageChunk(

// Used in LangSmith, export is important here
export function _convertMessagesToOpenAIParams(
messages: BaseMessage[]
messages: BaseMessage[],
model?: string
): OpenAICompletionParam[] {
// TODO: Function messages do not support array content, fix cast
return messages.flatMap((message) => {
let role = messageToOpenAIRole(message);
if (role === "system" && model?.startsWith("o1")) {
role = "developer";
}
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const completionParam: Record<string, any> = {
role: messageToOpenAIRole(message),
role,
content: message.content,
};
if (message.name != null) {
Expand Down Expand Up @@ -428,6 +448,12 @@ export interface ChatOpenAICallOptions
* [Learn more](https://platform.openai.com/docs/guides/latency-optimization#use-predicted-outputs).
*/
prediction?: OpenAIClient.ChatCompletionPredictionContent;

/**
* Constrains effort on reasoning for reasoning models. Currently supported values are low, medium, and high.
* Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response.
*/
reasoning_effort?: OpenAIClient.Chat.ChatCompletionReasoningEffort;
}

export interface ChatOpenAIFields
Expand Down Expand Up @@ -994,6 +1020,7 @@ export class ChatOpenAI<
"promptIndex",
"response_format",
"seed",
"reasoning_effort",
];
}

Expand Down Expand Up @@ -1092,6 +1119,8 @@ export class ChatOpenAI<

modalities?: Array<OpenAIClient.Chat.ChatCompletionModality>;

reasoningEffort?: OpenAIClient.Chat.ChatCompletionReasoningEffort;

constructor(
fields?: ChatOpenAIFields,
/** @deprecated */
Expand Down Expand Up @@ -1162,6 +1191,7 @@ export class ChatOpenAI<
this.__includeRawResponse = fields?.__includeRawResponse;
this.audio = fields?.audio;
this.modalities = fields?.modalities;
this.reasoningEffort = fields?.reasoningEffort;

if (this.azureOpenAIApiKey || this.azureADTokenProvider) {
if (
Expand Down Expand Up @@ -1337,6 +1367,10 @@ export class ChatOpenAI<
if (options?.prediction !== undefined) {
params.prediction = options.prediction;
}
const reasoningEffort = options?.reasoning_effort ?? this.reasoningEffort;
if (reasoningEffort !== undefined) {
params.reasoning_effort = reasoningEffort;
}
return params;
}

Expand All @@ -1360,7 +1394,7 @@ export class ChatOpenAI<
runManager?: CallbackManagerForLLMRun
): AsyncGenerator<ChatGenerationChunk> {
const messagesMapped: OpenAICompletionParam[] =
_convertMessagesToOpenAIParams(messages);
_convertMessagesToOpenAIParams(messages, this.model);
const params = {
...this.invocationParams(options, {
streaming: true,
Expand Down Expand Up @@ -1489,7 +1523,7 @@ export class ChatOpenAI<
const usageMetadata = {} as UsageMetadata;
const params = this.invocationParams(options);
const messagesMapped: OpenAICompletionParam[] =
_convertMessagesToOpenAIParams(messages);
_convertMessagesToOpenAIParams(messages, this.model);

if (params.stream) {
const stream = this._streamResponseChunks(messages, options, runManager);
Expand Down
21 changes: 18 additions & 3 deletions libs/langchain-openai/src/tests/chat_models.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1157,8 +1157,6 @@ describe("Audio output", () => {
content: [userInput],
}),
]);
// console.log("userInputRes.content", userInputRes.content);
// console.log("userInputRes.additional_kwargs.audio", userInputRes.additional_kwargs.audio);
expect(userInputRes.additional_kwargs.audio).toBeTruthy();
expect(
(userInputRes.additional_kwargs.audio as Record<string, any>).transcript
Expand Down Expand Up @@ -1191,6 +1189,23 @@ test("Can stream o1 requests", async () => {
expect(finalMsg.content.length).toBeGreaterThanOrEqual(1);
}

// At least a few content chunks should have been emitted while streaming
expect(numChunks).toBeGreaterThan(3);
});

// Integration test: the "developer" role (the o1-family alias for system
// prompts) is accepted end-to-end alongside the reasoningEffort option.
test("Allows developer messages with o1", async () => {
  const chatModel = new ChatOpenAI({ model: "o1", reasoningEffort: "low" });
  const response = await chatModel.invoke([
    {
      role: "developer",
      content: `Always respond only with the word "testing"`,
    },
    { role: "user", content: "hi" },
  ]);
  expect(response.content).toEqual("testing");
});
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,35 @@ test("withStructuredOutput zod schema function calling", async () => {
expect("number2" in result).toBe(true);
});

// Integration test: structured output (function calling against a zod schema)
// works with the o1 reasoning model, including a "developer"-role prompt.
test("withStructuredOutput with o1", async () => {
  const calculatorSchema = z.object({
    operation: z.enum(["add", "subtract", "multiply", "divide"]),
    number1: z.number(),
    number2: z.number(),
  });
  const structuredModel = new ChatOpenAI({ model: "o1" }).withStructuredOutput(
    calculatorSchema,
    { name: "calculator" }
  );
  const prompt = ChatPromptTemplate.fromMessages([
    ["developer", "You are VERY bad at math and must always use a calculator."],
    ["human", "Please help me!! What is 2 + 2?"],
  ]);
  const result = await prompt.pipe(structuredModel).invoke({});
  // The parsed tool call should contain every schema field.
  for (const key of ["operation", "number1", "number2"]) {
    expect(key in result).toBe(true);
  }
});

test("withStructuredOutput zod schema streaming", async () => {
const model = new ChatOpenAI({
temperature: 0,
Expand Down
6 changes: 6 additions & 0 deletions libs/langchain-openai/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,12 @@ export interface OpenAIChatInput extends OpenAIBaseInput {
* [Learn more](https://platform.openai.com/docs/guides/audio).
*/
audio?: OpenAIClient.Chat.ChatCompletionAudioParam;

/**
* Constrains effort on reasoning for reasoning models. Currently supported values are low, medium, and high.
* Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response.
*/
reasoningEffort?: OpenAIClient.ChatCompletionReasoningEffort;
}

export declare interface AzureOpenAIInput {
Expand Down
10 changes: 5 additions & 5 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -12904,7 +12904,7 @@ __metadata:
jest: ^29.5.0
jest-environment-node: ^29.6.4
js-tiktoken: ^1.0.12
openai: ^4.71.0
openai: ^4.77.0
prettier: ^2.8.3
release-it: ^17.6.0
rimraf: ^5.0.1
Expand Down Expand Up @@ -36216,9 +36216,9 @@ __metadata:
languageName: node
linkType: hard

"openai@npm:^4.71.0":
version: 4.71.0
resolution: "openai@npm:4.71.0"
"openai@npm:^4.77.0":
version: 4.77.0
resolution: "openai@npm:4.77.0"
dependencies:
"@types/node": ^18.11.18
"@types/node-fetch": ^2.6.4
Expand All @@ -36234,7 +36234,7 @@ __metadata:
optional: true
bin:
openai: bin/cli
checksum: ba4b3772e806c59b1ea1235a40486392c797906e45dd97914f2cd819b4be2996e207c7b7c67d43236692300354f4e9ffa8ebfca6e97d3555655ebf0f3f01e3f2
checksum: e311130e3b35a7dc924e7125cca3246b7ac958b3c451f2f4ef2cae72144c82429f9c6db7bf67059fc9e10f3911087ebf8a9a4b2919ac915235ed3c324897b146
languageName: node
linkType: hard

Expand Down

0 comments on commit c8a1cdf

Please sign in to comment.