Skip to content

Commit

Permalink
Merge pull request #6 from julien-c/add-gemma
Browse files Browse the repository at this point in the history
Add `gemma`
  • Loading branch information
biw authored Mar 11, 2024
2 parents 876a427 + c9cbedd commit 14fe924
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 0 deletions.
14 changes: 14 additions & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import type {
BaseGGUFMetadata,
BloomMetadata,
FalconMetadata,
GemmaMetadata,
GGUFMetadata,
GPT2Metadata,
GPTJMetadata,
Expand All @@ -16,6 +17,7 @@ import type {
import {
bloomMetadataSchema,
falconMetadataSchema,
gemmaMetadataSchema,
gPT2MetadataSchema,
gPTJMetadataSchema,
gPTNeoXMetadataSchema,
Expand Down Expand Up @@ -263,6 +265,7 @@ const isValidArchitecture = (
'gpt2',
'bloom',
'falcon',
'gemma',
'rwkv',
].includes(architecture)
}
Expand Down Expand Up @@ -317,6 +320,11 @@ const validateMetadata = (
if (res.success === false) return { error: res.error }
return { metadata: res.data }
}
case 'gemma': {
const res = gemmaMetadataSchema.safeParse(metadata)
if (res.success === false) return { error: res.error }
return { metadata: res.data }
}
case 'rwkv': {
const res = rWKVMetadataSchema.safeParse(metadata)
if (res.success === false) return { error: res.error }
Expand Down Expand Up @@ -604,6 +612,12 @@ export const isFalconMetadata = (
return metadata.general.architecture === 'falcon'
}

/**
 * Type guard that narrows a GGUF metadata value to `GemmaMetadata`.
 *
 * @param metadata - Any member of the `GGUFMetadata` union.
 * @returns `true` when `metadata.general.architecture` is exactly `'gemma'`.
 */
export const isGemmaMetadata = (
  metadata: GGUFMetadata,
): metadata is GemmaMetadata => {
  // The architecture tag on `general` is the discriminant for the union.
  const { architecture } = metadata.general
  return architecture === 'gemma'
}

export const isRWKVMetadata = (
metadata: GGUFMetadata,
): metadata is RWKVMetadata => {
Expand Down
33 changes: 33 additions & 0 deletions src/metadataTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export type ArchitectureType =
| 'gpt2'
| 'bloom'
| 'falcon'
| 'gemma'
| 'rwkv'

export type BaseGGUFMetadata = {
Expand Down Expand Up @@ -369,6 +370,37 @@ export type RWKVMetadata = {
}
}

/**
 * Metadata shape for a GGUF file whose `general.architecture` is `'gemma'`.
 * Architecture-specific hyperparameters live under the `gemma` key, while
 * `general` carries the shared base metadata plus the architecture tag.
 */
export type GemmaMetadata = {
gemma: {
attention: {
/** Also known as n_head. Number of attention heads. */
head_count: number
/** The number of heads per group used in Grouped-Query-Attention. If not
* present or if present and equal to gemma.attention.head_count, the model
* does not use GQA. */
head_count_kv?: number
/** Layer RMS normalization epsilon. */
layer_norm_rms_epsilon: number
}
/** Number of blocks in the model; presumably attention + feed-forward
* layers, as in other GGUF architectures. TODO confirm against the GGUF
* spec. */
block_count: number
/** Length of the context used during training or fine-tuning.
* NOTE(review): an earlier wording of this comment carried an RWKV-specific
* caveat about exceeding this limit; it was dropped here because nothing
* in this type establishes that it applies to Gemma. Confirm. */
context_length: number
/** Also known as n_embd. Embedding layer size. */
embedding_length: number
/** Also known as n_ff. The length of the feedforward layer. */
feed_forward_length: number
}
general: BaseGGUFMetadata & {
/**
* Describes what architecture this model implements. All lowercase ASCII,
* with only [a-z0-9]+ characters allowed. Fixed to 'gemma' for this type.
*/
architecture: 'gemma'
}
}

export type WhisperMetadata = {
general: BaseGGUFMetadata & {
/**
Expand Down Expand Up @@ -416,5 +448,6 @@ export type GGUFMetadata =
| GPT2Metadata
| BloomMetadata
| FalconMetadata
| GemmaMetadata
| RWKVMetadata
| WhisperMetadata
21 changes: 21 additions & 0 deletions src/zodValidators.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export const architectureTypeSchema = z.union([
z.literal('gpt2'),
z.literal('bloom'),
z.literal('falcon'),
z.literal('gemma'),
z.literal('rwkv'),
])

Expand Down Expand Up @@ -218,6 +219,25 @@ export const rWKVMetadataSchema = z.object({
}),
})

/** Validates the `gemma.attention` hyperparameters; `head_count_kv` is optional. */
const gemmaAttentionSchema = z.object({
  head_count: z.number(),
  head_count_kv: z.number().optional(),
  layer_norm_rms_epsilon: z.number(),
})

/** Validates the architecture-specific `gemma` section of the metadata. */
const gemmaSectionSchema = z.object({
  attention: gemmaAttentionSchema,
  block_count: z.number(),
  context_length: z.number(),
  embedding_length: z.number(),
  feed_forward_length: z.number(),
})

/**
 * Validates the `general` section: everything `baseGGUFMetadataSchema`
 * accepts, with `architecture` additionally pinned to the literal `'gemma'`.
 */
const gemmaGeneralSchema = baseGGUFMetadataSchema.and(
  z.object({
    architecture: z.literal('gemma'),
  }),
)

/**
 * Zod schema for GGUF metadata describing a Gemma model. It requires a
 * `gemma` section with numeric hyperparameters and a `general` section whose
 * `architecture` is `'gemma'`.
 */
export const gemmaMetadataSchema = z.object({
  gemma: gemmaSectionSchema,
  general: gemmaGeneralSchema,
})

export const whisperMetadataSchema = z.object({
general: baseGGUFMetadataSchema.and(
z.object({
Expand Down Expand Up @@ -253,6 +273,7 @@ export const gGUFMetadataSchema = z.union([
gPT2MetadataSchema,
bloomMetadataSchema,
falconMetadataSchema,
gemmaMetadataSchema,
rWKVMetadataSchema,
whisperMetadataSchema,
])

0 comments on commit 14fe924

Please sign in to comment.