diff --git a/package.json b/package.json index 689ec45d..7e999933 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "@ax-llm/ax", - "version": "9.0.9", + "version": "9.0.10", "type": "module", "description": "The best library to work with LLMs", "typings": "build/module/src/index.d.ts", diff --git a/src/ai/base.ts b/src/ai/base.ts index ebe36da6..1d2bf86b 100644 --- a/src/ai/base.ts +++ b/src/ai/base.ts @@ -16,7 +16,8 @@ import type { AxEmbedResponse, AxModelConfig, AxModelInfo, - AxModelInfoWithProvider + AxModelInfoWithProvider, + AxTokenUsage } from './types.js'; const colorLog = new ColorLog(); @@ -74,6 +75,9 @@ export class AxBaseAI< private fetch?: AxAIServiceOptions['fetch']; private tracer?: AxAIServiceOptions['tracer']; + private modelUsage?: AxTokenUsage; + private embedModelUsage?: AxTokenUsage; + protected apiURL: string; protected name: string; protected headers: Record; @@ -247,7 +251,9 @@ export class AxBaseAI< logChatRequest(req); } - const rv = this.rt ? await this.rt(fn) : await fn(); + const rv = this.rt + ? await this.rt(fn, { modelUsage: this.modelUsage }) + : await fn(); if (stream) { if (!this.generateChatStreamResp) { @@ -260,6 +266,10 @@ export class AxBaseAI< const res = respFn(resp, state); res.sessionId = options?.sessionId; + if (res.modelUsage) { + this.modelUsage = res.modelUsage; + } + if (span?.isRecording()) { setResponseAttr(res, span); } @@ -292,6 +302,10 @@ export class AxBaseAI< const res = this.generateChatResp(rv as TChatResponse); res.sessionId = options?.sessionId; + if (res.modelUsage) { + this.modelUsage = res.modelUsage; + } + if (span?.isRecording()) { setResponseAttr(res, span); } @@ -358,13 +372,16 @@ export class AxBaseAI< return res; }; - const resValue = this.rt ? await this.rt(async () => fn()) : await fn(); + const resValue = this.rt + ? await this.rt(fn, { embedModelUsage: this.embedModelUsage }) + : await fn(); const res = this.generateEmbedResp!(resValue as TEmbedResponse); res.sessionId = options?.sessionId; if (span?.isRecording()) { if (res.modelUsage) { + this.embedModelUsage = res.modelUsage; span.setAttributes({ [axSpanAttributes.LLM_USAGE_COMPLETION_TOKENS]: res.modelUsage.completionTokens ?? 0, diff --git a/src/ai/deepseek/api.ts b/src/ai/deepseek/api.ts index f11f24b1..5b21ce8b 100644 --- a/src/ai/deepseek/api.ts +++ b/src/ai/deepseek/api.ts @@ -26,7 +26,7 @@ export const axAIDeepSeekCodeConfig = (): DeepSeekConfig => export interface AxAIDeepSeekArgs { name: 'deepseek'; apiKey: string; - config: Readonly; + config?: Readonly; options?: Readonly; } diff --git a/src/ai/groq/api.ts b/src/ai/groq/api.ts index 2a22a905..b3d49fa8 100644 --- a/src/ai/groq/api.ts +++ b/src/ai/groq/api.ts @@ -1,7 +1,8 @@ +import { AxRateLimiterTokenUsage } from '../../util/ratelimit.js'; import { axBaseAIDefaultConfig } from '../base.js'; import { AxAIOpenAI } from '../openai/api.js'; import type { AxAIOpenAIConfig } from '../openai/types.js'; -import type { AxAIServiceOptions } from '../types.js'; +import type { AxAIServiceOptions, AxRateLimiterFunction } from '../types.js'; import { AxAIGroqModel } from './types.js'; @@ -16,8 +17,8 @@ const axAIGroqDefaultConfig = (): AxAIGroqAIConfig => export interface AxAIGroqArgs { name: 'groq'; apiKey: string; - config: Readonly; - options?: Readonly; + config?: Readonly; + options?: Readonly & { tokensPerMinute?: number }; } export class AxAIGroq extends AxAIOpenAI { @@ -33,10 +34,28 @@ export class AxAIGroq extends AxAIOpenAI { ...axAIGroqDefaultConfig(), ...config }; + + let rateLimiter = options?.rateLimiter; + if (!rateLimiter) { + const tokensPerMin = options?.tokensPerMinute ?? 5800; + const rt = new AxRateLimiterTokenUsage(tokensPerMin, tokensPerMin / 60); + + rateLimiter = async (func, info) => { + const totalTokens = info.modelUsage?.totalTokens || 0; + await rt.acquire(totalTokens); + return func(); + }; + } + + const _options = { + ...options, + rateLimiter, + streamingUsage: false + }; super({ apiKey, config: _config, - options: { ...options, streamingUsage: false }, + options: _options, apiURL: 'https://api.groq.com/openai/v1', modelInfo: [] }); diff --git a/src/ai/mistral/api.ts b/src/ai/mistral/api.ts index d61bbabd..4014aa19 100644 --- a/src/ai/mistral/api.ts +++ b/src/ai/mistral/api.ts @@ -23,7 +23,7 @@ export const axAIMistralBestConfig = (): AxAIOpenAIConfig => export interface AxAIMistralArgs { name: 'mistral'; apiKey: string; - config: Readonly; + config?: Readonly; options?: Readonly; } diff --git a/src/ai/together/api.ts b/src/ai/together/api.ts index d09c235d..229dadde 100644 --- a/src/ai/together/api.ts +++ b/src/ai/together/api.ts @@ -9,14 +9,15 @@ type TogetherAIConfig = AxAIOpenAIConfig; export const axAITogetherDefaultConfig = (): TogetherAIConfig => structuredClone({ - model: 'llama2-70b-4096', + // cspell:disable-next-line + model: 'mistralai/Mixtral-8x7B-Instruct-v0.1', ...axBaseAIDefaultConfig() }); export interface AxAITogetherArgs { name: 'together'; apiKey: string; - config: Readonly; + config?: Readonly; options?: Readonly; } diff --git a/src/ai/types.ts b/src/ai/types.ts index 2e7bb320..8d9dd32a 100644 --- a/src/ai/types.ts +++ b/src/ai/types.ts @@ -147,7 +147,10 @@ export type AxEmbedRequest = { embedModel?: string; }; -export type AxRateLimiterFunction = (func: unknown) => T; +export type AxRateLimiterFunction = ( + reqFunc: () => Promise>, + info: Readonly<{ modelUsage?: AxTokenUsage; embedModelUsage?: AxTokenUsage }> +) => Promise>; export type AxAIPromptConfig = { stream?: boolean; diff --git a/src/dsp/generate.ts b/src/dsp/generate.ts index 75b8ea53..48a0776c 100644 --- a/src/dsp/generate.ts +++ b/src/dsp/generate.ts @@ -166,7 +166,8 @@ export class AxGenerate< ai, modelConfig: mc, stream, - model + model, + rateLimiter }: Readonly< Omit & { ai: AxAIService; stream: boolean } >) { @@ -201,6 +202,7 @@ export class AxGenerate< { ...(sessionId ? { sessionId } : {}), ...(traceId ? { traceId } : {}), + ...(rateLimiter ? { rateLimiter } : {}), stream } ); @@ -214,6 +216,7 @@ export class AxGenerate< traceId, ai, modelConfig, + rateLimiter, stream = false }: Readonly< Omit & { @@ -232,7 +235,8 @@ export class AxGenerate< traceId, ai, stream, - modelConfig + modelConfig, + rateLimiter }); if (res instanceof ReadableStream) { @@ -396,7 +400,8 @@ export class AxGenerate< traceId, modelConfig, stream, - maxSteps: options?.maxSteps + maxSteps: options?.maxSteps, + rateLimiter: options?.rateLimiter }); const lastMemItem = mem.getLast(sessionId); diff --git a/src/dsp/program.ts b/src/dsp/program.ts index b8d4d28e..11fe5c79 100644 --- a/src/dsp/program.ts +++ b/src/dsp/program.ts @@ -3,7 +3,8 @@ import { readFileSync } from 'fs'; import type { AxAIService, AxChatResponse, - AxModelConfig + AxModelConfig, + AxRateLimiterFunction } from '../ai/types.js'; import type { AxAIMemory } from '../mem/types.js'; import type { AxTracer } from '../trace/index.js'; @@ -45,6 +46,7 @@ export type AxProgramForwardOptions = { sessionId?: string; traceId?: string | undefined; tracer?: AxTracer; + rateLimiter?: AxRateLimiterFunction; stream?: boolean; debug?: boolean; }; diff --git a/src/examples/food-search.ts b/src/examples/food-search.ts index e96afca2..6fc42ffc 100644 --- a/src/examples/food-search.ts +++ b/src/examples/food-search.ts @@ -147,6 +147,16 @@ const sig = new AxSignature( `customerQuery:string -> restaurant:string, priceRange:string "use $ signs to indicate price range"` ); +const ai = new AxAI({ + name: 'openai', + apiKey: process.env.OPENAI_APIKEY as string +}); + +// const ai = new AxAI({ +// name: 'groq', +// apiKey: process.env.GROQ_APIKEY as string +// }); + // const ai = new AxAI({ // name: 'cohere', // apiKey: process.env.COHERE_APIKEY as string @@ -157,11 +167,6 @@ const sig = new AxSignature( // apiKey: process.env.GOOGLE_APIKEY as string // }); -const ai = new AxAI({ - name: 'openai', - apiKey: process.env.OPENAI_APIKEY as string -}); - // const ai = new AxAI({ // name: 'anthropic', // apiKey: process.env.ANTHROPIC_APIKEY as string diff --git a/src/index.ts b/src/index.ts index 627c8e05..c9223d8e 100644 --- a/src/index.ts +++ b/src/index.ts @@ -7,6 +7,9 @@ export * from './dsp/index.js'; export * from './docs/index.js'; export * from './trace/index.js'; +// cspell: disable-next-line +export { AxRateLimiterTokenUsage } from './util/ratelimit.js'; + /* Transformers learn, Attention guides, wisdom turns— diff --git a/src/util/ratelimit.ts b/src/util/ratelimit.ts new file mode 100644 index 00000000..18cf23ca --- /dev/null +++ b/src/util/ratelimit.ts @@ -0,0 +1,50 @@ +export class AxRateLimiterTokenUsage { + private maxTokens: number; + private refillRate: number; + private currentTokens: number; + private lastRefillTime: number; + + constructor(maxTokens: number, refillRate: number) { + this.maxTokens = maxTokens; + this.refillRate = refillRate; + this.currentTokens = maxTokens; + this.lastRefillTime = Date.now(); + } + + private refillTokens() { + const now = Date.now(); + const timeElapsed = (now - this.lastRefillTime) / 1000; // Convert ms to seconds + const tokensToAdd = timeElapsed * this.refillRate; + this.currentTokens = Math.min( + this.maxTokens, + this.currentTokens + tokensToAdd + ); + this.lastRefillTime = now; + } + + private async waitUntilTokensAvailable(tokens: number): Promise { + this.refillTokens(); + if (this.currentTokens >= tokens) { + this.currentTokens -= tokens; + return; + } + + await new Promise((resolve) => setTimeout(resolve, 100)); // Wait for 100ms before checking again + return this.waitUntilTokensAvailable(tokens); // Recursive call + } + + public async acquire(tokens: number): Promise { + await this.waitUntilTokensAvailable(tokens); + } +} + +/** + * Example usage of the rate limiter. Limits to 5800 tokens per minute. +const rateLimiter = new AxRateLimiterTokenUsage(5800, 5800 / 60); + +const axRateLimiterFunction = async (func, info) => { + const totalTokens = info.modelUsage?.totalTokens || 0; + await rateLimiter.acquire(totalTokens); + return func(); +}; +**/