fix: fixed model not loading, even after baseUrl set in .env file #816

Merged
merged 8 commits on Dec 18, 2024
Changes from 2 commits
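
The core of the fix: a provider's Base URL set in the UI (or restored from cookies) now takes priority over the `.env` values when base URLs are resolved and the model list is built, and the API keys parsed on the client are passed through to `initializeModelList`. A condensed sketch of the intended resolution order (not the literal code from this diff; the helper name `resolveBaseUrl`, the local `ProviderSetting` type, and the example values are illustrative):

```ts
// Minimal stand-in for the project's IProviderSetting; only the field used here.
type ProviderSetting = { baseUrl?: string };

// Sketch: the UI/cookie setting wins, then the .env / server env value, then the provider default.
function resolveBaseUrl(
  provider: string,
  providerSettings?: Record<string, ProviderSetting>,
  envBaseUrl?: string, // e.g. OLLAMA_API_BASE_URL from .env or the Cloudflare env
  defaultBaseUrl = 'http://localhost:11434',
): string {
  // An empty string coming from the settings form counts as "not set".
  const settingBaseUrl = providerSettings?.[provider]?.baseUrl?.trim() || undefined;

  return settingBaseUrl || envBaseUrl || defaultBaseUrl;
}

// resolveBaseUrl('Ollama', { Ollama: { baseUrl: '' } }, 'http://192.168.0.10:11434')
// → 'http://192.168.0.10:11434' (an empty settings value falls back to the .env value)
```

The same ordering shows up in `getBaseURL` (api-key.ts) and the `get*Models` helpers (constants.ts) below.
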
6 changes: 5 additions & 1 deletion app/components/chat/BaseChat.tsx
@@ -119,6 +119,9 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(

useEffect(() => {
// Load API keys from cookies on component mount

let parsedApiKeys: Record<string, string> | undefined = {};

try {
const storedApiKeys = Cookies.get('apiKeys');

@@ -127,6 +130,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(

if (typeof parsedKeys === 'object' && parsedKeys !== null) {
setApiKeys(parsedKeys);
parsedApiKeys = parsedKeys;
}
}
} catch (error) {
@@ -155,7 +159,7 @@ export const BaseChat = React.forwardRef<HTMLDivElement, BaseChatProps>(
Cookies.remove('providers');
}

initializeModelList(providerSettings).then((modelList) => {
initializeModelList({ apiKeys: parsedApiKeys, providerSettings }).then((modelList) => {
setModelList(modelList);
});

7 changes: 6 additions & 1 deletion app/components/settings/providers/ProvidersTab.tsx
@@ -87,7 +87,12 @@ export default function ProvidersTab() {
type="text"
value={provider.settings.baseUrl || ''}
onChange={(e) => {
const newBaseUrl = e.target.value;
let newBaseUrl: string | undefined = e.target.value;

if (newBaseUrl && newBaseUrl.trim().length === 0) {
newBaseUrl = undefined;
}

updateProviderSettings(provider.name, { ...provider.settings, baseUrl: newBaseUrl });
logStore.logProvider(`Base URL updated for ${provider.name}`, {
provider: provider.name,
2 changes: 1 addition & 1 deletion app/entry.server.tsx
@@ -14,7 +14,7 @@ export default async function handleRequest(
remixContext: EntryContext,
_loadContext: AppLoadContext,
) {
await initializeModelList();
await initializeModelList({});

const readable = await renderToReadableStream(<RemixServer context={remixContext} url={request.url} />, {
signal: request.signal,
25 changes: 20 additions & 5 deletions app/lib/.server/llm/api-key.ts
@@ -3,6 +3,7 @@
* Preventing TS checks with files presented in the video for a better presentation.
*/
import { env } from 'node:process';
import type { IProviderSetting } from '~/types/model';

export function getAPIKey(cloudflareEnv: Env, provider: string, userApiKeys?: Record<string, string>) {
/**
@@ -50,16 +51,30 @@ export function getAPIKey(cloudflareEnv: Env, provider: string, userApiKeys?: Re
}
}

export function getBaseURL(cloudflareEnv: Env, provider: string) {
export function getBaseURL(cloudflareEnv: Env, provider: string, providerSettings?: Record<string, IProviderSetting>) {
let settingBaseUrl = providerSettings?.[provider].baseUrl;

if (settingBaseUrl && settingBaseUrl.length == 0) {
settingBaseUrl = undefined;
}

switch (provider) {
case 'Together':
return env.TOGETHER_API_BASE_URL || cloudflareEnv.TOGETHER_API_BASE_URL || 'https://api.together.xyz/v1';
return (
settingBaseUrl ||
env.TOGETHER_API_BASE_URL ||
cloudflareEnv.TOGETHER_API_BASE_URL ||
'https://api.together.xyz/v1'
);
case 'OpenAILike':
return env.OPENAI_LIKE_API_BASE_URL || cloudflareEnv.OPENAI_LIKE_API_BASE_URL;
return settingBaseUrl || env.OPENAI_LIKE_API_BASE_URL || cloudflareEnv.OPENAI_LIKE_API_BASE_URL;
case 'LMStudio':
return env.LMSTUDIO_API_BASE_URL || cloudflareEnv.LMSTUDIO_API_BASE_URL || 'http://localhost:1234';
return (
settingBaseUrl || env.LMSTUDIO_API_BASE_URL || cloudflareEnv.LMSTUDIO_API_BASE_URL || 'http://localhost:1234'
);
case 'Ollama': {
let baseUrl = env.OLLAMA_API_BASE_URL || cloudflareEnv.OLLAMA_API_BASE_URL || 'http://localhost:11434';
let baseUrl =
settingBaseUrl || env.OLLAMA_API_BASE_URL || cloudflareEnv.OLLAMA_API_BASE_URL || 'http://localhost:11434';

if (env.RUNNING_IN_DOCKER === 'true') {
baseUrl = baseUrl.replace('localhost', 'host.docker.internal');
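For Ollama, the Docker hostname rewrite still runs after this resolution, so a localhost URL (whether it came from the settings, a cookie, or `.env`) is redirected to the host machine when the app runs inside a container. A small sketch of that step (the standalone helper is illustrative; in the diff it is inlined in `getBaseURL`):

```ts
// Sketch: with RUNNING_IN_DOCKER=true, "localhost" must point at the host, not the container itself.
function adjustOllamaUrlForDocker(baseUrl: string, runningInDocker?: string): string {
  return runningInDocker === 'true'
    ? baseUrl.replace('localhost', 'host.docker.internal')
    : baseUrl;
}

// adjustOllamaUrlForDocker('http://localhost:11434', 'true')
// → 'http://host.docker.internal:11434'
```
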
11 changes: 8 additions & 3 deletions app/lib/.server/llm/model.ts
@@ -84,6 +84,8 @@ export function getHuggingFaceModel(apiKey: OptionalApiKey, model: string) {
}

export function getOllamaModel(baseURL: string, model: string) {
console.log({ baseURL, model });

const ollamaInstance = ollama(model, {
numCtx: DEFAULT_NUM_CTX,
}) as LanguageModelV1 & { config: any };
@@ -140,17 +142,20 @@ export function getPerplexityModel(apiKey: OptionalApiKey, model: string) {
export function getModel(
provider: string,
model: string,
env: Env,
serverEnv: Env,
apiKeys?: Record<string, string>,
providerSettings?: Record<string, IProviderSetting>,
) {
/*
* let apiKey; // Declare first
* let baseURL;
*/
// console.log({provider,model});

const apiKey = getAPIKey(serverEnv, provider, apiKeys); // Then assign
const baseURL = getBaseURL(serverEnv, provider, providerSettings);

const apiKey = getAPIKey(env, provider, apiKeys); // Then assign
const baseURL = providerSettings?.[provider].baseUrl || getBaseURL(env, provider);
// console.log({apiKey,baseURL});

switch (provider) {
case 'Anthropic':
9 changes: 6 additions & 3 deletions app/lib/.server/llm/stream-text.ts
@@ -151,10 +151,13 @@ export async function streamText(props: {
providerSettings?: Record<string, IProviderSetting>;
promptId?: string;
}) {
const { messages, env, options, apiKeys, files, providerSettings, promptId } = props;
const { messages, env: serverEnv, options, apiKeys, files, providerSettings, promptId } = props;

// console.log({serverEnv});

let currentModel = DEFAULT_MODEL;
let currentProvider = DEFAULT_PROVIDER.name;
const MODEL_LIST = await getModelList(apiKeys || {}, providerSettings);
const MODEL_LIST = await getModelList({ apiKeys, providerSettings, serverEnv: serverEnv as any });
const processedMessages = messages.map((message) => {
if (message.role === 'user') {
const { model, provider, content } = extractPropertiesFromMessage(message);
@@ -196,7 +199,7 @@
}

return _streamText({
model: getModel(currentProvider, currentModel, env, apiKeys, providerSettings) as any,
model: getModel(currentProvider, currentModel, serverEnv, apiKeys, providerSettings) as any,
system: systemPrompt,
maxTokens: dynamicMaxTokens,
messages: convertToCoreMessages(processedMessages as any),
6 changes: 5 additions & 1 deletion app/types/model.ts
@@ -3,7 +3,11 @@ import type { ModelInfo } from '~/utils/types';
export type ProviderInfo = {
staticModels: ModelInfo[];
name: string;
getDynamicModels?: (apiKeys?: Record<string, string>, providerSettings?: IProviderSetting) => Promise<ModelInfo[]>;
getDynamicModels?: (
apiKeys?: Record<string, string>,
providerSettings?: IProviderSetting,
serverEnv?: Record<string, string>,
) => Promise<ModelInfo[]>;
getApiKeyLink?: string;
labelForGetApiKey?: string;
icon?: string;
124 changes: 97 additions & 27 deletions app/utils/constants.ts
@@ -220,7 +220,6 @@ const PROVIDER_LIST: ProviderInfo[] = [
],
getApiKeyLink: 'https://huggingface.co/settings/tokens',
},

{
name: 'OpenAI',
staticModels: [
@@ -325,26 +324,46 @@ const staticModels: ModelInfo[] = PROVIDER_LIST.map((p) => p.staticModels).flat(

export let MODEL_LIST: ModelInfo[] = [...staticModels];

export async function getModelList(
apiKeys: Record<string, string>,
providerSettings?: Record<string, IProviderSetting>,
) {
export async function getModelList(options: {
apiKeys?: Record<string, string>;
providerSettings?: Record<string, IProviderSetting>;
serverEnv?: Record<string, string>;
}) {
const { apiKeys, providerSettings, serverEnv } = options;

// console.log({ providerSettings, serverEnv,env:process.env });
MODEL_LIST = [
...(
await Promise.all(
PROVIDER_LIST.filter(
(p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels,
).map((p) => p.getDynamicModels(apiKeys, providerSettings?.[p.name])),
).map((p) => p.getDynamicModels(apiKeys, providerSettings?.[p.name], serverEnv)),
)
).flat(),
...staticModels,
];

return MODEL_LIST;
}

async function getTogetherModels(apiKeys?: Record<string, string>, settings?: IProviderSetting): Promise<ModelInfo[]> {
async function getTogetherModels(
apiKeys?: Record<string, string>,
settings?: IProviderSetting,
serverEnv: Record<string, string> = {},
): Promise<ModelInfo[]> {
try {
const baseUrl = settings?.baseUrl || import.meta.env.TOGETHER_API_BASE_URL || '';
let settingsBaseUrl = settings?.baseUrl;

if (settingsBaseUrl && settingsBaseUrl.length == 0) {
settingsBaseUrl = undefined;
}

const baseUrl =
settingsBaseUrl ||
serverEnv?.TOGETHER_API_BASE_URL ||
process.env.TOGETHER_API_BASE_URL ||
import.meta.env.TOGETHER_API_BASE_URL ||
'';
const provider = 'Together';

if (!baseUrl) {
@@ -383,8 +402,19 @@ async function getTogetherModels(apiKeys?: Record<string, string>, settings?: IP
}
}

const getOllamaBaseUrl = (settings?: IProviderSetting) => {
const defaultBaseUrl = settings?.baseUrl || import.meta.env.OLLAMA_API_BASE_URL || 'http://localhost:11434';
const getOllamaBaseUrl = (settings?: IProviderSetting, serverEnv: Record<string, string> = {}) => {
let settingsBaseUrl = settings?.baseUrl;

if (settingsBaseUrl && settingsBaseUrl.length == 0) {
settingsBaseUrl = undefined;
}

const defaultBaseUrl =
settings?.baseUrl ||
serverEnv?.OLLAMA_API_BASE_URL ||
process.env.OLLAMA_API_BASE_URL ||
import.meta.env.OLLAMA_API_BASE_URL ||
'http://localhost:11434';

// Check if we're in the browser
if (typeof window !== 'undefined') {
@@ -398,9 +428,13 @@ const getOllamaBaseUrl = (settings?: IProviderSetting) => {
return isDocker ? defaultBaseUrl.replace('localhost', 'host.docker.internal') : defaultBaseUrl;
};

async function getOllamaModels(apiKeys?: Record<string, string>, settings?: IProviderSetting): Promise<ModelInfo[]> {
async function getOllamaModels(
apiKeys?: Record<string, string>,
settings?: IProviderSetting,
serverEnv: Record<string, string> = {},
): Promise<ModelInfo[]> {
try {
const baseUrl = getOllamaBaseUrl(settings);
const baseUrl = getOllamaBaseUrl(settings, serverEnv);
const response = await fetch(`${baseUrl}/api/tags`);
const data = (await response.json()) as OllamaApiResponse;

@@ -421,9 +455,21 @@ async function getOllamaModels(apiKeys?: Record<string, string>, settings?: IPro
async function getOpenAILikeModels(
apiKeys?: Record<string, string>,
settings?: IProviderSetting,
serverEnv: Record<string, string> = {},
): Promise<ModelInfo[]> {
try {
const baseUrl = settings?.baseUrl || import.meta.env.OPENAI_LIKE_API_BASE_URL || '';
let settingsBaseUrl = settings?.baseUrl;

if (settingsBaseUrl && settingsBaseUrl.length == 0) {
settingsBaseUrl = undefined;
}

const baseUrl =
settingsBaseUrl ||
serverEnv.OPENAI_LIKE_API_BASE_URL ||
process.env.OPENAI_LIKE_API_BASE_URL ||
import.meta.env.OPENAI_LIKE_API_BASE_URL ||
'';

if (!baseUrl) {
return [];
@@ -486,9 +532,24 @@ async function getOpenRouterModels(): Promise<ModelInfo[]> {
}));
}

async function getLMStudioModels(_apiKeys?: Record<string, string>, settings?: IProviderSetting): Promise<ModelInfo[]> {
async function getLMStudioModels(
_apiKeys?: Record<string, string>,
settings?: IProviderSetting,
serverEnv: Record<string, string> = {},
): Promise<ModelInfo[]> {
try {
const baseUrl = settings?.baseUrl || import.meta.env.LMSTUDIO_API_BASE_URL || 'http://localhost:1234';
let settingsBaseUrl = settings?.baseUrl;

if (settingsBaseUrl && settingsBaseUrl.length == 0) {
settingsBaseUrl = undefined;
}

const baseUrl =
settingsBaseUrl ||
serverEnv.LMSTUDIO_API_BASE_URL ||
process.env.LMSTUDIO_API_BASE_URL ||
import.meta.env.LMSTUDIO_API_BASE_URL ||
'http://localhost:1234';
const response = await fetch(`${baseUrl}/v1/models`);
const data = (await response.json()) as any;

@@ -503,29 +564,37 @@ async function getLMStudioModels(_apiKeys?: Record<string, string>, settings?: I
}
}

async function initializeModelList(providerSettings?: Record<string, IProviderSetting>): Promise<ModelInfo[]> {
let apiKeys: Record<string, string> = {};
async function initializeModelList(options: {
env?: Record<string, string>;
providerSettings?: Record<string, IProviderSetting>;
apiKeys?: Record<string, string>;
}): Promise<ModelInfo[]> {
const { providerSettings, apiKeys: providedApiKeys, env } = options;
let apiKeys: Record<string, string> = providedApiKeys || {};

try {
const storedApiKeys = Cookies.get('apiKeys');
if (!providedApiKeys) {
try {
const storedApiKeys = Cookies.get('apiKeys');

if (storedApiKeys) {
const parsedKeys = JSON.parse(storedApiKeys);
if (storedApiKeys) {
const parsedKeys = JSON.parse(storedApiKeys);

if (typeof parsedKeys === 'object' && parsedKeys !== null) {
apiKeys = parsedKeys;
if (typeof parsedKeys === 'object' && parsedKeys !== null) {
apiKeys = parsedKeys;
}
}
} catch (error: any) {
logStore.logError('Failed to fetch API keys from cookies', error);
logger.warn(`Failed to fetch apikeys from cookies: ${error?.message}`);
}
} catch (error: any) {
logStore.logError('Failed to fetch API keys from cookies', error);
logger.warn(`Failed to fetch apikeys from cookies: ${error?.message}`);
}

MODEL_LIST = [
...(
await Promise.all(
PROVIDER_LIST.filter(
(p): p is ProviderInfo & { getDynamicModels: () => Promise<ModelInfo[]> } => !!p.getDynamicModels,
).map((p) => p.getDynamicModels(apiKeys, providerSettings?.[p.name])),
).map((p) => p.getDynamicModels(apiKeys, providerSettings?.[p.name], env)),
)
).flat(),
...staticModels,
@@ -534,6 +603,7 @@ async function initializeModelList(providerSettings?: Record<string, IProviderSe
return MODEL_LIST;
}

// initializeModelList({})
export {
getOllamaModels,
getOpenAILikeModels,
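After this PR, both model-list entry points take a single options object instead of positional arguments. A usage sketch; the import path and the example values are assumptions, trimmed to the fields shown in this diff:

```ts
import { getModelList, initializeModelList } from '~/utils/constants';

// Server side (stream-text.ts in this diff): API keys, per-provider settings and the
// server env travel together in one options object.
const serverModels = await getModelList({
  apiKeys: { OpenAI: 'sk-example' },                                   // illustrative key
  providerSettings: { Ollama: { baseUrl: 'http://localhost:11434' } }, // illustrative setting
  serverEnv: process.env as Record<string, string>,
});

// Client side (BaseChat.tsx in this diff): the API keys parsed from the `apiKeys` cookie
// are forwarded explicitly; initializeModelList only falls back to reading cookies itself
// when no keys are passed.
const clientModels = await initializeModelList({
  apiKeys: { OpenAI: 'sk-example' },
  providerSettings: { Ollama: { baseUrl: 'http://localhost:11434' } },
});
```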