Skip to content

Commit

Permalink
feat(ts): evaluateComparative (#673)
Browse files Browse the repository at this point in the history
- **initial work on evaluate comparative**
- **Add create comparative experiment client method**
- **Add missing schema for create**
- **Wrap up a quick implementation**
  • Loading branch information
dqbd authored May 10, 2024
2 parents d12b958 + 970992f commit 0d07ad1
Show file tree
Hide file tree
Showing 8 changed files with 479 additions and 42 deletions.
64 changes: 64 additions & 0 deletions js/src/client.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import * as uuid from "uuid";

import { AsyncCaller, AsyncCallerParams } from "./utils/async_caller.js";
import {
ComparativeExperiment,
DataType,
Dataset,
DatasetDiffInfo,
Expand Down Expand Up @@ -187,6 +188,7 @@ interface FeedbackCreate {
feedback_source?: feedback_source | KVMap | null;
feedbackConfig?: FeedbackConfig;
session_id?: string;
comparative_experiment_id?: string;
}

interface FeedbackUpdate {
Expand Down Expand Up @@ -2233,6 +2235,7 @@ export class Client {
feedbackId,
feedbackConfig,
projectId,
comparativeExperimentId,
}: {
score?: ScoreType;
value?: ValueType;
Expand All @@ -2245,6 +2248,7 @@ export class Client {
feedbackId?: string;
eager?: boolean;
projectId?: string;
comparativeExperimentId?: string;
}
): Promise<Feedback> {
if (!runId && !projectId) {
Expand Down Expand Up @@ -2279,6 +2283,7 @@ export class Client {
correction,
comment,
feedback_source: feedback_source,
comparative_experiment_id: comparativeExperimentId,
feedbackConfig,
session_id: projectId,
};
Expand Down Expand Up @@ -2449,6 +2454,65 @@ export class Client {
return result as FeedbackIngestToken;
}

/**
 * Creates a comparative experiment that groups several existing experiments.
 *
 * When `referenceDatasetId` is not supplied, it is resolved by reading the
 * first experiment in `experimentIds` and using its `reference_dataset_id`.
 * The request is sent as a POST to `${apiUrl}/datasets/comparative`.
 *
 * @param name Name of the comparative experiment.
 * @param experimentIds IDs of the experiments to compare; must be non-empty.
 * @param referenceDatasetId Optional dataset ID the experiments refer to.
 * @param createdAt Optional creation time; defaults to the current time.
 * @param description Optional free-form description.
 * @param metadata Optional metadata, sent under `extra.metadata`.
 * @param id Optional ID for the new comparative experiment.
 * @returns The parsed JSON body of the response.
 * @throws Error if `experimentIds` is empty or no reference dataset can be
 *   determined.
 */
public async createComparativeExperiment({
  name,
  experimentIds,
  referenceDatasetId,
  createdAt,
  description,
  metadata,
  id,
}: {
  name: string;
  experimentIds: Array<string>;
  referenceDatasetId?: string;
  createdAt?: Date;
  description?: string;
  metadata?: Record<string, unknown>;
  id?: string;
}): Promise<ComparativeExperiment> {
  if (experimentIds.length === 0) {
    throw new Error("At least one experiment is required");
  }

  // Fall back to the dataset referenced by the first experiment.
  if (!referenceDatasetId) {
    referenceDatasetId = (
      await this.readProject({
        projectId: experimentIds[0],
      })
    ).reference_dataset_id;
  }

  // The previous check (`!referenceDatasetId == null`) compared a boolean
  // against null and therefore never triggered; test the value itself.
  if (referenceDatasetId == null) {
    throw new Error("A reference dataset is required");
  }

  const body = {
    id,
    name,
    experiment_ids: experimentIds,
    reference_dataset_id: referenceDatasetId,
    description,
    created_at: (createdAt ?? new Date()).toISOString(),
    extra: {} as Record<string, unknown>,
  };

  // Metadata is nested under `extra` rather than sent as a top-level field.
  if (metadata) body.extra["metadata"] = metadata;

  const response = await this.caller.call(
    fetch,
    `${this.apiUrl}/datasets/comparative`,
    {
      method: "POST",
      headers: { ...this.headers, "Content-Type": "application/json" },
      body: JSON.stringify(body),
      signal: AbortSignal.timeout(this.timeout_ms),
      ...this.fetchOptions,
    }
  );
  return await response.json();
}

/**
* Retrieves a list of presigned feedback tokens for a given run ID.
* @param runId The ID of the run.
Expand Down
59 changes: 40 additions & 19 deletions js/src/evaluation/_runner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,16 @@ import { v4 as uuidv4 } from "uuid";
// Target accepted by the evaluation runner: either a plain function
// (sync or async) taking an input KVMap and an optional config KVMap,
// or an object exposing an `invoke` method with the same signature.
type TargetT =
| ((input: KVMap, config?: KVMap) => Promise<KVMap>)
| ((input: KVMap, config?: KVMap) => KVMap)
| {
invoke: (input: KVMap, config?: KVMap) => KVMap;
}
| {
invoke: (input: KVMap, config?: KVMap) => Promise<KVMap>;
};
| { invoke: (input: KVMap, config?: KVMap) => KVMap }
| { invoke: (input: KVMap, config?: KVMap) => Promise<KVMap> };

// Function-only form of a target (no `invoke` wrapper objects): a sync or
// async function from an input KVMap and optional config KVMap to a KVMap.
type TargetNoInvoke =
| ((input: KVMap, config?: KVMap) => Promise<KVMap>)
| ((input: KVMap, config?: KVMap) => KVMap);

// Data format: a string (dataset name or dataset_id), or the examples
// themselves as an array or an async iterable of Example.
type DataT = string | AsyncIterable<Example> | Example[];

// Summary evaluator runs over the whole dataset
// and reports aggregate metric(s)
type SummaryEvaluatorT =
Expand All @@ -41,6 +39,7 @@ type SummaryEvaluatorT =
runs: Array<Run>,
examples: Array<Example>
) => EvaluationResult | EvaluationResults);

// Row-level evaluator
type EvaluatorT =
| RunEvaluator
Expand Down Expand Up @@ -650,14 +649,39 @@ class _ExperimentManager {
const examples = await this.getExamples();
const modifiedAt = examples.map((ex) => ex.modified_at);

const maxModifiedAt =
modifiedAt.length > 0
? new Date(
Math.max(...modifiedAt.map((date) => new Date(date).getTime()))
)
: undefined;
// Python might return microseconds, which we need
// to account for when comparing dates.
const modifiedAtTime = modifiedAt.map((date) => {
function getMiliseconds(isoString: string) {
const time = isoString.split("T").at(1);
if (!time) return "";

const regex = /[0-9]{2}:[0-9]{2}:[0-9]{2}.([0-9]+)/;
const strMiliseconds = time.match(regex)?.[1];
return strMiliseconds ?? "";
}

const jsDate = new Date(date);

return maxModifiedAt?.toISOString();
let source = getMiliseconds(date);
let parsed = getMiliseconds(jsDate.toISOString());

const length = Math.max(source.length, parsed.length);
source = source.padEnd(length, "0");
parsed = parsed.padEnd(length, "0");

const microseconds =
(Number.parseInt(source, 10) - Number.parseInt(parsed, 10)) / 1000;

const time = jsDate.getTime() + microseconds;
return { date, time };
});

if (modifiedAtTime.length === 0) return undefined;
return modifiedAtTime.reduce(
(max, current) => (current.time > max.time ? current : max),
modifiedAtTime[0]
).date;
}

async _end(): Promise<void> {
Expand Down Expand Up @@ -735,9 +759,7 @@ function convertInvokeToTopLevel(fn: TargetT): TargetNoInvoke {

async function _evaluate(
target: TargetT | AsyncGenerator<Run>,
fields: EvaluateOptions & {
experiment?: TracerSession;
}
fields: EvaluateOptions & { experiment?: TracerSession }
): Promise<ExperimentResults> {
const client = fields.client ?? new Client();
const runs = _isCallable(target) ? null : (target as AsyncGenerator<Run>);
Expand All @@ -759,11 +781,10 @@ async function _evaluate(
if (_isCallable(target)) {
manager = await manager.withPredictions(
convertInvokeToTopLevel(target as TargetT),
{
maxConcurrency: fields.maxConcurrency,
}
{ maxConcurrency: fields.maxConcurrency }
);
}

if (fields.evaluators) {
manager = await manager.withEvaluators(fields.evaluators, {
maxConcurrency: fields.maxConcurrency,
Expand Down
Loading

0 comments on commit 0d07ad1

Please sign in to comment.