Skip to content

Commit

Permalink
Add generic JS wrapClient method (#485)
Browse files Browse the repository at this point in the history
  • Loading branch information
jacoblee93 authored Mar 6, 2024
1 parent 8a67e50 commit 5d93731
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 7 deletions.
32 changes: 32 additions & 0 deletions js/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,38 @@ export async function POST(req: Request) {

See the [AI SDK docs](https://sdk.vercel.ai/docs) for more examples.

## Arbitrary SDKs

You can use the generic `wrapSDK` method to add tracing for arbitrary SDKs.

Do note that this will trace ALL methods in the SDK, not just chat completion endpoints.
If the SDK you are wrapping has other methods, we recommend using it for only LLM calls.

Here's an example using the Anthropic SDK:

```ts
import { wrapSDK } from "langsmith/wrappers";
import { Anthropic } from "@anthropic-ai/sdk";

const originalSDK = new Anthropic();
const sdkWithTracing = wrapSDK(originalSDK);

const response = await sdkWithTracing.messages.create({
messages: [
{
role: "user",
content: `What is 1 + 1? Respond only with "2" and nothing else.`,
},
],
model: "claude-3-sonnet-20240229",
max_tokens: 1024,
});
```

:::tip
[Click here](https://smith.langchain.com/public/0e7248af-bbed-47cf-be9f-5967fea1dec1/r) to see an example LangSmith trace of the above.
:::

#### Alternatives: **Log traces using a RunTree.**

A RunTree tracks your application. Each RunTree object is required to have a name and run_type. These and other important attributes are as follows:
Expand Down
2 changes: 1 addition & 1 deletion js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -138,4 +138,4 @@
},
"./package.json": "./package.json"
}
}
}
65 changes: 65 additions & 0 deletions js/src/tests/wrapped_sdk.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import { jest } from "@jest/globals";
import { OpenAI } from "openai";
import { wrapSDK } from "../wrappers.js";
import { Client } from "../client.js";

test.concurrent("chat.completions", async () => {
const client = new Client();
const callSpy = jest
// eslint-disable-next-line @typescript-eslint/no-explicit-any
.spyOn((client as any).caller, "call")
.mockResolvedValue({ ok: true, text: () => "" });

const originalClient = new OpenAI();
const patchedClient = wrapSDK(new OpenAI(), { client });

// invoke
const original = await originalClient.chat.completions.create({
messages: [{ role: "user", content: `Say 'foo'` }],
temperature: 0,
seed: 42,
model: "gpt-3.5-turbo",
});

const patched = await patchedClient.chat.completions.create({
messages: [{ role: "user", content: `Say 'foo'` }],
temperature: 0,
seed: 42,
model: "gpt-3.5-turbo",
});

expect(patched.choices).toEqual(original.choices);

// stream
const originalStream = await originalClient.chat.completions.create({
messages: [{ role: "user", content: `Say 'foo'` }],
temperature: 0,
seed: 42,
model: "gpt-3.5-turbo",
stream: true,
});

const originalChoices = [];
for await (const chunk of originalStream) {
originalChoices.push(chunk.choices);
}

const patchedStream = await patchedClient.chat.completions.create({
messages: [{ role: "user", content: `Say 'foo'` }],
temperature: 0,
seed: 42,
model: "gpt-3.5-turbo",
stream: true,
});

const patchedChoices = [];
for await (const chunk of patchedStream) {
patchedChoices.push(chunk.choices);
}

expect(patchedChoices).toEqual(originalChoices);
for (const call of callSpy.mock.calls) {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
expect((call[2] as any)["method"]).toBe("POST");
}
});
82 changes: 76 additions & 6 deletions js/src/wrappers.ts
Original file line number Diff line number Diff line change
@@ -1,22 +1,92 @@
import type { OpenAI } from "openai";
import type { Client } from "./index.js";
import { traceable } from "./traceable.js";

export const wrapOpenAI = (
openai: OpenAI,
type OpenAIType = {
chat: {
completions: {
create: (...args: any[]) => any;
};
};
completions: {
create: (...args: any[]) => any;
};
};

/**
* Wraps an OpenAI client's completion methods, enabling automatic LangSmith
* tracing. Method signatures are unchanged.
* @param openai An OpenAI client instance.
* @param options LangSmith options.
* @returns
*/
export const wrapOpenAI = <T extends OpenAIType>(
openai: T,
options?: { client?: Client }
): OpenAI => {
// @ts-expect-error Promise<APIPromise<...>> != APIPromise<...>
): T => {
openai.chat.completions.create = traceable(
openai.chat.completions.create.bind(openai.chat.completions),
Object.assign({ name: "ChatOpenAI", run_type: "llm" }, options?.client)
);

// @ts-expect-error Promise<APIPromise<...>> != APIPromise<...>
openai.completions.create = traceable(
openai.completions.create.bind(openai.completions),
Object.assign({ name: "OpenAI", run_type: "llm" }, options?.client)
);

return openai;
};

const _wrapClient = <T extends object>(
sdk: T,
runName: string,
options?: { client?: Client }
): T => {
return new Proxy(sdk, {
get(target, propKey, receiver) {
const originalValue = target[propKey as keyof T];
if (typeof originalValue === "function") {
return traceable(
originalValue.bind(target),
Object.assign(
{ name: [runName, propKey.toString()].join("."), run_type: "llm" },
options?.client
)
);
} else if (
originalValue != null &&
!Array.isArray(originalValue) &&
// eslint-disable-next-line no-instanceof/no-instanceof
!(originalValue instanceof Date) &&
typeof originalValue === "object"
) {
return _wrapClient(
originalValue,
[runName, propKey.toString()].join("."),
options
);
} else {
return Reflect.get(target, propKey, receiver);
}
},
});
};

/**
* Wrap an arbitrary SDK, enabling automatic LangSmith tracing.
* Method signatures are unchanged.
*
* Note that this will wrap and trace ALL SDK methods, not just
* LLM completion methods. If the passed SDK contains other methods,
* we recommend using the wrapped instance for LLM calls only.
* @param sdk An arbitrary SDK instance.
* @param options LangSmith options.
* @returns
*/
export const wrapSDK = <T extends object>(
sdk: T,
options?: { client?: Client; runName?: string }
): T => {
return _wrapClient(sdk, options?.runName ?? sdk.constructor?.name, {
client: options?.client,
});
};

0 comments on commit 5d93731

Please sign in to comment.