From de597665ec54082f7270eec9b7b9ae14e501a078 Mon Sep 17 00:00:00 2001
From: Anthony Powell
Date: Sun, 8 Sep 2024 14:06:10 -0400
Subject: [PATCH] Pass more config options through to LLM providers

---
 packages/cannoli-core/src/providers.ts | 69 ++++++++++++++------------
 1 file changed, 36 insertions(+), 33 deletions(-)

diff --git a/packages/cannoli-core/src/providers.ts b/packages/cannoli-core/src/providers.ts
index 9adc9c8..f455804 100644
--- a/packages/cannoli-core/src/providers.ts
+++ b/packages/cannoli-core/src/providers.ts
@@ -46,6 +46,8 @@ export const GenericModelConfigSchema = z.object({
 	microstat: z.boolean().optional(),
 	microstat_eta: z.coerce.number().optional(),
 	microstat_tau: z.coerce.number().optional(),
+	max_tokens: z.coerce.number().optional(),
+	user: z.string().optional(),
 	num_ctx: z.coerce.number().optional(),
 	num_gqa: z.coerce.number().optional(),
 	num_gpu: z.coerce.number().optional(),
@@ -80,35 +82,6 @@ export type GenericCompletionResponse = {
 	function_call?: ChatCompletionAssistantMessageParam.FunctionCall;
 };
 
-// @deprecated
-export const makeSampleConfig = (): GenericModelConfig => ({
-	apiKey: undefined,
-	baseURL: undefined,
-	model: "",
-	frequency_penalty: undefined,
-	presence_penalty: undefined,
-	stop: undefined,
-	function_call: undefined,
-	functions: undefined,
-	temperature: undefined,
-	top_p: undefined,
-	role: "user" || "assistant" || "system",
-	provider: undefined,
-	microstat: undefined,
-	microstat_eta: undefined,
-	microstat_tau: undefined,
-	num_ctx: undefined,
-	num_gqa: undefined,
-	num_gpu: undefined,
-	num_thread: undefined,
-	repeat_last_n: undefined,
-	repeat_penalty: undefined,
-	seed: undefined,
-	tfs_z: undefined,
-	num_predict: undefined,
-	top_k: undefined,
-});
-
 export type GetDefaultsByProvider = (provider: SupportedProviders) => GenericModelConfig | undefined;
 
 export type LangchainMessages = ReturnType<typeof LLMProvider.convertMessages>;
@@ -163,10 +136,6 @@ export class LLMProvider {
 		return defaults;
 	}
 
-	getSampleConfig() {
-		return makeSampleConfig();
-	}
-
 	getMergedConfig = (args?: Partial<{
 		configOverrides: GenericModelConfig;
 		provider: SupportedProviders;
@@ -202,6 +171,15 @@ export class LLMProvider {
 					apiKey: config.apiKey,
 					model: config.model,
 					temperature: config.temperature,
+					topP: config.top_p,
+					frequencyPenalty: config.frequency_penalty,
+					presencePenalty: config.presence_penalty,
+					stop: config.stop?.split(","),
+					maxTokens: config.max_tokens,
+					user: config.user,
+					// beta openai feature
+					// @ts-expect-error
+					seed: config.seed,
 					maxRetries: 3,
 					configuration: {
 						baseURL: url,
@@ -218,6 +196,15 @@
 					azureOpenAIApiInstanceName: config.azureOpenAIApiInstanceName,
 					azureOpenAIApiVersion: config.azureOpenAIApiVersion,
 					azureOpenAIBasePath: url,
+					user: config.user,
+					maxTokens: config.max_tokens,
+					// beta openai feature
+					// @ts-expect-error
+					seed: config.seed,
+					topP: config.top_p,
+					frequencyPenalty: config.frequency_penalty,
+					presencePenalty: config.presence_penalty,
+					stop: config.stop?.split(","),
 					maxRetries: 3,
 					configuration: {
 						baseURL: url,
@@ -230,6 +217,10 @@
 						baseUrl: url,
 						model: config.model,
 						temperature: config.temperature,
+						topP: config.top_p,
+						frequencyPenalty: config.frequency_penalty,
+						presencePenalty: config.presence_penalty,
+						stop: config.stop?.split(","),
 					});
 				}
 
@@ -237,6 +228,10 @@
 					baseUrl: url,
 					model: config.model,
 					temperature: config.temperature,
+					topP: config.top_p,
+					frequencyPenalty: config.frequency_penalty,
+					presencePenalty: config.presence_penalty,
+					stop: config.stop?.split(","),
 				});
 			case "gemini":
 				return new ChatGoogleGenerativeAI({
@@ -244,6 +239,9 @@
 					model: config.model,
 					apiKey: config.apiKey,
 					temperature: config.temperature,
+					maxOutputTokens: config.max_tokens,
+					topP: config.top_p,
+					stopSequences: config.stop?.split(","),
 				});
 			case "anthropic":
 				return new ChatAnthropic({
@@ -252,6 +250,10 @@
 					temperature: config.temperature,
 					maxRetries: 0,
 					anthropicApiUrl: url,
+					maxTokens: config.max_tokens,
+					topP: config.top_p,
+					stopSequences: config.stop?.split(","),
+					topK: config.top_k,
 					clientOptions: {
 						defaultHeaders: {
 							"anthropic-dangerous-direct-browser-access": "true",
@@ -263,6 +265,7 @@
 					apiKey: config.apiKey,
 					model: config.model,
 					temperature: config.temperature,
+					stopSequences: config.stop?.split(","),
 					maxRetries: 3,
 				});
 			default:
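
Note (not part of the patch): a minimal usage sketch of how the new pass-through options would be supplied by a caller. getMergedConfig and its configOverrides/provider arguments appear in the diff above; the llm instance and the concrete override values are hypothetical. stop stays a comma-separated string because each provider branch splits it with config.stop?.split(","):

	// Hypothetical call site; assumes `llm` is an existing LLMProvider instance.
	const config = llm.getMergedConfig({
		configOverrides: {
			max_tokens: 512,     // mapped to maxTokens (OpenAI/Azure/Anthropic) or maxOutputTokens (Gemini)
			user: "cannoli-run", // forwarded only by the OpenAI and Azure branches
			top_p: 0.9,          // mapped to topP
			stop: "###,END",     // comma-separated; split on "," before reaching the client
		},
		provider: "openai",
	});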