Skip to content

Commit

Permalink
Moved provider and setProvider variables to the higher level componen…
Browse files Browse the repository at this point in the history
…t so that it can be accessed in sendMessage.

Added provider to message queue in sendMessage.
Changed streamText to extract both model and provider.
  • Loading branch information
TommyHolmberg committed Nov 6, 2024
1 parent a6d81b1 commit 074e2f3
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 24 deletions.
9 changes: 7 additions & 2 deletions app/components/chat/BaseChat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ const EXAMPLE_PROMPTS = [

const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))]

const ModelSelector = ({ model, setModel, modelList, providerList }) => {
const [provider, setProvider] = useState(DEFAULT_PROVIDER);
const ModelSelector = ({ model, setModel, provider, setProvider, modelList, providerList }) => {
return (
<div className="mb-2">
<select
Expand Down Expand Up @@ -79,6 +78,8 @@ interface BaseChatProps {
input?: string;
model: string;
setModel: (model: string) => void;
provider: string;
setProvider: (provider: string) => void;
handleStop?: () => void;
sendMessage?: (event: React.UIEvent, messageInput?: string) => void;
handleInputChange?: (event: React.ChangeEvent<HTMLTextAreaElement>) => void;
Expand All @@ -100,6 +101,8 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
input = '',
model,
setModel,
provider,
setProvider,
sendMessage,
handleInputChange,
enhancePrompt,
Expand Down Expand Up @@ -157,6 +160,8 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
model={model}
setModel={setModel}
modelList={MODEL_LIST}
provider={provider}
setProvider={setProvider}
providerList={providerList}
/>
<div
Expand Down
9 changes: 6 additions & 3 deletions app/components/chat/Chat.client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import { useChatHistory } from '~/lib/persistence';
import { chatStore } from '~/lib/stores/chat';
import { workbenchStore } from '~/lib/stores/workbench';
import { fileModificationsToHTML } from '~/utils/diff';
import { DEFAULT_MODEL } from '~/utils/constants';
import { DEFAULT_MODEL, DEFAULT_PROVIDER } from '~/utils/constants';
import { cubicEasingFn } from '~/utils/easings';
import { createScopedLogger, renderLogger } from '~/utils/logger';
import { BaseChat } from './BaseChat';
Expand Down Expand Up @@ -74,6 +74,7 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp

const [chatStarted, setChatStarted] = useState(initialMessages.length > 0);
const [model, setModel] = useState(DEFAULT_MODEL);
const [provider, setProvider] = useState(DEFAULT_PROVIDER);

const { showChat } = useStore(chatStore);

Expand Down Expand Up @@ -182,15 +183,15 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
* manually reset the input and we'd have to manually pass in file attachments. However, those
* aren't relevant here.
*/
append({ role: 'user', content: `[Model: ${model}]\n\n${diff}\n\n${_input}` });
append({ role: 'user', content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n${diff}\n\n${_input}` });

/**
* After sending a new message we reset all modifications since the model
* should now be aware of all the changes.
*/
workbenchStore.resetAllFileModifications();
} else {
append({ role: 'user', content: `[Model: ${model}]\n\n${_input}` });
append({ role: 'user', content: `[Model: ${model}]\n\n[Provider: ${provider}]\n\n${_input}` });
}

setInput('');
Expand All @@ -215,6 +216,8 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
sendMessage={sendMessage}
model={model}
setModel={setModel}
provider={provider}
setProvider={setProvider}
messageRef={messageRef}
scrollRef={scrollRef}
handleInputChange={handleInputChange}
Expand Down
47 changes: 28 additions & 19 deletions app/lib/.server/llm/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,42 +24,51 @@ export type Messages = Message[];

export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;

function extractModelFromMessage(message: Message): { model: string; content: string } {
function extractPropertiesFromMessage(message: Message): { model: string; provider: string; content: string } {
const modelRegex = /^\[Model: (.*?)\]\n\n/;
const match = message.content.match(modelRegex);
const providerRegex = /\[Provider: (.*?)\]\n\n/;

if (match) {
const model = match[1];
const content = message.content.replace(modelRegex, '');
return { model, content };
}
// Extract model
const modelMatch = message.content.match(modelRegex);
const model = modelMatch ? modelMatch[1] : DEFAULT_MODEL;

// Default model if not specified
return { model: DEFAULT_MODEL, content: message.content };
// Extract provider
const providerMatch = message.content.match(providerRegex);
const provider = providerMatch ? providerMatch[1] : DEFAULT_PROVIDER;

// Remove model and provider lines from content
const cleanedContent = message.content
.replace(modelRegex, '')
.replace(providerRegex, '')
.trim();

return { model, provider, content: cleanedContent };
}

export function streamText(messages: Messages, env: Env, options?: StreamingOptions) {
let currentModel = DEFAULT_MODEL;
let currentProvider = DEFAULT_PROVIDER;

const processedMessages = messages.map((message) => {
if (message.role === 'user') {
const { model, content } = extractModelFromMessage(message);
if (model && MODEL_LIST.find((m) => m.name === model)) {
currentModel = model; // Update the current model
const { model, provider, content } = extractPropertiesFromMessage(message);

if (MODEL_LIST.find((m) => m.name === model)) {
currentModel = model;
}

currentProvider = provider;

return { ...message, content };
}
return message;
});

const provider = MODEL_LIST.find((model) => model.name === currentModel)?.provider || DEFAULT_PROVIDER;
return message; // No changes for non-user messages
});

return _streamText({
model: getModel(provider, currentModel, env),
model: getModel(currentProvider, currentModel, env),
system: getSystemPrompt(),
maxTokens: MAX_TOKENS,
// headers: {
// 'anthropic-beta': 'max-tokens-3-5-sonnet-2024-07-15',
// },
messages: convertToCoreMessages(processedMessages),
...options,
});
Expand Down

0 comments on commit 074e2f3

Please sign in to comment.