From 273e8fd8b49f9a646c7b4af5ad9ca549953f4121 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Aur=C3=A9lien=20Sudre?= Date: Tue, 12 Nov 2024 23:01:38 +0100 Subject: [PATCH] Add models list --- app/components/chat/BaseChat.tsx | 100 +++++++++---- app/components/chat/Chat.client.tsx | 9 +- app/entry.server.tsx | 3 - app/lib/.server/llm/api-key.ts | 22 +-- app/lib/.server/llm/model.ts | 218 ++++++++++++++++++++++++++-- app/lib/.server/llm/stream-text.ts | 57 ++++---- app/routes/api.chat.ts | 16 +- app/routes/api.enhancer.ts | 4 +- app/routes/api.models.ts | 27 +++- app/utils/constants.ts | 133 ++++++++--------- worker-configuration.d.ts | 1 + 11 files changed, 426 insertions(+), 164 deletions(-) diff --git a/app/components/chat/BaseChat.tsx b/app/components/chat/BaseChat.tsx index 5fd4cd033..9a30c9d17 100644 --- a/app/components/chat/BaseChat.tsx +++ b/app/components/chat/BaseChat.tsx @@ -7,7 +7,7 @@ import { Menu } from '~/components/sidebar/Menu.client'; import { IconButton } from '~/components/ui/IconButton'; import { Workbench } from '~/components/workbench/Workbench.client'; import { classNames } from '~/utils/classNames'; -import { MODEL_LIST, DEFAULT_PROVIDER } from '~/utils/constants'; +import { DEFAULT_PROVIDER, staticProviders } from '~/utils/constants'; import { Messages } from './Messages.client'; import { SendButton } from './SendButton.client'; import { useState } from 'react'; @@ -24,17 +24,57 @@ const EXAMPLE_PROMPTS = [ { text: 'How do I center a div?' }, ]; -const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))] +const providerList = staticProviders; + +const ModelSelector = ({ model, setModel, modelList, providerList, provider, setProvider, setModelList }) => { + const [apiKeys, setApiKeys] = useState>({}); + + const refreshModels = () => { + console.log('Refreshing models...'); + }; + + useEffect(() => { + const storedApiKeys = Cookies.get('apiKeys'); + if (storedApiKeys) { + setApiKeys(JSON.parse(storedApiKeys)); + } + }, []); + + useEffect(() => { + const firstModel = [...modelList].find((m) => m.provider == selectedProvider); + setModel(firstModel ? firstModel : null); + }, [provider]); -const ModelSelector = ({ model, setModel, modelList, providerList, provider, setProvider }) => { return (
- + +
+
); }; @@ -78,8 +120,8 @@ interface BaseChatProps { enhancingPrompt?: boolean; promptEnhanced?: boolean; input?: string; - model: string; - setModel: (model: string) => void; + model: ModelInfo; + setModel: (model: ModelInfo) => void; handleStop?: () => void; sendMessage?: (event: React.UIEvent, messageInput?: string) => void; handleInputChange?: (event: React.ChangeEvent) => void; @@ -111,6 +153,7 @@ export const BaseChat = React.forwardRef( const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200; const [provider, setProvider] = useState(DEFAULT_PROVIDER); const [apiKeys, setApiKeys] = useState>({}); + const [modelList, setModelList] = useState([]); // État pour la liste des modèles useEffect(() => { // Load API keys from cookies on component mount @@ -138,7 +181,7 @@ export const BaseChat = React.forwardRef( expires: 30, // 30 days secure: true, // Only send over HTTPS sameSite: 'strict', // Protect against CSRF - path: '/' // Accessible across the site + path: '/', // Accessible across the site }); } catch (error) { console.error('Error saving API keys to cookies:', error); @@ -192,10 +235,11 @@ export const BaseChat = React.forwardRef( (
{input.length > 3 ? (
- Use Shift + Return for a new line + Use Shift +{' '} + Return for + a new line
) : null} @@ -309,4 +355,4 @@ export const BaseChat = React.forwardRef( ); }, -); \ No newline at end of file +); diff --git a/app/components/chat/Chat.client.tsx b/app/components/chat/Chat.client.tsx index e1adca021..c200fc486 100644 --- a/app/components/chat/Chat.client.tsx +++ b/app/components/chat/Chat.client.tsx @@ -74,7 +74,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp const textareaRef = useRef(null); const [chatStarted, setChatStarted] = useState(initialMessages.length > 0); - const [model, setModel] = useState(DEFAULT_MODEL); + const [model, setModel] = useState(null); const { showChat } = useStore(chatStore); @@ -85,7 +85,8 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp const { messages, isLoading, input, handleInputChange, setInput, stop, append } = useChat({ api: '/api/chat', body: { - apiKeys + model, + apiKeys, }, onError: (error) => { logger.error('Request failed\n\n', error); @@ -188,7 +189,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp * manually reset the input and we'd have to manually pass in file attachments. However, those * aren't relevant here. */ - append({ role: 'user', content: `[Model: ${model}]\n\n${diff}\n\n${_input}` }); + append({ role: 'user', content: `${diff}\n\n${_input}` }); /** * After sending a new message we reset all modifications since the model @@ -196,7 +197,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp */ workbenchStore.resetAllFileModifications(); } else { - append({ role: 'user', content: `[Model: ${model}]\n\n${_input}` }); + append({ role: 'user', content: `${_input}` }); } setInput(''); diff --git a/app/entry.server.tsx b/app/entry.server.tsx index be2b42bf0..4baf07001 100644 --- a/app/entry.server.tsx +++ b/app/entry.server.tsx @@ -5,7 +5,6 @@ import { renderToReadableStream } from 'react-dom/server'; import { renderHeadToString } from 'remix-island'; import { Head } from './root'; import { themeStore } from '~/lib/stores/theme'; -import { initializeModelList } from '~/utils/constants'; export default async function handleRequest( request: Request, @@ -14,8 +13,6 @@ export default async function handleRequest( remixContext: EntryContext, _loadContext: AppLoadContext, ) { - await initializeModelList(); - const readable = await renderToReadableStream(, { signal: request.signal, onError(error: unknown) { diff --git a/app/lib/.server/llm/api-key.ts b/app/lib/.server/llm/api-key.ts index c4ab0eada..493f4781e 100644 --- a/app/lib/.server/llm/api-key.ts +++ b/app/lib/.server/llm/api-key.ts @@ -26,27 +26,27 @@ export function getAPIKey(cloudflareEnv: Env, provider: string, userApiKeys?: Re case 'OpenRouter': return env.OPEN_ROUTER_API_KEY || cloudflareEnv.OPEN_ROUTER_API_KEY; case 'Deepseek': - return env.DEEPSEEK_API_KEY || cloudflareEnv.DEEPSEEK_API_KEY + return env.DEEPSEEK_API_KEY || cloudflareEnv.DEEPSEEK_API_KEY; case 'Mistral': - return env.MISTRAL_API_KEY || cloudflareEnv.MISTRAL_API_KEY; - case "OpenAILike": + return env.MISTRAL_API_KEY || cloudflareEnv.MISTRAL_API_KEY; + case 'OpenAILike': return env.OPENAI_LIKE_API_KEY || cloudflareEnv.OPENAI_LIKE_API_KEY; default: - return ""; + return ''; } } export function getBaseURL(cloudflareEnv: Env, provider: string) { switch (provider) { case 'OpenAILike': - return env.OPENAI_LIKE_API_BASE_URL || cloudflareEnv.OPENAI_LIKE_API_BASE_URL; + return env.OPENAI_LIKE_API_BASE_URL || cloudflareEnv.OPENAI_LIKE_API_BASE_URL || 'http://localhost:4000'; case 'Ollama': - let baseUrl = env.OLLAMA_API_BASE_URL || cloudflareEnv.OLLAMA_API_BASE_URL || "http://localhost:11434"; - if (env.RUNNING_IN_DOCKER === 'true') { - baseUrl = baseUrl.replace("localhost", "host.docker.internal"); - } - return baseUrl; + let baseUrl = env.OLLAMA_API_BASE_URL || cloudflareEnv.OLLAMA_API_BASE_URL || 'http://localhost:11434'; + if (env.RUNNING_IN_DOCKER === 'true') { + baseUrl = baseUrl.replace('localhost', 'host.docker.internal'); + } + return baseUrl; default: - return ""; + return ''; } } diff --git a/app/lib/.server/llm/model.ts b/app/lib/.server/llm/model.ts index b56ad1026..232901dfc 100644 --- a/app/lib/.server/llm/model.ts +++ b/app/lib/.server/llm/model.ts @@ -5,9 +5,9 @@ import { createAnthropic } from '@ai-sdk/anthropic'; import { createOpenAI } from '@ai-sdk/openai'; import { createGoogleGenerativeAI } from '@ai-sdk/google'; import { ollama } from 'ollama-ai-provider'; -import { createOpenRouter } from "@openrouter/ai-sdk-provider"; -import { mistral } from '@ai-sdk/mistral'; +import { createOpenRouter } from '@openrouter/ai-sdk-provider'; import { createMistral } from '@ai-sdk/mistral'; +import type { ModelInfo, OllamaApiResponse, OllamaModel } from '~/utils/types'; export function getAnthropicModel(apiKey: string, model: string) { const anthropic = createAnthropic({ @@ -16,7 +16,8 @@ export function getAnthropicModel(apiKey: string, model: string) { return anthropic(model); } -export function getOpenAILikeModel(baseURL:string,apiKey: string, model: string) { + +export function getOpenAILikeModel(baseURL: string, apiKey: string, model: string) { const openai = createOpenAI({ baseURL, apiKey, @@ -24,6 +25,7 @@ export function getOpenAILikeModel(baseURL:string,apiKey: string, model: string) return openai(model); } + export function getOpenAIModel(apiKey: string, model: string) { const openai = createOpenAI({ apiKey, @@ -34,16 +36,14 @@ export function getOpenAIModel(apiKey: string, model: string) { export function getMistralModel(apiKey: string, model: string) { const mistral = createMistral({ - apiKey + apiKey, }); return mistral(model); } export function getGoogleModel(apiKey: string, model: string) { - const google = createGoogleGenerativeAI( - apiKey, - ); + const google = createGoogleGenerativeAI(apiKey); return google(model); } @@ -63,7 +63,7 @@ export function getOllamaModel(baseURL: string, model: string) { return Ollama; } -export function getDeepseekModel(apiKey: string, model: string){ +export function getDeepseekModel(apiKey: string, model: string) { const openai = createOpenAI({ baseURL: 'https://api.deepseek.com/beta', apiKey, @@ -74,13 +74,13 @@ export function getDeepseekModel(apiKey: string, model: string){ export function getOpenRouterModel(apiKey: string, model: string) { const openRouter = createOpenRouter({ - apiKey + apiKey, }); return openRouter.chat(model); } -export function getModel(provider: string, model: string, env: Env, apiKeys?: Record) { +export function getModel(provider: string, model: string, env: Env, apiKeys?: Record): ModelInfo { const apiKey = getAPIKey(env, provider, apiKeys); const baseURL = getBaseURL(env, provider); @@ -94,14 +94,204 @@ export function getModel(provider: string, model: string, env: Env, apiKeys?: Re case 'OpenRouter': return getOpenRouterModel(apiKey, model); case 'Google': - return getGoogleModel(apiKey, model) + return getGoogleModel(apiKey, model); case 'OpenAILike': - return getOpenAILikeModel(baseURL,apiKey, model); + return getOpenAILikeModel(baseURL, apiKey, model); case 'Deepseek': - return getDeepseekModel(apiKey, model) + return getDeepseekModel(apiKey, model); case 'Mistral': - return getMistralModel(apiKey, model); + return getMistralModel(apiKey, model); default: return getOllamaModel(baseURL, model); } } + +export async function getAnthropicModels(apiKey: string): Promise { + const anthropic = createAnthropic({ + apiKey, + }); + + return await anthropic.listModels(); +} + +export async function getOpenAIModels(apiKey: string): Promise { + const openai = createOpenAI({ + apiKey, + }); + + return await openai.listModels(); +} + +export async function getMistralModels(apiKey: string): Promise { + try { + const response = await fetch(`https://api.mistral.ai/v1/models`, { + method: 'GET', + headers: { + Authorization: `Bearer ${apiKey}`, + 'Content-Type': 'application/json', + Accept: 'application/json', + }, + }); + if (!response.ok) { + const body = await response.text(); + throw new Error(`Erreur HTTP ${response.status}: ${body}`); + } + + const data = await response.json(); + + return data.data.map((model: any) => ({ + id: model.id, + name: model.id, + provider: 'Mistral', + })); + } catch (error) { + console.error('Erreur lors de la communication avec Mistral:', error); + throw new Error('Impossible de récupérer la liste des modèles Mistral. Vérifiez votre configuration.'); + } +} + +export async function getGoogleModels(apiKey: string): Promise { + const google = createGoogleGenerativeAI(apiKey); + + return await google.listModels(); +} + +export async function getGroqModels(apiKey: string): Promise { + const openai = createOpenAI({ + baseURL: 'https://api.groq.com/openai/v1', + apiKey, + }); + + return await openai.listModels(); +} + +export async function getOllamaModels(baseURL: string): Promise { + try { + const response = await fetch(`${baseURL}/api/tags`, { + method: 'GET', + 'Content-Type': 'application/json', + Accept: 'application/json', + }); + if (!response.ok) { + const body = await response.text(); + throw new Error(`Erreur HTTP ${response.status}: ${body}`); + } + + const data = await response.json(); + + return data.models.map((model: any) => ({ + id: model.name, + name: model.name, + provider: 'Ollama', + })); + } catch (error) { + console.error('Erreur lors de la récupération des modèles Ollama:', error); + throw new Error('Impossible de récupérer la liste des modèles Ollama. Vérifiez votre configuration.'); + } +} + +export async function getOpenAILikeModels(baseURL: string, apiKey: string): Promise { + try { + const response = await fetch(`${baseURL}/v1/models`, { + method: 'GET', + headers: { + Authorization: `Bearer ${apiKey}`, + 'Content-Type': 'application/json', + Accept: 'application/json', + }, + }); + if (!response.ok) { + const body = await response.text(); + throw new Error(`Erreur HTTP ${response.status}: ${body}`); + } + + const data = await response.json(); + + return data.data.map((model: any) => ({ + id: model.id, + name: model.id, + provider: 'OpenAILike', + })); + } catch (error) { + console.error('Erreur lors de la récupération des modèles OpenAILike:', error); + throw new Error('Impossible de récupérer la liste des modèles OpenAILike. Vérifiez votre configuration.'); + } +} + +export async function getDeepseekModels(apiKey: string): Promise { + const openai = createOpenAI({ + baseURL: 'https://api.deepseek.com/beta', + apiKey, + }); + + return await openai.listModels(); +} + +export async function getOpenRouterModels(apiKey: string): Promise { + try { + const response = await fetch(`https://openrouter.ai/api/v1/models`, { + method: 'GET', + headers: { Authorization: `Bearer ${apiKey}` }, + }); + + if (!response.ok) { + const body = await response.text(); + throw new Error(`Erreur HTTP ${response.status}: ${body}`); + } + + const data = await response.json(); + + console.log('OpenRouter models:', data.data[0]); + + return data.data.map((model: any) => ({ + id: model.id, + name: model.id, + provider: 'OpenRouter', + })); + } catch (error) { + console.error('Erreur lors de la communication avec OpenRouter:', error); + throw new Error("Erreur lors de la communication avec l'IA. Vérifiez votre configuration OpenRouter."); + } +} + +export async function getModels(provider: string, env: Env, apiKeys?: Record): Promise { + const apiKey = getAPIKey(env, provider, apiKeys); + const baseURL = getBaseURL(env, provider); + + console.log(`Fetching models for provider: ${provider}, baseURL: ${baseURL}, key: ${apiKey}`); + + let models: ModelInfo[] = []; + + switch (provider) { + case 'Anthropic': + models = await getAnthropicModels(apiKey); + break; + case 'OpenAI': + models = await getOpenAIModels(apiKey); + break; + case 'Groq': + models = await getGroqModels(apiKey); + break; + case 'OpenRouter': + models = await getOpenRouterModels(apiKey); + break; + case 'Google': + models = await getGoogleModels(apiKey); + break; + case 'OpenAILike': + models = await getOpenAILikeModels(baseURL, apiKey); // Assuming OpenAILike uses the same API as OpenAI + break; + case 'Deepseek': + models = await getDeepseekModels(apiKey); + break; + case 'Mistral': + models = await getMistralModels(apiKey); + break; + default: + models = await getOllamaModels(baseURL); + } + + console.log(`Fetched models for provider ${provider}:`, models); + + return models; +} diff --git a/app/lib/.server/llm/stream-text.ts b/app/lib/.server/llm/stream-text.ts index ba951038a..7a364ffc6 100644 --- a/app/lib/.server/llm/stream-text.ts +++ b/app/lib/.server/llm/stream-text.ts @@ -5,6 +5,7 @@ import { getModel } from '~/lib/.server/llm/model'; import { MAX_TOKENS } from './constants'; import { getSystemPrompt } from './prompts'; import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants'; +import type { ModelInfo } from '~/utils/types'; interface ToolResult { toolCallId: string; @@ -24,48 +25,48 @@ export type Messages = Message[]; export type StreamingOptions = Omit[0], 'model'>; -function extractModelFromMessage(message: Message): { model: string; content: string } { - const modelRegex = /^\[Model: (.*?)\]\n\n/; - const match = message.content.match(modelRegex); +// function extractModelFromMessage(message: Message): { model: string; content: string } { +// const modelRegex = /^\[Model: (.*?)\]\n\n/; +// const match = message.content.match(modelRegex); - if (match) { - const model = match[1]; - const content = message.content.replace(modelRegex, ''); - return { model, content }; - } +// if (match) { +// const model = match[1]; +// const content = message.content.replace(modelRegex, ''); +// return { model, content }; +// } - // Default model if not specified - return { model: DEFAULT_MODEL, content: message.content }; -} +// // Default model if not specified +// return { model: DEFAULT_MODEL, content: message.content }; +// } export function streamText( - messages: Messages, - env: Env, + model: ModelInfo, + messages: Messages, + env: Env, options?: StreamingOptions, - apiKeys?: Record + apiKeys?: Record, ) { - let currentModel = DEFAULT_MODEL; - const processedMessages = messages.map((message) => { - if (message.role === 'user') { - const { model, content } = extractModelFromMessage(message); - if (model && MODEL_LIST.find((m) => m.name === model)) { - currentModel = model; // Update the current model - } - return { ...message, content }; - } - return message; - }); + // const processedMessages = messages.map((message) => { + // if (message.role === 'user') { + // const { model, content } = extractModelFromMessage(message); + // // if (model && MODEL_LIST.find((m) => m.name === model)) { + // // currentModel = model; // Update the current model + // // } + // return { ...message, content }; + // } + // return message; + // }); - const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER; + // const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER; return _streamText({ - model: getModel(provider, currentModel, env, apiKeys), + model: getModel(model.provider, model.name, env, apiKeys), system: getSystemPrompt(), maxTokens: MAX_TOKENS, // headers: { // 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15', // }, - messages: convertToCoreMessages(processedMessages), + messages: convertToCoreMessages(messages), ...options, }); } diff --git a/app/routes/api.chat.ts b/app/routes/api.chat.ts index 473f8c161..12dea8c68 100644 --- a/app/routes/api.chat.ts +++ b/app/routes/api.chat.ts @@ -5,15 +5,17 @@ import { MAX_RESPONSE_SEGMENTS, MAX_TOKENS } from '~/lib/.server/llm/constants'; import { CONTINUE_PROMPT } from '~/lib/.server/llm/prompts'; import { streamText, type Messages, type StreamingOptions } from '~/lib/.server/llm/stream-text'; import SwitchableStream from '~/lib/.server/llm/switchable-stream'; +import type { ModelInfo } from '~/utils/types'; export async function action(args: ActionFunctionArgs) { return chatAction(args); } async function chatAction({ context, request }: ActionFunctionArgs) { - const { messages, apiKeys } = await request.json<{ - messages: Messages, - apiKeys: Record + const { provider, model, messages, apiKeys } = await request.json<{ + messages: Messages; + model: ModelInfo; + apiKeys: Record; }>(); const stream = new SwitchableStream(); @@ -38,13 +40,13 @@ async function chatAction({ context, request }: ActionFunctionArgs) { messages.push({ role: 'assistant', content }); messages.push({ role: 'user', content: CONTINUE_PROMPT }); - const result = await streamText(messages, context.cloudflare.env, options); + const result = await streamText(model, messages, context.cloudflare.env, options, apiKeys); return stream.switchSource(result.toAIStream()); }, }; - const result = await streamText(messages, context.cloudflare.env, options, apiKeys); + const result = await streamText(model, messages, context.cloudflare.env, options, apiKeys); stream.switchSource(result.toAIStream()); @@ -56,11 +58,11 @@ async function chatAction({ context, request }: ActionFunctionArgs) { }); } catch (error) { console.log(error); - + if (error.message?.includes('API key')) { throw new Response('Invalid or missing API key', { status: 401, - statusText: 'Unauthorized' + statusText: 'Unauthorized', }); } diff --git a/app/routes/api.enhancer.ts b/app/routes/api.enhancer.ts index 5c8175ca3..895bccafc 100644 --- a/app/routes/api.enhancer.ts +++ b/app/routes/api.enhancer.ts @@ -2,6 +2,7 @@ import { type ActionFunctionArgs } from '@remix-run/cloudflare'; import { StreamingTextResponse, parseStreamPart } from 'ai'; import { streamText } from '~/lib/.server/llm/stream-text'; import { stripIndents } from '~/utils/stripIndent'; +import type { ModelInfo } from '~/utils/types'; const encoder = new TextEncoder(); const decoder = new TextDecoder(); @@ -11,10 +12,11 @@ export async function action(args: ActionFunctionArgs) { } async function enhancerAction({ context, request }: ActionFunctionArgs) { - const { message } = await request.json<{ message: string }>(); + const { model, message } = await request.json<{ message: string; model: ModelInfo }>(); try { const result = await streamText( + model, [ { role: 'user', diff --git a/app/routes/api.models.ts b/app/routes/api.models.ts index ace4ef009..85718868c 100644 --- a/app/routes/api.models.ts +++ b/app/routes/api.models.ts @@ -1,6 +1,25 @@ -import { json } from '@remix-run/cloudflare'; -import { MODEL_LIST } from '~/utils/constants'; +// Preventing TS checks with files presented in the video for a better presentation. +import { getModels } from '~/lib/.server/llm/model'; +import { type ActionFunctionArgs } from '@remix-run/cloudflare'; -export async function loader() { - return json(MODEL_LIST); +export async function action(args: ActionFunctionArgs) { + return modelsAction(args); +} + +async function modelsAction({ context, request }: ActionFunctionArgs) { + const { provider, apiKeys } = await request.json<{ + provider: string; + apiKeys: Record; + }>(); + + if (!provider || !context.cloudflare.env) { + throw new Response('Provider and environment are required', { status: 400 }); + } + + try { + const models = await getModels(provider, context.cloudflare.env, apiKeys); + return new Response(JSON.stringify(models), { status: 200 }); + } catch (error) { + return new Response(`Error fetching models: ${error}`, { status: 500 }); + } } diff --git a/app/utils/constants.ts b/app/utils/constants.ts index 35330575f..2e2b24330 100644 --- a/app/utils/constants.ts +++ b/app/utils/constants.ts @@ -7,45 +7,56 @@ export const MODEL_REGEX = /^\[Model: (.*?)\]\n\n/; export const DEFAULT_MODEL = 'google/gemini-flash-1.5-exp'; export const DEFAULT_PROVIDER = 'OpenRouter'; -const staticModels: ModelInfo[] = [ - { name: 'claude-3-5-sonnet-20240620', label: 'Claude 3.5 Sonnet', provider: 'Anthropic' }, - { name: 'gpt-4o', label: 'GPT-4o', provider: 'OpenAI' }, - { name: 'anthropic/claude-3.5-sonnet', label: 'Anthropic: Claude 3.5 Sonnet (OpenRouter)', provider: 'OpenRouter' }, - { name: 'anthropic/claude-3-haiku', label: 'Anthropic: Claude 3 Haiku (OpenRouter)', provider: 'OpenRouter' }, - { name: 'deepseek/deepseek-coder', label: 'Deepseek-Coder V2 236B (OpenRouter)', provider: 'OpenRouter' }, - { name: 'google/gemini-flash-1.5-exp', label: 'Google Gemini Flash 1.5 Exp (OpenRouter)', provider: 'OpenRouter' }, - { name: 'google/gemini-pro-1.5-exp', label: 'Google Gemini Pro 1.5 Exp (OpenRouter)', provider: 'OpenRouter' }, - { name: 'mistralai/mistral-nemo', label: 'OpenRouter Mistral Nemo (OpenRouter)', provider: 'OpenRouter' }, - { name: 'qwen/qwen-110b-chat', label: 'OpenRouter Qwen 110b Chat (OpenRouter)', provider: 'OpenRouter' }, - { name: 'cohere/command', label: 'Cohere Command (OpenRouter)', provider: 'OpenRouter' }, - { name: 'gemini-1.5-flash-latest', label: 'Gemini 1.5 Flash', provider: 'Google' }, - { name: 'gemini-1.5-pro-latest', label: 'Gemini 1.5 Pro', provider: 'Google'}, - { name: 'llama-3.1-70b-versatile', label: 'Llama 3.1 70b (Groq)', provider: 'Groq' }, - { name: 'llama-3.1-8b-instant', label: 'Llama 3.1 8b (Groq)', provider: 'Groq' }, - { name: 'llama-3.2-11b-vision-preview', label: 'Llama 3.2 11b (Groq)', provider: 'Groq' }, - { name: 'llama-3.2-3b-preview', label: 'Llama 3.2 3b (Groq)', provider: 'Groq' }, - { name: 'llama-3.2-1b-preview', label: 'Llama 3.2 1b (Groq)', provider: 'Groq' }, - { name: 'claude-3-opus-20240229', label: 'Claude 3 Opus', provider: 'Anthropic' }, - { name: 'claude-3-sonnet-20240229', label: 'Claude 3 Sonnet', provider: 'Anthropic' }, - { name: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku', provider: 'Anthropic' }, - { name: 'gpt-4o-mini', label: 'GPT-4o Mini', provider: 'OpenAI' }, - { name: 'gpt-4-turbo', label: 'GPT-4 Turbo', provider: 'OpenAI' }, - { name: 'gpt-4', label: 'GPT-4', provider: 'OpenAI' }, - { name: 'gpt-3.5-turbo', label: 'GPT-3.5 Turbo', provider: 'OpenAI' }, - { name: 'deepseek-coder', label: 'Deepseek-Coder', provider: 'Deepseek'}, - { name: 'deepseek-chat', label: 'Deepseek-Chat', provider: 'Deepseek'}, - { name: 'open-mistral-7b', label: 'Mistral 7B', provider: 'Mistral' }, - { name: 'open-mixtral-8x7b', label: 'Mistral 8x7B', provider: 'Mistral' }, - { name: 'open-mixtral-8x22b', label: 'Mistral 8x22B', provider: 'Mistral' }, - { name: 'open-codestral-mamba', label: 'Codestral Mamba', provider: 'Mistral' }, - { name: 'open-mistral-nemo', label: 'Mistral Nemo', provider: 'Mistral' }, - { name: 'ministral-8b-latest', label: 'Mistral 8B', provider: 'Mistral' }, - { name: 'mistral-small-latest', label: 'Mistral Small', provider: 'Mistral' }, - { name: 'codestral-latest', label: 'Codestral', provider: 'Mistral' }, - { name: 'mistral-large-latest', label: 'Mistral Large Latest', provider: 'Mistral' }, +const staticProviders: string[] = [ + 'Ollama', + 'Anthropic', + 'OpenAI', + 'OpenRouter', + 'Mistral', + 'Groq', + 'Deepseek', + 'OpenAILike', ]; -export let MODEL_LIST: ModelInfo[] = [...staticModels]; +// const staticModels: ModelInfo[] = [ +// { name: 'claude-3-5-sonnet-20240620', label: 'Claude 3.5 Sonnet', provider: 'Anthropic' }, +// { name: 'gpt-4o', label: 'GPT-4o', provider: 'OpenAI' }, +// { name: 'anthropic/claude-3.5-sonnet', label: 'Anthropic: Claude 3.5 Sonnet (OpenRouter)', provider: 'OpenRouter' }, +// { name: 'anthropic/claude-3-haiku', label: 'Anthropic: Claude 3 Haiku (OpenRouter)', provider: 'OpenRouter' }, +// { name: 'deepseek/deepseek-coder', label: 'Deepseek-Coder V2 236B (OpenRouter)', provider: 'OpenRouter' }, +// { name: 'google/gemini-flash-1.5-exp', label: 'Google Gemini Flash 1.5 Exp (OpenRouter)', provider: 'OpenRouter' }, +// { name: 'google/gemini-pro-1.5-exp', label: 'Google Gemini Pro 1.5 Exp (OpenRouter)', provider: 'OpenRouter' }, +// { name: 'mistralai/mistral-nemo', label: 'OpenRouter Mistral Nemo (OpenRouter)', provider: 'OpenRouter' }, +// { name: 'qwen/qwen-110b-chat', label: 'OpenRouter Qwen 110b Chat (OpenRouter)', provider: 'OpenRouter' }, +// { name: 'cohere/command', label: 'Cohere Command (OpenRouter)', provider: 'OpenRouter' }, +// { name: 'gemini-1.5-flash-latest', label: 'Gemini 1.5 Flash', provider: 'Google' }, +// { name: 'gemini-1.5-pro-latest', label: 'Gemini 1.5 Pro', provider: 'Google' }, +// { name: 'llama-3.1-70b-versatile', label: 'Llama 3.1 70b (Groq)', provider: 'Groq' }, +// { name: 'llama-3.1-8b-instant', label: 'Llama 3.1 8b (Groq)', provider: 'Groq' }, +// { name: 'llama-3.2-11b-vision-preview', label: 'Llama 3.2 11b (Groq)', provider: 'Groq' }, +// { name: 'llama-3.2-3b-preview', label: 'Llama 3.2 3b (Groq)', provider: 'Groq' }, +// { name: 'llama-3.2-1b-preview', label: 'Llama 3.2 1b (Groq)', provider: 'Groq' }, +// { name: 'claude-3-opus-20240229', label: 'Claude 3 Opus', provider: 'Anthropic' }, +// { name: 'claude-3-sonnet-20240229', label: 'Claude 3 Sonnet', provider: 'Anthropic' }, +// { name: 'claude-3-haiku-20240307', label: 'Claude 3 Haiku', provider: 'Anthropic' }, +// { name: 'gpt-4o-mini', label: 'GPT-4o Mini', provider: 'OpenAI' }, +// { name: 'gpt-4-turbo', label: 'GPT-4 Turbo', provider: 'OpenAI' }, +// { name: 'gpt-4', label: 'GPT-4', provider: 'OpenAI' }, +// { name: 'gpt-3.5-turbo', label: 'GPT-3.5 Turbo', provider: 'OpenAI' }, +// { name: 'deepseek-coder', label: 'Deepseek-Coder', provider: 'Deepseek' }, +// { name: 'deepseek-chat', label: 'Deepseek-Chat', provider: 'Deepseek' }, +// { name: 'open-mistral-7b', label: 'Mistral 7B', provider: 'Mistral' }, +// { name: 'open-mixtral-8x7b', label: 'Mistral 8x7B', provider: 'Mistral' }, +// { name: 'open-mixtral-8x22b', label: 'Mistral 8x22B', provider: 'Mistral' }, +// { name: 'open-codestral-mamba', label: 'Codestral Mamba', provider: 'Mistral' }, +// { name: 'open-mistral-nemo', label: 'Mistral Nemo', provider: 'Mistral' }, +// { name: 'ministral-8b-latest', label: 'Mistral 8B', provider: 'Mistral' }, +// { name: 'mistral-small-latest', label: 'Mistral Small', provider: 'Mistral' }, +// { name: 'codestral-latest', label: 'Codestral', provider: 'Mistral' }, +// { name: 'mistral-large-latest', label: 'Mistral Large Latest', provider: 'Mistral' }, +// ]; + +// export let MODEL_LIST: ModelInfo[] = [...staticModels]; const getOllamaBaseUrl = () => { const defaultBaseUrl = import.meta.env.OLLAMA_API_BASE_URL || 'http://localhost:11434'; @@ -54,20 +65,18 @@ const getOllamaBaseUrl = () => { // Frontend always uses localhost return defaultBaseUrl; } - + // Backend: Check if we're running in Docker const isDocker = process.env.RUNNING_IN_DOCKER === 'true'; - - return isDocker - ? defaultBaseUrl.replace("localhost", "host.docker.internal") - : defaultBaseUrl; + + return isDocker ? defaultBaseUrl.replace('localhost', 'host.docker.internal') : defaultBaseUrl; }; async function getOllamaModels(): Promise { try { const base_url = getOllamaBaseUrl(); const response = await fetch(`${base_url}/api/tags`); - const data = await response.json() as OllamaApiResponse; + const data = (await response.json()) as OllamaApiResponse; return data.models.map((model: OllamaModel) => ({ name: model.name, @@ -80,32 +89,26 @@ async function getOllamaModels(): Promise { } async function getOpenAILikeModels(): Promise { - try { - const base_url =import.meta.env.OPENAI_LIKE_API_BASE_URL || ""; - if (!base_url) { + try { + const base_url = import.meta.env.OPENAI_LIKE_API_BASE_URL || 'http://localhost:4000'; + if (!base_url) { return []; - } - const api_key = import.meta.env.OPENAI_LIKE_API_KEY ?? ""; - const response = await fetch(`${base_url}/models`, { - headers: { - Authorization: `Bearer ${api_key}`, - } - }); - const res = await response.json() as any; + } + const api_key = import.meta.env.OPENAI_LIKE_API_KEY ?? ''; + const response = await fetch(`${base_url}/models`, { + headers: { + Authorization: `Bearer ${api_key}`, + }, + }); + const res = (await response.json()) as any; return res.data.map((model: any) => ({ name: model.id, label: model.id, provider: 'OpenAILike', })); - }catch (e) { - return [] - } - -} -async function initializeModelList(): Promise { - const ollamaModels = await getOllamaModels(); - const openAiLikeModels = await getOpenAILikeModels(); - MODEL_LIST = [...ollamaModels,...openAiLikeModels, ...staticModels]; + } catch (e) { + return []; + } } -initializeModelList().then(); -export { getOllamaModels, getOpenAILikeModels, initializeModelList }; + +export { getOllamaModels, getOpenAILikeModels, staticProviders }; diff --git a/worker-configuration.d.ts b/worker-configuration.d.ts index 82961ecd6..215987e78 100644 --- a/worker-configuration.d.ts +++ b/worker-configuration.d.ts @@ -1,6 +1,7 @@ interface Env { ANTHROPIC_API_KEY: string; OPENAI_API_KEY: string; + MISTRAL_API_KEY: string; GROQ_API_KEY: string; OPEN_ROUTER_API_KEY: string; OLLAMA_API_BASE_URL: string;