Skip to content

Commit

Permalink
Add multi project query support (OR) (#439)
Browse files Browse the repository at this point in the history
  • Loading branch information
hinthornw authored Feb 14, 2024
1 parent 878ba07 commit dc8e7db
Show file tree
Hide file tree
Showing 9 changed files with 237 additions and 298 deletions.
4 changes: 2 additions & 2 deletions js/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
"devDependencies": {
"@babel/preset-env": "^7.22.4",
"@jest/globals": "^29.5.0",
"@langchain/core": "^0.1.28",
"@tsconfig/recommended": "^1.0.2",
"@types/jest": "^29.5.1",
"@typescript-eslint/eslint-plugin": "^5.59.8",
Expand All @@ -75,7 +76,6 @@
"eslint-plugin-no-instanceof": "^1.0.1",
"eslint-plugin-prettier": "^4.2.1",
"jest": "^29.5.0",
"langchain": "^0.0.147",
"prettier": "^2.8.8",
"ts-jest": "^29.1.0",
"ts-node": "^10.9.1",
Expand Down Expand Up @@ -122,4 +122,4 @@
},
"./package.json": "./package.json"
}
}
}
71 changes: 63 additions & 8 deletions js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ interface ClientConfig {
}

interface ListRunsParams {
projectId?: string;
projectName?: string;
projectId?: string | string[];
projectName?: string | string[];
executionOrder?: number;
parentRunId?: string;
referenceExampleId?: string;
Expand Down Expand Up @@ -818,15 +818,23 @@ export class Client {
filter,
limit,
}: ListRunsParams): AsyncIterable<Run> {
let projectId_ = projectId;
let projectIds: string[] = [];
if (projectId) {
projectIds = Array.isArray(projectId) ? projectId : [projectId];
}
if (projectName) {
if (projectId) {
throw new Error("Only one of projectId or projectName may be given");
}
projectId_ = (await this.readProject({ projectName })).id;
const projectNames = Array.isArray(projectName)
? projectName
: [projectName];
const projectIds_ = await Promise.all(
projectNames.map((name) =>
this.readProject({ projectName: name }).then((project) => project.id)
)
);
projectIds.push(...projectIds_);
}
const body = {
session: projectId_ ? [projectId_] : null,
session: projectIds.length ? projectIds : null,
run_type: runType,
reference_example: referenceExampleId,
query,
Expand Down Expand Up @@ -1110,6 +1118,53 @@ export class Client {
return result as TracerSession;
}

public async hasProject({
projectId,
projectName,
}: {
projectId?: string;
projectName?: string;
}): Promise<boolean> {
// TODO: Add a head request
let path = "/sessions";
const params = new URLSearchParams();
if (projectId !== undefined && projectName !== undefined) {
throw new Error("Must provide either projectName or projectId, not both");
} else if (projectId !== undefined) {
assertUuid(projectId);
path += `/${projectId}`;
} else if (projectName !== undefined) {
params.append("name", projectName);
} else {
throw new Error("Must provide projectName or projectId");
}
const response = await this.caller.call(
fetch,
`${this.apiUrl}${path}?${params}`,
{
method: "GET",
headers: this.headers,
signal: AbortSignal.timeout(this.timeout_ms),
}
);
// consume the response body to release the connection
// https://undici.nodejs.org/#/?id=garbage-collection
try {
const result = await response.json();
if (!response.ok) {
return false;
}
// If it's OK and we're querying by name, need to check the list is not empty
if (Array.isArray(result)) {
return result.length > 0;
}
// projectId querying
return true;
} catch (e) {
return false;
}
}

public async readProject({
projectId,
projectName,
Expand Down
4 changes: 4 additions & 0 deletions js/src/tests/batch_client.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ test.concurrent(
const langchainClient = new Client({
autoBatchTracing: true,
callerOptions: { maxRetries: 0 },
timeout_ms: 30_000,
});
const projectName = "__test_persist_update_run_batch_1";
await deleteProject(langchainClient, projectName);
Expand Down Expand Up @@ -97,6 +98,7 @@ test.concurrent(
autoBatchTracing: true,
callerOptions: { maxRetries: 0 },
pendingAutoBatchedRunLimit: 2,
timeout_ms: 30_000,
});
const projectName = "__test_persist_update_run_batch_above_bs_limit";
await deleteProject(langchainClient, projectName);
Expand Down Expand Up @@ -141,6 +143,7 @@ test.concurrent(
const langchainClient = new Client({
autoBatchTracing: true,
callerOptions: { maxRetries: 0 },
timeout_ms: 30_000,
});
const projectName = "__test_persist_update_run_batch_with_delay";
await deleteProject(langchainClient, projectName);
Expand Down Expand Up @@ -181,6 +184,7 @@ test.concurrent(
const langchainClient = new Client({
autoBatchTracing: true,
callerOptions: { maxRetries: 0 },
timeout_ms: 30_000,
});
const projectName = "__test_persist_update_run_tree";
await deleteProject(langchainClient, projectName);
Expand Down
2 changes: 1 addition & 1 deletion js/src/tests/client.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import { Dataset, Run } from "../schemas.js";
import { FunctionMessage, HumanMessage } from "langchain/schema";
import { FunctionMessage, HumanMessage } from "@langchain/core/messages";

import { Client } from "../client.js";
import { v4 as uuidv4 } from "uuid";
Expand Down
89 changes: 89 additions & 0 deletions js/src/tests/run_trees.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,25 @@ async function waitUntil(
throw new Error("Timeout");
}

async function pollRunsUntilCount(
client: Client,
projectName: string,
count: number
): Promise<void> {
await waitUntil(
async () => {
try {
const runs = await toArray(client.listRuns({ projectName }));
return runs.length === count;
} catch (e) {
return false;
}
},
120_000, // Wait up to 120 seconds
3000 // every 3 second
);
}

test.concurrent(
"Test post and patch run",
async () => {
Expand Down Expand Up @@ -130,3 +149,73 @@ test.concurrent(
},
120_000
);

test.concurrent(
"Test list runs multi project",
async () => {
const projectNames = [
"__My JS Tracer Project - test_list_runs_multi_project",
"__My JS Tracer Project - test_list_runs_multi_project2",
];

try {
const langchainClient = new Client({ timeout_ms: 30000 });

for (const project of projectNames) {
if (await langchainClient.hasProject({ projectName: project })) {
await langchainClient.deleteProject({ projectName: project });
}
}

const parentRunConfig: RunTreeConfig = {
name: "parent_run",
inputs: { text: "hello world" },
project_name: projectNames[0],
client: langchainClient,
};

const parent_run = new RunTree(parentRunConfig);
await parent_run.postRun();
await parent_run.end({ output: "Completed: foo" });
await parent_run.patchRun();

const parentRunConfig2: RunTreeConfig = {
name: "parent_run",
inputs: { text: "hello world" },
project_name: projectNames[1],
client: langchainClient,
};

const parent_run2 = new RunTree(parentRunConfig2);
await parent_run2.postRun();
await parent_run2.end({ output: "Completed: foo" });
await parent_run2.patchRun();
await pollRunsUntilCount(langchainClient, projectNames[0], 1);
await pollRunsUntilCount(langchainClient, projectNames[1], 1);

const runsIter = langchainClient.listRuns({
projectName: projectNames,
});
const runs = await toArray(runsIter);

expect(runs.length).toBe(2);
expect(
runs.every((run) => run?.outputs?.["output"] === "Completed: foo")
).toBe(true);
expect(runs[0].session_id).not.toBe(runs[1].session_id);
} finally {
const langchainClient = new Client();

for (const project of projectNames) {
if (await langchainClient.hasProject({ projectName: project })) {
try {
await langchainClient.deleteProject({ projectName: project });
} catch (e) {
console.debug(e);
}
}
}
}
},
120_000
);
4 changes: 2 additions & 2 deletions js/src/utils/async_caller.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,10 @@ export class AsyncCaller {
}
}
},
retries: this.maxRetries,
randomize: true,
// If needed we can change some of the defaults here,
// but they're quite sensible.
retries: this.maxRetries,
randomize: true,
}
),
{ throwOnTimeout: true }
Expand Down
Loading

0 comments on commit dc8e7db

Please sign in to comment.