Skip to content

Commit

Permalink
fix: added default ratelimiter to groq
Browse files Browse the repository at this point in the history
Added a new request rate controller to prevent hitting rate limits. It
can be used with any model or prompt. For Groq it's added by default.
  • Loading branch information
dosco committed Jun 23, 2024
1 parent ceae901 commit 8d74f9e
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 22 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@ax-llm/ax",
"version": "9.0.9",
"version": "9.0.10",
"type": "module",
"description": "The best library to work with LLMs",
"typings": "build/module/src/index.d.ts",
Expand Down
23 changes: 20 additions & 3 deletions src/ai/base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ import type {
AxEmbedResponse,
AxModelConfig,
AxModelInfo,
AxModelInfoWithProvider
AxModelInfoWithProvider,
AxTokenUsage
} from './types.js';

const colorLog = new ColorLog();
Expand Down Expand Up @@ -74,6 +75,9 @@ export class AxBaseAI<
private fetch?: AxAIServiceOptions['fetch'];
private tracer?: AxAIServiceOptions['tracer'];

private modelUsage?: AxTokenUsage;
private embedModelUsage?: AxTokenUsage;

protected apiURL: string;
protected name: string;
protected headers: Record<string, string>;
Expand Down Expand Up @@ -247,7 +251,9 @@ export class AxBaseAI<
logChatRequest(req);
}

const rv = this.rt ? await this.rt(fn) : await fn();
const rv = this.rt
? await this.rt(fn, { modelUsage: this.modelUsage })
: await fn();

if (stream) {
if (!this.generateChatStreamResp) {
Expand All @@ -260,6 +266,10 @@ export class AxBaseAI<
const res = respFn(resp, state);
res.sessionId = options?.sessionId;

if (res.modelUsage) {
this.modelUsage = res.modelUsage;
}

if (span?.isRecording()) {
setResponseAttr(res, span);
}
Expand Down Expand Up @@ -292,6 +302,10 @@ export class AxBaseAI<
const res = this.generateChatResp(rv as TChatResponse);
res.sessionId = options?.sessionId;

if (res.modelUsage) {
this.modelUsage = res.modelUsage;
}

if (span?.isRecording()) {
setResponseAttr(res, span);
}
Expand Down Expand Up @@ -358,13 +372,16 @@ export class AxBaseAI<
return res;
};

const resValue = this.rt ? await this.rt(async () => fn()) : await fn();
const resValue = this.rt
? await this.rt(fn, { embedModelUsage: this.embedModelUsage })
: await fn();
const res = this.generateEmbedResp!(resValue as TEmbedResponse);

res.sessionId = options?.sessionId;

if (span?.isRecording()) {
if (res.modelUsage) {
this.embedModelUsage = res.modelUsage;
span.setAttributes({
[axSpanAttributes.LLM_USAGE_COMPLETION_TOKENS]:
res.modelUsage.completionTokens ?? 0,
Expand Down
2 changes: 1 addition & 1 deletion src/ai/deepseek/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ export const axAIDeepSeekCodeConfig = (): DeepSeekConfig =>
export interface AxAIDeepSeekArgs {
name: 'deepseek';
apiKey: string;
config: Readonly<DeepSeekConfig>;
config?: Readonly<DeepSeekConfig>;
options?: Readonly<AxAIServiceOptions>;
}

Expand Down
27 changes: 23 additions & 4 deletions src/ai/groq/api.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import { AxRateLimiterTokenUsage } from '../../util/ratelimit.js';
import { axBaseAIDefaultConfig } from '../base.js';
import { AxAIOpenAI } from '../openai/api.js';
import type { AxAIOpenAIConfig } from '../openai/types.js';
import type { AxAIServiceOptions } from '../types.js';
import type { AxAIServiceOptions, AxRateLimiterFunction } from '../types.js';

import { AxAIGroqModel } from './types.js';

Expand All @@ -16,8 +17,8 @@ const axAIGroqDefaultConfig = (): AxAIGroqAIConfig =>
export interface AxAIGroqArgs {
name: 'groq';
apiKey: string;
config: Readonly<AxAIGroqAIConfig>;
options?: Readonly<AxAIServiceOptions>;
config?: Readonly<AxAIGroqAIConfig>;
options?: Readonly<AxAIServiceOptions> & { tokensPerMinute?: number };
}

export class AxAIGroq extends AxAIOpenAI {
Expand All @@ -33,10 +34,28 @@ export class AxAIGroq extends AxAIOpenAI {
...axAIGroqDefaultConfig(),
...config
};

let rateLimiter = options?.rateLimiter;
if (!rateLimiter) {
const tokensPerMin = options?.tokensPerMinute ?? 5800;
const rt = new AxRateLimiterTokenUsage(tokensPerMin, tokensPerMin / 60);

rateLimiter = async (func, info) => {
const totalTokens = info.modelUsage?.totalTokens || 0;
await rt.acquire(totalTokens);
return func();
};
}

const _options = {
...options,
rateLimiter,
streamingUsage: false
};
super({
apiKey,
config: _config,
options: { ...options, streamingUsage: false },
options: _options,
apiURL: 'https://api.groq.com/openai/v1',
modelInfo: []
});
Expand Down
2 changes: 1 addition & 1 deletion src/ai/mistral/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ export const axAIMistralBestConfig = (): AxAIOpenAIConfig =>
export interface AxAIMistralArgs {
name: 'mistral';
apiKey: string;
config: Readonly<MistralConfig>;
config?: Readonly<MistralConfig>;
options?: Readonly<AxAIServiceOptions>;
}

Expand Down
5 changes: 3 additions & 2 deletions src/ai/together/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@ type TogetherAIConfig = AxAIOpenAIConfig;

export const axAITogetherDefaultConfig = (): TogetherAIConfig =>
structuredClone({
model: 'llama2-70b-4096',
// cspell:disable-next-line
model: 'mistralai/Mixtral-8x7B-Instruct-v0.1',
...axBaseAIDefaultConfig()
});

export interface AxAITogetherArgs {
name: 'together';
apiKey: string;
config: Readonly<TogetherAIConfig>;
config?: Readonly<TogetherAIConfig>;
options?: Readonly<AxAIServiceOptions>;
}

Expand Down
5 changes: 4 additions & 1 deletion src/ai/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,10 @@ export type AxEmbedRequest = {
embedModel?: string;
};

export type AxRateLimiterFunction = <T>(func: unknown) => T;
/**
 * Hook that wraps each outgoing request so callers can throttle it.
 * `reqFunc` performs the actual request; `info` carries the most recently
 * observed token usage (chat and embed tracked separately) so
 * implementations can budget by tokens. Must resolve with `reqFunc`'s
 * result (a value or, for streaming responses, a ReadableStream).
 */
export type AxRateLimiterFunction = <T = unknown>(
  reqFunc: () => Promise<T | ReadableStream<T>>,
  info: Readonly<{ modelUsage?: AxTokenUsage; embedModelUsage?: AxTokenUsage }>
) => Promise<T | ReadableStream<T>>;

export type AxAIPromptConfig = {
stream?: boolean;
Expand Down
11 changes: 8 additions & 3 deletions src/dsp/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,8 @@ export class AxGenerate<
ai,
modelConfig: mc,
stream,
model
model,
rateLimiter
}: Readonly<
Omit<AxProgramForwardOptions, 'ai'> & { ai: AxAIService; stream: boolean }
>) {
Expand Down Expand Up @@ -201,6 +202,7 @@ export class AxGenerate<
{
...(sessionId ? { sessionId } : {}),
...(traceId ? { traceId } : {}),
...(rateLimiter ? { rateLimiter } : {}),
stream
}
);
Expand All @@ -214,6 +216,7 @@ export class AxGenerate<
traceId,
ai,
modelConfig,
rateLimiter,
stream = false
}: Readonly<
Omit<AxProgramForwardOptions, 'ai' | 'mem'> & {
Expand All @@ -232,7 +235,8 @@ export class AxGenerate<
traceId,
ai,
stream,
modelConfig
modelConfig,
rateLimiter
});

if (res instanceof ReadableStream) {
Expand Down Expand Up @@ -396,7 +400,8 @@ export class AxGenerate<
traceId,
modelConfig,
stream,
maxSteps: options?.maxSteps
maxSteps: options?.maxSteps,
rateLimiter: options?.rateLimiter
});

const lastMemItem = mem.getLast(sessionId);
Expand Down
4 changes: 3 additions & 1 deletion src/dsp/program.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ import { readFileSync } from 'fs';
import type {
AxAIService,
AxChatResponse,
AxModelConfig
AxModelConfig,
AxRateLimiterFunction
} from '../ai/types.js';
import type { AxAIMemory } from '../mem/types.js';
import type { AxTracer } from '../trace/index.js';
Expand Down Expand Up @@ -45,6 +46,7 @@ export type AxProgramForwardOptions = {
sessionId?: string;
traceId?: string | undefined;
tracer?: AxTracer;
rateLimiter?: AxRateLimiterFunction;
stream?: boolean;
debug?: boolean;
};
Expand Down
15 changes: 10 additions & 5 deletions src/examples/food-search.ts
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@ const sig = new AxSignature(
`customerQuery:string -> restaurant:string, priceRange:string "use $ signs to indicate price range"`
);

const ai = new AxAI({
name: 'openai',
apiKey: process.env.OPENAI_APIKEY as string
});

// const ai = new AxAI({
// name: 'groq',
// apiKey: process.env.GROQ_APIKEY as string
// });

// const ai = new AxAI({
// name: 'cohere',
// apiKey: process.env.COHERE_APIKEY as string
Expand All @@ -157,11 +167,6 @@ const sig = new AxSignature(
// apiKey: process.env.GOOGLE_APIKEY as string
// });

const ai = new AxAI({
name: 'openai',
apiKey: process.env.OPENAI_APIKEY as string
});

// const ai = new AxAI({
// name: 'anthropic',
// apiKey: process.env.ANTHROPIC_APIKEY as string
Expand Down
3 changes: 3 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ export * from './dsp/index.js';
export * from './docs/index.js';
export * from './trace/index.js';

// cspell: disable-next-line
export { AxRateLimiterTokenUsage } from './util/ratelimit.js';

/*
Transformers learn,
Attention guides, wisdom turns—
Expand Down
50 changes: 50 additions & 0 deletions src/util/ratelimit.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/**
 * Token-bucket rate limiter for throttling requests by LLM token usage.
 *
 * The bucket holds up to `maxTokens` and refills continuously at
 * `refillRate` tokens per second. `acquire(n)` resolves once `n` tokens
 * (clamped to the bucket capacity) have been deducted.
 */
export class AxRateLimiterTokenUsage {
  private readonly maxTokens: number;
  private readonly refillRate: number;
  private currentTokens: number;
  private lastRefillTime: number;

  /**
   * @param maxTokens Bucket capacity — the maximum burst of tokens.
   * @param refillRate Tokens restored to the bucket per second.
   */
  constructor(maxTokens: number, refillRate: number) {
    this.maxTokens = maxTokens;
    this.refillRate = refillRate;
    this.currentTokens = maxTokens;
    this.lastRefillTime = Date.now();
  }

  // Top up the bucket from elapsed wall-clock time, capped at capacity.
  private refillTokens() {
    const now = Date.now();
    const timeElapsed = (now - this.lastRefillTime) / 1000; // Convert ms to seconds
    const tokensToAdd = timeElapsed * this.refillRate;
    this.currentTokens = Math.min(
      this.maxTokens,
      this.currentTokens + tokensToAdd
    );
    this.lastRefillTime = now;
  }

  private async waitUntilTokensAvailable(tokens: number): Promise<void> {
    // Clamp to capacity: a request larger than the bucket can ever hold
    // would otherwise wait forever.
    const needed = Math.min(tokens, this.maxTokens);

    for (;;) {
      this.refillTokens();
      if (this.currentTokens >= needed) {
        this.currentTokens -= needed;
        return;
      }
      // Sleep just long enough for the deficit to refill rather than
      // busy-polling on a fixed 100ms interval.
      const deficit = needed - this.currentTokens;
      const waitMs = Math.max(
        10,
        Math.ceil((deficit / this.refillRate) * 1000)
      );
      await new Promise((resolve) => setTimeout(resolve, waitMs));
    }
  }

  /**
   * Resolves once `tokens` tokens were deducted from the bucket.
   * Requests exceeding the bucket capacity are treated as a full-bucket
   * acquisition so they complete instead of hanging.
   */
  public async acquire(tokens: number): Promise<void> {
    await this.waitUntilTokensAvailable(tokens);
  }
}

/**
* Example usage of the rate limiter. Limits to 5800 tokens per minute.
const rateLimiter = new AxRateLimiterTokenUsage(5800, 5800 / 60);
const axRateLimiterFunction = async (func, info) => {
const totalTokens = info.modelUsage?.totalTokens || 0;
await rateLimiter.acquire(totalTokens);
return func();
};
**/

0 comments on commit 8d74f9e

Please sign in to comment.