Skip to content

Commit

Permalink
feat: 🎉 add stats aggregation for generation results
Browse files Browse the repository at this point in the history
  • Loading branch information
pelikhan committed Oct 12, 2024
1 parent 1fd31b5 commit 156e3ab
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 12 deletions.
47 changes: 46 additions & 1 deletion packages/cli/src/run.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { capitalize } from "inflection"
import { resolve, join, relative, dirname } from "node:path"
import { isQuiet, wrapColor } from "./log"
import { emptyDir, ensureDir, appendFileSync } from "fs-extra"
import { emptyDir, ensureDir, appendFileSync, exists } from "fs-extra"
import { convertDiagnosticsToSARIF } from "./sarif"
import { buildProject } from "./build"
import { diagnosticsToCSV } from "../../core/src/ast"
Expand Down Expand Up @@ -34,6 +34,7 @@ import {
TRACE_DETAILS,
CLI_ENV_VAR_RX,
AGENT_MEMORY_CACHE_NAME,
STATS_DIR_NAME,
} from "../../core/src/constants"
import { isCancelError, errorMessage } from "../../core/src/error"
import { Fragment, GenerationResult } from "../../core/src/generation"
Expand Down Expand Up @@ -74,6 +75,8 @@ import { prettifyMarkdown } from "../../core/src/markdown"
import { delay } from "es-toolkit"
import { GenerationStats } from "../../core/src/usage"
import { traceAgentMemory } from "../../core/src/agent"
import { JSONLineCache } from "../../core/src/cache"
import { appendFile, stat } from "node:fs/promises"

function parseVars(
vars: string[],
Expand Down Expand Up @@ -340,6 +343,7 @@ export async function runScript(
}
if (!isQuiet) logVerbose("") // force new line

await aggregateResults(scriptId, outTrace, result)
await traceAgentMemory(trace)
if (outAnnotations && result.annotations?.length) {
if (isJSONLFilename(outAnnotations))
Expand Down Expand Up @@ -544,3 +548,44 @@ export async function runScript(

return { exitCode: 0, result }
}

async function aggregateResults(
scriptId: string,
outTrace: string,
result: GenerationResult
) {
const statsDir = dotGenaiscriptPath(".")
await ensureDir(statsDir)
const statsFile = path.join(statsDir, "stats.csv")
if (!(await exists(statsFile)))
await writeFile(
statsFile,
[
"script",
"status",
"cost",
"total_tokens",
"prompt_tokens",
"completion_tokens",
"trace",
"version",
].join(",") + "\n",
{ encoding: "utf-8" }
)
await appendFile(
statsFile,
[
scriptId,
result.status,
result.stats.cost,
result.stats.total_tokens,
result.stats.prompt_tokens,
result.stats.completion_tokens,
path.basename(outTrace),
result.version,
]
.map((s) => String(s))
.join(",") + "\n",
{ encoding: "utf-8" }
)
}
1 change: 1 addition & 0 deletions packages/core/src/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ export const PROMPTFOO_CONFIG_DIR = ".genaiscript/config/tests"
export const PROMPTFOO_REMOTE_API_PORT = 15500

export const RUNS_DIR_NAME = "runs"
export const STATS_DIR_NAME = "stats"

export const EMOJI_SUCCESS = "✅"
export const EMOJI_FAIL = "❌"
Expand Down
13 changes: 12 additions & 1 deletion packages/core/src/generation.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
// Import necessary modules and interfaces
import { CancellationToken } from "./cancellation"
import { LanguageModel } from "./chat"
import { ChatCompletionMessageParam, ChatCompletionsOptions } from "./chattypes"
import {
ChatCompletionMessageParam,
ChatCompletionsOptions,
ChatCompletionUsage,
} from "./chattypes"
import { MarkdownTrace } from "./trace"
import { GenerationStats } from "./usage"

Expand Down Expand Up @@ -66,6 +70,13 @@ export interface GenerationResult extends GenerationOutput {
* Version of the GenAIScript used
*/
version: string

/**
* Statistics of the generation
*/
stats: {
cost: number
} & ChatCompletionUsage
}

// Type representing possible statuses of generation
Expand Down
4 changes: 4 additions & 0 deletions packages/core/src/promptrunner.ts
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,10 @@ export async function runTemplate(
genVars,
schemas,
json,
stats: {
cost: options.stats.cost(),
...options.stats.usage,
},
}

// If there's an error, provide status text
Expand Down
20 changes: 10 additions & 10 deletions packages/core/src/usage.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import {

/**
* Estimates the cost of a chat completion based on the model and usage.
*
*
* @param modelId - The identifier of the model used for chat completion.
* @param usage - The usage statistics for the chat completion.
* @returns The estimated cost or undefined if estimation is not possible.
Expand Down Expand Up @@ -57,7 +57,7 @@ export function estimateCost(modelId: string, usage: ChatCompletionUsage) {

/**
* Renders the cost as a string for display purposes.
*
*
* @param value - The cost to be rendered.
* @returns A string representation of the cost.
*/
Expand Down Expand Up @@ -89,7 +89,7 @@ export class GenerationStats {

/**
* Constructs a GenerationStats instance.
*
*
* @param model - The model used for chat completions.
* @param label - Optional label for the statistics.
*/
Expand All @@ -113,7 +113,7 @@ export class GenerationStats {

/**
* Calculates the total cost based on the usage statistics.
*
*
* @returns The total cost.
*/
cost(): number {
Expand All @@ -125,7 +125,7 @@ export class GenerationStats {

/**
* Accumulates the usage statistics from this instance and its children.
*
*
* @returns The accumulated usage statistics.
*/
accumulatedUsage(): ChatCompletionUsage {
Expand All @@ -149,7 +149,7 @@ export class GenerationStats {

/**
* Creates a new child GenerationStats instance.
*
*
* @param model - The model used for the child chat completions.
* @param label - Optional label for the child's statistics.
* @returns The created child GenerationStats instance.
Expand All @@ -162,7 +162,7 @@ export class GenerationStats {

/**
* Traces the generation statistics using a MarkdownTrace instance.
*
*
* @param trace - The MarkdownTrace instance used for tracing.
*/
trace(trace: MarkdownTrace) {
Expand All @@ -176,7 +176,7 @@ export class GenerationStats {

/**
* Helper method to trace individual statistics.
*
*
* @param trace - The MarkdownTrace instance used for tracing.
*/
private traceStats(trace: MarkdownTrace) {
Expand Down Expand Up @@ -218,7 +218,7 @@ export class GenerationStats {

/**
* Helper method to log tokens with indentation.
*
*
* @param indent - The indentation used for logging.
*/
private logTokens(indent: string) {
Expand All @@ -240,7 +240,7 @@ export class GenerationStats {

/**
* Adds usage statistics to the current instance.
*
*
* @param req - The request containing details about the chat completion.
* @param usage - The usage statistics to be added.
*/
Expand Down

0 comments on commit 156e3ab

Please sign in to comment.