diff --git a/discojs/discojs-core/src/informant/training_informant/base.ts b/discojs/discojs-core/src/informant/training_informant/base.ts
index fa0cc9e25..d9e63e59e 100644
--- a/discojs/discojs-core/src/informant/training_informant/base.ts
+++ b/discojs/discojs-core/src/informant/training_informant/base.ts
@@ -11,6 +11,8 @@ export abstract class Base {
   protected readonly trainingGraphInformant = new GraphInformant()
   protected readonly validationGraphInformant = new GraphInformant()
 
+  private _losses = List<number>()
+
   // statistics
   protected currentRound = 0
   protected currentNumberOfParticipants = 0
@@ -71,6 +73,17 @@ export abstract class Base {
     return this.validationGraphInformant.accuracy()
   }
 
+  /** add loss for current round */
+  // eslint-disable-next-line accessor-pairs
+  set loss (loss: number) {
+    this._losses = this._losses.push(loss)
+  }
+
+  /** return loss of each round */
+  get losses (): List<number> {
+    return this._losses
+  }
+
   trainingAccuracyData (): List<number> {
     return this.trainingGraphInformant.data()
   }
diff --git a/discojs/discojs-core/src/models/gpt/config.ts b/discojs/discojs-core/src/models/gpt/config.ts
new file mode 100644
index 000000000..b8e879dfb
--- /dev/null
+++ b/discojs/discojs-core/src/models/gpt/config.ts
@@ -0,0 +1,77 @@
+type ModelType =
+  | 'gpt2'
+  | 'gpt2-medium'
+  | 'gpt2-large'
+  | 'gpt2-xl'
+  | 'gpt-mini'
+  | 'gpt-micro'
+  | 'gpt-nano'
+
+interface ModelSize {
+  nLayer?: number
+  nHead?: number
+  nEmbd?: number
+}
+
+export interface GPTConfig {
+  lr: number
+  batchSize: number
+  blockSize: number
+  vocabSize: number
+  evaluate?: boolean
+  maxEvalBatches?: number
+  evaluateEvery?: number
+  epochs?: number
+  maxIter?: number
+  weightDecay?: number
+  verbose?: 0 | 1
+  bias?: boolean
+  debug?: boolean
+  dropout?: number
+  residDrop?: number
+  embdDrop?: number
+  tokEmb?: boolean
+  lmHead?: boolean
+  modelType: ModelType
+}
+
+export const DEFAULT_CONFIG: Required<GPTConfig> = {
+  lr: 0.001,
+  weightDecay: 0,
+  batchSize: 2,
+  epochs: 9999,
+  maxIter: 10_000,
+  verbose: 0,
+  modelType: 'gpt-nano',
+  evaluate: true,
+  maxEvalBatches: 12,
+  evaluateEvery: 100,
+  blockSize: 128,
+  vocabSize: 50258,
+  bias: true,
+  debug: false,
+  dropout: 0.2,
+  residDrop: 0.2,
+  embdDrop: 0.2,
+  tokEmb: true,
+  lmHead: true
+}
+
+export function getModelSizes (modelType: ModelType): Required<ModelSize> {
+  switch (modelType) {
+    case 'gpt2':
+      return { nLayer: 12, nHead: 12, nEmbd: 768 }
+    case 'gpt2-medium':
+      return { nLayer: 24, nHead: 16, nEmbd: 1024 }
+    case 'gpt2-large':
+      return { nLayer: 36, nHead: 20, nEmbd: 1280 }
+    case 'gpt2-xl':
+      return { nLayer: 48, nHead: 25, nEmbd: 1600 }
+    case 'gpt-mini':
+      return { nLayer: 6, nHead: 6, nEmbd: 192 }
+    case 'gpt-micro':
+      return { nLayer: 4, nHead: 4, nEmbd: 128 }
+    case 'gpt-nano':
+      return { nLayer: 3, nHead: 3, nEmbd: 48 }
+  }
+}
diff --git a/discojs/discojs-core/src/models/gpt/evaluate.ts b/discojs/discojs-core/src/models/gpt/evaluate.ts
new file mode 100644
index 000000000..cd07653b7
--- /dev/null
+++ b/discojs/discojs-core/src/models/gpt/evaluate.ts
@@ -0,0 +1,54 @@
+import tf from '@tensorflow/tfjs'
+
+export default async function evaluate (
+  model: tf.LayersModel,
+  dataset: tf.data.Dataset<{ xs: tf.Tensor, ys: tf.Tensor }>
+): Promise<Record<'val_loss' | 'val_perplexity' | 'acc' | 'val_acc', number>> {
+  let datasetSize = 0
+  let totalLoss = 0
+  const acc: [number, number] = [0, 0]
+
+  await dataset.map(({ xs, ys }) => {
+    const logits = model.apply(xs)
+    if (Array.isArray(logits)) {
+      throw new Error('model outputted more than one tensor')
+    }
+    if (logits instanceof tf.SymbolicTensor) {
+      throw new Error('model outputted a symbolic tensor')
+    }
+    xs.dispose()
+
+    return { logits, ys }
+  }).mapAsync(async ({ logits, ys }) => {
+    const loss = (await tf.losses.softmaxCrossEntropy(ys, logits).array())
+    if (typeof loss !== 'number') {
+      throw new Error('got multiple loss values')
+    }
+
+    const accTensor = tf.metrics.categoricalAccuracy(ys, logits)
+    const accSize = accTensor.shape.reduce((l, r) => l * r, 1)
+    const accSum = accTensor.sum()
+    const accSummed = await accSum.array()
+    if (typeof accSummed !== 'number') {
+      throw new Error('got multiple accuracy sums')
+    }
+
+    tf.dispose([ys, logits, accTensor, accSum])
+
+    return { loss, accSummed, accSize }
+  }).forEachAsync(({ loss, accSummed, accSize }) => {
+    datasetSize += 1
+    totalLoss += loss
+    acc[0] += accSummed
+    acc[1] += accSize
+  })
+
+  const loss = totalLoss / datasetSize
+
+  return {
+    val_loss: loss,
+    val_perplexity: Math.exp(loss),
+    acc: acc[0] / acc[1],
+    val_acc: acc[0] / acc[1]
+  }
+}
diff --git a/discojs/discojs-core/src/models/gpt/index.ts b/discojs/discojs-core/src/models/gpt/index.ts
new file mode 100644
index 000000000..6c3fff9ad
--- /dev/null
+++ b/discojs/discojs-core/src/models/gpt/index.ts
@@ -0,0 +1,139 @@
+import tf from '@tensorflow/tfjs'
+
+import { WeightsContainer } from '../..'
+import type { Dataset } from '../../dataset'
+import { Sink } from '../../utils/event_emitter'
+
+import type { EpochLogs, Prediction, Sample } from '../model'
+import { Model } from '../model'
+
+import { GPTLMHeadModel } from './model'
+
+// TODO too big config
+interface Config {
+  modelType: 'gpt-nano'
+  epochs: number // TODO mv to Task
+  maxIter: number
+  batchSize: number
+  blockSize: number
+  lr: number
+  vocabSize: number
+  maxEvalBatches: number
+}
+
+export class GPT extends Model {
+  private readonly model: GPTLMHeadModel
+
+  private static readonly batchSize = 4
+  private static readonly blockSize = 128
+  private static readonly vocabSize = 50258
+
+  constructor () {
+    super()
+
+    // TODO sensible defaults?
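+    // These literal values shadow DEFAULT_CONFIG from ./config.ts: anything
+    // not listed here (dropout rates, weight decay, ...) falls back to those
+    // defaults when the config is resolved in ./train.ts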
+    const config: Config = {
+      modelType: 'gpt-nano',
+      epochs: 1,
+      maxIter: 2,
+      batchSize: GPT.batchSize,
+      blockSize: GPT.blockSize,
+      lr: 0.001,
+      vocabSize: GPT.vocabSize,
+      maxEvalBatches: 1
+    }
+
+    this.model = new GPTLMHeadModel(config)
+  }
+
+  override get weights (): WeightsContainer {
+    return new WeightsContainer(this.model.weights.map((w) => w.read()))
+  }
+
+  override set weights (ws: WeightsContainer) {
+    this.model.setWeights(ws.weights)
+  }
+
+  private convertCharDataset (dataset: Dataset): Dataset {
+    const batchSize = 4
+    const sampleSize = GPT.blockSize + 1
+    const chunkSize = sampleSize * batchSize * 2
+
+    // reassemble a little-endian uint16 token id from two bytes
+    function toUInt16 (low: number, high: number): number {
+      low &= 0xff
+      high &= 0xff
+      return (high << 8) | low
+    }
+
+    // TODO add support for small last batch
+    return dataset.batch(chunkSize, false).mapAsync(async (chunk) => {
+      if (!(chunk instanceof tf.Tensor)) {
+        throw new Error('chunk is not a Tensor')
+      }
+      if (chunk.shape.length !== 2 || chunk.shape[1] !== 1) {
+        throw new Error('dataset is not made of single characters')
+      }
+
+      const buffer = await chunk.buffer()
+
+      const xs = tf.buffer([batchSize, GPT.blockSize], 'int32')
+      const ys = tf.buffer([batchSize, GPT.blockSize, GPT.vocabSize], 'int32')
+
+      for (let i = 0; i < batchSize; i++) {
+        for (let j = 0; j < sampleSize; j++) {
+          const idx = (i * sampleSize + j) * 2
+          const low = buffer.get(idx)
+          const high = buffer.get(idx + 1)
+          const token = toUInt16(low, high)
+          if (j < sampleSize - 1) xs.set(token, i, j)
+          if (j > 0) ys.set(1, i, j - 1, token)
+        }
+      }
+
+      return { xs: xs.toTensor(), ys: ys.toTensor() }
+    })
+  }
+
+  override async * train (
+    trainingData: Dataset,
+    validationData?: Dataset,
+    epochs = 1,
+    tracker = new Sink()
+  ): AsyncGenerator<EpochLogs> {
+    for (let i = 0; i < epochs; i++) {
+      let logs: tf.Logs | undefined
+
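+      // fitDataset is driven one epoch at a time (epochs: 1 below) so that
+      // each epoch's logs can be yielded to the caller as soon as they exist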
+      await this.model.fitDataset(
+        this.convertCharDataset(trainingData), {
+          epochs: 1,
+          validationData: validationData !== undefined ? this.convertCharDataset(validationData) : validationData,
+          callbacks: {
+            onEpochEnd: (_, cur) => { logs = cur },
+            onBatchBegin: () => { tracker.emit('batchBegin', undefined) },
+            onBatchEnd: () => { tracker.emit('batchEnd', undefined) }
+          }
+        })
+
+      yield logs
+    }
+  }
+
+  override async predict (input: Sample): Promise<Prediction> {
+    const ret = this.model.predict(input)
+    if (Array.isArray(ret)) {
+      throw new Error('prediction yielded multiple Tensors but should have returned only one')
+    }
+
+    return ret
+  }
+
+  static deserialize (weights: WeightsContainer): Model {
+    const model = new GPT()
+    model.weights = weights
+    return model
+  }
+
+  serialize (): WeightsContainer {
+    return this.weights
+  }
+}
diff --git a/discojs/discojs-core/src/models/gpt/model.ts b/discojs/discojs-core/src/models/gpt/model.ts
new file mode 100644
index 000000000..f881c50c6
--- /dev/null
+++ b/discojs/discojs-core/src/models/gpt/model.ts
@@ -0,0 +1,542 @@
+import tf, { LayersModel, layers, serialization } from '@tensorflow/tfjs'
+
+import type { GPTConfig } from './config'
+import { getModelSizes, DEFAULT_CONFIG } from './config'
+import { train } from './train'
+import type { TrainingCallbacks } from './types'
+
+class Range extends layers.Layer {
+  static readonly className = 'Range'
+
+  computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
+    return inputShape
+  }
+
+  call (input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[] {
+    return tf.tidy(() => {
+      if (Array.isArray(input)) {
+        // TODO support multitensor
+        input = input[0]
+      }
+      this.invokeCallHook(input, kwargs)
+      const T = input.shape[1]
+      if (T === undefined) throw new Error('unexpected shape')
+      return tf.reshape(tf.range(0, T, 1, 'int32'), [1, T])
+    })
+  }
+}
+serialization.registerClass(Range)
+
+class LogLayer extends layers.Layer {
+  static readonly className = 'LogLayer'
+
+  computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
+    return inputShape
+  }
+
+  call (input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[] {
+    return tf.tidy(() => {
+      if (Array.isArray(input)) {
+        input = input[0]
+      }
+      this.invokeCallHook(input, kwargs)
+      return input
+    })
+  }
+}
+serialization.registerClass(LogLayer)
+
+class CausalSelfAttentionBase extends layers.Layer {
+  static readonly className = 'CausalSelfAttentionBase'
+
+  private readonly blockSize: number
+  private readonly nHead: number
+  private readonly nEmbd: number
+  private readonly dropout: number
+  private readonly mask: tf.Tensor
+
+  constructor (
+    private readonly config: ConstructorParameters<typeof layers.Layer>[0] & Record<'blockSize' | 'nHead' | 'nEmbd' | 'dropout', number>
+  ) {
+    super(config)
+
+    this.blockSize = config.blockSize
+    this.nEmbd = config.nEmbd
+    this.nHead = config.nHead
+    this.dropout = config.dropout
+
+    this.mask = tf.linalg.bandPart(tf.ones([this.blockSize, this.blockSize]), -1, 0)
+  }
+
+  computeOutputShape (): tf.Shape | tf.Shape[] {
+    // TODO doesn't take input shape into account
+    return [null, this.blockSize, this.nEmbd]
+  }
+
+  getConfig (): serialization.ConfigDict {
+    const config = super.getConfig()
+    return Object.assign({}, config, this.config)
+  }
+
+  call (input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[] {
+    return tf.tidy(() => {
+      if (Array.isArray(input)) {
+        input = input[0]
+      }
+      this.invokeCallHook(input, kwargs)
+
+      let [q, k, v] = input.split(3, -1) as [tf.Tensor, tf.Tensor, tf.Tensor]
+      const [B, T, C] = k.shape
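+      // split the channels into heads, [B, T, C] -> [B, nHead, T, C / nHead],
+      // so that attention is computed independently (and in parallel) per head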
+      const splitHeads = (x: tf.Tensor): tf.Tensor4D =>
+        x.reshape([B, T, this.nHead, C / this.nHead]).transpose([0, 2, 1, 3])
+      q = splitHeads(q)
+      k = splitHeads(k)
+      v = splitHeads(v)
+
+      let att = tf.mul(
+        tf.matMul(q, k, false, true),
+        tf.div(
+          1,
+          tf.sqrt(tf.cast(k.shape[k.shape.length - 1], 'float32'))
+        )
+      )
+      att = tf.add(att, tf.mul(tf.sub(1, this.mask), -1e9))
+      att = tf.softmax(att, -1)
+      att = kwargs.training === true ? tf.dropout(att, this.dropout) : att
+
+      let y = tf.matMul(att, v)
+      y = tf.transpose(y, [0, 2, 1, 3])
+      y = tf.reshape(y, [B, T, C])
+
+      return y
+    })
+  }
+}
+serialization.registerClass(CausalSelfAttentionBase)
+
+type CausalSelfAttentionConfig =
+  ConstructorParameters<typeof layers.Layer>[0]
+  & Record<'blockSize' | 'nHead' | 'nEmbd' | 'dropout', number>
+  & { bias: boolean }
+
+class CausalSelfAttention extends layers.Layer {
+  static readonly className = 'CausalSelfAttention'
+
+  private readonly nHead: number
+  private readonly nEmbd: number
+  private readonly dropout: number
+  private readonly bias: boolean
+  private readonly mask: tf.Tensor2D
+
+  cAttnKernel?: tf.LayerVariable
+  cAttnBias?: tf.LayerVariable
+  cProjKernel?: tf.LayerVariable
+  cProjBias?: tf.LayerVariable
+
+  constructor (private readonly config: CausalSelfAttentionConfig) {
+    super(config)
+
+    this.nEmbd = config.nEmbd
+    this.nHead = config.nHead
+    this.dropout = config.dropout
+    this.bias = config.bias
+
+    this.mask = tf.linalg.bandPart(tf.ones([config.blockSize, config.blockSize]), -1, 0)
+  }
+
+  build (): void {
+    this.cAttnKernel = this.addWeight(
+      'c_attn/kernel',
+      [this.nEmbd, 3 * this.nEmbd],
+      'float32',
+      tf.initializers.glorotNormal({})
+    )
+    this.cAttnBias = this.addWeight(
+      'c_attn/bias',
+      [3 * this.nEmbd],
+      'float32',
+      tf.initializers.zeros()
+    )
+    this.cProjKernel = this.addWeight(
+      'c_proj/kernel',
+      [this.nEmbd, this.nEmbd],
+      'float32',
+      tf.initializers.glorotNormal({})
+    )
+    this.cProjBias = this.addWeight(
+      'c_proj/bias',
+      [this.nEmbd],
+      'float32',
+      tf.initializers.zeros()
+    )
+  }
+
+  computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
+    return inputShape
+  }
+
+  getConfig (): serialization.ConfigDict {
+    const config = super.getConfig()
+    return Object.assign({}, config, this.config)
+  }
+
+  call (input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[] {
+    return tf.tidy(() => {
+      if (this.cAttnKernel === undefined ||
+        this.cAttnBias === undefined ||
+        this.cProjKernel === undefined ||
+        this.cProjBias === undefined
+      ) { throw new Error('not built') }
+
+      if (Array.isArray(input)) {
+        input = input[0]
+      }
+      this.invokeCallHook(input, kwargs)
+
+      const dense = (x: tf.Tensor, kernel: tf.LayerVariable, bias: tf.LayerVariable): tf.Tensor => {
+        const k = kernel.read().expandDims(0).tile([x.shape[0], 1, 1])
+        const m = x.matMul(k)
+        if (this.bias) {
+          return tf.add(m, bias.read())
+        } else {
+          return m
+        }
+      }
+
+      const cAttn = dense(input, this.cAttnKernel, this.cAttnBias)
+
+      let [q, k, v] = tf.split(cAttn, 3, -1) as [tf.Tensor, tf.Tensor, tf.Tensor]
+      const [B, T, C] = k.shape
+
+      const splitHeads = (x: tf.Tensor): tf.Tensor =>
+        tf.transpose(
+          tf.reshape(x, [B, T, this.nHead, C / this.nHead]),
+          [0, 2, 1, 3]
+        )
+
+      q = splitHeads(q)
+      k = splitHeads(k)
+      v = splitHeads(v)
+
+      let att = tf.mul(
+        tf.matMul(q, k, false, true),
+        tf.div(
+          1,
+          tf.sqrt(tf.cast(k.shape[k.shape.length - 1], 'float32'))
+        )
+      )
+
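+      // causal mask: entries above the diagonal (future positions) get -1e9
+      // added before the softmax, i.e. effectively zero attention weight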
+      const mask = this.mask.slice([0, 0], [T, T])
+      att = tf.add(att, tf.mul(tf.sub(1, mask), -1e9))
+      att = tf.softmax(att, -1)
+      att = kwargs.training === true ? tf.dropout(att, this.dropout) : att
+
+      let y = tf.matMul(att, v)
+      y = tf.transpose(y, [0, 2, 1, 3])
+      y = tf.reshape(y, [B, T, C])
+      y = dense(y, this.cProjKernel, this.cProjBias)
+      y = kwargs.training === true ? tf.dropout(y, this.dropout) : y
+
+      return y
+    })
+  }
+}
+serialization.registerClass(CausalSelfAttention)
+
+class GELU extends layers.Layer {
+  static readonly className = 'GELU'
+
+  constructor () {
+    super({})
+  }
+
+  computeOutputShape (inputShape: tf.Shape | tf.Shape[]): tf.Shape | tf.Shape[] {
+    return inputShape
+  }
+
+  call (input: tf.Tensor | tf.Tensor[], kwargs: Record<string, unknown>): tf.Tensor | tf.Tensor[] {
+    return tf.tidy(() => {
+      if (Array.isArray(input)) {
+        // TODO support multitensor
+        input = input[0]
+      }
+      this.invokeCallHook(input, kwargs)
+      // tanh approximation of the GELU activation
+      const cdf = tf.mul(
+        0.5,
+        tf.add(
+          1,
+          tf.tanh(
+            tf.mul(
+              tf.sqrt(tf.div(2, Math.PI)),
+              tf.add(input, tf.mul(0.044715, tf.pow(input, 3)))
+            )
+          )
+        )
+      )
+      return tf.mul(input, cdf)
+    })
+  }
+}
+serialization.registerClass(GELU)
+
+function MLP (conf: any): LayersModel {
+  const config = Object.assign({ name: 'mlp' }, conf)
+  const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] })
+  let x
+  x = tf.layers
+    .dense({
+      name: config.name + '/c_fc',
+      units: 4 * config.nEmbd,
+      inputDim: config.nEmbd,
+      inputShape: [config.blockSize, config.nEmbd]
+    })
+    .apply(inputs)
+  x = new GELU().apply(x)
+  x = tf.layers
+    .dense({
+      name: config.name + '/c_proj',
+      units: config.nEmbd,
+      inputDim: 4 * config.nEmbd,
+      inputShape: [config.blockSize, 4 * config.nEmbd]
+    })
+    .apply(x)
+  x = tf.layers
+    .dropout({
+      name: config.name + '/drop',
+      rate: config.residDrop
+    })
+    .apply(x)
+  return tf.model({ inputs, outputs: x as any })
+}
+
+function Block (conf: CausalSelfAttentionConfig & { debug: boolean }): LayersModel {
+  const config = Object.assign({ name: 'h' }, conf)
+  const inputs = tf.input({ shape: [config.blockSize, config.nEmbd] })
+  let x1, x2
+  x1 = tf.layers
+    .layerNormalization({ name: config.name + '/ln_1', epsilon: 1e-5 })
+    .apply(inputs)
+  if (config.debug) {
+    x1 = new LogLayer({ name: config.name + '/ln_1_log' }).apply(x1)
+  }
+  x1 = new CausalSelfAttention(
+    Object.assign({}, config, { name: config.name + '/attn' })
+  ).apply(x1)
+  x1 = tf.layers.add().apply([inputs, x1 as any])
+  x2 = tf.layers
+    .layerNormalization({ name: config.name + '/ln_2', epsilon: 1e-5 })
+    .apply(x1)
+  x2 = MLP(Object.assign({}, config, { name: config.name + '/mlp' })).apply(
+    x2
+  )
+  x2 = tf.layers.add().apply([x1 as any, x2 as any])
+  return tf.model({ name: config.name, inputs, outputs: x2 as any })
+}
+
+function GPT (conf: GPTConfig): LayersModel {
+  const configDefaults = {
+    name: 'transformer',
+    ...DEFAULT_CONFIG
+  }
+
+  const modelSizes = getModelSizes(conf.modelType)
+  const config = Object.assign({}, configDefaults, conf, modelSizes)
+
+  console.log('IN MODEL CONFIG', config)
+
+  const inputs = tf.input({ shape: [null] })
+
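+  // GPT-2 style input block: learned token embeddings (wte) are summed with
+  // learned position embeddings (wpe); Range emits the indices 0..T-1 fed to wpe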
+  const tokEmb = config.tokEmb
+    ? tf.layers
+      .embedding({
+        name: config.name + '/wte',
+        inputDim: config.vocabSize,
+        outputDim: config.nEmbd,
+        embeddingsInitializer: 'zeros',
+        embeddingsRegularizer: undefined,
+        activityRegularizer: undefined
+      })
+      .apply(inputs)
+    : inputs
+
+  const range = new Range({}).apply(inputs)
+  let posEmb = tf.layers
+    .embedding({
+      name: config.name + '/wpe',
+      inputDim: config.blockSize,
+      outputDim: config.nEmbd,
+      embeddingsInitializer: 'zeros'
+    })
+    .apply(range)
+  if (config.debug) {
+    posEmb = new LogLayer({ name: 'posEmb' }).apply(posEmb)
+  }
+
+  let x
+  x = tf.layers.add().apply([tokEmb as any, posEmb as any])
+  x = tf.layers
+    .dropout({
+      name: 'drop',
+      rate: config.embdDrop
+    })
+    .apply(x)
+  if (config.debug) {
+    x = new LogLayer({ name: 'dropadd' }).apply(x)
+  }
+
+  for (let i = 0; i < config.nLayer; i++) {
+    x = Block(
+      Object.assign({}, config, { name: config.name + '/h/' + i })
+    ).apply(x)
+  }
+  x = tf.layers
+    .layerNormalization({ name: config.name + '/ln_f', epsilon: 1e-5 })
+    .apply(x)
+  if (config.debug) {
+    x = new LogLayer({ name: 'fin/ln' }).apply(x)
+  }
+
+  if (config.lmHead) {
+    x = tf.layers
+      .dense({
+        name: 'lm_head',
+        units: config.vocabSize,
+        inputDim: config.nEmbd,
+        inputShape: [config.blockSize, config.nEmbd],
+        useBias: false
+      })
+      .apply(x)
+  }
+  return tf.model({ inputs, outputs: x as any })
+}
+
+interface GenerateConfig {
+  maxNewTokens: number
+  temperature: number
+  doSample: boolean
+}
+
+const defaultGenerateConfig: GenerateConfig = {
+  maxNewTokens: 20,
+  temperature: 1.0,
+  doSample: false
+}
+
+function prepareIdx (idx: tf.TensorLike): tf.Tensor2D {
+  return tf.tidy(() => {
+    let ret: tf.Tensor
+    if (idx instanceof tf.Tensor) {
+      ret = idx.clone()
+    } else {
+      ret = tf.tensor(idx)
+    }
+    if (ret.dtype !== 'int32') {
+      ret = ret.toInt()
+    }
+    switch (ret.shape.length) {
+      case 1:
+        return ret.expandDims(0)
+      case 2:
+        return ret as tf.Tensor2D
+      default:
+        throw new Error('unexpected shape')
+    }
+  })
+}
+
+/**
+ * tfjs does not export LazyIterator and Dataset...
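+ * The minimal declarations below therefore only describe the members used in
+ * this file; they just need to stay structurally compatible with tfjs.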
+ */ +declare abstract class LazyIterator { + abstract next (): Promise> +} + +declare abstract class Dataset { + abstract iterator (): Promise> + size: number +} + +class GPTModel extends LayersModel { + constructor (protected readonly config: GPTConfig) { + const gpt = GPT(config) + const { inputs, outputs, name } = gpt + super({ inputs, outputs, name }) + Object.assign(this, gpt) + } + + async fitDataset ( + dataset: Dataset, + args: tf.ModelFitDatasetArgs + ): Promise { + console.log('=== GPTModel custom train function ===') + const config = { ...this.config, ...args } + + await train( + this, + dataset as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>, + config, + args.callbacks as TrainingCallbacks, + args.validationData as tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }> + ) + + return new tf.History() + } +} + +interface GenerateOutput { idxNext: tf.Tensor2D, timePerToken: number } + +class GPTLMHeadModel extends GPTModel { + async generate (idxRaw: tf.TensorLike, conf: GenerateConfig, act?: (_: GenerateOutput) => Promise): Promise { + const config = Object.assign({}, defaultGenerateConfig, conf) + let idx = prepareIdx(idxRaw) + for (let step = 0; step < config.maxNewTokens; step++) { + const { idxNext, timePerToken } = this.generateOnce(this, idx, config) + const idxNew = idx.concat(idxNext, 1) + tf.dispose(idx) + idx = idxNew + const idxNextArr = await (idxNext as any).array() + tf.dispose(idxNext) + if (act !== undefined) { + await act({ idxNext: idxNextArr, timePerToken }) + } + } + const idxArr = await idx.array() + tf.dispose(idx) + return idxArr + } + + private generateOnce (model: tf.LayersModel, idx: tf.Tensor2D, config: GenerateConfig): GenerateOutput { + let timePerToken = performance.now() + + const idxNext = tf.tidy(() => { + const blockSize: any = model.inputs[0].shape[1] + const idxCond = + idx.shape[1] <= blockSize + ? 
idx + : idx.slice([0, -blockSize], [-1, -1]) + const outputed = model.predict(idxCond) + if (Array.isArray(outputed)) throw new Error('model outputed multiple values') + if (outputed.shape.length !== 3) throw new Error('model outputed weird shape') + const logits = outputed as tf.Tensor3D + + timePerToken = performance.now() - timePerToken + const logitsScaled = logits + .slice([0, idx.shape[1] - 1, 0]) + .reshape([logits.shape[0], logits.shape[2]]) + .div(tf.scalar(config.temperature)) + const probs = logitsScaled.softmax(-1) + if (config.doSample) { + return tf.multinomial(probs, 1) as tf.Tensor2D + } else { + return probs.argMax(-1).expandDims(1) + } + }) + + return { + idxNext, + timePerToken + } + } +} + +export { GPT, GPTModel, GPTLMHeadModel } diff --git a/discojs/discojs-core/src/models/gpt/optimizers.ts b/discojs/discojs-core/src/models/gpt/optimizers.ts new file mode 100644 index 000000000..7c8ee03ea --- /dev/null +++ b/discojs/discojs-core/src/models/gpt/optimizers.ts @@ -0,0 +1,118 @@ +import tf, { AdamOptimizer } from '@tensorflow/tfjs' + +function l2Loss (tensor: tf.Tensor): tf.Tensor { + return tf.div(tf.sum(tf.square(tensor)), 2) +} + +function globalNorm (tensors: tf.Tensor[]): tf.Tensor { + const halfSquaredNorms: tf.Tensor[] = [] + tensors.forEach((tensor: tf.Tensor) => { + halfSquaredNorms.push(l2Loss(tensor)) + }) + const halfSquaredNorm: tf.Tensor = tf.sum(tf.stack(halfSquaredNorms)) + const norm: tf.Tensor = tf.sqrt( + tf.mul(halfSquaredNorm, tf.scalar(2.0, halfSquaredNorm.dtype)) + ) + return norm +} + +function clipByGlobalNorm ( + tensors: tf.Tensor[], + clipNorm: number, + useNorm?: tf.Tensor +): tf.Tensor[] { + return tf.tidy(() => { + useNorm = useNorm ?? globalNorm(tensors) + const scale: tf.Tensor = tf.mul( + clipNorm, + tf.minimum( + tf.div(tf.scalar(1.0), useNorm), + tf.div(tf.scalar(1.0, useNorm.dtype), clipNorm) + ) + ) + const tensorsClipped: tf.Tensor[] = [] + tensors.forEach((tensor: tf.Tensor) => { + tensorsClipped.push(tf.clone(tf.mul(tensor, scale))) + }) + return tensorsClipped + }) +} + +function clipByGlobalNormObj ( + tensorsObj: Record, + clipNorm: number, + useNorm?: tf.Tensor +): Record { + const varNames: string[] = Object.keys(tensorsObj) + const tensorsArr: tf.Tensor[] = varNames.map((n: string) => tensorsObj[n]) + const tensorsArrClipped: tf.Tensor[] = clipByGlobalNorm( + tensorsArr, + clipNorm, + useNorm + ) + const tensorsObjClipped: Record = {} + tensorsArrClipped.forEach((t: tf.Tensor, ti: number) => { + tensorsObjClipped[varNames[ti]] = t + }) + return tensorsObjClipped +} + +class AdamW extends AdamOptimizer { + weightDecayRate: number + includeInWeightDecay: string[] + excludeFromWeightDecay: string[] + gradientClipNorm: number + + constructor (params: { + learningRate?: number + beta1?: number + beta2?: number + epsilon?: number + weightDecayRate?: number + includeInWeightDecay?: string[] + excludeFromWeightDecay?: string[] + gradientClipNorm?: number + }) { + console.log('Using custom AdamW optimizer') + const defaultParams = { + learningRate: 0.1, + beta1: 0.9, + beta2: 0.999, + epsilon: 1e-7, + weightDecayRate: 0, + includeInWeightDecay: [], + excludeFromWeightDecay: [], + gradientClipNorm: 1.0 + } + const p = Object.assign({}, defaultParams, params) + super(p.learningRate, p.beta1, p.beta2, p.epsilon) + this.weightDecayRate = p.weightDecayRate + this.includeInWeightDecay = p.includeInWeightDecay + this.excludeFromWeightDecay = p.excludeFromWeightDecay + this.gradientClipNorm = p.gradientClipNorm + } + + applyGradients 
+  applyGradients (variableGradients: Record<string, tf.Tensor> | Array<{ name: string, tensor: tf.Tensor }>): void {
+    const varNames: string[] = Array.isArray(variableGradients)
+      ? variableGradients.map((v) => v.name)
+      : Object.keys(variableGradients)
+
+    varNames.forEach((name: string) => {
+      if (this.includeInWeightDecay.includes(name)) {
+        const value = tf.engine().registeredVariables[name]
+        const newValue: tf.Tensor = tf.sub(
+          value,
+          tf.mul(
+            this.learningRate,
+            tf.mul(value, this.weightDecayRate)
+          )
+        )
+        value.assign(newValue)
+      }
+    })
+
+    super.applyGradients(variableGradients)
+  }
+}
+
+export { AdamW, clipByGlobalNorm, clipByGlobalNormObj }
diff --git a/discojs/discojs-core/src/models/gpt/train.ts b/discojs/discojs-core/src/models/gpt/train.ts
new file mode 100644
index 000000000..3d1c92649
--- /dev/null
+++ b/discojs/discojs-core/src/models/gpt/train.ts
@@ -0,0 +1,115 @@
+import tf from '@tensorflow/tfjs'
+
+import { AdamW, clipByGlobalNormObj } from './optimizers'
+import type { GPTConfig } from './config'
+import { DEFAULT_CONFIG } from './config'
+import evaluate from './evaluate'
+import type { TrainingCallbacks } from './types'
+
+function resolveConfig (config: GPTConfig): Required<GPTConfig> {
+  return {
+    ...DEFAULT_CONFIG,
+    ...config
+  }
+}
+
+function getCustomAdam (model: tf.LayersModel, c: Required<GPTConfig>): tf.Optimizer {
+  const includeInWeightDecay: string[] = []
+  const excludeFromWeightDecay: string[] = []
+
+  // TODO unsafe cast
+  const namedWeights = (model as unknown as any).getNamedWeights() as Array<{ name: string, tensor: tf.Tensor }>
+
+  namedWeights.forEach((v) => {
+    if (
+      v.name.includes('bias') ||
+      v.name.includes('normalization') ||
+      v.name.includes('emb')
+    ) {
+      excludeFromWeightDecay.push(v.name)
+    } else {
+      includeInWeightDecay.push(v.name)
+    }
+  })
+
+  return new AdamW({
+    learningRate: c.lr,
+    weightDecayRate: c.weightDecay,
+    includeInWeightDecay,
+    excludeFromWeightDecay
+  })
+}
+
+export async function train (
+  model: tf.LayersModel,
+  ds: tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>,
+  config: GPTConfig,
+  callbacks: TrainingCallbacks,
+  evalDs?: tf.data.Dataset<{ xs: tf.Tensor2D, ys: tf.Tensor3D }>
+): Promise<void> {
+  const c = resolveConfig(config)
+
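+  // plain Adam is enough when no decay is requested; the custom AdamW otherwise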
+  const opt = c.weightDecay !== 0 ? getCustomAdam(model, c) : tf.train.adam(c.lr)
+
+  await callbacks.onTrainBegin?.()
+
+  console.warn('=== Starting training ===')
+
+  for (let epoch = 1; epoch <= c.epochs; epoch++) {
+    await callbacks.onEpochBegin?.(epoch)
+
+    await tf.data.zip<[number, { xs: tf.Tensor2D, ys: tf.Tensor3D }]>([
+      tf.data.generator(function * () {
+        for (let i = 1; i <= c.maxIter; i++) { yield i }
+      }),
+      ds
+    ]).mapAsync(async ([iteration, { xs, ys }]) => {
+      await callbacks.onBatchBegin?.(iteration)
+      return { iteration, xs, ys }
+    }).map(({ iteration, xs, ys }) => tf.tidy(() => {
+      const { grads, value: loss } = opt.computeGradients(() => {
+        const logits = model.apply(xs)
+        if (Array.isArray(logits)) {
+          throw new Error('model outputted more than one tensor')
+        }
+        if (logits instanceof tf.SymbolicTensor) {
+          throw new Error('model outputted a symbolic tensor')
+        }
+
+        const loss = tf.losses.softmaxCrossEntropy(ys, logits)
+        return loss as tf.Scalar
+      })
+
+      tf.dispose([xs, ys])
+
+      const gradsClipped = clipByGlobalNormObj(grads, 1)
+      opt.applyGradients(gradsClipped)
+
+      return { iteration, loss }
+    })).mapAsync(async ({ iteration, loss }) => {
+      const raw = await loss.array()
+      tf.dispose(loss)
+      return [iteration, raw]
+    }).mapAsync(async ([iteration, loss]) => {
+      await callbacks.onBatchEnd?.(iteration)
+      return [iteration, loss]
+    }).forEachAsync(([iteration, loss]) => {
+      console.log(
+        `Epoch: ${epoch}`,
+        `\tStep: ${iteration} / ${c.maxIter}`,
+        `\tLoss: ${loss.toFixed(3)}`,
+        `\tMemory: ${(tf.memory().numBytes / 1024 / 1024).toFixed(2)} MB`
+      )
+    })
+
+    let logs: tf.Logs | undefined
+    if (evalDs !== undefined) {
+      logs = await evaluate(model, evalDs)
+    }
+    await callbacks.onEpochEnd?.(epoch, logs)
+
+    await new Promise((resolve) => setTimeout(resolve, 1))
+  }
+
+  await callbacks.onTrainEnd?.()
+}
diff --git a/discojs/discojs-core/src/models/gpt/types.ts b/discojs/discojs-core/src/models/gpt/types.ts
new file mode 100644
index 000000000..ed40e168d
--- /dev/null
+++ b/discojs/discojs-core/src/models/gpt/types.ts
@@ -0,0 +1,10 @@
+import type tf from '@tensorflow/tfjs'
+
+export interface TrainingCallbacks {
+  onEpochBegin?: (epoch: number, logs?: tf.Logs) => Promise<void>
+  onEpochEnd?: (epoch: number, logs?: tf.Logs) => Promise<void>
+  onBatchBegin?: (batch: number, logs?: tf.Logs) => Promise<void>
+  onBatchEnd?: (batch: number, logs?: tf.Logs) => Promise<void>
+  onTrainBegin?: (logs?: tf.Logs) => Promise<void>
+  onTrainEnd?: (logs?: tf.Logs) => Promise<void>
+}
diff --git a/discojs/discojs-core/src/models/index.ts b/discojs/discojs-core/src/models/index.ts
index 25b868724..e6bf727ff 100644
--- a/discojs/discojs-core/src/models/index.ts
+++ b/discojs/discojs-core/src/models/index.ts
@@ -1,2 +1,4 @@
 export { Model } from './model'
+
+export { GPT } from './gpt'
 export { TFJS } from './tfjs'
diff --git a/discojs/discojs-core/src/serialization/model.ts b/discojs/discojs-core/src/serialization/model.ts
index fa4dcdd59..d41585490 100644
--- a/discojs/discojs-core/src/serialization/model.ts
+++ b/discojs/discojs-core/src/serialization/model.ts
@@ -2,7 +2,9 @@ import msgpack from 'msgpack-lite'
 import type tf from '@tensorflow/tfjs'
 
 import type { Model } from '..'
-import { models } from '..'
+import { models, serialization } from '..'
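+
+// a model is encoded as the msgpack pair [Type, payload]: decode() reads the
+// numeric tag first and then dispatches to the matching deserializer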
+const enum Type { TFJS, GPT }
 
 export type Encoded = Uint8Array
 
@@ -13,7 +15,12 @@ export function isEncoded (raw: unknown): raw is Encoded {
 export async function encode (model: Model): Promise<Encoded> {
   if (model instanceof models.TFJS) {
     const serialized = await model.serialize()
-    return msgpack.encode(serialized)
+    return msgpack.encode([Type.TFJS, serialized])
+  }
+
+  if (model instanceof models.GPT) {
+    const serialized = await serialization.weights.encode(model.serialize())
+    return msgpack.encode([Type.GPT, serialized])
   }
 
   throw new Error('unknown model type')
@@ -23,9 +30,32 @@ export async function decode (encoded: unknown): Promise<Model> {
   if (!isEncoded(encoded)) {
     throw new Error('invalid encoding')
   }
-  const raw = msgpack.decode(encoded)
+  const raw: unknown = msgpack.decode(encoded)
+
+  if (!Array.isArray(raw) || raw.length !== 2) {
+    throw new Error('invalid encoding')
+  }
+  const [type, rawModel] = raw as [unknown, unknown]
 
-  // TODO how to select model type? prepend with model id
-  // TODO totally unsafe casting
-  return await models.TFJS.deserialize(raw as tf.io.ModelArtifacts)
+  if (typeof type !== 'number' || (type !== Type.TFJS && type !== Type.GPT)) {
+    throw new Error('invalid encoding')
+  }
+  switch (type) {
+    case Type.TFJS:
+      // TODO totally unsafe casting
+      return await models.TFJS.deserialize(rawModel as tf.io.ModelArtifacts)
+    case Type.GPT: {
+      if (!Array.isArray(rawModel)) {
+        throw new Error('invalid encoding')
+      }
+      const arr: unknown[] = rawModel
+      if (arr.some((r) => typeof r !== 'number')) {
+        throw new Error('invalid encoding')
+      }
+      const nums = arr as number[]
+
+      const serialized = serialization.weights.decode(nums)
+      return models.GPT.deserialize(serialized)
+    }
+  }
 }
diff --git a/discojs/discojs-core/src/training/trainer/trainer.ts b/discojs/discojs-core/src/training/trainer/trainer.ts
index eea1d866f..c9b20d784 100644
--- a/discojs/discojs-core/src/training/trainer/trainer.ts
+++ b/discojs/discojs-core/src/training/trainer/trainer.ts
@@ -80,6 +80,9 @@
     if (logs !== undefined && !isNaN(logs.acc) && !isNaN(logs.val_acc)) {
       this.trainingInformant.updateTrainingGraph(this.roundDecimals(logs.acc))
       this.trainingInformant.updateValidationGraph(this.roundDecimals(logs.val_acc))
+      if (logs.val_loss !== undefined) {
+        this.trainingInformant.loss = logs.val_loss
+      }
     } else {
       this.trainerLogger.error('onEpochEnd: NaN value')
     }