Skip to content

Commit

Permalink
Merge pull request #4133 from janhq/fix/4096-failed-to-get-huggingfac…
Browse files Browse the repository at this point in the history
…e-models

fix: 4096 - failed to get huggingface models
  • Loading branch information
louis-jan authored Nov 26, 2024
2 parents ad84845 + dc649bf commit 3a9a8da
Show file tree
Hide file tree
Showing 8 changed files with 21 additions and 72 deletions.
13 changes: 0 additions & 13 deletions core/src/browser/core.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ import { joinPath } from './core'
import { openFileExplorer } from './core'
import { getJanDataFolderPath } from './core'
import { abortDownload } from './core'
import { getFileSize } from './core'
import { executeOnMain } from './core'

describe('test core apis', () => {
Expand Down Expand Up @@ -66,18 +65,6 @@ describe('test core apis', () => {
expect(result).toBe('aborted')
})

it('should get file size', async () => {
const url = 'http://example.com/file'
globalThis.core = {
api: {
getFileSize: jest.fn().mockResolvedValue(1024),
},
}
const result = await getFileSize(url)
expect(globalThis.core.api.getFileSize).toHaveBeenCalledWith(url)
expect(result).toBe(1024)
})

it('should execute function on main process', async () => {
const extension = 'testExtension'
const method = 'testMethod'
Expand Down
10 changes: 0 additions & 10 deletions core/src/browser/core.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,6 @@ const downloadFile: (downloadRequest: DownloadRequest, network?: NetworkConfig)
network
) => globalThis.core?.api?.downloadFile(downloadRequest, network)

/**
* Get unit in bytes for a remote file.
*
* @param url - The url of the file.
* @returns {Promise<number>} - A promise that resolves with the file size.
*/
const getFileSize: (url: string) => Promise<number> = (url: string) =>
globalThis.core.api?.getFileSize(url)

/**
* Aborts the download of a specific file.
* @param {string} fileName - The name of the file whose download is to be aborted.
Expand Down Expand Up @@ -167,7 +158,6 @@ export {
getUserHomePath,
systemInformation,
showToast,
getFileSize,
dirName,
FileStat,
}
17 changes: 5 additions & 12 deletions core/src/node/api/processors/download.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@ jest.mock('fs', () => ({
createWriteStream: jest.fn(),
}))

const requestMock = jest.fn((options, callback) => {
callback(new Error('Test error'), null)
})
jest.mock('request', () => requestMock)

jest.mock('request-progress', () => {
return jest.fn().mockImplementation(() => {
return {
Expand Down Expand Up @@ -54,18 +59,6 @@ describe('Downloader', () => {
beforeEach(() => {
jest.resetAllMocks()
})
it('should handle getFileSize errors correctly', async () => {
const observer = jest.fn()
const url = 'http://example.com/file'

const downloader = new Downloader(observer)
const requestMock = jest.fn((options, callback) => {
callback(new Error('Test error'), null)
})
jest.mock('request', () => requestMock)

await expect(downloader.getFileSize(observer, url)).rejects.toThrow('Test error')
})

it('should pause download correctly', () => {
const observer = jest.fn()
Expand Down
21 changes: 0 additions & 21 deletions core/src/node/api/processors/download.ts
Original file line number Diff line number Diff line change
Expand Up @@ -135,25 +135,4 @@ export class Downloader implements Processor {
pauseDownload(_observer: any, fileName: any) {
DownloadManager.instance.networkRequests[fileName]?.pause()
}

async getFileSize(_observer: any, url: string): Promise<number> {
return new Promise((resolve, reject) => {
const request = require('request')
request(
{
url,
method: 'HEAD',
},
function (err: any, response: any) {
if (err) {
console.error('Getting file size failed:', err)
reject(err)
} else {
const size: number = response.headers['content-length'] ?? -1
resolve(size)
}
}
)
})
}
}
1 change: 0 additions & 1 deletion core/src/types/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ export enum DownloadRoute {
pauseDownload = 'pauseDownload',
resumeDownload = 'resumeDownload',
getDownloadProgress = 'getDownloadProgress',
getFileSize = 'getFileSize',
}

export enum DownloadEvent {
Expand Down
2 changes: 1 addition & 1 deletion web/containers/ModelSearch/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ const ModelSearch = ({ onSearchLocal }: Props) => {
errMessage = err.message
}
toaster({
title: 'Failed to get Hugging Face models',
title: 'Oops, you may be rate limited, give it a bit more time',
description: errMessage,
type: 'error',
})
Expand Down
13 changes: 8 additions & 5 deletions web/utils/huggingface.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,8 @@ import {
toHuggingFaceUrl,
InvalidHostError,
} from './huggingface'
import { getFileSize } from '@janhq/core'

// Mock the getFileSize function
jest.mock('@janhq/core', () => ({
getFileSize: jest.fn(),
AllQuantizations: ['q4_0', 'q4_1', 'q5_0', 'q5_1', 'q8_0'],
}))

Expand Down Expand Up @@ -38,9 +35,15 @@ describe('huggingface utils', () => {
}

;(global.fetch as jest.Mock).mockResolvedValue({
json: jest.fn().mockResolvedValue(mockResponse),
json: jest
.fn()
.mockResolvedValueOnce(mockResponse)
.mockResolvedValueOnce([{
path: 'model-q4_0.gguf', size: 1000000,
},{
path: 'model-q4_0.gguf', size: 2000
}]),
})
;(getFileSize as jest.Mock).mockResolvedValue(1000000)

const result = await fetchHuggingFaceRepoData('user/repo')

Expand Down
16 changes: 7 additions & 9 deletions web/utils/huggingface.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { AllQuantizations, getFileSize, HuggingFaceRepoData } from '@janhq/core'
import { AllQuantizations, HuggingFaceRepoData } from '@janhq/core'

/**
* Fetches data from a Hugging Face repository.
Expand Down Expand Up @@ -39,21 +39,19 @@ export const fetchHuggingFaceRepoData = async (
)
}

const promises: Promise<number>[] = []

// fetching file sizes
const url = new URL(sanitizedUrl)
const paths = url.pathname.split('/').filter((e) => e.trim().length > 0)

const repoTree: { path: string; size: number }[] = await fetch(
`https://huggingface.co/api/models/${paths[2]}/${paths[3]}/tree/main`
).then((res) => res.json())

for (const sibling of data.siblings) {
const downloadUrl = `https://huggingface.co/${paths[2]}/${paths[3]}/resolve/main/${sibling.rfilename}`
sibling.downloadUrl = downloadUrl
promises.push(getFileSize(downloadUrl))
}

const result = await Promise.all(promises)
for (let i = 0; i < data.siblings.length; i++) {
data.siblings[i].fileSize = result[i]
sibling.fileSize =
repoTree.find((file) => file.path === sibling.rfilename)?.size ?? 0
}

AllQuantizations.forEach((quantization) => {
Expand Down

0 comments on commit 3a9a8da

Please sign in to comment.