Commit

chore: Update constants.ts and fix model name bug
ZerxZ committed Oct 26, 2024
1 parent 1d8d6f1 commit 9b830ed
Showing 6 changed files with 48 additions and 69 deletions.
49 changes: 29 additions & 20 deletions app/components/chat/BaseChat.tsx
@@ -7,7 +7,7 @@ import { Menu } from '~/components/sidebar/Menu.client';
import { IconButton } from '~/components/ui/IconButton';
import { Workbench } from '~/components/workbench/Workbench.client';
import { classNames } from '~/utils/classNames';
-import { DEFAULT_PROVIDER, initializeModelList } from '~/utils/constants';
+import { DEFAULT_PROVIDER, MODEL_LIST, initializeModelList, isInitialized } from '~/utils/constants';
import { Messages } from './Messages.client';
import { SendButton } from './SendButton.client';
import { useState } from 'react';
@@ -21,36 +21,47 @@ const EXAMPLE_PROMPTS = [
{ text: 'Make a space invaders game' },
{ text: 'How do I center a div?' },
];
-const MODEL_LIST= await initializeModelList();
-const providerList = [...new Set(MODEL_LIST.map((model) => model.provider))]
-
-function ModelSelector (args) {
-const {model, setModel, modelList,providerList} = args;
+function ModelSelector({ model, setModel }) {
+const modelList = MODEL_LIST;
+const providerList =[...new Set([...modelList.map((m) => m.provider),"OpenAILike","Ollama"])];
+const initialize = async () => {
+if (!isInitialized) {
+await initializeModelList();
+}
+};
+initialize();

const [provider, setProvider] = useState(DEFAULT_PROVIDER);



+const handleProviderChange = (e) => {
+setProvider(e.target.value);
+const firstModel = modelList.find((m) => m.provider === e.target.value);
+setModel(firstModel ? firstModel.name : '');
+};

return (
<div className="mb-2">
<select
value={provider}
-onChange={(e) => {
-setProvider(e.target.value);
-const firstModel = [...modelList].find((m) => m.provider == e.target.value);
-setModel(firstModel ? firstModel.name : '');
-}}
+onChange={handleProviderChange}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
-{providerList.map(provider => (
-<option key={provider} value={provider}>
-{provider}
-</option>
-))}
+{providerList.map(providerName=>( <option key={providerName} value={providerName}>
+{providerName}
+</option>))}
</select>
<select
value={model}
onChange={(e) => setModel(e.target.value)}
className="w-full p-2 rounded-lg border border-bolt-elements-borderColor bg-bolt-elements-prompt-background text-bolt-elements-textPrimary focus:outline-none"
>
-{[...modelList]
-.filter((e) => e.provider == provider && e.name)
+{modelList
+.filter((e) => e.provider === provider && e.name)
.map((modelOption) => (
<option key={modelOption.name} value={modelOption.name}>
{modelOption.label}
@@ -59,7 +70,7 @@ function ModelSelector (args) {
</select>
</div>
);
-};
+}

const TEXTAREA_MIN_HEIGHT = 76;

@@ -153,8 +164,6 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
<ModelSelector
model={model}
setModel={setModel}
-modelList={MODEL_LIST}
-providerList={providerList}
/>
<div
className={classNames(
1 change: 0 additions & 1 deletion app/lib/.server/llm/model.ts
@@ -6,7 +6,6 @@ import { createOpenAI } from '@ai-sdk/openai';
import { createGoogleGenerativeAI } from '@ai-sdk/google';
import { ollama } from 'ollama-ai-provider';
import { createOpenRouter } from "@openrouter/ai-sdk-provider";
-import { mistral } from '@ai-sdk/mistral';
import { createMistral } from '@ai-sdk/mistral';

export function getAnthropicModel(apiKey: string, model: string) {
7 changes: 4 additions & 3 deletions app/routes/api.models.ts
@@ -1,7 +1,8 @@
-import { json } from '@remix-run/cloudflare';
+import { json, type LoaderFunctionArgs } from '@remix-run/cloudflare';
import { initializeModelList } from '~/utils/tools';

-export async function loader() {
-const modelList = await initializeModelList();
+export async function loader({context}: LoaderFunctionArgs) {
+const { env } = context.cloudflare;
+const modelList = await initializeModelList(env);
return json(modelList);
}
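
Taken together, the updated app/routes/api.models.ts after this hunk reads roughly as follows (reconstructed from the diff above; the typing of the env binding on context.cloudflare is assumed to come from the generated Wrangler types):

import { json, type LoaderFunctionArgs } from '@remix-run/cloudflare';
import { initializeModelList } from '~/utils/tools';

// Thread the Cloudflare Pages env bindings into the model-list initialization.
export async function loader({ context }: LoaderFunctionArgs) {
  const { env } = context.cloudflare;
  const modelList = await initializeModelList(env);
  return json(modelList);
}
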
4 changes: 2 additions & 2 deletions app/utils/constants.ts
@@ -55,10 +55,10 @@ export function setModelList(models: ModelInfo[]): void {
export function getStaticModels(): ModelInfo[] {
return [...staticModels];
}
-let isInitialized = false;
+export let isInitialized = false;

export async function initializeModelList(): Promise<ModelInfo[]> {
-if (isInitialized ) {
+if (isInitialized) {
return MODEL_LIST;
}
if (IS_SERVER){
51 changes: 10 additions & 41 deletions app/utils/tools.ts
@@ -1,47 +1,16 @@
import type { ModelInfo, OllamaApiResponse, OllamaModel } from '~/utils/types';
import { getStaticModels,setModelList } from '~/utils/constants';
-import { env } from 'node:process';
+import { getAPIKey, getBaseURL } from '~/lib/.server/llm/api-key';


export let MODEL_LIST: ModelInfo[] = [...getStaticModels()];

-export function getAPIKey(provider: string) {
-switch (provider) {
-case 'Anthropic':
-return env.ANTHROPIC_API_KEY;
-case 'OpenAI':
-return env.OPENAI_API_KEY;
-case 'Google':
-return env.GOOGLE_GENERATIVE_AI_API_KEY;
-case 'Groq':
-return env.GROQ_API_KEY;
-case 'OpenRouter':
-return env.OPEN_ROUTER_API_KEY;
-case 'Deepseek':
-return env.DEEPSEEK_API_KEY;
-case 'Mistral':
-return env.MISTRAL_API_KEY;
-case "OpenAILike":
-return import.meta.env.OPENAI_LIKE_API_KEY || env.OPENAI_LIKE_API_KEY;
-default:
-return "";
-}
-}
-export function getBaseURL( provider: string){
-switch (provider) {
-case 'OpenAILike':
-return import.meta.env.OPENAI_LIKE_API_BASE_URL || env.OPENAI_LIKE_API_BASE_URL || "";
-case 'Ollama':
-return import.meta.env.OLLAMA_API_BASE_URL || env.OLLAMA_API_BASE_URL || "http://localhost:11434";
-default:
-return "";
-}
-}


let isInitialized = false;
-async function getOllamaModels(): Promise<ModelInfo[]> {
+async function getOllamaModels(env: Env): Promise<ModelInfo[]> {
try {
-const base_url = getBaseURL("Ollama") ;
+const base_url = getBaseURL(env,"Ollama") ;
const response = await fetch(`${base_url}/api/tags`);
const data = await response.json() as OllamaApiResponse;
return data.models.map((model: OllamaModel) => ({
@@ -58,17 +27,17 @@ async function getOllamaModels(): Promise<ModelInfo[]> {
}
}

-async function getOpenAILikeModels(): Promise<ModelInfo[]> {
+async function getOpenAILikeModels(env: Env): Promise<ModelInfo[]> {
try {
-const base_url = getBaseURL("OpenAILike") ;
+const base_url = getBaseURL(env,"OpenAILike") ;
if (!base_url) {
return [{
name: "Empty",
label: "Empty",
provider: "OpenAILike"
}];
}
-const api_key = getAPIKey("OpenAILike") ?? "";
+const api_key = getAPIKey(env,"OpenAILike") ?? "";
const response = await fetch(`${base_url}/models`, {
headers: {
Authorization: `Bearer ${api_key}`,
@@ -91,13 +60,13 @@ async function getOpenAILikeModels(): Promise<ModelInfo[]> {
}


-async function initializeModelList(): Promise<ModelInfo[]> {
+async function initializeModelList(env: Env): Promise<ModelInfo[]> {
if (isInitialized) {
return MODEL_LIST;
}
isInitialized = true;
-const ollamaModels = await getOllamaModels();
-const openAiLikeModels = await getOpenAILikeModels();
+const ollamaModels = await getOllamaModels(env);
+const openAiLikeModels = await getOpenAILikeModels(env);
MODEL_LIST = [...getStaticModels(), ...ollamaModels, ...openAiLikeModels];
setModelList(MODEL_LIST);
return MODEL_LIST;
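
The new ~/lib/.server/llm/api-key module itself is not part of this diff, so the following is only a sketch of what the relocated, env-aware helpers presumably look like. The environment variable names and fallbacks are taken from the code removed from tools.ts above; reading them from the Cloudflare env bindings instead of node:process is an assumption, and the actual contents of api-key.ts may differ.

// Hypothetical sketch of the relocated helpers in app/lib/.server/llm/api-key.ts.
// Variable names and fallbacks come from the removed tools.ts code; the Env-based
// lookup (rather than node:process) is the assumption behind this sketch.
export function getAPIKey(env: Env, provider: string) {
  switch (provider) {
    case 'Anthropic':
      return env.ANTHROPIC_API_KEY;
    case 'OpenAI':
      return env.OPENAI_API_KEY;
    case 'OpenAILike':
      return env.OPENAI_LIKE_API_KEY;
    // ...the remaining providers (Google, Groq, OpenRouter, Deepseek, Mistral) follow the same pattern
    default:
      return '';
  }
}

export function getBaseURL(env: Env, provider: string) {
  switch (provider) {
    case 'OpenAILike':
      return env.OPENAI_LIKE_API_BASE_URL || '';
    case 'Ollama':
      return env.OLLAMA_API_BASE_URL || 'http://localhost:11434';
    default:
      return '';
  }
}
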
5 changes: 3 additions & 2 deletions package.json
@@ -6,13 +6,14 @@
"sideEffects": false,
"type": "module",
"scripts": {
"deploy": "npm run build && wrangler pages deploy",
"deploy": "pnpm run build && wrangler pages deploy",
"dev:deploy": "pnpm run build && wrangler pages dev",
"build": "remix vite:build",
"dev": "remix vite:dev",
"test": "vitest --run",
"test:watch": "vitest",
"lint": "eslint --cache --cache-location ./node_modules/.cache/eslint .",
"lint:fix": "npm run lint -- --fix",
"lint:fix": "pnpm run lint -- --fix",
"start": "bindings=$(./bindings.sh) && wrangler pages dev ./build/client $bindings --ip 0.0.0.0 --port 3000",
"typecheck": "tsc",
"typegen": "wrangler types",
