Skip to content

Commit

Permalink
Add wrapAISDKModel method for Vercel's AI SDK (#896)
Browse files Browse the repository at this point in the history
Improve support for tracing Vercel AI SDK calls with a new
`wrapAISDKModel` method

Because AI SDK calls nest the text/object stream within the response,
this PR adds an optional parameter to `traceable` that specifies a key
in the response object to wrap. It also adds support for specially
tapping `ReadableStream` so that `traceable` does not change them to
plain async iterators.

Here's an example:

```ts
import { anthropic } from "@ai-sdk/anthropic";
import { streamObject } from "ai";
import { wrapAISDKModel } from "langsmith/wrappers/vercel";

const modelWithTracing = wrapAISDKModel(anthropic("claude-3-haiku-20240307"));
const { partialObjectStream } = await streamObject({
  model: modelWithTracing,
  prompt: "Write a vegetarian lasagna recipe for 4 people.",
  schema: z.object({
    ingredients: z.array(z.string()),
  }),
});
for await (const chunk of partialObjectStream) {
  console.log(chunk);
}
```

TODO: improve support in the UI for properly rendering token counts and
aggregating chunks. The data is returned but is not in one of our
currently supported formats.

Example traces:

Generate text:
https://smith.langchain.com/public/fbd12847-9485-43cf-a0a3-82c0b3318594/r
Generate object:
https://smith.langchain.com/public/c25944c7-2428-44d0-991b-ef3b5e47cab5/r
Stream text:
https://smith.langchain.com/public/2fb34c85-fec5-4361-a487-ffcf67718750/r
Stream object:
https://smith.langchain.com/public/27f5c9ba-006d-4a2f-b569-2e8da09dae23/r

CC @dqbd
  • Loading branch information
jacoblee93 authored Jul 30, 2024
2 parents 546a36f + 31cb73b commit d020931
Show file tree
Hide file tree
Showing 11 changed files with 568 additions and 73 deletions.
4 changes: 4 additions & 0 deletions js/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ Chinook_Sqlite.sql
/wrappers/openai.js
/wrappers/openai.d.ts
/wrappers/openai.d.cts
/wrappers/vercel.cjs
/wrappers/vercel.js
/wrappers/vercel.d.ts
/wrappers/vercel.d.cts
/singletons/traceable.cjs
/singletons/traceable.js
/singletons/traceable.d.ts
Expand Down
20 changes: 18 additions & 2 deletions js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@
"wrappers/openai.js",
"wrappers/openai.d.ts",
"wrappers/openai.d.cts",
"wrappers/vercel.cjs",
"wrappers/vercel.js",
"wrappers/vercel.d.ts",
"wrappers/vercel.d.cts",
"singletons/traceable.cjs",
"singletons/traceable.js",
"singletons/traceable.d.ts",
Expand Down Expand Up @@ -101,17 +105,18 @@
"uuid": "^9.0.0"
},
"devDependencies": {
"@ai-sdk/openai": "^0.0.40",
"@babel/preset-env": "^7.22.4",
"@faker-js/faker": "^8.4.1",
"@jest/globals": "^29.5.0",
"langchain": "^0.2.10",
"@langchain/core": "^0.2.17",
"@langchain/langgraph": "^0.0.29",
"@langchain/openai": "^0.2.5",
"@tsconfig/recommended": "^1.0.2",
"@types/jest": "^29.5.1",
"@typescript-eslint/eslint-plugin": "^5.59.8",
"@typescript-eslint/parser": "^5.59.8",
"ai": "^3.2.37",
"babel-jest": "^29.5.0",
"cross-env": "^7.0.3",
"dotenv": "^16.1.3",
Expand All @@ -121,11 +126,13 @@
"eslint-plugin-no-instanceof": "^1.0.1",
"eslint-plugin-prettier": "^4.2.1",
"jest": "^29.5.0",
"langchain": "^0.2.10",
"openai": "^4.38.5",
"prettier": "^2.8.8",
"ts-jest": "^29.1.0",
"ts-node": "^10.9.1",
"typescript": "^5.4.5"
"typescript": "^5.4.5",
"zod": "^3.23.8"
},
"peerDependencies": {
"@langchain/core": "*",
Expand Down Expand Up @@ -249,6 +256,15 @@
"import": "./wrappers/openai.js",
"require": "./wrappers/openai.cjs"
},
"./wrappers/vercel": {
"types": {
"import": "./wrappers/vercel.d.ts",
"require": "./wrappers/vercel.d.cts",
"default": "./wrappers/vercel.d.ts"
},
"import": "./wrappers/vercel.js",
"require": "./wrappers/vercel.cjs"
},
"./singletons/traceable": {
"types": {
"import": "./singletons/traceable.d.ts",
Expand Down
1 change: 1 addition & 0 deletions js/scripts/create-entrypoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ const entrypoints = {
wrappers: "wrappers/index",
anonymizer: "anonymizer/index",
"wrappers/openai": "wrappers/openai",
"wrappers/vercel": "wrappers/vercel",
"singletons/traceable": "singletons/traceable",
};

Expand Down
77 changes: 77 additions & 0 deletions js/src/tests/wrapped_ai_sdk.int.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import { openai } from "@ai-sdk/openai";
import {
generateObject,
generateText,
streamObject,
streamText,
tool,
} from "ai";
import { z } from "zod";
import { wrapAISDKModel } from "../wrappers/vercel.js";

test("AI SDK generateText", async () => {
const modelWithTracing = wrapAISDKModel(openai("gpt-4o-mini"));
const { text } = await generateText({
model: modelWithTracing,
prompt: "Write a vegetarian lasagna recipe for 4 people.",
});
console.log(text);
});

test("AI SDK generateText with a tool", async () => {
const modelWithTracing = wrapAISDKModel(openai("gpt-4o-mini"));
const { text } = await generateText({
model: modelWithTracing,
prompt:
"Write a vegetarian lasagna recipe for 4 people. Get ingredients first.",
tools: {
getIngredients: tool({
description: "get a list of ingredients",
parameters: z.object({
ingredients: z.array(z.string()),
}),
execute: async () =>
JSON.stringify(["pasta", "tomato", "cheese", "onions"]),
}),
},
maxToolRoundtrips: 2,
});
console.log(text);
});

test("AI SDK generateObject", async () => {
const modelWithTracing = wrapAISDKModel(openai("gpt-4o-mini"));
const { object } = await generateObject({
model: modelWithTracing,
prompt: "Write a vegetarian lasagna recipe for 4 people.",
schema: z.object({
ingredients: z.array(z.string()),
}),
});
console.log(object);
});

test("AI SDK streamText", async () => {
const modelWithTracing = wrapAISDKModel(openai("gpt-4o-mini"));
const { textStream } = await streamText({
model: modelWithTracing,
prompt: "Write a vegetarian lasagna recipe for 4 people.",
});
for await (const chunk of textStream) {
console.log(chunk);
}
});

test("AI SDK streamObject", async () => {
const modelWithTracing = wrapAISDKModel(openai("gpt-4o-mini"));
const { partialObjectStream } = await streamObject({
model: modelWithTracing,
prompt: "Write a vegetarian lasagna recipe for 4 people.",
schema: z.object({
ingredients: z.array(z.string()),
}),
});
for await (const chunk of partialObjectStream) {
console.log(chunk);
}
});
93 changes: 92 additions & 1 deletion js/src/traceable.ts
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ export function traceable<Func extends (...args: any[]) => any>(
// eslint-disable-next-line @typescript-eslint/no-explicit-any
aggregator?: (args: any[]) => any;
argsConfigPath?: [number] | [number, string];
__finalTracedIteratorKey?: string;

/**
* Extract invocation parameters from the arguments of the traced function.
Expand All @@ -294,7 +295,12 @@ export function traceable<Func extends (...args: any[]) => any>(
}
) {
type Inputs = Parameters<Func>;
const { aggregator, argsConfigPath, ...runTreeConfig } = config ?? {};
const {
aggregator,
argsConfigPath,
__finalTracedIteratorKey,
...runTreeConfig
} = config ?? {};

const traceableFunc = (
...args: Inputs | [RunTree, ...Inputs] | [RunnableConfigLike, ...Inputs]
Expand Down Expand Up @@ -434,6 +440,47 @@ export function traceable<Func extends (...args: any[]) => any>(
return chunks;
}

function tapReadableStreamForTracing(
stream: ReadableStream<unknown>,
snapshot: ReturnType<typeof AsyncLocalStorage.snapshot> | undefined
) {
const reader = stream.getReader();
let finished = false;
const chunks: unknown[] = [];

const tappedStream = new ReadableStream({
async start(controller) {
// eslint-disable-next-line no-constant-condition
while (true) {
const result = await (snapshot
? snapshot(() => reader.read())
: reader.read());
if (result.done) {
finished = true;
await currentRunTree?.end(
handleRunOutputs(await handleChunks(chunks))
);
await handleEnd();
controller.close();
break;
}
chunks.push(result.value);
controller.enqueue(result.value);
}
},
async cancel(reason) {
if (!finished) await currentRunTree?.end(undefined, "Cancelled");
await currentRunTree?.end(
handleRunOutputs(await handleChunks(chunks))
);
await handleEnd();
return reader.cancel(reason);
},
});

return tappedStream;
}

async function* wrapAsyncIteratorForTracing(
iterator: AsyncIterator<unknown, unknown, undefined>,
snapshot: ReturnType<typeof AsyncLocalStorage.snapshot> | undefined
Expand Down Expand Up @@ -463,10 +510,14 @@ export function traceable<Func extends (...args: any[]) => any>(
await handleEnd();
}
}

function wrapAsyncGeneratorForTracing(
iterable: AsyncIterable<unknown>,
snapshot: ReturnType<typeof AsyncLocalStorage.snapshot> | undefined
) {
if (isReadableStream(iterable)) {
return tapReadableStreamForTracing(iterable, snapshot);
}
const iterator = iterable[Symbol.asyncIterator]();
const wrappedIterator = wrapAsyncIteratorForTracing(iterator, snapshot);
iterable[Symbol.asyncIterator] = () => wrappedIterator;
Expand Down Expand Up @@ -512,6 +563,25 @@ export function traceable<Func extends (...args: any[]) => any>(
return wrapAsyncGeneratorForTracing(returnValue, snapshot);
}

if (
!Array.isArray(returnValue) &&
typeof returnValue === "object" &&
returnValue != null &&
__finalTracedIteratorKey !== undefined &&
isAsyncIterable(
(returnValue as Record<string, any>)[__finalTracedIteratorKey]
)
) {
const snapshot = AsyncLocalStorage.snapshot();
return {
...returnValue,
[__finalTracedIteratorKey]: wrapAsyncGeneratorForTracing(
(returnValue as Record<string, any>)[__finalTracedIteratorKey],
snapshot
),
};
}

const tracedPromise = new Promise<unknown>((resolve, reject) => {
Promise.resolve(returnValue)
.then(
Expand All @@ -523,6 +593,27 @@ export function traceable<Func extends (...args: any[]) => any>(
);
}

if (
!Array.isArray(rawOutput) &&
typeof rawOutput === "object" &&
rawOutput != null &&
__finalTracedIteratorKey !== undefined &&
isAsyncIterable(
(rawOutput as Record<string, any>)[__finalTracedIteratorKey]
)
) {
const snapshot = AsyncLocalStorage.snapshot();
return {
...rawOutput,
[__finalTracedIteratorKey]: wrapAsyncGeneratorForTracing(
(rawOutput as Record<string, any>)[
__finalTracedIteratorKey
],
snapshot
),
};
}

if (isGenerator(wrappedFunc) && isIteratorLike(rawOutput)) {
const chunks = gatherAll(rawOutput);

Expand Down
72 changes: 72 additions & 0 deletions js/src/wrappers/generic.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import type { RunTreeConfig } from "../index.js";
import { traceable } from "../traceable.js";

export const _wrapClient = <T extends object>(
sdk: T,
runName: string,
options?: Omit<RunTreeConfig, "name">
): T => {
return new Proxy(sdk, {
get(target, propKey, receiver) {
const originalValue = target[propKey as keyof T];
if (typeof originalValue === "function") {
return traceable(originalValue.bind(target), {
run_type: "llm",
...options,
name: [runName, propKey.toString()].join("."),
});
} else if (
originalValue != null &&
!Array.isArray(originalValue) &&
// eslint-disable-next-line no-instanceof/no-instanceof
!(originalValue instanceof Date) &&
typeof originalValue === "object"
) {
return _wrapClient(
originalValue,
[runName, propKey.toString()].join("."),
options
);
} else {
return Reflect.get(target, propKey, receiver);
}
},
});
};

type WrapSDKOptions = Partial<
RunTreeConfig & {
/**
* @deprecated Use `name` instead.
*/
runName: string;
}
>;

/**
* Wrap an arbitrary SDK, enabling automatic LangSmith tracing.
* Method signatures are unchanged.
*
* Note that this will wrap and trace ALL SDK methods, not just
* LLM completion methods. If the passed SDK contains other methods,
* we recommend using the wrapped instance for LLM calls only.
* @param sdk An arbitrary SDK instance.
* @param options LangSmith options.
* @returns
*/
export const wrapSDK = <T extends object>(
sdk: T,
options?: WrapSDKOptions
): T => {
const traceableOptions = options ? { ...options } : undefined;
if (traceableOptions != null) {
delete traceableOptions.runName;
delete traceableOptions.name;
}

return _wrapClient(
sdk,
options?.name ?? options?.runName ?? sdk.constructor?.name,
traceableOptions
);
};
1 change: 1 addition & 0 deletions js/src/wrappers/index.ts
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
export * from "./openai.js";
export { wrapSDK } from "./generic.js";
Loading

0 comments on commit d020931

Please sign in to comment.