-
Notifications
You must be signed in to change notification settings - Fork 84
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add generic JS wrapClient method (#485)
CC @dqbd @hinthornw
- Loading branch information
1 parent
8a67e50
commit 5d93731
Showing
4 changed files
with
174 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -138,4 +138,4 @@ | |
}, | ||
"./package.json": "./package.json" | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import { jest } from "@jest/globals"; | ||
import { OpenAI } from "openai"; | ||
import { wrapSDK } from "../wrappers.js"; | ||
import { Client } from "../client.js"; | ||
|
||
test.concurrent("chat.completions", async () => { | ||
const client = new Client(); | ||
const callSpy = jest | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
.spyOn((client as any).caller, "call") | ||
.mockResolvedValue({ ok: true, text: () => "" }); | ||
|
||
const originalClient = new OpenAI(); | ||
const patchedClient = wrapSDK(new OpenAI(), { client }); | ||
|
||
// invoke | ||
const original = await originalClient.chat.completions.create({ | ||
messages: [{ role: "user", content: `Say 'foo'` }], | ||
temperature: 0, | ||
seed: 42, | ||
model: "gpt-3.5-turbo", | ||
}); | ||
|
||
const patched = await patchedClient.chat.completions.create({ | ||
messages: [{ role: "user", content: `Say 'foo'` }], | ||
temperature: 0, | ||
seed: 42, | ||
model: "gpt-3.5-turbo", | ||
}); | ||
|
||
expect(patched.choices).toEqual(original.choices); | ||
|
||
// stream | ||
const originalStream = await originalClient.chat.completions.create({ | ||
messages: [{ role: "user", content: `Say 'foo'` }], | ||
temperature: 0, | ||
seed: 42, | ||
model: "gpt-3.5-turbo", | ||
stream: true, | ||
}); | ||
|
||
const originalChoices = []; | ||
for await (const chunk of originalStream) { | ||
originalChoices.push(chunk.choices); | ||
} | ||
|
||
const patchedStream = await patchedClient.chat.completions.create({ | ||
messages: [{ role: "user", content: `Say 'foo'` }], | ||
temperature: 0, | ||
seed: 42, | ||
model: "gpt-3.5-turbo", | ||
stream: true, | ||
}); | ||
|
||
const patchedChoices = []; | ||
for await (const chunk of patchedStream) { | ||
patchedChoices.push(chunk.choices); | ||
} | ||
|
||
expect(patchedChoices).toEqual(originalChoices); | ||
for (const call of callSpy.mock.calls) { | ||
// eslint-disable-next-line @typescript-eslint/no-explicit-any | ||
expect((call[2] as any)["method"]).toBe("POST"); | ||
} | ||
}); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,22 +1,92 @@ | ||
import type { OpenAI } from "openai"; | ||
import type { Client } from "./index.js"; | ||
import { traceable } from "./traceable.js"; | ||
|
||
export const wrapOpenAI = ( | ||
openai: OpenAI, | ||
type OpenAIType = { | ||
chat: { | ||
completions: { | ||
create: (...args: any[]) => any; | ||
}; | ||
}; | ||
completions: { | ||
create: (...args: any[]) => any; | ||
}; | ||
}; | ||
|
||
/** | ||
* Wraps an OpenAI client's completion methods, enabling automatic LangSmith | ||
* tracing. Method signatures are unchanged. | ||
* @param openai An OpenAI client instance. | ||
* @param options LangSmith options. | ||
* @returns | ||
*/ | ||
export const wrapOpenAI = <T extends OpenAIType>( | ||
openai: T, | ||
options?: { client?: Client } | ||
): OpenAI => { | ||
// @ts-expect-error Promise<APIPromise<...>> != APIPromise<...> | ||
): T => { | ||
openai.chat.completions.create = traceable( | ||
openai.chat.completions.create.bind(openai.chat.completions), | ||
Object.assign({ name: "ChatOpenAI", run_type: "llm" }, options?.client) | ||
); | ||
|
||
// @ts-expect-error Promise<APIPromise<...>> != APIPromise<...> | ||
openai.completions.create = traceable( | ||
openai.completions.create.bind(openai.completions), | ||
Object.assign({ name: "OpenAI", run_type: "llm" }, options?.client) | ||
); | ||
|
||
return openai; | ||
}; | ||
|
||
const _wrapClient = <T extends object>( | ||
sdk: T, | ||
runName: string, | ||
options?: { client?: Client } | ||
): T => { | ||
return new Proxy(sdk, { | ||
get(target, propKey, receiver) { | ||
const originalValue = target[propKey as keyof T]; | ||
if (typeof originalValue === "function") { | ||
return traceable( | ||
originalValue.bind(target), | ||
Object.assign( | ||
{ name: [runName, propKey.toString()].join("."), run_type: "llm" }, | ||
options?.client | ||
) | ||
); | ||
} else if ( | ||
originalValue != null && | ||
!Array.isArray(originalValue) && | ||
// eslint-disable-next-line no-instanceof/no-instanceof | ||
!(originalValue instanceof Date) && | ||
typeof originalValue === "object" | ||
) { | ||
return _wrapClient( | ||
originalValue, | ||
[runName, propKey.toString()].join("."), | ||
options | ||
); | ||
} else { | ||
return Reflect.get(target, propKey, receiver); | ||
} | ||
}, | ||
}); | ||
}; | ||
|
||
/** | ||
* Wrap an arbitrary SDK, enabling automatic LangSmith tracing. | ||
* Method signatures are unchanged. | ||
* | ||
* Note that this will wrap and trace ALL SDK methods, not just | ||
* LLM completion methods. If the passed SDK contains other methods, | ||
* we recommend using the wrapped instance for LLM calls only. | ||
* @param sdk An arbitrary SDK instance. | ||
* @param options LangSmith options. | ||
* @returns | ||
*/ | ||
export const wrapSDK = <T extends object>( | ||
sdk: T, | ||
options?: { client?: Client; runName?: string } | ||
): T => { | ||
return _wrapClient(sdk, options?.runName ?? sdk.constructor?.name, { | ||
client: options?.client, | ||
}); | ||
}; |