Skip to content

Commit

Permalink
Merge branch 'main' of github.com:DeabLabs/cannoli
Browse files Browse the repository at this point in the history
  • Loading branch information
blindmansion committed Sep 8, 2024
2 parents 63a40e1 + de59766 commit 366897d
Showing 1 changed file with 36 additions and 33 deletions.
69 changes: 36 additions & 33 deletions packages/cannoli-core/src/providers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ export const GenericModelConfigSchema = z.object({
microstat: z.boolean().optional(),
microstat_eta: z.coerce.number().optional(),
microstat_tau: z.coerce.number().optional(),
max_tokens: z.coerce.number().optional(),
user: z.string().optional(),
num_ctx: z.coerce.number().optional(),
num_gqa: z.coerce.number().optional(),
num_gpu: z.coerce.number().optional(),
Expand Down Expand Up @@ -81,35 +83,6 @@ export type GenericCompletionResponse = {
function_call?: ChatCompletionAssistantMessageParam.FunctionCall;
};

// @deprecated — retained for backward compatibility; prefer provider defaults.
// Builds a template GenericModelConfig listing every supported key so callers
// can discover the full configuration surface. All tunables are undefined;
// only `model` (empty string) and `role` ("user") carry concrete values.
export const makeSampleConfig = (): GenericModelConfig => ({
  apiKey: undefined,
  baseURL: undefined,
  model: "",
  frequency_penalty: undefined,
  presence_penalty: undefined,
  stop: undefined,
  function_call: undefined,
  functions: undefined,
  temperature: undefined,
  top_p: undefined,
  // Original expression `"user" || "assistant" || "system"` always evaluated
  // to "user" (|| returns the first truthy operand) — the chain was misleading.
  role: "user",
  provider: undefined,
  microstat: undefined,
  microstat_eta: undefined,
  microstat_tau: undefined,
  // Keys declared by GenericModelConfigSchema; included for completeness.
  max_tokens: undefined,
  user: undefined,
  num_ctx: undefined,
  num_gqa: undefined,
  num_gpu: undefined,
  num_thread: undefined,
  repeat_last_n: undefined,
  repeat_penalty: undefined,
  seed: undefined,
  tfs_z: undefined,
  num_predict: undefined,
  top_k: undefined,
});

export type GetDefaultsByProvider = (provider: SupportedProviders) => GenericModelConfig | undefined;

export type LangchainMessages = ReturnType<typeof LLMProvider.convertMessages>;
Expand Down Expand Up @@ -164,10 +137,6 @@ export class LLMProvider {
return defaults;
}

/**
 * @deprecated Thin wrapper that delegates to the module-level
 * makeSampleConfig helper (itself deprecated); returns a template
 * config with every supported key present.
 */
getSampleConfig() {
  const sampleConfig = makeSampleConfig();
  return sampleConfig;
}

getMergedConfig = (args?: Partial<{
configOverrides: GenericModelConfig;
provider: SupportedProviders;
Expand Down Expand Up @@ -203,6 +172,15 @@ export class LLMProvider {
apiKey: config.apiKey,
model: config.model,
temperature: config.temperature,
topP: config.top_p,
frequencyPenalty: config.frequency_penalty,
presencePenalty: config.presence_penalty,
stop: config.stop?.split(","),
maxTokens: config.max_tokens,
user: config.user,
// beta openai feature
// @ts-expect-error
seed: config.seed,
maxRetries: 3,
configuration: {
baseURL: url,
Expand All @@ -219,6 +197,15 @@ export class LLMProvider {
azureOpenAIApiInstanceName: config.azureOpenAIApiInstanceName,
azureOpenAIApiVersion: config.azureOpenAIApiVersion,
azureOpenAIBasePath: url,
user: config.user,
maxTokens: config.max_tokens,
// beta openai feature
// @ts-expect-error
seed: config.seed,
topP: config.top_p,
frequencyPenalty: config.frequency_penalty,
presencePenalty: config.presence_penalty,
stop: config.stop?.split(","),
maxRetries: 3,
configuration: {
baseURL: url,
Expand All @@ -231,20 +218,31 @@ export class LLMProvider {
baseUrl: url,
model: config.model,
temperature: config.temperature,
topP: config.top_p,
frequencyPenalty: config.frequency_penalty,
presencePenalty: config.presence_penalty,
stop: config.stop?.split(","),
});
}

return new ChatOllama({
baseUrl: url,
model: config.model,
temperature: config.temperature,
topP: config.top_p,
frequencyPenalty: config.frequency_penalty,
presencePenalty: config.presence_penalty,
stop: config.stop?.split(","),
});
case "gemini":
return new ChatGoogleGenerativeAI({
maxRetries: 3,
model: config.model,
apiKey: config.apiKey,
temperature: config.temperature,
maxOutputTokens: config.max_tokens,
topP: config.top_p,
stopSequences: config.stop?.split(","),
});
case "anthropic":
return new ChatAnthropic({
Expand All @@ -253,6 +251,10 @@ export class LLMProvider {
temperature: config.temperature,
maxRetries: 0,
anthropicApiUrl: url,
maxTokens: config.max_tokens,
topP: config.top_p,
stopSequences: config.stop?.split(","),
topK: config.top_k,
clientOptions: {
defaultHeaders: {
"anthropic-dangerous-direct-browser-access": "true",
Expand All @@ -264,6 +266,7 @@ export class LLMProvider {
apiKey: config.apiKey,
model: config.model,
temperature: config.temperature,
stopSequences: config.stop?.split(","),
maxRetries: 3,
});
default:
Expand Down

0 comments on commit 366897d

Please sign in to comment.