Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: refactor/standardise model providers code + add "get provider key" #251

Merged
merged 10 commits into from
Nov 14, 2024
18 changes: 15 additions & 3 deletions app/components/chat/APIKeyManager.tsx
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
import React, { useState } from 'react';
import { IconButton } from '~/components/ui/IconButton';
import type { ProviderInfo } from '~/types/model';

interface APIKeyManagerProps {
provider: string;
provider: ProviderInfo;
apiKey: string;
setApiKey: (key: string) => void;
getApiKeyLink?: string;
labelForGetApiKey?: string;
}

export const APIKeyManager: React.FC<APIKeyManagerProps> = ({ provider, apiKey, setApiKey }) => {
export const APIKeyManager: React.FC<APIKeyManagerProps> = ({
provider,
apiKey,
setApiKey,
}) => {
const [isEditing, setIsEditing] = useState(false);
const [tempKey, setTempKey] = useState(apiKey);

Expand All @@ -18,7 +25,7 @@ export const APIKeyManager: React.FC<APIKeyManagerProps> = ({ provider, apiKey,

return (
<div className="flex items-center gap-2 mt-2 mb-2">
<span className="text-sm text-bolt-elements-textSecondary">{provider} API Key:</span>
<span className="text-sm text-bolt-elements-textSecondary">{provider?.name} API Key:</span>
{isEditing ? (
<>
<input
Expand All @@ -42,6 +49,11 @@ export const APIKeyManager: React.FC<APIKeyManagerProps> = ({ provider, apiKey,
<IconButton onClick={() => setIsEditing(true)} title="Edit API Key">
<div className="i-ph:pencil-simple" />
</IconButton>

{provider?.getApiKeyLink && <IconButton onClick={() => window.open(provider?.getApiKeyLink)} title="Edit API Key">
<span className="mr-2">{provider?.labelForGetApiKey || 'Get API Key'}</span>
<div className={provider?.icon || "i-ph:key"} />
</IconButton>}
</>
)}
</div>
Expand Down
53 changes: 28 additions & 25 deletions app/components/chat/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@ import { Menu } from '~/components/sidebar/Menu.client';
import { IconButton } from '~/components/ui/IconButton';
import { Workbench } from '~/components/workbench/Workbench.client';
import { classNames } from '~/utils/classNames';
import { MODEL_LIST, DEFAULT_PROVIDER } from '~/utils/constants';
import { MODEL_LIST, DEFAULT_PROVIDER, PROVIDER_LIST, initializeModelList } from '~/utils/constants';
import { Messages } from './Messages.client';
import { SendButton } from './SendButton.client';
import { useState } from 'react';
import { APIKeyManager } from './APIKeyManager';
import Cookies from 'js-cookie';

import styles from './BaseChat.module.scss';
import type { ProviderInfo } from '~/utils/types';

const EXAMPLE_PROMPTS = [
{ text: 'Build a todo app in React using Tailwind' },
Expand All @@ -24,42 +25,35 @@ const EXAMPLE_PROMPTS = [
{ text: 'How do I center a div?' },
];

const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))];
const providerList = PROVIDER_LIST;

const ModelSelector = ({ model, setModel, provider, setProvider, modelList, providerList }) => {
return (
<div className="mb-2 flex gap-2">
<select
value={provider}
value={provider?.name}
onChange={(e) => {
setProvider(e.target.value);
setProvider(providerList.find(p => p.name === e.target.value));
const firstModel = [...modelList].find((m) => m.provider == e.target.value);
setModel(firstModel ? firstModel.name : '');
}}
className="flex-1 p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none focus:ring-2 focus:ring-bolt-elements-focus transition-all"
>
{providerList.map((provider) => (
<option key={provider} value={provider}>
{provider}
<option key={provider.name} value={provider.name}>
{provider.name}
</option>
))}
<option key="Ollama" value="Ollama">
Ollama
</option>
<option key="OpenAILike" value="OpenAILike">
OpenAILike
</option>
<option key="LMStudio" value="LMStudio">
LMStudio
</option>
</select>
<select
key={provider?.name}
value={model}
onChange={(e) => setModel(e.target.value)}
style={{maxWidth: "70%"}}
className="flex-1 p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none focus:ring-2 focus:ring-bolt-elements-focus transition-all"
>
{[...modelList]
.filter((e) => e.provider == provider && e.name)
.filter((e) => e.provider == provider?.name && e.name)
.map((modelOption) => (
<option key={modelOption.name} value={modelOption.name}>
{modelOption.label}
Expand All @@ -85,8 +79,8 @@ interface BaseChatProps {
input?: string;
model?: string;
setModel?: (model: string) => void;
provider?: string;
setProvider?: (provider: string) => void;
provider?: ProviderInfo;
setProvider?: (provider: ProviderInfo) => void;
handleStop?: () => void;
sendMessage?: (event: React.UIEvent, messageInput?: string) => void;
handleInputChange?: (event: React.ChangeEvent<HTMLTextAreaElement>) => void;
Expand Down Expand Up @@ -117,8 +111,11 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
},
ref,
) => {
console.log(provider);
const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200;
const [apiKeys, setApiKeys] = useState<Record<string, string>>({});
const [modelList, setModelList] = useState(MODEL_LIST);


useEffect(() => {
// Load API keys from cookies on component mount
Expand All @@ -135,6 +132,10 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
// Clear invalid cookie data
Cookies.remove('apiKeys');
}

initializeModelList().then(modelList => {
setModelList(modelList);
});
}, []);

const updateApiKey = (provider: string, key: string) => {
Expand Down Expand Up @@ -198,18 +199,20 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
})}
>
<ModelSelector
key={provider?.name + ':' + modelList.length}
model={model}
setModel={setModel}
modelList={MODEL_LIST}
modelList={modelList}
provider={provider}
setProvider={setProvider}
providerList={providerList}
/>
<APIKeyManager
provider={provider}
apiKey={apiKeys[provider] || ''}
setApiKey={(key) => updateApiKey(provider, key)}
providerList={PROVIDER_LIST}
/>
{provider &&
<APIKeyManager
provider={provider}
apiKey={apiKeys[provider.name] || ''}
setApiKey={(key) => updateApiKey(provider.name, key)}
/>}
<div
className={classNames(
'shadow-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background backdrop-filter backdrop-blur-[8px] rounded-lg overflow-hidden transition-all',
Expand Down
17 changes: 9 additions & 8 deletions app/components/chat/Chat.client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import { useChatHistory } from '~/lib/persistence';
import { chatStore } from '~/lib/stores/chat';
import { workbenchStore } from '~/lib/stores/workbench';
import { fileModificationsToHTML } from '~/utils/diff';
import { DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, PROVIDER_LIST } from '~/utils/constants';
import { cubicEasingFn } from '~/utils/easings';
import { createScopedLogger, renderLogger } from '~/utils/logger';
import { BaseChat } from './BaseChat';
import Cookies from 'js-cookie';
import type { ProviderInfo } from '~/utils/types';

const toastAnimation = cssTransition({
enter: 'animated fadeInRight',
Expand Down Expand Up @@ -80,7 +81,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
});
const [provider, setProvider] = useState(() => {
const savedProvider = Cookies.get('selectedProvider');
return savedProvider || DEFAULT_PROVIDER;
return PROVIDER_LIST.find(p => p.name === savedProvider) || DEFAULT_PROVIDER;
});

const { showChat } = useStore(chatStore);
Expand All @@ -96,7 +97,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
},
onError: (error) => {
logger.error('Request failed\n\n', error);
toast.error('There was an error processing your request');
toast.error('There was an error processing your request: ' + (error.message ? error.message : "No details were returned"));
},
onFinish: () => {
logger.debug('Finished streaming');
Expand Down Expand Up @@ -195,15 +196,15 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
* manually reset the input and we'd have to manually pass in file attachments. However, those
* aren't relevant here.
*/
append({ role: 'user', content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n${diff}\n\n${_input}` });
append({ role: 'user', content: `[Model: ${model}]\n\n[Provider: ${provider.name}]\n\n${diff}\n\n${_input}` });

/**
* After sending a new message we reset all modifications since the model
* should now be aware of all the changes.
*/
workbenchStore.resetAllFileModifications();
} else {
append({ role: 'user', content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n${_input}` });
append({ role: 'user', content: `[Model: ${model}]\n\n[Provider: ${provider.name}]\n\n${_input}` });
}

setInput('');
Expand All @@ -227,9 +228,9 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
Cookies.set('selectedModel', newModel, { expires: 30 });
};

const handleProviderChange = (newProvider: string) => {
const handleProviderChange = (newProvider: ProviderInfo) => {
setProvider(newProvider);
Cookies.set('selectedProvider', newProvider, { expires: 30 });
Cookies.set('selectedProvider', newProvider.name, { expires: 30 });
};

return (
Expand Down Expand Up @@ -263,7 +264,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
})}
enhancePrompt={() => {
enhancePrompt(
input,
input,
(input) => {
setInput(input);
scrollTextArea();
Expand Down
5 changes: 2 additions & 3 deletions app/lib/.server/llm/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import { createOpenAI } from '@ai-sdk/openai';
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { ollama } from 'ollama-ai-provider';
import { createOpenRouter } from "@openrouter/ai-sdk-provider";
import { mistral } from '@ai-sdk/mistral';
import { createMistral } from '@ai-sdk/mistral';

export function getAnthropicModel(apiKey: string, model: string) {
Expand Down Expand Up @@ -41,9 +40,9 @@ export function getMistralModel(apiKey: string, model: string) {
}

export function getGoogleModel(apiKey: string, model: string) {
const google = createGoogleGenerativeAI(
const google = createGoogleGenerativeAI({
apiKey,
);
});

return google(model);
}
Expand Down
31 changes: 27 additions & 4 deletions app/routes/api.chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,35 @@ export async function action(args: ActionFunctionArgs) {
return chatAction(args);
}

function parseCookies(cookieHeader) {
const cookies = {};

// Split the cookie string by semicolons and spaces
const items = cookieHeader.split(";").map(cookie => cookie.trim());

items.forEach(item => {
const [name, ...rest] = item.split("=");
if (name && rest) {
// Decode the name and value, and join value parts in case it contains '='
const decodedName = decodeURIComponent(name.trim());
const decodedValue = decodeURIComponent(rest.join("=").trim());
cookies[decodedName] = decodedValue;
}
});

return cookies;
}

async function chatAction({ context, request }: ActionFunctionArgs) {
const { messages, apiKeys } = await request.json<{
messages: Messages,
apiKeys: Record<string, string>
const { messages } = await request.json<{
messages: Messages
}>();

const cookieHeader = request.headers.get("Cookie");

// Parse the cookie's value (returns an object or null if no cookie exists)
const apiKeys = JSON.parse(parseCookies(cookieHeader).apiKeys || "{}");

const stream = new SwitchableStream();

try {
Expand Down Expand Up @@ -56,7 +79,7 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
});
} catch (error) {
console.log(error);

if (error.message?.includes('API key')) {
throw new Response('Invalid or missing API key', {
status: 401,
Expand Down
10 changes: 10 additions & 0 deletions app/types/model.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import type { ModelInfo } from '~/utils/types';

export type ProviderInfo = {
staticModels: ModelInfo[],
name: string,
getDynamicModels?: () => Promise<ModelInfo[]>,
getApiKeyLink?: string,
labelForGetApiKey?: string,
icon?:string,
};
Loading