Skip to content

Commit

Permalink
Pass more config options through to LLM providers
Browse files Browse the repository at this point in the history
  • Loading branch information
cephalization committed Sep 8, 2024
1 parent 2de6b85 commit de59766
Showing 1 changed file with 36 additions and 33 deletions.
69 changes: 36 additions & 33 deletions packages/cannoli-core/src/providers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ export const GenericModelConfigSchema = z.object({
microstat: z.boolean().optional(),
microstat_eta: z.coerce.number().optional(),
microstat_tau: z.coerce.number().optional(),
max_tokens: z.coerce.number().optional(),
user: z.string().optional(),
num_ctx: z.coerce.number().optional(),
num_gqa: z.coerce.number().optional(),
num_gpu: z.coerce.number().optional(),
Expand Down Expand Up @@ -80,35 +82,6 @@ export type GenericCompletionResponse = {
function_call?: ChatCompletionAssistantMessageParam.FunctionCall;
};

/**
 * Builds a template config object with every supported field present
 * (placeholder values, mostly `undefined`), illustrating the full shape
 * of `GenericModelConfig`.
 *
 * @deprecated Prefer per-provider defaults (see `GetDefaultsByProvider`).
 * @returns a fresh sample `GenericModelConfig`.
 */
export const makeSampleConfig = (): GenericModelConfig => ({
	apiKey: undefined,
	baseURL: undefined,
	model: "",
	frequency_penalty: undefined,
	presence_penalty: undefined,
	stop: undefined,
	function_call: undefined,
	functions: undefined,
	temperature: undefined,
	top_p: undefined,
	// Fix: the original `"user" || "assistant" || "system"` short-circuits
	// to "user" at runtime anyway; state the default role explicitly.
	role: "user",
	provider: undefined,
	// NOTE(review): "microstat" mirrors the schema's spelling; Ollama's
	// actual parameter is "mirostat" — confirm upstream before renaming.
	microstat: undefined,
	microstat_eta: undefined,
	microstat_tau: undefined,
	// Keep in sync with GenericModelConfigSchema, which also declares
	// max_tokens and user.
	max_tokens: undefined,
	user: undefined,
	num_ctx: undefined,
	num_gqa: undefined,
	num_gpu: undefined,
	num_thread: undefined,
	repeat_last_n: undefined,
	repeat_penalty: undefined,
	seed: undefined,
	tfs_z: undefined,
	num_predict: undefined,
	top_k: undefined,
});

// Resolver mapping a provider name to that provider's default model
// configuration, or `undefined` when no defaults are registered for it.
export type GetDefaultsByProvider = (provider: SupportedProviders) => GenericModelConfig | undefined;

// Message-list shape produced by `LLMProvider.convertMessages`; derived with
// ReturnType so it stays in sync with the implementation automatically.
export type LangchainMessages = ReturnType<typeof LLMProvider.convertMessages>;
Expand Down Expand Up @@ -163,10 +136,6 @@ export class LLMProvider {
return defaults;
}

/**
 * Returns a template configuration with every supported field present.
 * Thin instance-level wrapper around the module's `makeSampleConfig` factory.
 */
getSampleConfig() {
	const sampleConfig = makeSampleConfig();
	return sampleConfig;
}

getMergedConfig = (args?: Partial<{
configOverrides: GenericModelConfig;
provider: SupportedProviders;
Expand Down Expand Up @@ -202,6 +171,15 @@ export class LLMProvider {
apiKey: config.apiKey,
model: config.model,
temperature: config.temperature,
topP: config.top_p,
frequencyPenalty: config.frequency_penalty,
presencePenalty: config.presence_penalty,
stop: config.stop?.split(","),
maxTokens: config.max_tokens,
user: config.user,
// beta openai feature
// @ts-expect-error
seed: config.seed,
maxRetries: 3,
configuration: {
baseURL: url,
Expand All @@ -218,6 +196,15 @@ export class LLMProvider {
azureOpenAIApiInstanceName: config.azureOpenAIApiInstanceName,
azureOpenAIApiVersion: config.azureOpenAIApiVersion,
azureOpenAIBasePath: url,
user: config.user,
maxTokens: config.max_tokens,
// beta openai feature
// @ts-expect-error
seed: config.seed,
topP: config.top_p,
frequencyPenalty: config.frequency_penalty,
presencePenalty: config.presence_penalty,
stop: config.stop?.split(","),
maxRetries: 3,
configuration: {
baseURL: url,
Expand All @@ -230,20 +217,31 @@ export class LLMProvider {
baseUrl: url,
model: config.model,
temperature: config.temperature,
topP: config.top_p,
frequencyPenalty: config.frequency_penalty,
presencePenalty: config.presence_penalty,
stop: config.stop?.split(","),
});
}

return new ChatOllama({
baseUrl: url,
model: config.model,
temperature: config.temperature,
topP: config.top_p,
frequencyPenalty: config.frequency_penalty,
presencePenalty: config.presence_penalty,
stop: config.stop?.split(","),
});
case "gemini":
return new ChatGoogleGenerativeAI({
maxRetries: 3,
model: config.model,
apiKey: config.apiKey,
temperature: config.temperature,
maxOutputTokens: config.max_tokens,
topP: config.top_p,
stopSequences: config.stop?.split(","),
});
case "anthropic":
return new ChatAnthropic({
Expand All @@ -252,6 +250,10 @@ export class LLMProvider {
temperature: config.temperature,
maxRetries: 0,
anthropicApiUrl: url,
maxTokens: config.max_tokens,
topP: config.top_p,
stopSequences: config.stop?.split(","),
topK: config.top_k,
clientOptions: {
defaultHeaders: {
"anthropic-dangerous-direct-browser-access": "true",
Expand All @@ -263,6 +265,7 @@ export class LLMProvider {
apiKey: config.apiKey,
model: config.model,
temperature: config.temperature,
stopSequences: config.stop?.split(","),
maxRetries: 3,
});
default:
Expand Down

0 comments on commit de59766

Please sign in to comment.