From c7bfa239047e5558e4c67361bf4914f297c643f2 Mon Sep 17 00:00:00 2001 From: Douglas Kogut Date: Mon, 24 Jun 2024 16:55:41 -0400 Subject: [PATCH] refactor --- src/index.ts | 140 +++++++++++++++++++++++---------------------------- 1 file changed, 63 insertions(+), 77 deletions(-) diff --git a/src/index.ts b/src/index.ts index 2be582f..4763aa3 100644 --- a/src/index.ts +++ b/src/index.ts @@ -41,14 +41,11 @@ const ggufMagicNumber = Buffer.from([0x47, 0x47, 0x55, 0x46]).readInt32LE() const fileChunkSize = 10 * 1024 * 1024 -type GGUFFile = { data: Buffer; offset: number } +type GGUFFile = { data: Buffer; fd: number; offset: number } -const readFileChunk = async ( - fd: number, - file: GGUFFile, -): Promise => { +const readFileChunk = async (file: GGUFFile): Promise => { const buffer = Buffer.alloc(fileChunkSize) - const { bytesRead } = await fs.read(fd, buffer, 0, fileChunkSize, null) + const { bytesRead } = await fs.read(file.fd, buffer, 0, fileChunkSize, null) if (bytesRead !== fileChunkSize) { return new Error('unexpected bytes read') } @@ -57,13 +54,12 @@ const readFileChunk = async ( } const readNBytes = async ( - fd: number, numBytes: number, file: GGUFFile, ): Promise<{ error: Error } | { bytes: Buffer; error?: undefined }> => { const end = file.offset + numBytes if (end > file.data.length) { - const err = await readFileChunk(fd, file) + const err = await readFileChunk(file) if (err) return { error: err } } const buffer = file.data.subarray(file.offset, end) @@ -71,68 +67,62 @@ const readNBytes = async ( return { bytes: buffer } } -const readUint8 = async (fd: number, file: GGUFFile): Promise => { - const bytes = await readNBytes(fd, 1, file) +const readUint8 = async (file: GGUFFile): Promise => { + const bytes = await readNBytes(1, file) if (bytes.error) return bytes return { error: null, value: bytes.bytes.readUInt8() } } -const readUint16 = async (fd: number, file: GGUFFile): Promise => { - const bytes = await readNBytes(fd, 2, file) +const readUint16 = async (file: GGUFFile): Promise => { + const bytes = await readNBytes(2, file) if (bytes.error) return bytes return { error: null, value: bytes.bytes.readUInt16LE() } } -const readUint32 = async (fd: number, file: GGUFFile): Promise => { - const bytes = await readNBytes(fd, 4, file) +const readUint32 = async (file: GGUFFile): Promise => { + const bytes = await readNBytes(4, file) if (bytes.error) return bytes return { error: null, value: bytes.bytes.readUInt32LE() } } -const readUint64 = async (fd: number, file: GGUFFile): Promise => { - const bytes = await readNBytes(fd, 8, file) +const readUint64 = async (file: GGUFFile): Promise => { + const bytes = await readNBytes(8, file) if (bytes.error) return bytes return { error: null, value: bytes.bytes.readBigUInt64LE() } } -const readInt8 = async (fd: number, file: GGUFFile): Promise => { - const bytes = await readNBytes(fd, 1, file) +const readInt8 = async (file: GGUFFile): Promise => { + const bytes = await readNBytes(1, file) if (bytes.error) return bytes return { error: null, value: bytes.bytes.readInt8() } } -const readInt16 = async (fd: number, file: GGUFFile): Promise => { - const bytes = await readNBytes(fd, 2, file) +const readInt16 = async (file: GGUFFile): Promise => { + const bytes = await readNBytes(2, file) if (bytes.error) return bytes return { error: null, value: bytes.bytes.readInt16LE() } } -const readInt32 = async (fd: number, file: GGUFFile): Promise => { - const bytes = await readNBytes(fd, 4, file) +const readInt32 = async (file: GGUFFile): Promise => { + const bytes = await readNBytes(4, file) if (bytes.error) return bytes return { error: null, value: bytes.bytes.readInt32LE() } } -const readInt64 = async (fd: number, file: GGUFFile): Promise => { - const bytes = await readNBytes(fd, 8, file) +const readInt64 = async (file: GGUFFile): Promise => { + const bytes = await readNBytes(8, file) if (bytes.error) return bytes return { error: null, value: bytes.bytes.readBigInt64LE() } } -const readFloat32 = async ( - fd: number, - file: GGUFFile, -): Promise => { - const bytes = await readNBytes(fd, 4, file) +const readFloat32 = async (file: GGUFFile): Promise => { + const bytes = await readNBytes(4, file) if (bytes.error) return bytes return { error: null, value: bytes.bytes.readFloatLE() } } -const readFloat64 = async ( - fd: number, - file: GGUFFile, -): Promise => { - const bytes = await readNBytes(fd, 8, file) +const readFloat64 = async (file: GGUFFile): Promise => { + const bytes = await readNBytes(8, file) if (bytes.error) return bytes const arrayBuffer = new ArrayBuffer(8) const view = new DataView(arrayBuffer) @@ -143,30 +133,28 @@ const readFloat64 = async ( } const readBool = async ( - fd: number, file: GGUFFile, ): Promise<{ error: Error } | { error: null; value: boolean }> => { - const bytes = await readNBytes(fd, 1, file) + const bytes = await readNBytes(1, file) if (bytes.error) return bytes return { error: null, value: !!bytes.bytes.readUint8() } } const readVersionedSize = async ( - fd: number, version: Version, file: GGUFFile, ): Promise => { let value: bigint switch (version) { case 1: { - const n = await readUint32(fd, file) + const n = await readUint32(file) if (n.error) return n value = BigInt(n.value) break } case 3: case 2: { - const n = await readUint64(fd, file) + const n = await readUint64(file) if (n.error) return n value = n.value break @@ -176,13 +164,12 @@ const readVersionedSize = async ( } const readString = async ( - fd: number, version: Version, file: GGUFFile, ): Promise<{ error: Error } | { error: null; value: string }> => { - const nBytes = await readVersionedSize(fd, version, file) + const nBytes = await readVersionedSize(version, file) if (nBytes.error) return nBytes - const strBuffer = await readNBytes(fd, Number(nBytes.value), file) // TODO: fix cast + const strBuffer = await readNBytes(Number(nBytes.value), file) // TODO: fix cast if (strBuffer.error) return strBuffer return { error: null, @@ -192,85 +179,84 @@ const readString = async ( } const readArray = async ( - fd: number, version: Version, file: GGUFFile, ): Promise<{ error: Error } | { error: null; value: MetadataArray }> => { - const arrType = await readUint32(fd, file) + const arrType = await readUint32(file) if (arrType.error) return arrType - const numElts = await readVersionedSize(fd, version, file) + const numElts = await readVersionedSize(version, file) if (numElts.error) return numElts const ret: MetadataArray = [] for (let i = 0; i < numElts.value; ++i) { switch (arrType.value) { case 0: { - const value = await readUint8(fd, file) + const value = await readUint8(file) if (value.error) return value ret.push(value.value) break } case 1: { - const value = await readInt8(fd, file) + const value = await readInt8(file) if (value.error) return value ret.push(value.value) break } case 2: { - const value = await readUint16(fd, file) + const value = await readUint16(file) if (value.error) return value ret.push(value.value) break } case 3: { - const value = await readInt16(fd, file) + const value = await readInt16(file) if (value.error) return value ret.push(value.value) break } case 4: { - const value = await readUint32(fd, file) + const value = await readUint32(file) if (value.error) return value ret.push(value.value) break } case 5: { - const value = await readInt32(fd, file) + const value = await readInt32(file) if (value.error) return value ret.push(value.value) break } case 6: { - const value = await readFloat32(fd, file) + const value = await readFloat32(file) if (value.error) return value ret.push(value.value) break } case 7: { - const value = await readBool(fd, file) + const value = await readBool(file) if (value.error) return value ret.push(value.value) break } case 8: { - const value = await readString(fd, version, file) + const value = await readString(version, file) if (value.error) return value ret.push(value.value) break } case 10: { - const value = await readUint64(fd, file) + const value = await readUint64(file) if (value.error) return value ret.push(value.value) break } case 11: { - const value = await readInt64(fd, file) + const value = await readInt64(file) if (value.error) return value ret.push(value.value) break } case 12: { - const value = await readFloat64(fd, file) + const value = await readFloat64(file) if (value.error) return value ret.push(value.value) break @@ -433,15 +419,15 @@ export const parseRawMetadata = async ( filePath: string, ): Promise => { const fd = await fs.open(filePath, 'r') - const file: GGUFFile = { data: Buffer.from([]), offset: 0 } + const file: GGUFFile = { data: Buffer.from([]), fd, offset: 0 } - const magic = await readUint32(fd, file) + const magic = await readUint32(file) if (magic.error) return magic if (magic.value !== ggufMagicNumber) { return { error: new Error('invalid gguf magic number') } } - const version = await readUint32(fd, file) + const version = await readUint32(file) if (version.error) return version if (!isVersion(version.value)) { return { @@ -449,10 +435,10 @@ export const parseRawMetadata = async ( } } - const tensorCount = await readVersionedSize(fd, version.value, file) + const tensorCount = await readVersionedSize(version.value, file) if (tensorCount.error) return tensorCount - const numKv = await readVersionedSize(fd, version.value, file) + const numKv = await readVersionedSize(version.value, file) if (numKv.error) return numKv const metadata: Record = {} @@ -494,86 +480,86 @@ export const parseRawMetadata = async ( } for (let i = 0; i < numKv.value; ++i) { - const key = await readString(fd, version.value, file) + const key = await readString(version.value, file) if (key.error) return key - const keyType = await readUint32(fd, file) + const keyType = await readUint32(file) if (keyType.error) return keyType switch (keyType.value) { case 0: { - const value = await readUint8(fd, file) + const value = await readUint8(file) if (value.error) return value setKey(key.value, value.value) break } case 1: { - const value = await readInt8(fd, file) + const value = await readInt8(file) if (value.error) return value setKey(key.value, value.value) break } case 2: { - const value = await readUint16(fd, file) + const value = await readUint16(file) if (value.error) return value setKey(key.value, value.value) break } case 3: { - const value = await readInt16(fd, file) + const value = await readInt16(file) if (value.error) return value setKey(key.value, value.value) break } case 4: { - const value = await readUint32(fd, file) + const value = await readUint32(file) if (value.error) return value setKey(key.value, value.value) break } case 5: { - const value = await readInt32(fd, file) + const value = await readInt32(file) if (value.error) return value setKey(key.value, value.value) break } case 6: { - const value = await readFloat32(fd, file) + const value = await readFloat32(file) if (value.error) return value setKey(key.value, value.value) break } case 7: { - const value = await readBool(fd, file) + const value = await readBool(file) if (value.error) return value setKey(key.value, value.value) break } case 8: { - const value = await readString(fd, version.value, file) + const value = await readString(version.value, file) if (value.error) return value setKey(key.value, value.value) break } case 9: { - const value = await readArray(fd, version.value, file) + const value = await readArray(version.value, file) if (value.error) return value setKey(key.value, value.value) break } case 10: { - const value = await readUint64(fd, file) + const value = await readUint64(file) if (value.error) return value setKey(key.value, value.value) break } case 11: { - const value = await readInt64(fd, file) + const value = await readInt64(file) if (value.error) return value setKey(key.value, value.value) break } case 12: { - const value = await readFloat64(fd, file) + const value = await readFloat64(file) if (value.error) return value setKey(key.value, value.value) break