discojs-core/models: add gpt
Closes: #641
Closes: #619
Closes: #600
peacefulotter authored and tharvik committed Mar 13, 2024
1 parent dbdb764 commit 7f95b86
Showing 11 changed files with 1,109 additions and 6 deletions.
13 changes: 13 additions & 0 deletions discojs/discojs-core/src/informant/training_informant/base.ts
@@ -11,6 +11,8 @@ export abstract class Base {
  protected readonly trainingGraphInformant = new GraphInformant()
  protected readonly validationGraphInformant = new GraphInformant()

  private _losses = List<number>()

  // statistics
  protected currentRound = 0
  protected currentNumberOfParticipants = 0
@@ -71,6 +73,17 @@ export abstract class Base {
    return this.validationGraphInformant.accuracy()
  }

  /** Record the loss for the current round. */
  // eslint-disable-next-line accessor-pairs
  set loss (loss: number) {
    this._losses = this._losses.push(loss)
  }

  /** Return the loss recorded for each round. */
  get losses (): List<number> {
    return this._losses
  }

  trainingAccuracyData (): List<number> {
    return this.trainingGraphInformant.data()
  }
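For context, a minimal usage sketch of the new accessors, assuming a concrete informant instance (`informant` is an illustrative name, not part of the commit):

  // Record one aggregated loss per training round, then read the history back.
  informant.loss = 0.93 // round 1
  informant.loss = 0.71 // round 2
  console.log(informant.losses.toArray()) // [0.93, 0.71]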
77 changes: 77 additions & 0 deletions discojs/discojs-core/src/models/gpt/config.ts
@@ -0,0 +1,77 @@
type ModelType =
  | 'gpt2'
  | 'gpt2-medium'
  | 'gpt2-large'
  | 'gpt2-xl'
  | 'gpt-mini'
  | 'gpt-micro'
  | 'gpt-nano'

interface ModelSize {
  nLayer?: number
  nHead?: number
  nEmbd?: number
}

export interface GPTConfig {
  lr: number
  batchSize: number
  blockSize: number
  vocabSize: number
  evaluate?: boolean
  maxEvalBatches?: number
  evaluateEvery?: number
  epochs?: number
  maxIter?: number
  weightDecay?: number
  verbose?: 0 | 1
  bias?: boolean
  debug?: boolean
  dropout?: number
  residDrop?: number
  embdDrop?: number
  tokEmb?: boolean
  lmHead?: boolean
  modelType: ModelType
}

export const DEFAULT_CONFIG: Required<GPTConfig> = {
  lr: 0.001,
  weightDecay: 0,
  batchSize: 2,
  epochs: 9999,
  maxIter: 10_000,
  verbose: 0,
  modelType: 'gpt-nano',
  evaluate: true,
  maxEvalBatches: 12,
  evaluateEvery: 100,
  blockSize: 128,
  vocabSize: 50258,
  bias: true,
  debug: false,
  dropout: 0.2,
  residDrop: 0.2,
  embdDrop: 0.2,
  tokEmb: true,
  lmHead: true
}

export function getModelSizes (modelType: ModelType): Required<ModelSize> {
  switch (modelType) {
    case 'gpt2':
      return { nLayer: 12, nHead: 12, nEmbd: 768 }
    case 'gpt2-medium':
      return { nLayer: 24, nHead: 16, nEmbd: 1024 }
    case 'gpt2-large':
      return { nLayer: 36, nHead: 20, nEmbd: 1280 }
    case 'gpt2-xl':
      return { nLayer: 48, nHead: 25, nEmbd: 1600 }
    case 'gpt-mini':
      return { nLayer: 6, nHead: 6, nEmbd: 192 }
    case 'gpt-micro':
      return { nLayer: 4, nHead: 4, nEmbd: 128 }
    case 'gpt-nano':
      return { nLayer: 3, nHead: 3, nEmbd: 48 }
  }
}
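As a hedged illustration of how the two exports compose (this merge is not part of the file itself), a caller can overlay the defaults with the sizes for the chosen model type:

  import { DEFAULT_CONFIG, getModelSizes } from './config'

  // 'gpt-nano' resolves to { nLayer: 3, nHead: 3, nEmbd: 48 }.
  const sizes = getModelSizes(DEFAULT_CONFIG.modelType)
  const resolved = { ...DEFAULT_CONFIG, ...sizes }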
54 changes: 54 additions & 0 deletions discojs/discojs-core/src/models/gpt/evaluate.ts
@@ -0,0 +1,54 @@
import tf from '@tensorflow/tfjs'

export default async function evaluate (
  model: tf.LayersModel,
  dataset: tf.data.Dataset<{ xs: tf.Tensor, ys: tf.Tensor }>
): Promise<Record<'acc' | 'val_acc' | 'val_loss' | 'val_perplexity', number>> {
  let datasetSize = 0
  let totalLoss = 0
  const acc: [number, number] = [0, 0]

  await dataset.map(({ xs, ys }) => {
    const logits = model.apply(xs)
    if (Array.isArray(logits)) {
      throw new Error('model outputted many tensors')
    }
    if (logits instanceof tf.SymbolicTensor) {
      throw new Error('model outputted a symbolic tensor')
    }
    xs.dispose()

    return { logits, ys }
  }).mapAsync(async ({ logits, ys }) => {
    const loss = (await tf.losses.softmaxCrossEntropy(ys, logits).array())
    if (typeof loss !== 'number') {
      throw new Error('got multiple losses')
    }

    const accTensor = tf.metrics.categoricalAccuracy(ys, logits)
    const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
    const accSum = accTensor.sum()
    const accSummed = await accSum.array()
    if (typeof accSummed !== 'number') {
      throw new Error('got multiple accuracy sums')
    }

    tf.dispose([ys, logits, accTensor, accSum])

    return { loss, accSummed, accSize }
  }).forEachAsync(({ loss, accSummed, accSize }) => {
    datasetSize += 1
    totalLoss += loss
    acc[0] += accSummed
    acc[1] += accSize
  })

  const loss = totalLoss / datasetSize

  return {
    val_loss: loss,
    val_perplexity: Math.exp(loss),
    acc: acc[0] / acc[1],
    val_acc: acc[0] / acc[1]
  }
}
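A minimal usage sketch (`model` and `validationSet` are illustrative names): the returned perplexity is simply the exponential of the mean validation loss.

  const metrics = await evaluate(model, validationSet)
  console.log(metrics.val_loss, metrics.val_acc)
  console.log(metrics.val_perplexity) // === Math.exp(metrics.val_loss)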
139 changes: 139 additions & 0 deletions discojs/discojs-core/src/models/gpt/index.ts
@@ -0,0 +1,139 @@
import tf from '@tensorflow/tfjs'

import { WeightsContainer } from '../..'
import type { Dataset } from '../../dataset'
import { Sink } from '../../utils/event_emitter'

import type { EpochLogs, Prediction, Sample } from '../model'
import { Model } from '../model'

import { GPTLMHeadModel } from './model'

// TODO too big config
interface Config {
  modelType: 'gpt-nano'
  epochs: number // TODO mv to Task
  maxIter: number
  batchSize: number
  blockSize: number
  lr: number
  vocabSize: number
  maxEvalBatches: number
}

export class GPT extends Model {
  private readonly model: GPTLMHeadModel

  private static readonly batchSize = 4
  private static readonly blockSize = 128
  private static readonly vocabSize = 50258

  constructor () {
    super()

    // TODO sensible defaults?
    const config: Config = {
      modelType: 'gpt-nano',
      epochs: 1,
      maxIter: 2,
      batchSize: GPT.batchSize,
      blockSize: GPT.blockSize,
      lr: 0.001,
      vocabSize: GPT.vocabSize,
      maxEvalBatches: 1
    }

    this.model = new GPTLMHeadModel(config)
  }

  override get weights (): WeightsContainer {
    return new WeightsContainer(this.model.weights.map((w) => w.read()))
  }

  override set weights (ws: WeightsContainer) {
    this.model.setWeights(ws.weights)
  }

  /**
   * Convert a dataset of single bytes into (xs, ys) training pairs.
   *
   * Each token is stored as two consecutive bytes (little-endian uint16);
   * xs holds a window of blockSize tokens and ys one-hot encodes, at every
   * position, the token that follows it.
   */
  private convertCharDataset (dataset: Dataset): Dataset {
    const batchSize = 4
    const sampleSize = GPT.blockSize + 1
    const chunkSize = sampleSize * batchSize * 2

    function toUInt16 (low: number, high: number): number {
      low &= 0xff
      high &= 0xff
      return (high << 8) | low
    }

    // TODO add support for small last batch
    return dataset.batch(chunkSize, false).mapAsync(async (chunk) => {
      if (!(chunk instanceof tf.Tensor)) {
        throw new Error('chunk is not a Tensor')
      }
      if (chunk.shape.length !== 2 || chunk.shape[1] !== 1) {
        throw new Error('dataset does not contain single bytes')
      }

      const buffer = await chunk.buffer()

      const xs = tf.buffer([batchSize, GPT.blockSize], 'int32')
      const ys = tf.buffer([batchSize, GPT.blockSize, GPT.vocabSize], 'int32')

      for (let i = 0; i < batchSize; i++) {
        for (let j = 0; j < sampleSize; j++) {
          const idx = (i * sampleSize + j) * 2
          const low = buffer.get(idx)
          const high = buffer.get(idx + 1)
          const token = toUInt16(low, high)
          if (j < sampleSize - 1) xs.set(token, i, j)
          if (j > 0) ys.set(1, i, j - 1, token)
        }
      }

      return { xs: xs.toTensor(), ys: ys.toTensor() }
    })
  }

  override async * train (
    trainingData: Dataset,
    validationData?: Dataset,
    epochs = 1,
    tracker = new Sink()
  ): AsyncGenerator<EpochLogs, void> {
    for (let i = 0; i < epochs; i++) {
      let logs: tf.Logs | undefined

      await this.model.fitDataset(
        this.convertCharDataset(trainingData), {
          epochs: 1,
          validationData: validationData !== undefined ? this.convertCharDataset(validationData) : undefined,
          callbacks: {
            onEpochEnd: (_, cur) => { logs = cur },
            onBatchBegin: () => { tracker.emit('batchBegin', undefined) },
            onBatchEnd: () => { tracker.emit('batchEnd', undefined) }
          }
        })

      if (logs === undefined) {
        throw new Error('epoch did not produce any logs')
      }
      yield logs
    }
  }

  override async predict (input: Sample): Promise<Prediction> {
    const ret = this.model.predict(input)
    if (Array.isArray(ret)) {
      throw new Error('prediction yielded many Tensors but should have returned only one')
    }

    return ret
  }

  static deserialize (weights: WeightsContainer): Model {
    const model = new GPT()
    model.weights = weights
    return model
  }

  serialize (): WeightsContainer {
    return this.weights
  }
}
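Finally, a hedged end-to-end sketch (`trainingSet` and `validationSet` are illustrative names, not from the commit): train the nano model for two epochs, then round-trip its weights through serialize/deserialize.

  const gpt = new GPT()
  for await (const logs of gpt.train(trainingSet, validationSet, 2)) {
    console.log('epoch logs:', logs)
  }
  // The weights survive a serialize/deserialize round trip.
  const restored = GPT.deserialize(gpt.serialize())
  console.log(restored.weights)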