diff --git a/core/src/api/index.ts b/core/src/api/index.ts index 23d5f6f874..3cf2693e7b 100644 --- a/core/src/api/index.ts +++ b/core/src/api/index.ts @@ -9,7 +9,8 @@ export enum AppRoute { openAppDirectory = 'openAppDirectory', openFileExplore = 'openFileExplorer', relaunch = 'relaunch', - joinPath = 'joinPath' + joinPath = 'joinPath', + baseName = 'baseName', } export enum AppEvent { diff --git a/core/src/core.ts b/core/src/core.ts index 0f20feb1e0..2cfd43a390 100644 --- a/core/src/core.ts +++ b/core/src/core.ts @@ -51,6 +51,27 @@ const openFileExplorer: (path: string) => Promise = (path) => */ const joinPath: (paths: string[]) => Promise = (paths) => global.core.api?.joinPath(paths) +/** + * Retrieves the basename from a path. + * @param path - The path to extract the basename from. + * @returns {Promise} A promise that resolves with the basename. + */ +const baseName: (path: string) => Promise = (path) => global.core.api?.baseName(path) + +/** + * Opens an external URL in the default web browser. + * + * @param {string} url - The URL to open. + * @returns {Promise} - A promise that resolves when the URL has been successfully opened. + */ +const openExternalUrl: (url: string) => Promise = (url) => + global.core.api?.openExternalUrl(url) + +/** + * Gets the resource path of the application. + * + * @returns {Promise} - A promise that resolves with the resource path. + */ const getResourcePath: () => Promise = () => global.core.api?.getResourcePath() /** @@ -74,4 +95,6 @@ export { openFileExplorer, getResourcePath, joinPath, + openExternalUrl, + baseName, } diff --git a/core/src/node/api/routes/common.ts b/core/src/node/api/routes/common.ts index 194429c66a..184ca131d7 100644 --- a/core/src/node/api/routes/common.ts +++ b/core/src/node/api/routes/common.ts @@ -1,6 +1,6 @@ import { AppRoute } from '../../../api' import { HttpServer } from '../HttpServer' -import { join } from 'path' +import { basename, join } from 'path' import { chatCompletions, deleteBuilder, @@ -36,7 +36,11 @@ export const commonRouter = async (app: HttpServer) => { // App Routes app.post(`/app/${AppRoute.joinPath}`, async (request: any, reply: any) => { const args = JSON.parse(request.body) as any[] - console.debug('joinPath: ', ...args[0]) reply.send(JSON.stringify(join(...args[0]))) }) + + app.post(`/app/${AppRoute.baseName}`, async (request: any, reply: any) => { + const args = JSON.parse(request.body) as any[] + reply.send(JSON.stringify(basename(args[0]))) + }) } diff --git a/electron/handlers/app.ts b/electron/handlers/app.ts index 2966ef888d..726ed612e4 100644 --- a/electron/handlers/app.ts +++ b/electron/handlers/app.ts @@ -1,13 +1,9 @@ import { app, ipcMain, shell, nativeTheme } from 'electron' -import { join } from 'path' +import { join, basename } from 'path' import { WindowManager } from './../managers/window' import { userSpacePath } from './../utils/path' import { AppRoute } from '@janhq/core' -import { getResourcePath } from './../utils/path' -import { - ExtensionManager, - ModuleManager, -} from '@janhq/core/node' +import { ExtensionManager, ModuleManager } from '@janhq/core/node' export function handleAppIPCs() { /** @@ -53,6 +49,13 @@ export function handleAppIPCs() { join(...paths) ) + /** + * Retrieves the basename from the given path, respecting the current OS. + */ + ipcMain.handle(AppRoute.baseName, async (_event, path: string) => + basename(path) + ) + + /** * Relaunches the app in production - reload window in development. * @param _event - The IPC event object.
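Reviewer note: the new `baseName` route follows the same pattern as `joinPath` and `openExternalUrl` — a thin `global.core.api` wrapper that resolves either over Electron IPC (`ipcMain.handle(AppRoute.baseName, ...)`) or over the HTTP API (`/app/baseName`). A minimal sketch of how an extension might combine these helpers; the function name and sample values are hypothetical, not part of this PR:

```typescript
import { baseName, joinPath } from '@janhq/core'

// Hypothetical helper: derive the local file name for a model download
// and build its destination path inside the user space.
async function resolveModelDestination(sourceUrl: string, userSpace: string) {
  const fileName = await baseName(sourceUrl) // e.g. "tinyllama.gguf"
  return joinPath([userSpace, 'models', fileName])
}
```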
diff --git a/electron/handlers/download.ts b/electron/handlers/download.ts index 9c49b44f34..621d850432 100644 --- a/electron/handlers/download.ts +++ b/electron/handlers/download.ts @@ -46,8 +46,11 @@ export function handleDownloaderIPCs() { */ ipcMain.handle(DownloadRoute.downloadFile, async (_event, url, fileName) => { const userDataPath = join(app.getPath('home'), 'jan') - if (typeof fileName === 'string' && fileName.includes('file:/')) { - fileName = fileName.replace('file:/', '') + if ( + typeof fileName === 'string' && + (fileName.includes('file:/') || fileName.includes('file:\\')) + ) { + fileName = fileName.replace('file:/', '').replace('file:\\', '') } const destination = resolve(userDataPath, fileName) const rq = request(url) diff --git a/electron/handlers/fileManager.ts b/electron/handlers/fileManager.ts index 5dc4483443..2a78deaf93 100644 --- a/electron/handlers/fileManager.ts +++ b/electron/handlers/fileManager.ts @@ -2,8 +2,8 @@ import { ipcMain } from 'electron' // @ts-ignore import reflect from '@alumna/reflect' -import { FileManagerRoute, getResourcePath } from '@janhq/core' -import { userSpacePath } from './../utils/path' +import { FileManagerRoute } from '@janhq/core' +import { userSpacePath, getResourcePath } from './../utils/path' /** * Handles file system extensions operations. diff --git a/electron/handlers/fs.ts b/electron/handlers/fs.ts index fdfaba6063..8f7e434cc9 100644 --- a/electron/handlers/fs.ts +++ b/electron/handlers/fs.ts @@ -13,8 +13,16 @@ export function handleFsIPCs() { return import(moduleName).then((mdl) => mdl[route]( ...args.map((arg) => - typeof arg === 'string' && arg.includes('file:/') - ? join(userSpacePath, arg.replace('file:/', '')) + typeof arg === 'string' && + (arg.includes(`file:/`) || arg.includes(`file:\\`)) + ? join( + userSpacePath, + arg + .replace(`file://`, '') + .replace(`file:/`, '') + .replace(`file:\\\\`, '') + .replace(`file:\\`, '') + ) : arg ) ) diff --git a/extensions/conversational-extension/src/index.ts b/extensions/conversational-extension/src/index.ts index 6531c489a0..ac31351dff 100644 --- a/extensions/conversational-extension/src/index.ts +++ b/extensions/conversational-extension/src/index.ts @@ -25,7 +25,7 @@ export default class JSONConversationalExtension */ async onLoad() { if (!(await fs.existsSync(JSONConversationalExtension._homeDir))) - fs.mkdirSync(JSONConversationalExtension._homeDir) + await fs.mkdirSync(JSONConversationalExtension._homeDir) console.debug('JSONConversationalExtension loaded') } diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index e1ccc9f2af..3abcfe766f 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -32,7 +32,8 @@ import { join } from "path"; * It also subscribes to events emitted by the @janhq/core package and handles new message requests. */ export default class JanInferenceNitroExtension implements InferenceExtension { - private static readonly _homeDir = "engines"; + private static readonly _homeDir = "file://engines"; + private static readonly _settingsDir = "file://settings"; private static readonly _engineMetadataFileName = "nitro.json"; private static _currentModel: Model; @@ -58,9 +59,13 @@ export default class JanInferenceNitroExtension implements InferenceExtension { /** * Subscribes to events emitted by the @janhq/core package. 
*/ - async onLoad() { - if (!(await fs.existsSync(JanInferenceNitroExtension._homeDir))) - fs.mkdirSync(JanInferenceNitroExtension._homeDir); + async onLoad(): Promise { + if (!(await fs.existsSync(JanInferenceNitroExtension._homeDir))) { + await fs.mkdirSync(JanInferenceNitroExtension._homeDir).catch((err) => console.debug(err)); + } + + if (!(await fs.existsSync(JanInferenceNitroExtension._settingsDir))) + await fs.mkdirSync(JanInferenceNitroExtension._settingsDir); this.writeDefaultEngineSettings(); // Events subscription @@ -79,6 +84,24 @@ export default class JanInferenceNitroExtension implements InferenceExtension { events.on(EventName.OnInferenceStopped, () => { JanInferenceNitroExtension.handleInferenceStopped(this); }); + + // Attempt to fetch nvidia info + await executeOnMain(MODULE, "updateNvidiaInfo", {}); + + const gpuDriverConf = await fs.readFileSync( + join(JanInferenceNitroExtension._settingsDir, "settings.json") + ); + if (gpuDriverConf.notify && gpuDriverConf.run_mode === "cpu") { + // Driver is fully installed, but not in use + if (gpuDriverConf.nvidia_driver?.exist && gpuDriverConf.cuda?.exist) { + events.emit("OnGPUCompatiblePrompt", {}); + // Prompt user to switch + } else if (gpuDriverConf.nvidia_driver?.exist) { + // Prompt user to install cuda toolkit + events.emit("OnGPUDriverMissingPrompt", {}); + } + } + Promise.resolve() } /** diff --git a/extensions/inference-nitro-extension/src/module.ts b/extensions/inference-nitro-extension/src/module.ts index b8706a7dca..4537f801cd 100644 --- a/extensions/inference-nitro-extension/src/module.ts +++ b/extensions/inference-nitro-extension/src/module.ts @@ -1,9 +1,11 @@ const fs = require("fs"); +const fsPromises = fs.promises; const path = require("path"); -const { spawn } = require("child_process"); +const { exec, spawn } = require("child_process"); const tcpPortUsed = require("tcp-port-used"); const fetchRetry = require("fetch-retry")(global.fetch); const si = require("systeminformation"); +const { readFileSync, writeFileSync, existsSync } = require("fs"); // The PORT to use for the Nitro subprocess const PORT = 3928; @@ -14,6 +16,27 @@ const NITRO_HTTP_UNLOAD_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/llamacp const NITRO_HTTP_VALIDATE_MODEL_URL = `${NITRO_HTTP_SERVER_URL}/inferences/llamacpp/modelstatus`; const NITRO_HTTP_KILL_URL = `${NITRO_HTTP_SERVER_URL}/processmanager/destroy`; const SUPPORTED_MODEL_FORMAT = ".gguf"; +const NVIDIA_INFO_FILE = path.join( + require("os").homedir(), + "jan", + "settings", + "settings.json" +); + +const DEFALT_SETTINGS = { + "notify": true, + "run_mode": "cpu", + "nvidia_driver": { + "exist": false, + "version": "" + }, + "cuda": { + "exist": false, + "version": "" + }, + "gpus": [], + "gpu_highest_vram": "" +} // The subprocess instance for Nitro let subprocess = undefined; @@ -29,6 +52,125 @@ function stopModel(): Promise { return killSubprocess(); } +/** + * Validate nvidia and cuda for linux and windows + */ +async function updateNvidiaDriverInfo(): Promise { + exec( + "nvidia-smi --query-gpu=driver_version --format=csv,noheader", + (error, stdout) => { + let data; + try { + data = JSON.parse(readFileSync(NVIDIA_INFO_FILE, "utf8")); + } catch (error) { + data = DEFALT_SETTINGS; + } + + if (!error) { + const firstLine = stdout.split("\n")[0].trim(); + data["nvidia_driver"].exist = true; + data["nvidia_driver"].version = firstLine; + } else { + data["nvidia_driver"].exist = false; + } + + writeFileSync(NVIDIA_INFO_FILE, JSON.stringify(data, null, 2)); + Promise.resolve(); + } + 
); +} + +function checkFileExistenceInPaths(file: string, paths: string[]): boolean { + return paths.some((p) => existsSync(path.join(p, file))); +} + +function updateCudaExistence() { + let files: string[]; + let paths: string[]; + + if (process.platform === "win32") { + files = ["cublas64_12.dll", "cudart64_12.dll", "cublasLt64_12.dll"]; + paths = process.env.PATH ? process.env.PATH.split(path.delimiter) : []; + const nitro_cuda_path = path.join(__dirname, "bin", "win-cuda"); + paths.push(nitro_cuda_path); + } else { + files = ["libcudart.so.12", "libcublas.so.12", "libcublasLt.so.12"]; + paths = process.env.LD_LIBRARY_PATH + ? process.env.LD_LIBRARY_PATH.split(path.delimiter) + : []; + const nitro_cuda_path = path.join(__dirname, "bin", "linux-cuda"); + paths.push(nitro_cuda_path); + paths.push("/usr/lib/x86_64-linux-gnu/"); + } + + let cudaExists = files.every( + (file) => existsSync(file) || checkFileExistenceInPaths(file, paths) + ); + + let data; + try { + data = JSON.parse(readFileSync(NVIDIA_INFO_FILE, "utf8")); + } catch (error) { + data = DEFALT_SETTINGS; + } + + data["cuda"].exist = cudaExists; + if (cudaExists) { + data.run_mode = "gpu"; + } + writeFileSync(NVIDIA_INFO_FILE, JSON.stringify(data, null, 2)); +} + +async function updateGpuInfo(): Promise { + exec( + "nvidia-smi --query-gpu=index,memory.total --format=csv,noheader,nounits", + (error, stdout) => { + let data; + try { + data = JSON.parse(readFileSync(NVIDIA_INFO_FILE, "utf8")); + } catch (error) { + data = DEFALT_SETTINGS; + } + + if (!error) { + // Get GPU info and gpu has higher memory first + let highestVram = 0; + let highestVramId = "0"; + let gpus = stdout + .trim() + .split("\n") + .map((line) => { + let [id, vram] = line.split(", "); + vram = vram.replace(/\r/g, ""); + if (parseFloat(vram) > highestVram) { + highestVram = parseFloat(vram); + highestVramId = id; + } + return { id, vram }; + }); + + data["gpus"] = gpus; + data["gpu_highest_vram"] = highestVramId; + } else { + data["gpus"] = []; + } + + writeFileSync(NVIDIA_INFO_FILE, JSON.stringify(data, null, 2)); + Promise.resolve(); + } + ); +} + +async function updateNvidiaInfo() { + if (process.platform !== "darwin") { + await Promise.all([ + updateNvidiaDriverInfo(), + updateCudaExistence(), + updateGpuInfo(), + ]); + } +} + /** * Initializes a Nitro subprocess to load a machine learning model. * @param wrapper - The model wrapper. @@ -222,14 +364,26 @@ async function killSubprocess(): Promise { * Using child-process to spawn the process * Should run exactly platform specified Nitro binary version */ +/** + * Spawns a Nitro subprocess. + * @param nitroResourceProbe - The Nitro resource probe. + * @returns A promise that resolves when the Nitro subprocess is started. 
+ */ function spawnNitroProcess(nitroResourceProbe: any): Promise { console.debug("Starting Nitro subprocess..."); return new Promise(async (resolve, reject) => { let binaryFolder = path.join(__dirname, "bin"); // Current directory by default + let cudaVisibleDevices = ""; let binaryName; - if (process.platform === "win32") { - binaryName = "win-start.bat"; + let nvida_info = JSON.parse(readFileSync(NVIDIA_INFO_FILE, "utf8")); + if (nvida_info["run_mode"] === "cpu") { + binaryFolder = path.join(binaryFolder, "win-cpu"); + } else { + binaryFolder = path.join(binaryFolder, "win-cuda"); + cudaVisibleDevices = nvida_info["gpu_highest_vram"]; + } + binaryName = "nitro.exe"; } else if (process.platform === "darwin") { if (process.arch === "arm64") { binaryFolder = path.join(binaryFolder, "mac-arm64"); @@ -238,13 +392,24 @@ function spawnNitroProcess(nitroResourceProbe: any): Promise { } binaryName = "nitro"; } else { - binaryName = "linux-start.sh"; + let nvida_info = JSON.parse(readFileSync(NVIDIA_INFO_FILE, "utf8")); + if (nvida_info["run_mode"] === "cpu") { + binaryFolder = path.join(binaryFolder, "win-cpu"); + } else { + binaryFolder = path.join(binaryFolder, "win-cuda"); + cudaVisibleDevices = nvida_info["gpu_highest_vram"]; + } + binaryName = "nitro"; } const binaryPath = path.join(binaryFolder, binaryName); // Execute the binary subprocess = spawn(binaryPath, [1, LOCAL_HOST, PORT], { cwd: binaryFolder, + env: { + ...process.env, + CUDA_VISIBLE_DEVICES: cudaVisibleDevices, + }, }); // Handle subprocess output @@ -296,4 +461,5 @@ module.exports = { stopModel, killSubprocess, dispose, + updateNvidiaInfo, }; diff --git a/extensions/inference-openai-extension/src/index.ts b/extensions/inference-openai-extension/src/index.ts index d612e474e9..27436749f6 100644 --- a/extensions/inference-openai-extension/src/index.ts +++ b/extensions/inference-openai-extension/src/index.ts @@ -53,9 +53,13 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension { /** * Subscribes to events emitted by the @janhq/core package. */ - async onLoad() { - if (!(await fs.existsSync(JanInferenceOpenAIExtension._homeDir))) - fs.mkdirSync(JanInferenceOpenAIExtension._homeDir); + async onLoad(): Promise { + if (!(await fs.existsSync(JanInferenceOpenAIExtension._homeDir))) { + await fs + .mkdirSync(JanInferenceOpenAIExtension._homeDir) + .catch((err) => console.debug(err)); + } + JanInferenceOpenAIExtension.writeDefaultEngineSettings(); // Events subscription @@ -73,6 +77,7 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension { events.on(EventName.OnInferenceStopped, () => { JanInferenceOpenAIExtension.handleInferenceStopped(this); }); + Promise.resolve(); } /** @@ -87,7 +92,7 @@ export default class JanInferenceOpenAIExtension implements InferenceExtension { JanInferenceOpenAIExtension._engineMetadataFileName ); if (await fs.existsSync(engineFile)) { - const engine = await fs.readFileSync(engineFile, 'utf-8'); + const engine = await fs.readFileSync(engineFile, "utf-8"); JanInferenceOpenAIExtension._engineSettings = typeof engine === "object" ? 
engine : JSON.parse(engine); } else { diff --git a/extensions/inference-triton-trtllm-extension/src/index.ts b/extensions/inference-triton-trtllm-extension/src/index.ts index 7ef5270e70..aed2f581ac 100644 --- a/extensions/inference-triton-trtllm-extension/src/index.ts +++ b/extensions/inference-triton-trtllm-extension/src/index.ts @@ -34,7 +34,7 @@ import { EngineSettings } from "./@types/global"; export default class JanInferenceTritonTrtLLMExtension implements InferenceExtension { - private static readonly _homeDir = "engines"; + private static readonly _homeDir = "file://engines"; private static readonly _engineMetadataFileName = "triton_trtllm.json"; static _currentModel: Model; diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index 38cf28ef02..8c8972dac3 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -8,8 +8,8 @@ import { InferenceEngine, joinPath, } from '@janhq/core' -import { basename } from 'path' import { ModelExtension, Model } from '@janhq/core' +import { baseName } from '@janhq/core/.' /** * A extension for models @@ -34,7 +34,7 @@ export default class JanModelExtension implements ModelExtension { * Called when the extension is loaded. * @override */ - onLoad(): void { + async onLoad() { this.copyModelsToHomeDir() } @@ -48,7 +48,7 @@ export default class JanModelExtension implements ModelExtension { try { // list all of the files under the home directory - if (fs.existsSync(JanModelExtension._homeDir)) { + if (await fs.existsSync(JanModelExtension._homeDir)) { // ignore if the model is already downloaded console.debug('Models already persisted.') return @@ -62,7 +62,7 @@ export default class JanModelExtension implements ModelExtension { const srcPath = await joinPath([resourePath, 'models']) const userSpace = await getUserSpace() - const destPath = await joinPath([userSpace, JanModelExtension._homeDir]) + const destPath = await joinPath([userSpace, 'models']) await fs.syncFile(srcPath, destPath) @@ -98,7 +98,7 @@ export default class JanModelExtension implements ModelExtension { // try to retrieve the download file name from the source url // if it fails, use the model ID as the file name - const extractedFileName = basename(model.source_url) + const extractedFileName = await model.source_url.split('/').pop() const fileName = extractedFileName .toLowerCase() .endsWith(JanModelExtension._supportedModelFormat) diff --git a/extensions/monitoring-extension/src/index.ts b/extensions/monitoring-extension/src/index.ts index 2e5e50ffa9..f1cbf6dad4 100644 --- a/extensions/monitoring-extension/src/index.ts +++ b/extensions/monitoring-extension/src/index.ts @@ -18,7 +18,7 @@ export default class JanMonitoringExtension implements MonitoringExtension { /** * Called when the extension is loaded. */ - onLoad(): void {} + async onLoad() {} /** * Called when the extension is unloaded. 
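Reviewer note: `updateGpuInfo` in `extensions/inference-nitro-extension/src/module.ts` parses the output of `nvidia-smi --query-gpu=index,memory.total --format=csv,noheader,nounits` inline inside the `exec` callback. A standalone sketch of that parsing step (the function name is mine, not part of the PR); extracting it like this would also make the GPU-selection logic unit-testable without a GPU present:

```typescript
// Parses nvidia-smi CSV lines such as "0, 24576\n1, 8192\n" and returns the
// GPU list plus the id of the card with the most VRAM, mirroring what
// updateGpuInfo() stores in settings.json.
function parseGpuList(stdout: string): {
  gpus: { id: string; vram: string }[]
  gpuHighestVram: string
} {
  let highestVram = 0
  let highestVramId = '0'
  const gpus = stdout
    .trim()
    .split('\n')
    .filter((line) => line.length > 0)
    .map((line) => {
      const [id, rawVram] = line.split(', ')
      const vram = (rawVram ?? '0').replace(/\r/g, '')
      if (parseFloat(vram) > highestVram) {
        highestVram = parseFloat(vram)
        highestVramId = id
      }
      return { id, vram }
    })
  return { gpus, gpuHighestVram: highestVramId }
}
```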
diff --git a/web/containers/GPUDriverPromptModal/index.tsx b/web/containers/GPUDriverPromptModal/index.tsx new file mode 100644 index 0000000000..68efa33d52 --- /dev/null +++ b/web/containers/GPUDriverPromptModal/index.tsx @@ -0,0 +1,84 @@ +import React from 'react' + +import { openExternalUrl } from '@janhq/core' + +import { + ModalClose, + ModalFooter, + ModalContent, + Modal, + ModalTitle, + ModalHeader, + Button, +} from '@janhq/uikit' + +import { useAtom } from 'jotai' + +import { isShowNotificationAtom, useSettings } from '@/hooks/useSettings' + +const GPUDriverPrompt: React.FC = () => { + const [showNotification, setShowNotification] = useAtom( + isShowNotificationAtom + ) + + const { saveSettings } = useSettings() + const onDoNotShowAgainChange = (e: React.ChangeEvent) => { + const isChecked = !e.target.checked + saveSettings({ notify: isChecked }) + } + + const openChanged = () => { + setShowNotification(false) + } + + return ( +
+ + + + Missing Nvidia Driver and Cuda Toolkit + +

+ It looks like the NVIDIA driver, the CUDA toolkit, or both are missing. + Please follow the instructions on the{' '} + + openExternalUrl('https://developer.nvidia.com/cuda-toolkit') + } + > + NVIDIA CUDA Toolkit Installation Page + {' '} + and the{' '} + + openExternalUrl('https://www.nvidia.com/Download/index.aspx') + } + > + NVIDIA Driver Installation Page + + .

+
+ + Don't show again +
+ +
+ + + +
+
+
+
+
+ ) +} +export default GPUDriverPrompt diff --git a/web/containers/Providers/EventHandler.tsx b/web/containers/Providers/EventHandler.tsx index a828a02a11..e755804ff6 100644 --- a/web/containers/Providers/EventHandler.tsx +++ b/web/containers/Providers/EventHandler.tsx @@ -119,6 +119,8 @@ export default function EventHandler({ children }: { children: ReactNode }) { } } } + function handleGpuCompatiblePrompt() {} + function handleGpuDriverMissingPrompt() {} useEffect(() => { if (window.core?.events) { @@ -127,6 +129,8 @@ export default function EventHandler({ children }: { children: ReactNode }) { events.on(EventName.OnModelReady, handleModelReady) events.on(EventName.OnModelFail, handleModelFail) events.on(EventName.OnModelStopped, handleModelStopped) + events.on('OnGPUCompatiblePrompt', handleGpuCompatiblePrompt) + events.on('OnGPUDriverMissingPrompt', handleGpuDriverMissingPrompt) } // eslint-disable-next-line react-hooks/exhaustive-deps }, []) diff --git a/web/containers/Providers/EventListener.tsx b/web/containers/Providers/EventListener.tsx index 046f2ecd23..ff661aacc4 100644 --- a/web/containers/Providers/EventListener.tsx +++ b/web/containers/Providers/EventListener.tsx @@ -1,9 +1,8 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { basename } from 'path' - import { PropsWithChildren, useEffect, useRef } from 'react' +import { baseName } from '@janhq/core' import { useAtomValue, useSetAtom } from 'jotai' import { useDownloadState } from '@/hooks/useDownloadState' @@ -37,10 +36,11 @@ export default function EventListenerWrapper({ children }: PropsWithChildren) { useEffect(() => { if (window && window.electronAPI) { window.electronAPI.onFileDownloadUpdate( - (_event: string, state: any | undefined) => { + async (_event: string, state: any | undefined) => { if (!state) return + const modelName = await baseName(state.fileName) const model = modelsRef.current.find( - (model) => modelBinFileName(model) === basename(state.fileName) + (model) => modelBinFileName(model) === modelName ) if (model) setDownloadState({ @@ -50,25 +50,31 @@ export default function EventListenerWrapper({ children }: PropsWithChildren) { } ) - window.electronAPI.onFileDownloadError((_event: string, state: any) => { - console.error('Download error', state) - const model = modelsRef.current.find( - (model) => modelBinFileName(model) === basename(state.fileName) - ) - if (model) setDownloadStateFailed(model.id) - }) - - window.electronAPI.onFileDownloadSuccess((_event: string, state: any) => { - if (state && state.fileName) { + window.electronAPI.onFileDownloadError( + async (_event: string, state: any) => { + console.error('Download error', state) + const modelName = await baseName(state.fileName) const model = modelsRef.current.find( - (model) => modelBinFileName(model) === basename(state.fileName) + (model) => modelBinFileName(model) === modelName ) - if (model) { - setDownloadStateSuccess(model.id) - setDownloadedModels([...downloadedModelRef.current, model]) + if (model) setDownloadStateFailed(model.id) + } + ) + + window.electronAPI.onFileDownloadSuccess( + async (_event: string, state: any) => { + if (state && state.fileName) { + const modelName = await baseName(state.fileName) + const model = modelsRef.current.find( + async (model) => modelBinFileName(model) === modelName + ) + if (model) { + setDownloadStateSuccess(model.id) + setDownloadedModels([...downloadedModelRef.current, model]) + } } } - }) + ) window.electronAPI.onAppUpdateDownloadUpdate( (_event: string, progress: any) => { diff 
--git a/web/containers/Providers/index.tsx b/web/containers/Providers/index.tsx index 3decde8091..c7e6e26a1c 100644 --- a/web/containers/Providers/index.tsx +++ b/web/containers/Providers/index.tsx @@ -8,6 +8,7 @@ import { TooltipProvider } from '@janhq/uikit' import { PostHogProvider } from 'posthog-js/react' +import GPUDriverPrompt from '@/containers/GPUDriverPromptModal' import EventListenerWrapper from '@/containers/Providers/EventListener' import JotaiWrapper from '@/containers/Providers/Jotai' import ThemeWrapper from '@/containers/Providers/Theme' @@ -25,11 +26,11 @@ import { instance } from '@/utils/posthog' import { extensionManager } from '@/extension' const Providers = (props: PropsWithChildren) => { + const { children } = props + const [setupCore, setSetupCore] = useState(false) const [activated, setActivated] = useState(false) - const { children } = props - async function setupExtensions() { // Register all active extensions await extensionManager.registerActive() @@ -74,6 +75,7 @@ const Providers = (props: PropsWithChildren) => { {children} + diff --git a/web/hooks/useEngineSettings.ts b/web/hooks/useEngineSettings.ts index 349275e96b..258a89aa48 100644 --- a/web/hooks/useEngineSettings.ts +++ b/web/hooks/useEngineSettings.ts @@ -2,7 +2,9 @@ import { fs, joinPath } from '@janhq/core' export const useEngineSettings = () => { const readOpenAISettings = async () => { - if (!fs.existsSync(await joinPath(['file://engines', 'openai.json']))) + if ( + !(await fs.existsSync(await joinPath(['file://engines', 'openai.json']))) + ) return {} const settings = await fs.readFileSync( await joinPath(['file://engines', 'openai.json']), diff --git a/web/hooks/useSettings.ts b/web/hooks/useSettings.ts new file mode 100644 index 0000000000..34d123359e --- /dev/null +++ b/web/hooks/useSettings.ts @@ -0,0 +1,67 @@ +import { useEffect, useState } from 'react' + +import { fs, joinPath } from '@janhq/core' +import { atom, useAtom } from 'jotai' + +export const isShowNotificationAtom = atom(false) + +export const useSettings = () => { + const [isGPUModeEnabled, setIsGPUModeEnabled] = useState(false) // New state for GPU mode + const [showNotification, setShowNotification] = useAtom( + isShowNotificationAtom + ) + + useEffect(() => { + setTimeout(() => validateSettings, 3000) + }, []) + + const validateSettings = async () => { + readSettings().then((settings) => { + if ( + settings && + settings.notify && + ((settings.nvidia_driver?.exist && !settings.cuda?.exist) || + !settings.nvidia_driver?.exist) + ) { + setShowNotification(true) + } + + // Check if run_mode is 'gpu' or 'cpu' and update state accordingly + setIsGPUModeEnabled(settings?.run_mode === 'gpu') + }) + } + + const readSettings = async () => { + if (!window?.core?.api) { + return + } + const settingsFile = await joinPath(['file://settings', 'settings.json']) + if (await fs.existsSync(settingsFile)) { + const settings = await fs.readFileSync(settingsFile, 'utf-8') + return typeof settings === 'object' ? 
settings : JSON.parse(settings) + } + return {} + } + const saveSettings = async ({ + runMode, + notify, + }: { + runMode?: string | undefined + notify?: boolean | undefined + }) => { + const settingsFile = await joinPath(['file://settings', 'settings.json']) + const settings = await readSettings() + if (runMode != null) settings.run_mode = runMode + if (notify != null) settings.notify = notify + await fs.writeFileSync(settingsFile, JSON.stringify(settings)) + } + + return { + showNotification, + isGPUModeEnabled, + readSettings, + saveSettings, + setShowNotification, + validateSettings, + } +} diff --git a/web/screens/Settings/Advanced/index.tsx b/web/screens/Settings/Advanced/index.tsx index 4bfc2ee4e8..122f3f8688 100644 --- a/web/screens/Settings/Advanced/index.tsx +++ b/web/screens/Settings/Advanced/index.tsx @@ -1,17 +1,57 @@ 'use client' -import { useContext } from 'react' +import { useContext, useEffect, useState } from 'react' import { Switch, Button } from '@janhq/uikit' import { FeatureToggleContext } from '@/context/FeatureToggle' +import { useSettings } from '@/hooks/useSettings' + const Advanced = () => { const { experimentalFeatureEnabed, setExperimentalFeatureEnabled } = useContext(FeatureToggleContext) + const [gpuEnabled, setGpuEnabled] = useState(false) + const { readSettings, saveSettings, validateSettings, setShowNotification } = + useSettings() + + useEffect(() => { + readSettings().then((settings) => { + setGpuEnabled(settings.run_mode === 'gpu') + }) + }, []) return (
+ {/* CPU / GPU switching */} + +
+
+
+
NVIDIA GPU
+
+

+ Enable GPU acceleration for NVIDIA GPUs.

+
+ { + if (e === true) { + saveSettings({ runMode: 'gpu' }) + setGpuEnabled(true) + setShowNotification(false) + setTimeout(() => { + validateSettings() + }, 300) + } else { + saveSettings({ runMode: 'cpu' }) + setGpuEnabled(false) + } + }} + /> +
+ {/* Experimental */}
@@ -20,8 +60,7 @@ const Advanced = () => {

- Enable experimental features that may be unstable - tested. + Enable experimental features that may be unstable tested.

{

- Open the directory where your app data, like conversation history and model configurations, is located. + Open the directory where your app data, like conversation history + and model configurations, is located.
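Reviewer note: several hunks in this PR read and write the same `settings.json` document — `DEFALT_SETTINGS` and the `updateNvidia*` helpers in `extensions/inference-nitro-extension/src/module.ts` on the Node side, and the new `useSettings` hook plus the Advanced-settings GPU toggle on the web side. A type sketch of that shape, derived only from the fields used in this diff (the interface name is mine, not part of the PR), may help keep the two sides in sync:

```typescript
// Shape of ~/jan/settings/settings.json as written by updateNvidiaInfo()
// and read by web/hooks/useSettings.ts. Field names follow DEFALT_SETTINGS.
interface GpuSettings {
  notify: boolean // whether to show the GPU driver prompt
  run_mode: 'cpu' | 'gpu' // which Nitro binary variant to launch
  nvidia_driver: { exist: boolean; version: string }
  cuda: { exist: boolean; version: string }
  gpus: { id: string; vram: string }[] // parsed from nvidia-smi output
  gpu_highest_vram: string // id of the GPU with the most VRAM
}
```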