diff --git a/src/context.py b/src/context.py
index d2657cd5..6dfa88e7 100644
--- a/src/context.py
+++ b/src/context.py
@@ -111,6 +111,7 @@ class WandB(DataClass):
 
 
 class Optimizer(DataClass):
+    ema_beta: float = 0.99
     use_shampoo: bool = True
     block_size: int = 512
     epsilon: float = 1e-5
diff --git a/src/model.py b/src/model.py
index 989ec88c..adf88da3 100644
--- a/src/model.py
+++ b/src/model.py
@@ -83,14 +83,13 @@ def conv(ctx: Context, inp: jnp.ndarray, conv_kernel: int, scale: float, in_feat
 def bottleneck_block(ctx: Context, inp: jnp.ndarray) -> jnp.ndarray:
     ctx = ctx.add_to_prefix("bottleneck")
     inp = scale_norm_act(ctx, inp, ctx.dims.features, act=False, init_mean=None)
-    inp = conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, 1 / ctx.dims.heads,
-               ctx.dims.features, ctx.dims.inner_bottleneck_features)
+    inp = conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, 1 / ctx.dims.heads, ctx.dims.features,
+               ctx.dims.inner_bottleneck_features)
     inp = scale_norm_act(ctx, inp, ctx.dims.inner_bottleneck_features, psum=True)
-    inp = conv(ctx, inp, ctx.dims.inner_bottleneck_kernel, 1,
-               ctx.dims.inner_bottleneck_features, ctx.dims.inner_bottleneck_features)
+    inp = conv(ctx, inp, ctx.dims.inner_bottleneck_kernel, 1, ctx.dims.inner_bottleneck_features,
+               ctx.dims.inner_bottleneck_features)
     inp = scale_norm_act(ctx, inp, ctx.dims.inner_bottleneck_features)
-    return conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, 1,
-                ctx.dims.inner_bottleneck_features, ctx.dims.features)
+    return conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, 1, ctx.dims.inner_bottleneck_features, ctx.dims.features)
 
 
 def pointwise_block(ctx: Context, inp: jnp.ndarray) -> jnp.ndarray:
@@ -284,8 +283,8 @@ def _grad(dy: REVERSIBLE_CTX) -> typing.Tuple[
     return _fn(*src)
 
 
-def cross_entropy_loss(ctx: Context, src_wgt: typing.Tuple[jnp.ndarray, jnp.ndarray], tgt: jnp.ndarray
-                       ) -> typing.Tuple[jnp.ndarray, jnp.ndarray]:
+def cross_entropy_loss(ctx: Context, src_wgt: typing.Tuple[jnp.ndarray, jnp.ndarray], tgt: jnp.ndarray) -> typing.Tuple[
+    jnp.ndarray, jnp.ndarray]:
     # Forward: logsumexp(x) - x[target]
     # Backward: (logsumexp(x) - x[target] + logsumexp(x)^2 * z_loss).grad
     # -> softmax(x) - 1 + softmax(x) * logsumexp(x) * z_loss
@@ -293,8 +292,7 @@ def cross_entropy_loss(ctx: Context, src_wgt: typing.Tuple[jnp.ndarray, jnp.ndar
     devices = ctx.dims.heads
 
     def _xent_slice(inp: typing.Tuple[
-        jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
-                    carry):
+        jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], carry):
         inp, i, wgt, inner_tgt, index, d_wgt, loss, accuracy = inp
         inp_slice = inp[i]
         tmp = matmul(inp_slice, wgt).reshape(devices, -1, ctx.dims.vocab)
@@ -324,14 +322,10 @@ def _fn(inp: jnp.ndarray, inner_tgt: jnp.ndarray, wgt: jnp.ndarray):
         inner_tgt = inner_tgt.reshape(ctx.data.vocab_size // ctx.dims.inner_bottleneck_features, -1)
         index = lax.psum_scatter(jnp.arange(ctx.dims.heads), ParallelAxes.model) // devices
         index = index.astype(jnp.int32)
-        (_, _, _, _, _, d_wgt, loss, accuracy), dx = lax.scan(_xent_slice, (inp, jnp.zeros((), dtype=jnp.int32),
-                                                                            wgt,
-                                                                            inner_tgt, index,
-                                                                            jnp.zeros(wgt.shape[::-1],
-                                                                                      dtype=jnp.float32),
-                                                                            jnp.zeros((), dtype=jnp.float32),
-                                                                            jnp.zeros((), dtype=jnp.float32)), None,
-                                                              inp.shape[0])
+        (_, _, _, _, _, d_wgt, loss, accuracy), dx = lax.scan(_xent_slice, (
+            inp, jnp.zeros((), dtype=jnp.int32), wgt, inner_tgt, index,
+            jnp.zeros(wgt.shape[::-1], dtype=jnp.float32), jnp.zeros((), dtype=jnp.float32),
+            jnp.zeros((), dtype=jnp.float32)), None, inp.shape[0])
         dx = dx.transpose(1, 0, 2) / tgt.size  # Shape[Features, inp.shape[0] // step, step // devices]
         dx = lax.all_gather(dx, ParallelAxes.model, axis=2).reshape(ctx.dims.features, -1).transpose(1, 0)
         dx = dx.reshape(original_shape)
@@ -360,7 +354,7 @@ def _grad(dy) -> typing.Tuple[jnp.ndarray, jnp.ndarray, None, jnp.ndarray]:
     return _fn(*src)
 
 
-def body_ctx(ctx: Context, src: jnp.ndarray) -> typing.Union[typing.Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
+def _body_ctx(ctx: Context, src: jnp.ndarray) -> typing.Union[typing.Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
     src = input_embed(ctx, src)
     zero = jnp.zeros_like(src)
     src = (ctx.parameters, src, zero, src, zero)
@@ -381,6 +375,33 @@ def body_ctx(ctx: Context, src: jnp.ndarray) -> typing.Union[typing.Tuple[jnp.nd
     return out, wgt
 
 
+def _consistency_loss(out: jnp.ndarray, ema: jnp.ndarray) -> jnp.ndarray:
+    @jax.custom_gradient
+    def _fn(o: jnp.ndarray, e: jnp.ndarray):
+        # forward: (o - e) ** 2
+        # backward: (o - e) * 2
+        def _grad(dy: jnp.ndarray) -> typing.Tuple[jnp.ndarray, None]:
+            grad = (o - e) * 2 / out.size
+            return grad, None
+
+        return jnp.zeros(()), _grad
+
+    return _fn(out, ema)
+
+
+def body_ctx(ctx: Context, src: jnp.ndarray) -> typing.Union[typing.Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
+    out = _body_ctx(ctx, src)
+    name_cache = ctx.name_cache
+    ema_ctx = ctx.add_to_prefix("ema_weight")
+    ema_ctx.name_cache = {}
+    ema_out = _body_ctx(ema_ctx, src)
+    ctx.name_cache = name_cache
+
+    if not ctx.is_initializing:
+        return out
+    return out + _consistency_loss(out, ema_out), out[1]
+
+
 def compute(params: typing.Dict[str, jnp.ndarray], inp: jnp.ndarray) -> typing.Tuple[jnp.ndarray, jnp.ndarray]:
     ctx = Context()
     ctx.parameters = params
diff --git a/src/optimizer.py b/src/optimizer.py
index adb3be5e..ed305124 100644
--- a/src/optimizer.py
+++ b/src/optimizer.py
@@ -41,6 +41,10 @@ def small_parameter(param_name: str, grad: jnp.ndarray) -> bool:
     return "norm" in param_name.lower() or "rezero" in param_name.lower() or grad.ndim < 2
 
 
+def weighted_add(x0, x1, beta, heavyball: bool = False):
+    return x0 * beta + x1 * (1 if heavyball else (1 - beta))
+
+
 def ema(ctx: Context, inp: jnp.ndarray, step: jnp.ndarray, beta: float, prefix: str,
         quantize: typing.Optional[bool] = None, init_val: typing.Optional[jnp.ndarray] = None,
         heavyball: bool = False) -> jnp.ndarray:
@@ -49,7 +53,7 @@ def ema(ctx: Context, inp: jnp.ndarray, step: jnp.ndarray, beta: float, prefix:
         quantize = not small_parameter(ctx.global_prefix, inp)
     state = get_param(ctx, "momentum_buffer", inp.shape, dtype=jnp.bfloat16 if quantize else ctx.model.storage_dtype,
                       init_val=jnp.zeros_like(inp) if init_val is None else init_val)
-    new_state = state.astype(jnp.float32) * beta + inp * (1 if heavyball else (1 - beta))
+    new_state = weighted_add(state.astype(jnp.float32), inp, beta, heavyball)
     assign(ctx, "momentum_buffer", new_state)
     if heavyball:
         return new_state
@@ -132,3 +136,6 @@ def update(ctx: Context, grads: typing.Dict[str, jnp.ndarray], step: jnp.ndarray
             grad = ema(inner_ctx, grad, step, 1 - ctx.optimizer.momentum_beta, "momentum", heavyball=True)
         ctx.parameters[param_name] = (1 + ctx.optimizer.weight_decay * parameter_lr) * ctx.parameters[param_name]
         ctx.parameters[param_name] = grad * parameter_lr + ctx.parameters[param_name]
+        ema_name = [k for k in ctx.parameters.keys() if param_name in k and 'ema_weight' in k][0]
+        ctx.parameters[ema_name] = weighted_add(ctx.parameters[ema_name], ctx.parameters[param_name],
+                                                ctx.optimizer.ema_beta)
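
Review sketch, not part of the diff: the three files act together as a mean-teacher style regulariser. body_ctx runs the network twice (once with the live weights, once with the "ema_weight" copy), _consistency_loss adds nothing to the forward value but injects the gradient of a mean squared error between the two outputs, and update refreshes the ema_weight copy each step via weighted_add(..., ctx.optimizer.ema_beta). The standalone sketch below illustrates that mechanism only; the names consistency_loss, ema_update, student/teacher are illustrative, the real code threads state through Context, and unlike the diff's _consistency_loss this version keeps the upstream cotangent dy and returns an explicit zero for the teacher instead of None.

import jax
import jax.numpy as jnp


def consistency_loss(student_out: jnp.ndarray, teacher_out: jnp.ndarray) -> jnp.ndarray:
    # Zero-valued forward pass; the backward pass pushes the student towards the
    # EMA teacher with the gradient of mean((student - teacher) ** 2).
    @jax.custom_gradient
    def _fn(o, e):
        def _grad(dy):
            # d/do mean((o - e) ** 2) = 2 * (o - e) / o.size, scaled by the upstream cotangent.
            return dy * 2 * (o - e) / o.size, jnp.zeros_like(e)

        return jnp.zeros(()), _grad

    return _fn(student_out, teacher_out)


def ema_update(ema_params, params, beta: float = 0.99):
    # Same arithmetic as weighted_add(ema, param, ema_beta), applied per parameter.
    return jax.tree_util.tree_map(lambda e, p: e * beta + p * (1 - beta), ema_params, params)


student = jnp.array([1.0, 2.0])
teacher = jnp.array([0.0, 0.0])
print(consistency_loss(student, teacher))            # 0.0 -- no effect on the reported loss
print(jax.grad(consistency_loss)(student, teacher))  # [1. 2.] -- MSE gradient w.r.t. the student only

Returning an explicit zero cotangent for the teacher keeps gradients out of the EMA copy, which only ever changes through the optimizer-side weighted_add.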
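
For the optimizer refactor, a small sketch of the two modes the new weighted_add helper covers, assuming plain float32 arrays in place of the quantized momentum_buffer state:

import jax.numpy as jnp


def weighted_add(x0, x1, beta, heavyball: bool = False):
    # Same expression as the helper added to src/optimizer.py.
    return x0 * beta + x1 * (1 if heavyball else (1 - beta))


buffer, grad = jnp.zeros(3), jnp.ones(3)

# EMA (used for momentum_buffer with heavyball=False and for the new ema_weight copies):
# new = beta * old + (1 - beta) * incoming
print(weighted_add(buffer, grad, 0.9))        # [0.1 0.1 0.1]

# Heavy-ball momentum: the incoming gradient enters with weight 1, only the buffer decays.
print(weighted_add(buffer, grad, 0.9, True))  # [1. 1. 1.]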