Skip to content

Commit

Permalink
fixup! discojs-core/models: add gpt
Browse files Browse the repository at this point in the history
  • Loading branch information
tharvik committed Mar 1, 2024
1 parent bcddd72 commit f1d0036
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 12 deletions.
5 changes: 3 additions & 2 deletions discojs/discojs-core/src/models/gpt/evaluate.ts
100755 → 100644
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import tf from '@tensorflow/tfjs'

import { data } from '../..'
import { GPTConfig } from '.'
import type { data } from '../..'

import type { GPTConfig } from './config'

export default async function evaluate(
model: any,
Expand Down
13 changes: 9 additions & 4 deletions discojs/discojs-core/src/models/gpt/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
export * from './train'
export * from './optimizers'
export * from './model'
export * from './config'
import { TFJS } from '../tfjs'

import { GPTLMHeadModel } from './model'

export class GPT extends TFJS {
constructor () {
super(new GPTLMHeadModel({}))
}
}
9 changes: 6 additions & 3 deletions discojs/discojs-core/src/models/gpt/model.ts
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import tf from '@tensorflow/tfjs'

import { GPTConfig, getModelSizes, DEFAULT_CONFIG } from '.'
import { data, training } from '../..'
import type { data } from '../..'

import type { GPTConfig } from './config'
import { getModelSizes, DEFAULT_CONFIG } from './config'
import { train } from './train'
import type { TrainingCallbacks } from './types'

const Range = (config: any) => new Range_(config)
class Range_ extends tf.layers.Layer {
Expand Down Expand Up @@ -591,7 +594,7 @@ class GPTModel extends tf.LayersModel {
this,
dataset as data.Dataset,
config,
args.callbacks as training.TrainingCallbacks
args.callbacks as TrainingCallbacks
)
return {} as tf.History
}
Expand Down
9 changes: 6 additions & 3 deletions discojs/discojs-core/src/models/gpt/train.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import tf from '@tensorflow/tfjs'

import { data, training } from '../..'
import type { data } from '../..'

import { AdamW, clipByGlobalNormObj } from './optimizers'
import { GPTConfig, DEFAULT_CONFIG } from './config'
import type { GPTConfig } from './config'
import { DEFAULT_CONFIG } from './config'
import evaluate from './evaluate'
import { TrainingCallbacks } from './types'

export type GPTConfigWithWandb = Required<GPTConfig>

Expand Down Expand Up @@ -39,7 +42,7 @@ export async function train(
model: tf.LayersModel,
ds: data.Dataset,
config: GPTConfig,
callbacks: training.TrainingCallbacks,
callbacks: TrainingCallbacks,
evalDs?: data.Dataset
): Promise<void> {
const c = getConfig(config)
Expand Down
10 changes: 10 additions & 0 deletions discojs/discojs-core/src/models/gpt/types.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import type tf from '@tensorflow/tfjs'

export interface TrainingCallbacks {
onEpochBegin: (epoch: number, logs?: tf.Logs) => Promise<void>
onEpochEnd: (epoch: number, logs?: tf.Logs) => Promise<void>
onBatchBegin: (batch: number, logs?: tf.Logs) => Promise<void>
onBatchEnd: (batch: number, logs?: tf.Logs) => Promise<void>
onTrainBegin: (logs?: tf.Logs) => Promise<void>
onTrainEnd: (logs?: tf.Logs) => Promise<void>
}
1 change: 1 addition & 0 deletions discojs/discojs-core/src/models/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export interface Events extends Record<string, unknown> {
/** Trainable predictor */
// TODO make it typesafe: same shape of data/input/weights
export abstract class Model {
// TODO move to train generator input/output
abstract get weights (): WeightsContainer
abstract set weights (ws: WeightsContainer)

Expand Down

0 comments on commit f1d0036

Please sign in to comment.