diff --git a/app/components/chat/BaseChat.tsx b/app/components/chat/BaseChat.tsx index fb9446a73..8ddfeeeef 100644 --- a/app/components/chat/BaseChat.tsx +++ b/app/components/chat/BaseChat.tsx @@ -7,12 +7,11 @@ import { Menu } from '~/components/sidebar/Menu.client'; import { IconButton } from '~/components/ui/IconButton'; import { Workbench } from '~/components/workbench/Workbench.client'; import { classNames } from '~/utils/classNames'; -import { DEFAULT_PROVIDER, MODEL_LIST, initializeModelList, isInitialized } from '~/utils/constants'; import { Messages } from './Messages.client'; import { SendButton } from './SendButton.client'; -import { useState } from 'react'; import styles from './BaseChat.module.scss'; +import type { ModelInfo } from '~/utils/types'; const EXAMPLE_PROMPTS = [ { text: 'Build a todo app in React using Tailwind' }, @@ -24,31 +23,20 @@ const EXAMPLE_PROMPTS = [ -function ModelSelector({ model, setModel }) { - const modelList = MODEL_LIST; - const providerList =[...new Set([...modelList.map((m) => m.provider),"OpenAILike","Ollama"])]; - const initialize = async () => { - if (!isInitialized) { - await initializeModelList(); - } - }; - initialize(); +function ModelSelector({ model, setModel ,provider,setProvider,modelList,providerList}) { - const [provider, setProvider] = useState(DEFAULT_PROVIDER); - const handleProviderChange = (e) => { - setProvider(e.target.value); - const firstModel = modelList.find((m) => m.provider === e.target.value); - setModel(firstModel ? firstModel.name : ''); - }; - return (
); @@ -85,8 +73,12 @@ interface BaseChatProps { enhancingPrompt?: boolean; promptEnhanced?: boolean; input?: string; - model: string; - setModel: (model: string) => void; + model?: string; + setModel?: (model: string) => void; + provider?: string; + setProvider?: (provider: string) => void; + modelList?: ModelInfo[]; + providerList?: string[]; handleStop?: () => void; sendMessage?: (event: React.UIEvent, messageInput?: string) => void; handleInputChange?: (event: React.ChangeEvent) => void; @@ -108,6 +100,10 @@ export const BaseChat = React.forwardRef( input = '', model, setModel, + provider, + setProvider, + modelList, + providerList, sendMessage, handleInputChange, enhancePrompt, @@ -116,7 +112,6 @@ export const BaseChat = React.forwardRef( ref, ) => { const TEXTAREA_MAX_HEIGHT = chatStarted ? 400 : 200; - return (
(
( } event.preventDefault(); - + console.log('Enter pressed'); + console.log("event", event); sendMessage?.(event); } }} diff --git a/app/components/chat/Chat.client.tsx b/app/components/chat/Chat.client.tsx index 458bd8364..ed4897b4c 100644 --- a/app/components/chat/Chat.client.tsx +++ b/app/components/chat/Chat.client.tsx @@ -11,7 +11,7 @@ import { useChatHistory } from '~/lib/persistence'; import { chatStore } from '~/lib/stores/chat'; import { workbenchStore } from '~/lib/stores/workbench'; import { fileModificationsToHTML } from '~/utils/diff'; -import { DEFAULT_MODEL } from '~/utils/constants'; +import { DEFAULT_MODEL, DEFAULT_PROVIDER, initializeModelList, isInitialized, MODEL_LIST } from '~/utils/constants'; import { cubicEasingFn } from '~/utils/easings'; import { createScopedLogger, renderLogger } from '~/utils/logger'; import { BaseChat } from './BaseChat'; @@ -74,6 +74,19 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp const [chatStarted, setChatStarted] = useState(initialMessages.length > 0); const [model, setModel] = useState(DEFAULT_MODEL); + const [provider, setProvider] = useState(DEFAULT_PROVIDER); + const [modelList, setModelList] = useState(MODEL_LIST); + const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), 'Ollama', 'OpenAILike'])]); + const initialize = async () => { + if (!isInitialized) { + const models= await initializeModelList(); + const modelList = models; + const providerList = [...new Set([...models.map((m) => m.provider),"Ollama","OpenAILike"])]; + setModelList(modelList); + setProviderList(providerList); + } + }; + initialize(); const { showChat } = useStore(chatStore); @@ -182,7 +195,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp * manually reset the input and we'd have to manually pass in file attachments. However, those * aren't relevant here. */ - append({ role: 'user', content: `[Model: ${model}]\n\n${diff}\n\n${_input}` }); + append({ role: 'user', content: `[Model: ${model}Provider: ${provider}]\n\n${diff}\n\n${_input}` }); /** * After sending a new message we reset all modifications since the model @@ -190,7 +203,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp */ workbenchStore.resetAllFileModifications(); } else { - append({ role: 'user', content: `[Model: ${model}]\n\n${_input}` }); + append({ role: 'user', content: `[Model: ${model}Provider: ${provider}]\n\n${_input}` }); } setInput(''); @@ -215,6 +228,10 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp sendMessage={sendMessage} model={model} setModel={setModel} + provider={provider} + setProvider={setProvider} + modelList={modelList} + providerList={providerList} messageRef={messageRef} scrollRef={scrollRef} handleInputChange={handleInputChange} diff --git a/app/lib/.server/llm/stream-text.ts b/app/lib/.server/llm/stream-text.ts index de3d5bfa8..e3f233bcf 100644 --- a/app/lib/.server/llm/stream-text.ts +++ b/app/lib/.server/llm/stream-text.ts @@ -4,7 +4,7 @@ import { streamText as _streamText, convertToCoreMessages } from 'ai'; import { getModel } from '~/lib/.server/llm/model'; import { MAX_TOKENS } from './constants'; import { getSystemPrompt } from './prompts'; -import { MODEL_LIST, DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants'; +import { DEFAULT_MODEL, DEFAULT_PROVIDER, hasModel } from '~/utils/constants'; interface ToolResult { toolCallId: string; @@ -25,42 +25,47 @@ export type Messages = Message[]; export type StreamingOptions = Omit[0], 'model'>; function extractModelFromMessage(message: Message): { model: string; content: string } { - const modelRegex = /^\[Model: (.*?)\]\n\n/; + const modelRegex = /^\[Model: (.*?)Provider: (.*?)\]\n\n/; const match = message.content.match(modelRegex); - if (match) { - const model = match[1]; - const content = message.content.replace(modelRegex, ''); - return { model, content }; + if (!match) { + return { model: DEFAULT_MODEL, content: message.content,provider: DEFAULT_PROVIDER }; } - + const [_,model,provider] = match; + const content = message.content.replace(modelRegex, ''); + return { model, content ,provider}; // Default model if not specified - return { model: DEFAULT_MODEL, content: message.content }; + } export function streamText(messages: Messages, env: Env, options?: StreamingOptions) { let currentModel = DEFAULT_MODEL; + let currentProvider = DEFAULT_PROVIDER; + const lastMessage = messages.findLast((message) => message.role === 'user'); + if (lastMessage) { + const { model, provider } = extractModelFromMessage(lastMessage); + if (hasModel(model, provider)) { + currentModel = model; + currentProvider = provider; + } + } const processedMessages = messages.map((message) => { if (message.role === 'user') { - const { model, content } = extractModelFromMessage(message); - if (model && MODEL_LIST.find((m) => m.name === model)) { - currentModel = model; // Update the current model - } + const { content } = extractModelFromMessage(message); return { ...message, content }; } return message; }); - const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER; - + const coreMessages = convertToCoreMessages(processedMessages); return _streamText({ - model: getModel(provider, currentModel, env), + model: getModel(currentProvider, currentModel, env), system: getSystemPrompt(), maxTokens: MAX_TOKENS, // headers: { // 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15', // }, - messages: convertToCoreMessages(processedMessages), + messages: coreMessages, ...options, }); } diff --git a/app/routes/_index.tsx b/app/routes/_index.tsx index 86d73409c..7f197908a 100644 --- a/app/routes/_index.tsx +++ b/app/routes/_index.tsx @@ -3,6 +3,8 @@ import { ClientOnly } from 'remix-utils/client-only'; import { BaseChat } from '~/components/chat/BaseChat'; import { Chat } from '~/components/chat/Chat.client'; import { Header } from '~/components/header/Header'; +import { useState } from 'react'; +import { DEFAULT_MODEL, DEFAULT_PROVIDER, MODEL_LIST } from '~/utils/constants'; export const meta: MetaFunction = () => { return [{ title: 'Bolt' }, { name: 'description', content: 'Talk with Bolt, an AI assistant from StackBlitz' }]; @@ -11,10 +13,14 @@ export const meta: MetaFunction = () => { export const loader = () => json({}); export default function Index() { + const [model, setModel] = useState(DEFAULT_MODEL); + const [provider, setProvider] = useState(DEFAULT_PROVIDER); + const [modelList, setModelList] = useState(MODEL_LIST); + const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), 'Ollama', 'OpenAILike'])]); return (
- }>{() => } + }>{() => }
); } diff --git a/app/utils/constants.ts b/app/utils/constants.ts index c3b398892..8ba57905b 100644 --- a/app/utils/constants.ts +++ b/app/utils/constants.ts @@ -45,7 +45,14 @@ const staticModels: ModelInfo[] = [ ]; export let MODEL_LIST: ModelInfo[] = [...staticModels]; - +export function hasModel(modelName:string,provider:string): boolean { + for (const model of MODEL_LIST) { + if ( model.provider === provider && model.name === modelName) { + return true; + } + } + return false +} export const IS_SERVER = typeof window === 'undefined'; export function setModelList(models: ModelInfo[]): void {