From dbaf84225e61c8164abf429a7e1c215c2fd4cee6 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 7 Mar 2024 16:03:03 +0100 Subject: [PATCH 1/2] Add basic support for `gemma` --- src/index.ts | 14 ++++++++++++++ src/metadataTypes.ts | 23 +++++++++++++++++++++++ src/zodValidators.ts | 16 ++++++++++++++++ 3 files changed, 53 insertions(+) diff --git a/src/index.ts b/src/index.ts index 03a8b0e..d3fc2c8 100644 --- a/src/index.ts +++ b/src/index.ts @@ -5,6 +5,7 @@ import type { BaseGGUFMetadata, BloomMetadata, FalconMetadata, + GemmaMetadata, GGUFMetadata, GPT2Metadata, GPTJMetadata, @@ -16,6 +17,7 @@ import type { import { bloomMetadataSchema, falconMetadataSchema, + gemmaMetadataSchema, gPT2MetadataSchema, gPTJMetadataSchema, gPTNeoXMetadataSchema, @@ -263,6 +265,7 @@ const isValidArchitecture = ( 'gpt2', 'bloom', 'falcon', + 'gemma', 'rwkv', ].includes(architecture) } @@ -317,6 +320,11 @@ const validateMetadata = ( if (res.success === false) return { error: res.error } return { metadata: res.data } } + case 'gemma': { + const res = gemmaMetadataSchema.safeParse(metadata) + if (res.success === false) return { error: res.error } + return { metadata: res.data } + } case 'rwkv': { const res = rWKVMetadataSchema.safeParse(metadata) if (res.success === false) return { error: res.error } @@ -604,6 +612,12 @@ export const isFalconMetadata = ( return metadata.general.architecture === 'falcon' } +export const isGemmaMetadata = ( + metadata: GGUFMetadata, +): metadata is GemmaMetadata => { + return metadata.general.architecture === 'gemma' +} + export const isRWKVMetadata = ( metadata: GGUFMetadata, ): metadata is RWKVMetadata => { diff --git a/src/metadataTypes.ts b/src/metadataTypes.ts index 3707c53..37e5151 100644 --- a/src/metadataTypes.ts +++ b/src/metadataTypes.ts @@ -10,6 +10,7 @@ export type ArchitectureType = | 'gpt2' | 'bloom' | 'falcon' + | 'gemma' | 'rwkv' export type BaseGGUFMetadata = { @@ -369,6 +370,27 @@ export type RWKVMetadata = { } } +export type GemmaMetadata = { + gemma: { + block_count: number + /** Length of the context used during training or fine-tuning. RWKV is able + * to handle larger context than this limit, but the output quality + * may suffer. */ + context_length: number + /** Also known as n_embd. Embedding layer size. */ + embedding_length: number + /** Also known as n_ff. The length of the feedforward layer. */ + feed_forward_length: number + } + general: BaseGGUFMetadata & { + /** + * describes what architecture this model implements. All lowercase ASCII, + * with only [a-z0-9]+ characters allowed. + **/ + architecture: 'gemma' + } +} + export type WhisperMetadata = { general: BaseGGUFMetadata & { /** @@ -416,5 +438,6 @@ export type GGUFMetadata = | GPT2Metadata | BloomMetadata | FalconMetadata + | GemmaMetadata | RWKVMetadata | WhisperMetadata diff --git a/src/zodValidators.ts b/src/zodValidators.ts index 6f7fb15..dea3735 100644 --- a/src/zodValidators.ts +++ b/src/zodValidators.ts @@ -9,6 +9,7 @@ export const architectureTypeSchema = z.union([ z.literal('gpt2'), z.literal('bloom'), z.literal('falcon'), + z.literal('gemma'), z.literal('rwkv'), ]) @@ -218,6 +219,20 @@ export const rWKVMetadataSchema = z.object({ }), }) +export const gemmaMetadataSchema = z.object({ + gemma: z.object({ + block_count: z.number(), + context_length: z.number(), + embedding_length: z.number(), + feed_forward_length: z.number(), + }), + general: baseGGUFMetadataSchema.and( + z.object({ + architecture: z.literal('gemma'), + }), + ), +}) + export const whisperMetadataSchema = z.object({ general: baseGGUFMetadataSchema.and( z.object({ @@ -253,6 +268,7 @@ export const gGUFMetadataSchema = z.union([ gPT2MetadataSchema, bloomMetadataSchema, falconMetadataSchema, + gemmaMetadataSchema, rWKVMetadataSchema, whisperMetadataSchema, ]) From c9cbedd3df08981f58954f87ca4d36b022854f52 Mon Sep 17 00:00:00 2001 From: Julien Chaumond Date: Thu, 7 Mar 2024 16:09:30 +0100 Subject: [PATCH 2/2] Also validate attention --- src/metadataTypes.ts | 10 ++++++++++ src/zodValidators.ts | 5 +++++ 2 files changed, 15 insertions(+) diff --git a/src/metadataTypes.ts b/src/metadataTypes.ts index 37e5151..990dc32 100644 --- a/src/metadataTypes.ts +++ b/src/metadataTypes.ts @@ -372,6 +372,16 @@ export type RWKVMetadata = { export type GemmaMetadata = { gemma: { + attention: { + /** Also known as n_head. Number of attention heads. */ + head_count: number + /** The number of heads per group used in Grouped-Query-Attention. If not + * present or if present and equal to [llm].attention.head_count, the model + * does not use GQA. */ + head_count_kv?: number + /** Layer RMS normalization epsilon. */ + layer_norm_rms_epsilon: number + } block_count: number /** Length of the context used during training or fine-tuning. RWKV is able * to handle larger context than this limit, but the output quality diff --git a/src/zodValidators.ts b/src/zodValidators.ts index dea3735..8f34a1e 100644 --- a/src/zodValidators.ts +++ b/src/zodValidators.ts @@ -221,6 +221,11 @@ export const rWKVMetadataSchema = z.object({ export const gemmaMetadataSchema = z.object({ gemma: z.object({ + attention: z.object({ + head_count: z.number(), + head_count_kv: z.number().optional(), + layer_norm_rms_epsilon: z.number(), + }), block_count: z.number(), context_length: z.number(), embedding_length: z.number(),