Skip to content

Commit

Permalink
chore: Update constants.ts and fix model name bug
Browse files Browse the repository at this point in the history
  • Loading branch information
ZerxZ committed Oct 26, 2024
1 parent 9b830ed commit de22dcd
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 52 deletions.
62 changes: 31 additions & 31 deletions app/components/chat/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@ import { Menu } from '~/components/sidebar/Menu.client';
import { IconButton } from '~/components/ui/IconButton';
import { Workbench } from '~/components/workbench/Workbench.client';
import { classNames } from '~/utils/classNames';
import { DEFAULT_PROVIDER, MODEL_LIST, initializeModelList, isInitialized } from '~/utils/constants';
import { Messages } from './Messages.client';
import { SendButton } from './SendButton.client';
import { useState } from 'react';

import styles from './BaseChat.module.scss';
import type { ModelInfo } from '~/utils/types';

const EXAMPLE_PROMPTS = [
{ text: 'Build a todo app in React using Tailwind' },
Expand All @@ -24,31 +23,20 @@ const EXAMPLE_PROMPTS = [



function ModelSelector({ model, setModel }) {
const modelList = MODEL_LIST;
const providerList =[...new Set([...modelList.map((m) => m.provider),"OpenAILike","Ollama"])];
const initialize = async () => {
if (!isInitialized) {
await initializeModelList();
}
};
initialize();
function ModelSelector({ model, setModel ,provider,setProvider,modelList,providerList}) {

const [provider, setProvider] = useState(DEFAULT_PROVIDER);



const handleProviderChange = (e) => {
setProvider(e.target.value);
const firstModel = modelList.find((m) => m.provider === e.target.value);
setModel(firstModel ? firstModel.name : '');
};

return (
<div className="mb-2">
<select
value={provider}
onChange={handleProviderChange}
onChange={(e) => {
setProvider(e.target.value);
const firstModel = modelList.find((m) => m.provider === e.target.value);
setModel(firstModel ? firstModel.name : '');
}}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
{providerList.map(providerName=>( <option key={providerName} value={providerName}>
Expand All @@ -57,16 +45,16 @@ function ModelSelector({ model, setModel }) {
</select>
<select
value={model}
onChange={(e) => setModel(e.target.value)}
onChange={(e) => {
setModel(e.target.value)
}}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
{modelList
.filter((e) => e.provider === provider && e.name)
.map((modelOption) => (
<option key={modelOption.name} value={modelOption.name}>
{modelOption.label}
</option>
))}
{[...modelList].filter(e => e.provider == provider && e.name).map((modelOption) => (
<option key={modelOption.name} value={modelOption.name}>
{modelOption.label}
</option>
))}
</select>
</div>
);
Expand All @@ -85,8 +73,12 @@ interface BaseChatProps {
enhancingPrompt?: boolean;
promptEnhanced?: boolean;
input?: string;
model: string;
setModel: (model: string) => void;
model?: string;
setModel?: (model: string) => void;
provider?: string;
setProvider?: (provider: string) => void;
modelList?: ModelInfo[];
providerList?: string[];
handleStop?: () => void;
sendMessage?: (event: React.UIEvent, messageInput?: string) => void;
handleInputChange?: (event: React.ChangeEvent<HTMLTextAreaElement>) => void;
Expand All @@ -108,6 +100,10 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
input = '',
model,
setModel,
provider,
setProvider,
modelList,
providerList,
sendMessage,
handleInputChange,
enhancePrompt,
Expand All @@ -116,7 +112,6 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
ref,
) => {
const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200;

return (
<div
ref={ref}
Expand Down Expand Up @@ -164,6 +159,10 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
<ModelSelector
model={model}
setModel={setModel}
provider={provider}
setProvider={setProvider}
modelList={modelList}
providerList={providerList}
/>
<div
className={classNames(
Expand All @@ -180,7 +179,8 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
}

event.preventDefault();

console.log('Enter pressed');
console.log("event", event);
sendMessage?.(event);
}
}}
Expand Down
23 changes: 20 additions & 3 deletions app/components/chat/Chat.client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { useChatHistory } from '~/lib/persistence';
import { chatStore } from '~/lib/stores/chat';
import { workbenchStore } from '~/lib/stores/workbench';
import { fileModificationsToHTML } from '~/utils/diff';
import { DEFAULT_MODEL } from '~/utils/constants';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, initializeModelList, isInitialized, MODEL_LIST } from '~/utils/constants';
import { cubicEasingFn } from '~/utils/easings';
import { createScopedLogger, renderLogger } from '~/utils/logger';
import { BaseChat } from './BaseChat';
Expand Down Expand Up @@ -74,6 +74,19 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp

const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
const [model, setModel] = useState(DEFAULT_MODEL);
const [provider, setProvider] = useState(DEFAULT_PROVIDER);
const [modelList, setModelList] = useState(MODEL_LIST);
const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), 'Ollama', 'OpenAILike'])]);
const initialize = async () => {
if (!isInitialized) {
const models= await initializeModelList();
const modelList = models;
const providerList = [...new Set([...models.map((m) => m.provider),"Ollama","OpenAILike"])];
setModelList(modelList);
setProviderList(providerList);
}
};
initialize();

const { showChat } = useStore(chatStore);

Expand Down Expand Up @@ -182,15 +195,15 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
* manually reset the input and we'd have to manually pass in file attachments. However, those
* aren't relevant here.
*/
append({ role: 'user', content: `[Model: ${model}]\n\n${diff}\n\n${_input}` });
append({ role: 'user', content: `[Model: ${model}Provider: ${provider}]\n\n${diff}\n\n${_input}` });

/**
* After sending a new message we reset all modifications since the model
* should now be aware of all the changes.
*/
workbenchStore.resetAllFileModifications();
} else {
append({ role: 'user', content: `[Model: ${model}]\n\n${_input}` });
append({ role: 'user', content: `[Model: ${model}Provider: ${provider}]\n\n${_input}` });
}

setInput('');
Expand All @@ -215,6 +228,10 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
sendMessage={sendMessage}
model={model}
setModel={setModel}
provider={provider}
setProvider={setProvider}
modelList={modelList}
providerList={providerList}
messageRef={messageRef}
scrollRef={scrollRef}
handleInputChange={handleInputChange}
Expand Down
37 changes: 21 additions & 16 deletions app/lib/.server/llm/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { streamText as _streamText, convertToCoreMessages } from 'ai';
import { getModel } from '~/lib/.server/llm/model';
import { MAX_TOKENS } from './constants';
import { getSystemPrompt } from './prompts';
import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, hasModel } from '~/utils/constants';

interface ToolResult<Name extends string, Args, Result> {
toolCallId: string;
Expand All @@ -25,42 +25,47 @@ export type Messages = Message[];
export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;

function extractModelFromMessage(message: Message): { model: string; content: string } {
const modelRegex = /^\[Model: (.*?)\]\n\n/;
const modelRegex = /^\[Model: (.*?)Provider: (.*?)\]\n\n/;
const match = message.content.match(modelRegex);

if (match) {
const model = match[1];
const content = message.content.replace(modelRegex, '');
return { model, content };
if (!match) {
return { model: DEFAULT_MODEL, content: message.content,provider: DEFAULT_PROVIDER };
}

const [_,model,provider] = match;
const content = message.content.replace(modelRegex, '');
return { model, content ,provider};
// Default model if not specified
return { model: DEFAULT_MODEL, content: message.content };

}

export function streamText(messages: Messages, env: Env, options?: StreamingOptions) {
let currentModel = DEFAULT_MODEL;
let currentProvider = DEFAULT_PROVIDER;
const lastMessage = messages.findLast((message) => message.role === 'user');
if (lastMessage) {
const { model, provider } = extractModelFromMessage(lastMessage);
if (hasModel(model, provider)) {
currentModel = model;
currentProvider = provider;
}
}
const processedMessages = messages.map((message) => {
if (message.role === 'user') {
const { model, content } = extractModelFromMessage(message);
if (model && MODEL_LIST.find((m) => m.name === model)) {
currentModel = model; // Update the current model
}
const { content } = extractModelFromMessage(message);
return { ...message, content };
}
return message;
});

const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER;

const coreMessages = convertToCoreMessages(processedMessages);
return _streamText({
model: getModel(provider, currentModel, env),
model: getModel(currentProvider, currentModel, env),
system: getSystemPrompt(),
maxTokens: MAX_TOKENS,
// headers: {
// 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
// },
messages: convertToCoreMessages(processedMessages),
messages: coreMessages,
...options,
});
}
8 changes: 7 additions & 1 deletion app/routes/_index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import { ClientOnly } from 'remix-utils/client-only';
import { BaseChat } from '~/components/chat/BaseChat';
import { Chat } from '~/components/chat/Chat.client';
import { Header } from '~/components/header/Header';
import { useState } from 'react';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_LIST } from '~/utils/constants';

export const meta: MetaFunction = () => {
return [{ title: 'Bolt' }, { name: 'description', content: 'Talk with Bolt, an AI assistant from StackBlitz' }];
Expand All @@ -11,10 +13,14 @@ export const meta: MetaFunction = () => {
export const loader = () => json({});

export default function Index() {
const [model, setModel] = useState(DEFAULT_MODEL);
const [provider, setProvider] = useState(DEFAULT_PROVIDER);
const [modelList, setModelList] = useState(MODEL_LIST);
const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), 'Ollama', 'OpenAILike'])]);
return (
<div className="flex flex-col h-full w-full">
<Header />
<ClientOnly fallback={<BaseChat />}>{() => <Chat />}</ClientOnly>
<ClientOnly fallback={<BaseChat model={model} modelList={modelList} provider={provider} providerList={providerList} setModel={setModel} setProvider={setProvider}/>}>{() => <Chat />}</ClientOnly>
</div>
);
}
9 changes: 8 additions & 1 deletion app/utils/constants.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,14 @@ const staticModels: ModelInfo[] = [
];

export let MODEL_LIST: ModelInfo[] = [...staticModels];

export function hasModel(modelName:string,provider:string): boolean {
for (const model of MODEL_LIST) {
if ( model.provider === provider && model.name === modelName) {
return true;
}
}
return false
}
export const IS_SERVER = typeof window === 'undefined';

export function setModelList(models: ModelInfo[]): void {
Expand Down

0 comments on commit de22dcd

Please sign in to comment.