Skip to content

Commit

Permalink
Merge pull request #16 from jxnl/main-class-update
Browse files Browse the repository at this point in the history
Implementation alternative
  • Loading branch information
jxnl authored Jan 2, 2024
2 parents 9e81b2a + 174e923 commit 3893e1a
Show file tree
Hide file tree
Showing 5 changed files with 297 additions and 179 deletions.
2 changes: 0 additions & 2 deletions examples/extract_user/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@ const client = Instructor({
mode: "FUNCTIONS"
})

//@ts-expect-error these types wont work since were using a proxy and just returning the OAI instance type
const user: User = await client.chat.completions.create({
messages: [{ role: "user", content: "Jason Liu is 30 years old" }],
model: "gpt-3.5-turbo",
//@ts-expect-error same as above
response_model: UserSchema,
max_retries: 3
})
Expand Down
337 changes: 162 additions & 175 deletions src/instructor.ts
Original file line number Diff line number Diff line change
@@ -1,202 +1,189 @@
import assert from "assert"
import OpenAI from "openai"
import {
ChatCompletion,
ChatCompletionCreateParams,
ChatCompletionMessage
} from "openai/resources/index.mjs"
import { ZodSchema } from "zod"
import { JsonSchema7Type, zodToJsonSchema } from "zod-to-json-schema"
OAIBuildFunctionParams,
OAIBuildMessageBasedParams,
OAIBuildToolFunctionParams
} from "@/oai/params"
import {
OAIResponseFnArgsParser,
OAIResponseJSONStringParser,
OAIResponseToolArgsParser
} from "@/oai/parser"
import OpenAI from "openai"
import { ChatCompletion, ChatCompletionCreateParamsNonStreaming } from "openai/resources/index.mjs"
import { ZodObject } from "zod"
import zodToJsonSchema from "zod-to-json-schema"

import { MODE } from "@/constants/modes"

export class OpenAISchema {
private response_model: ReturnType<typeof zodToJsonSchema>
constructor(public zod_schema: ZodSchema) {
this.response_model = zodToJsonSchema(zod_schema)
}

get definitions() {
return this.response_model["definitions"]
}

get properties() {
return this.response_model["properties"]
}
const MODE_TO_PARSER = {
[MODE.FUNCTIONS]: OAIResponseFnArgsParser,
[MODE.TOOLS]: OAIResponseToolArgsParser,
[MODE.JSON]: OAIResponseJSONStringParser,
[MODE.MD_JSON]: OAIResponseJSONStringParser,
[MODE.JSON_SCHEMA]: OAIResponseJSONStringParser
}

get openai_schema() {
return {
name: this.response_model["title"] || "schema",
description:
this.response_model["description"] ||
`Correctly extracted \`${
this.response_model["title"] || "schema"
}\` with all the required parameters with correct types`,
parameters: Object.keys(this.response_model).reduce(
(acc, curr) => {
if (
curr.startsWith("$") ||
["title", "description", "additionalProperties"].includes(curr)
)
return acc
acc[curr] = this.response_model[curr]
return acc
},
{} as {
[key: string]: object | JsonSchema7Type
}
)
}
}
const MODE_TO_PARAMS = {
[MODE.FUNCTIONS]: OAIBuildFunctionParams,
[MODE.TOOLS]: OAIBuildToolFunctionParams,
[MODE.JSON]: OAIBuildMessageBasedParams,
[MODE.MD_JSON]: OAIBuildMessageBasedParams,
[MODE.JSON_SCHEMA]: OAIBuildMessageBasedParams
}

type PatchedChatCompletionCreateParams = ChatCompletionCreateParams & {
response_model?: ZodSchema | OpenAISchema
type PatchedChatCompletionCreateParams = ChatCompletionCreateParamsNonStreaming & {
//eslint-disable-next-line @typescript-eslint/no-explicit-any
response_model?: ZodObject<any>
max_retries?: number
}

function handleResponseModel(
response_model: ZodSchema | OpenAISchema,
args: PatchedChatCompletionCreateParams[],
mode: MODE = "FUNCTIONS"
): [OpenAISchema, PatchedChatCompletionCreateParams[], MODE] {
if (!(response_model instanceof OpenAISchema)) {
response_model = new OpenAISchema(response_model)
class Instructor {
readonly client: OpenAI
readonly mode: MODE

/**
* Creates an instance of the `Instructor` class.
* @param {OpenAI} client - The OpenAI client.
* @param {string} mode - The mode of operation.
*/
constructor({ client, mode }: { client: OpenAI; mode: MODE }) {
this.client = client
this.mode = mode
}

if (mode === MODE.FUNCTIONS) {
args[0].functions = [response_model.openai_schema]
args[0].function_call = { name: response_model.openai_schema.name }
} else if (mode === MODE.TOOLS) {
args[0].tools = [{ type: "function", function: response_model.openai_schema }]
args[0].tool_choice = {
type: "function",
function: { name: response_model.openai_schema.name }
/**
* Handles chat completion with retries.
* @param {PatchedChatCompletionCreateParams} params - The parameters for chat completion.
* @returns {Promise<any>} The response from the chat completion.
*/
chatCompletion = async ({ max_retries = 3, ...params }: PatchedChatCompletionCreateParams) => {
let attempts = 0
let validationIssues = []
let lastMessage = null

const completionParams = this.buildChatCompletionParams(params)

const makeCompletionCall = async () => {
let resolvedParams = completionParams

try {
if (validationIssues.length > 0) {
resolvedParams = {
...completionParams,
messages: [
...completionParams.messages,
...(lastMessage ? [lastMessage] : []),
{
role: "system",
content: `Your last response had the following validation issues, please try again: ${validationIssues.join(
", "
)}`
}
]
}
}

const completion = await this.client.chat.completions.create(resolvedParams)
const response = this.parseOAIResponse(completion)

return response
} catch (error) {
throw error
}
}
} else if ([MODE.JSON, MODE.MD_JSON, MODE.JSON_SCHEMA].includes(mode)) {
let message: string = `As a genius expert, your task is to understand the content and provide the parsed objects in json that match the following json_schema: \n${JSON.stringify(
response_model.properties
)}`
if (response_model["definitions"]) {
message += `Here are some more definitions to adhere to: \n${JSON.stringify(
response_model.definitions
)}`

const makeCompletionCallWithRetries = async () => {
try {
const data = await makeCompletionCall()
const validation = params.response_model.safeParse(data)

if (!validation.success) {
if ("error" in validation) {
lastMessage = {
role: "assistant",
content: JSON.stringify(data)
}

validationIssues = validation.error.issues.map(issue => issue.message)
throw validation.error
} else {
throw new Error("Validation failed.")
}
}

return data
} catch (error) {
if (attempts < max_retries) {
attempts++
return await makeCompletionCallWithRetries()
} else {
throw error
}
}
}
if (mode === MODE.JSON) {
args[0].response_format = { type: "json_object" }
} else if (mode == MODE.JSON_SCHEMA) {
args[0].response_format = { type: "json_object" }
} else if (mode === MODE.MD_JSON) {
args[0].messages.push({
role: "assistant",
content: "```json"
})
args[0].stop = "```"

return await makeCompletionCallWithRetries()
}

/**
* Builds the chat completion parameters.
* @param {PatchedChatCompletionCreateParams} params - The parameters for chat completion.
* @returns {ChatCompletionCreateParamsNonStreaming} The chat completion parameters.
*/
private buildChatCompletionParams = ({
response_model,
...params
}: PatchedChatCompletionCreateParams): ChatCompletionCreateParamsNonStreaming => {
const jsonSchema = zodToJsonSchema(response_model, "response_model")

const definition = {
name: "response_model",
...jsonSchema.definitions.response_model
}
if (args[0].messages[0].role != "system") {
args[0].messages.unshift({ role: "system", content: message })
} else {
args[0].messages[0].content += `\n${message}`

const paramsForMode = MODE_TO_PARAMS[this.mode](definition, params, this.mode)

return {
stream: false,
...paramsForMode
}
} else {
console.error("unknown mode", mode)
}
return [response_model, args, mode]
}

function processResponse(
response: OpenAI.Chat.Completions.ChatCompletion,
response_model: OpenAISchema,
mode: MODE = "FUNCTIONS"
) {
const message = response.choices[0].message
if (mode === MODE.FUNCTIONS) {
assert.equal(
message.function_call!.name,
response_model.openai_schema.name,
"Function name does not match"
)
return response_model.zod_schema.parse(JSON.parse(message.function_call!.arguments))
} else if (mode === MODE.TOOLS) {
const tool_call = message.tool_calls![0]
assert.equal(
tool_call.function.name,
response_model.openai_schema.name,
"Tool name does not match"
)
return response_model.zod_schema.parse(JSON.parse(tool_call.function.arguments))
} else if ([MODE.JSON, MODE.MD_JSON, MODE.JSON_SCHEMA].includes(mode)) {
return response_model.zod_schema.parse(JSON.parse(message.content!))
} else {
console.error("unknown mode", mode)
}
}
/**
* Parses the OAI response.
* @param {ChatCompletion} response - The response from the chat completion.
* @returns {any} The parsed response.
*/
private parseOAIResponse = (response: ChatCompletion) => {
const parser = MODE_TO_PARSER[this.mode]

function dumpMessage(message: ChatCompletionMessage) {
const ret: ChatCompletionMessage = {
role: message.role,
content: message.content || ""
return parser(response)
}
if (message.tool_calls) {
ret["content"] += JSON.stringify(message.tool_calls)
}
if (message.function_call) {
ret["content"] += JSON.stringify(message.function_call)

/**
* Public chat interface.
*/
public chat = {
completions: {
create: this.chatCompletion
}
}
return ret
}

const patch = ({
client,
mode
}: {
client: OpenAI
response_model?: ZodSchema | OpenAISchema
max_retries?: number
mode?: MODE
}): OpenAI => {
client.chat.completions.create = new Proxy(client.chat.completions.create, {
async apply(target, ctx, args: PatchedChatCompletionCreateParams[]) {
const max_retries = args[0].max_retries || 1
let retries = 0,
response: ChatCompletion | undefined = undefined,
response_model = args[0].response_model
;[response_model, args, mode] = handleResponseModel(response_model!, args, mode)

delete args[0].response_model
delete args[0].max_retries

while (retries < max_retries) {
try {
response = (await target.apply(
ctx,
args as [PatchedChatCompletionCreateParams]
)) as ChatCompletion
return processResponse(response, response_model as OpenAISchema, mode)
} catch (error) {
console.error(error.errors || error)
if (!response) {
break
}
args[0].messages.push(dumpMessage(response.choices[0].message))
args[0].messages.push({
role: "user",
content: `Recall the function correctly, fix the errors, exceptions found\n${error}`
})
if (mode == MODE.MD_JSON) {
args[0].messages.push({ role: "assistant", content: "```json" })
}
retries++
if (retries > max_retries) {
throw error
}
} finally {
response = undefined
}
type OAIClientExtended = OpenAI & Instructor

export default function (args: { client: OpenAI; mode: MODE }): OAIClientExtended {
const instructor = new Instructor(args)

const instructorWithProxy = new Proxy(instructor, {
get: (target, prop, receiver) => {
if (prop in target) {
return Reflect.get(target, prop, receiver)
}

return Reflect.get(target.client, prop, receiver)
}
})
return client
}

export default patch
return instructorWithProxy as OAIClientExtended
}
Loading

0 comments on commit 3893e1a

Please sign in to comment.