Refactor chat implementation and add API key support
ZerxZ committed Oct 28, 2024
1 parent 1bec743 commit 1c0c5c0
Showing 5 changed files with 85 additions and 62 deletions.
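In outline, the commit stops encoding the model choice as a `[Model: …Provider: …]` prefix inside the user message and instead sends it as structured fields in the request body. A sketch of the flow, with the payload type abridged from app/routes/api.chat.ts below (the simplified `messages` element type is illustrative):

// Data flow after this commit (abridged from the diffs below):
//   Chat.client.tsx  --POST /api/chat-->  { messages, model, provider, api_key }
//   api.chat.ts      -->  streamText(chatRequest, env, options)
//   stream-text.ts   -->  getModel(provider, model, api_key, env)
type ChatRequest = {
  messages: { role: 'user' | 'assistant'; content: string }[]; // simplified
  model: string;
  provider: string;
  api_key: string; // empty string lets the server fall back to its env key
};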
32 changes: 23 additions & 9 deletions app/components/chat/Chat.client.tsx
@@ -11,7 +11,14 @@ import { useChatHistory } from '~/lib/persistence';
import { chatStore } from '~/lib/stores/chat';
import { workbenchStore } from '~/lib/stores/workbench';
import { fileModificationsToHTML } from '~/utils/diff';
-import { DEFAULT_MODEL, DEFAULT_PROVIDER, initializeModelList, isInitialized, MODEL_LIST } from '~/utils/constants';
+import {
+DEFAULT_MODEL,
+DEFAULT_PROVIDER,
+initializeModelList,
+isInitialized,
+MODEL_LIST,
+PROVIDER_LIST
+} from '~/utils/constants';
import { cubicEasingFn } from '~/utils/easings';
import { createScopedLogger, renderLogger } from '~/utils/logger';
import { BaseChat } from './BaseChat';
@@ -76,12 +83,12 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
const [model, setModel] = useState(DEFAULT_MODEL);
const [provider, setProvider] = useState(DEFAULT_PROVIDER);
const [modelList, setModelList] = useState(MODEL_LIST);
-const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), 'Ollama', 'OpenAILike'])]);
+const [providerList, setProviderList] = useState([...new Set([...MODEL_LIST.map((m) => m.provider), ...PROVIDER_LIST])]);
+// TODO: Add API key
+const [api_key, setApiKey] = useState("");
const initialize = async () => {
if (!isInitialized) {
-const models= await initializeModelList();
-const modelList = models;
-const providerList = [...new Set([...models.map((m) => m.provider),"Ollama","OpenAILike"])];
+const { modelList , providerList }= await initializeModelList();
setModelList(modelList);
setProviderList(providerList);
}
@@ -184,7 +191,12 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
chatStore.setKey('aborted', false);

runAnimation();

+const message = { role: 'user', content: "" };
+const body = {
+model,
+provider,
+api_key,
+}
if (fileModifications !== undefined) {
const diff = fileModificationsToHTML(fileModifications);

@@ -195,17 +207,19 @@ export const ChatImpl = memo(({ initialMessages, storeMessageHistory }: ChatProp
* manually reset the input and we'd have to manually pass in file attachments. However, those
* aren't relevant here.
*/
-append({ role: 'user', content: `[Model: ${model}Provider: ${provider}]\n\n${diff}\n\n${_input}` });
+message.content = `${diff}\n\n${_input}`;

/**
* After sending a new message we reset all modifications since the model
* should now be aware of all the changes.
*/
workbenchStore.resetAllFileModifications();
} else {
-append({ role: 'user', content: `[Model: ${model}Provider: ${provider}]\n\n${_input}` });
+message.content = _input;
}

+append(message,{
+body
+})
setInput('');

resetEnhancer();
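For context, `append` here is the helper returned by the `useChat` hook from the `ai` SDK; the diff relies on its second argument to merge extra fields into the POST body. A minimal sketch of the resulting call shape, assuming the SDK variant this repo uses (the concrete model and key values are placeholders, not repo defaults):

// Illustrative only: how the refactored submit path forwards the new fields.
import { useChat } from 'ai/react';

function ChatSketch() {
  const { append } = useChat({ api: '/api/chat' });

  const send = (input: string) =>
    append(
      { role: 'user', content: input },
      // Field names match the ChatRequest type added in api.chat.ts;
      // the values below are placeholders.
      { body: { model: 'some-model-id', provider: 'Anthropic', api_key: '' } },
    );

  return <button onClick={() => send('Hello')}>Send</button>;
}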
20 changes: 10 additions & 10 deletions app/lib/.server/llm/model.ts
@@ -79,27 +79,27 @@ export function getOpenRouterModel(apiKey: string, model: string) {
return openRouter.chat(model);
}

-export function getModel(provider: string, model: string, env: Env) {
-const apiKey = getAPIKey(env, provider);
+export function getModel(provider: string, model: string,apiKey:string, env: Env) {
+const _apiKey = apiKey || getAPIKey(env, provider);
const baseURL = getBaseURL(env, provider);

switch (provider) {
case 'Anthropic':
-return getAnthropicModel(apiKey, model);
+return getAnthropicModel(_apiKey, model);
case 'OpenAI':
-return getOpenAIModel(apiKey, model);
+return getOpenAIModel(_apiKey, model);
case 'Groq':
-return getGroqModel(apiKey, model);
+return getGroqModel(_apiKey, model);
case 'OpenRouter':
-return getOpenRouterModel(apiKey, model);
+return getOpenRouterModel(_apiKey, model);
case 'Google':
-return getGoogleModel(apiKey, model)
+return getGoogleModel(_apiKey, model)
case 'OpenAILike':
-return getOpenAILikeModel(baseURL,apiKey, model);
+return getOpenAILikeModel(baseURL,_apiKey, model);
case 'Deepseek':
-return getDeepseekModel(apiKey, model)
+return getDeepseekModel(_apiKey, model)
case 'Mistral':
-return getMistralModel(apiKey, model);
+return getMistralModel(_apiKey, model);
default:
return getOllamaModel(baseURL, model);
}
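The substantive change in this file is the `_apiKey` fallback: a key supplied in the request wins, and the server's environment key is only consulted when the client sends none. A minimal sketch of that precedence rule (the function name is illustrative):

// Sketch of the precedence getModel now applies per request.
function resolveApiKey(requestKey: string, envKey?: string): string {
  // `requestKey || envKey` treats an empty string from the client as "unset",
  // which is what lets the UI send api_key: "" and still use server-side keys.
  return requestKey || envKey || '';
}

resolveApiKey('sk-from-user', 'sk-from-env'); // -> 'sk-from-user'
resolveApiKey('', 'sk-from-env');             // -> 'sk-from-env'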
56 changes: 22 additions & 34 deletions app/lib/.server/llm/stream-text.ts
@@ -5,6 +5,7 @@ import { getModel } from '~/lib/.server/llm/model';
import { MAX_TOKENS } from './constants';
import { getSystemPrompt } from './prompts';
import { DEFAULT_MODEL, DEFAULT_PROVIDER, hasModel } from '~/utils/constants';
+import type { ChatRequest } from '~/routes/api.chat';

interface ToolResult<Name extends string, Args, Result> {
toolCallId: string;
@@ -24,40 +25,27 @@ export type Messages = Message[];

export type StreamingOptions = Omit<Parameters<typeof _streamText>[0], 'model'>;

-function extractModelFromMessage(message: Message): { model: string; content: string } {
-const modelRegex = /^\[Model: (.*?)Provider: (.*?)\]\n\n/;
-const match = message.content.match(modelRegex);
-
-if (!match) {
-return { model: DEFAULT_MODEL, content: message.content,provider: DEFAULT_PROVIDER };
-}
-const [_,model,provider] = match;
-const content = message.content.replace(modelRegex, '');
-return { model, content ,provider};
-// Default model if not specified
-
-}
-
-export function streamText(messages: Messages, env: Env, options?: StreamingOptions) {
-let currentModel = DEFAULT_MODEL;
-let currentProvider = DEFAULT_PROVIDER;
-const lastMessage = messages.findLast((message) => message.role === 'user');
-if (lastMessage) {
-const { model, provider } = extractModelFromMessage(lastMessage);
-if (hasModel(model, provider)) {
-currentModel = model;
-currentProvider = provider;
-}
-}
-const processedMessages = messages.map((message) => {
-if (message.role === 'user') {
-const { content } = extractModelFromMessage(message);
-return { ...message, content };
-}
-return message;
-});
-
-const coreMessages = convertToCoreMessages(processedMessages);
+// function extractModelFromMessage(message: Message): { model: string; content: string } {
+// const modelRegex = /^\[Model: (.*?)Provider: (.*?)\]\n\n/;
+// const match = message.content.match(modelRegex);
+//
+// if (!match) {
+// return { model: DEFAULT_MODEL, content: message.content,provider: DEFAULT_PROVIDER };
+// }
+// const [_,model,provider] = match;
+// const content = message.content.replace(modelRegex, '');
+// return { model, content ,provider};
+// // Default model if not specified
+//
+// }
+
+export function streamText(chatRequest: ChatRequest, env: Env, options?: StreamingOptions) {
+const { messages,model,api_key,provider } = chatRequest;
+const _hasModel = hasModel(model, provider);
+let currentModel = _hasModel ? model : DEFAULT_MODEL;
+let currentProvider = _hasModel ? provider:DEFAULT_PROVIDER;
+
+const coreMessages = convertToCoreMessages(messages);
return _streamText({
model: getModel(currentProvider, currentModel, env),
system: getSystemPrompt(),
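Because model and provider now arrive as separate fields, validation collapses to a single `hasModel` lookup, and the pair falls back to the defaults together rather than independently, which avoids mismatched model/provider combinations. A small sketch of that selection logic (names mirror the diff; the defaults shown are placeholders):

// Sketch of the model/provider resolution streamText now performs.
const DEFAULT_MODEL = 'default-model-id'; // placeholder
const DEFAULT_PROVIDER = 'Anthropic';     // placeholder

function resolveSelection(model: string, provider: string, hasModel: (m: string, p: string) => boolean) {
  const known = hasModel(model, provider);
  // Fall back as a pair: a valid model under the wrong provider is rejected too.
  return known ? { model, provider } : { model: DEFAULT_MODEL, provider: DEFAULT_PROVIDER };
}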
13 changes: 9 additions & 4 deletions app/routes/api.chat.ts
@@ -9,10 +9,15 @@ import SwitchableStream from '~/lib/.server/llm/switchable-stream';
export async function action(args: ActionFunctionArgs) {
return chatAction(args);
}
+export type ChatRequest ={
+messages: Messages;
+model:string,
+provider:string,
+api_key:string
+}

async function chatAction({ context, request }: ActionFunctionArgs) {
-const { messages } = await request.json<{ messages: Messages }>();
+const chatRequest = await request.json<ChatRequest>();
const stream = new SwitchableStream();

try {
@@ -34,13 +39,13 @@ async function chatAction({ context, request }: ActionFunctionArgs) {
messages.push({ role: 'assistant', content });
messages.push({ role: 'user', content: CONTINUE_PROMPT });

-const result = await streamText(messages, context.cloudflare.env, options);
+const result = await streamText(chatRequest, context.cloudflare.env, options);

return stream.switchSource(result.toAIStream());
},
};

-const result = await streamText(messages, context.cloudflare.env, options);
+const result = await streamText(chatRequest, context.cloudflare.env, options);

stream.switchSource(result.toAIStream());

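End to end, the action now parses the entire JSON body as a `ChatRequest`, so the route can be exercised directly. A hedged example of such a request (values are placeholders; the route path comes from the file name):

// Illustrative request against the refactored endpoint.
const res = await fetch('/api/chat', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    messages: [{ role: 'user', content: 'Summarize this diff.' }],
    model: 'some-model-id', // placeholder
    provider: 'OpenAI',
    api_key: '', // empty -> server falls back to getAPIKey(env, provider)
  }),
});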
26 changes: 21 additions & 5 deletions app/utils/constants.ts
@@ -45,6 +45,8 @@ const staticModels: ModelInfo[] = [
];

export let MODEL_LIST: ModelInfo[] = [...staticModels];
+export const PROVIDER_LIST: string[] = ['Ollama', 'OpenAILike']

export function hasModel(modelName:string,provider:string): boolean {
for (const model of MODEL_LIST) {
if ( model.provider === provider && model.name === modelName) {
@@ -63,17 +65,31 @@ export function getStaticModels(): ModelInfo[] {
return [...staticModels];
}
export let isInitialized = false;

-export async function initializeModelList(): Promise<ModelInfo[]> {
+type ModelList={
+modelList:ModelInfo[],
+providerList:string[]
+}
+export async function initializeModelList(): Promise<ModelList> {
if (isInitialized) {
-return MODEL_LIST;
+return {
+modelList:MODEL_LIST,
+providerList:[...new Set([...MODEL_LIST.map((m) => m.provider),...PROVIDER_LIST ])]
+}
}
if (IS_SERVER){
isInitialized = true;
-return MODEL_LIST;
+return {
+modelList:MODEL_LIST,
+providerList:[...new Set([...MODEL_LIST.map((m) => m.provider),...PROVIDER_LIST ])]
+}
}
isInitialized = true;
const response = await fetch('/api/models');
MODEL_LIST = (await response.json()) as ModelInfo[];
-return MODEL_LIST;
+return {
+modelList:MODEL_LIST,
+providerList:[...new Set([...MODEL_LIST.map((m) => m.provider),...PROVIDER_LIST ])]
+}
}
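Callers of `initializeModelList` now get both lists in one object, which is what lets Chat.client.tsx destructure `{ modelList, providerList }` above. A short consumption sketch (the import path uses the repo's `~` alias); note the same `[...new Set(...)]` expression is repeated in all three branches and could be hoisted into a shared helper:

// Sketch: consuming the new return shape of initializeModelList.
import { initializeModelList } from '~/utils/constants';

async function loadModels() {
  const { modelList, providerList } = await initializeModelList();
  console.log(`models: ${modelList.length}, providers: ${providerList.join(', ')}`);
}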
