Skip to content

Commit

Permalink
add the splits to metadata in evaluate
Browse files Browse the repository at this point in the history
  • Loading branch information
samnoyes committed May 18, 2024
1 parent 27294a2 commit 97393b2
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 5 deletions.
16 changes: 16 additions & 0 deletions js/src/evaluation/_runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -694,13 +694,29 @@ class _ExperimentManager {
).date;
}

async _getDatasetSplits(): Promise<string[] | undefined> {
const examples = await this.getExamples();
const allSplits = examples.reduce((acc, ex) => {
if (ex.metadata && ex.metadata.dataset_split) {
if (Array.isArray(ex.metadata.dataset_split)) {
ex.metadata.dataset_split.forEach((split) => acc.add(split));
} else if (typeof ex.metadata.dataset_split === "string") {
acc.add(ex.metadata.dataset_split);
}
}
return acc;
}, new Set<string>());
return allSplits.size ? Array.from(allSplits) : undefined;
}

async _end(): Promise<void> {
const experiment = this._experiment;
if (!experiment) {
throw new Error("Experiment not yet started.");
}
const projectMetadata = await this._getExperimentMetadata();
projectMetadata["dataset_version"] = await this._getDatasetVersion();
projectMetadata["dataset_splits"] = await this._getDatasetSplits();
// Update revision_id if not already set
if (!projectMetadata["revision_id"]) {
projectMetadata["revision_id"] = await getDefaultRevisionId();
Expand Down
4 changes: 2 additions & 2 deletions js/src/tests/client.int.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ test.concurrent("Test LangSmith Client Dataset CRD", async () => {
// Says 'example updated' or something similar
const newExampleValue = await client.readExample(example.id);
expect(newExampleValue.inputs.col1).toBe("updatedExampleCol1");
expect(newExampleValue.metadata?.dataset_split).toBe(["my_split2"]);
expect(newExampleValue.metadata?.dataset_split).toStrictEqual(["my_split2"]);

await client.updateExample(example.id, {
inputs: { col1: "updatedExampleCol3" },
Expand All @@ -112,7 +112,7 @@ test.concurrent("Test LangSmith Client Dataset CRD", async () => {
// Says 'example updated' or something similar
const newExampleValue2 = await client.readExample(example.id);
expect(newExampleValue2.inputs.col1).toBe("updatedExampleCol3");
expect(newExampleValue2.metadata?.dataset_split).toBe(["my_split3"]);
expect(newExampleValue2.metadata?.dataset_split).toStrictEqual(["my_split3"]);
await client.deleteExample(example.id);
const examples2 = await toArray(
client.listExamples({ datasetId: newDataset.id })
Expand Down
85 changes: 84 additions & 1 deletion js/src/tests/evaluate.int.test.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { EvaluationResult } from "../evaluation/evaluator.js";
import { evaluate } from "../evaluation/_runner.js";
import { Example, Run } from "../schemas.js";
import { Example, Run, TracerSession } from "../schemas.js";
import { Client } from "../index.js";
import { afterAll, beforeAll } from "@jest/globals";
import { RunnableLambda } from "@langchain/core/runnables";
Expand Down Expand Up @@ -30,6 +30,13 @@ afterAll(async () => {
await client.deleteDataset({
datasetName: TESTING_DATASET_NAME,
});
try {
await client.deleteDataset({
datasetName: "my_splits_ds2",
});
} catch (_) {
//pass
}
});

test("evaluate can evaluate", async () => {
Expand Down Expand Up @@ -351,6 +358,82 @@ test("can pass multiple evaluators", async () => {
);
});

test("split info saved correctly", async () => {
const client = new Client();
// create a new dataset
await client.createDataset("my_splits_ds2", {
description:
"For testing purposed. Is created & deleted for each test run.",
});
// create examples
await client.createExamples({
inputs: [{ input: 1 }, { input: 2 }, { input: 3 }],
outputs: [{ output: 2 }, { output: 3 }, { output: 4 }],
splits: [["test"], ["train"], ["validation", "test"]],
datasetName: "my_splits_ds2",
});

const targetFunc = (input: Record<string, any>) => {
console.log("__input__", input);
return {
foo: input.input + 1,
};
};
await evaluate(targetFunc, {
data: client.listExamples({ datasetName: "my_splits_ds2" }),
description: "splits info saved correctly",
});

const exp = client.listProjects({ referenceDatasetName: "my_splits_ds2" });
let myExp: TracerSession | null = null;
for await (const session of exp) {
myExp = session;
}
expect(myExp?.extra?.metadata?.dataset_splits.sort()).toEqual(
["test", "train", "validation"].sort()
);

await evaluate(targetFunc, {
data: client.listExamples({
datasetName: "my_splits_ds2",
splits: ["test"],
}),
description: "splits info saved correctly",
});

const exp2 = client.listProjects({ referenceDatasetName: "my_splits_ds2" });
let myExp2: TracerSession | null = null;
for await (const session of exp2) {
if (myExp2 === null || session.start_time > myExp2.start_time) {
myExp2 = session;
}
}

expect(myExp2?.extra?.metadata?.dataset_splits.sort()).toEqual(
["test", "validation"].sort()
);

await evaluate(targetFunc, {
data: client.listExamples({
datasetName: "my_splits_ds2",
splits: ["train"],
}),
description: "splits info saved correctly",
});

const exp3 = client.listProjects({ referenceDatasetName: "my_splits_ds2" });
let myExp3: TracerSession | null = null;
for await (const session of exp3) {
if (myExp3 === null || session.start_time > myExp3.start_time) {
myExp3 = session;
}
}

expect(myExp3?.extra?.metadata?.dataset_splits.sort()).toEqual(
["train"].sort()
);
});

test("can pass multiple summary evaluators", async () => {
const targetFunc = (input: Record<string, any>) => {
console.log("__input__", input);
Expand Down
18 changes: 18 additions & 0 deletions python/langsmith/evaluation/_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,13 +1322,31 @@ def _get_dataset_version(self) -> Optional[str]:
max_modified_at = max(modified_at) if modified_at else None
return max_modified_at.isoformat() if max_modified_at else None

def _get_dataset_splits(self) -> Optional[list[str]]:
examples = list(self.examples)
splits = set()
for example in examples:
if (
example.metadata
and example.metadata.get("dataset_split")
and isinstance(example.metadata["dataset_split"], list)
):
for split in example.metadata["dataset_split"]:
if isinstance(split, str):
splits.add(split)
else:
splits.add("base")

return list(splits)

def _end(self) -> None:
experiment = self._experiment
if experiment is None:
raise ValueError("Experiment not started yet.")

project_metadata = self._get_experiment_metadata()
project_metadata["dataset_version"] = self._get_dataset_version()
project_metadata["dataset_splits"] = self._get_dataset_splits()
self.client.update_project(
experiment.id,
end_time=datetime.datetime.now(datetime.timezone.utc),
Expand Down
4 changes: 2 additions & 2 deletions python/tests/integration_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,15 @@ def test_list_examples(langchain_client: Client) -> None:
example.id
for example in example_list
if example.metadata is not None
and example.metadata.get("dataset_split") == "test"
and "test" in example.metadata.get("dataset_split", [])
][0],
split="train",
)

example_list = list(
langchain_client.list_examples(dataset_id=dataset.id, splits=["test"])
)
assert len(example_list) == 2
assert len(example_list) == 1

example_list = list(
langchain_client.list_examples(dataset_id=dataset.id, splits=["train"])
Expand Down

0 comments on commit 97393b2

Please sign in to comment.