MESA/SAM #59

Closed
wants to merge 2 commits into from
1 change: 1 addition & 0 deletions src/context.py
@@ -111,6 +111,7 @@ class WandB(DataClass):


class Optimizer(DataClass):
ema_beta: float = 0.99
use_shampoo: bool = True
block_size: int = 512
epsilon: float = 1e-5
59 changes: 40 additions & 19 deletions src/model.py
@@ -83,14 +83,13 @@ def conv(ctx: Context, inp: jnp.ndarray, conv_kernel: int, scale: float, in_feat
def bottleneck_block(ctx: Context, inp: jnp.ndarray) -> jnp.ndarray:
ctx = ctx.add_to_prefix("bottleneck")
inp = scale_norm_act(ctx, inp, ctx.dims.features, act=False, init_mean=None)
inp = conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, 1 / ctx.dims.heads,
ctx.dims.features, ctx.dims.inner_bottleneck_features)
inp = conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, 1 / ctx.dims.heads, ctx.dims.features,
ctx.dims.inner_bottleneck_features)
inp = scale_norm_act(ctx, inp, ctx.dims.inner_bottleneck_features, psum=True)
inp = conv(ctx, inp, ctx.dims.inner_bottleneck_kernel, 1,
ctx.dims.inner_bottleneck_features, ctx.dims.inner_bottleneck_features)
inp = conv(ctx, inp, ctx.dims.inner_bottleneck_kernel, 1, ctx.dims.inner_bottleneck_features,
ctx.dims.inner_bottleneck_features)
inp = scale_norm_act(ctx, inp, ctx.dims.inner_bottleneck_features)
return conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, 1,
ctx.dims.inner_bottleneck_features, ctx.dims.features)
return conv(ctx, inp, ctx.dims.outer_bottleneck_kernel, 1, ctx.dims.inner_bottleneck_features, ctx.dims.features)


def pointwise_block(ctx: Context, inp: jnp.ndarray) -> jnp.ndarray:
@@ -284,17 +283,16 @@ def _grad(dy: REVERSIBLE_CTX) -> typing.Tuple[
return _fn(*src)


def cross_entropy_loss(ctx: Context, src_wgt: typing.Tuple[jnp.ndarray, jnp.ndarray], tgt: jnp.ndarray
) -> typing.Tuple[jnp.ndarray, jnp.ndarray]:
def cross_entropy_loss(ctx: Context, src_wgt: typing.Tuple[jnp.ndarray, jnp.ndarray], tgt: jnp.ndarray) -> typing.Tuple[
jnp.ndarray, jnp.ndarray]:
# Forward: logsumexp(x) - x[target]
# Backward: (logsumexp(x) - x[target] + logsumexp(x)^2 * z_loss).grad
# -> softmax(x) - 1 + softmax(x) * logsumexp(x) * z_loss
src, param = src_wgt
devices = ctx.dims.heads

def _xent_slice(inp: typing.Tuple[
jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray],
carry):
jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray], carry):
inp, i, wgt, inner_tgt, index, d_wgt, loss, accuracy = inp
inp_slice = inp[i]
tmp = matmul(inp_slice, wgt).reshape(devices, -1, ctx.dims.vocab)
@@ -324,14 +322,10 @@ def _fn(inp: jnp.ndarray, inner_tgt: jnp.ndarray, wgt: jnp.ndarray):
inner_tgt = inner_tgt.reshape(ctx.data.vocab_size // ctx.dims.inner_bottleneck_features, -1)
index = lax.psum_scatter(jnp.arange(ctx.dims.heads), ParallelAxes.model) // devices
index = index.astype(jnp.int32)
(_, _, _, _, _, d_wgt, loss, accuracy), dx = lax.scan(_xent_slice, (inp, jnp.zeros((), dtype=jnp.int32),
wgt,
inner_tgt, index,
jnp.zeros(wgt.shape[::-1],
dtype=jnp.float32),
jnp.zeros((), dtype=jnp.float32),
jnp.zeros((), dtype=jnp.float32)), None,
inp.shape[0])
(_, _, _, _, _, d_wgt, loss, accuracy), dx = lax.scan(_xent_slice, (
inp, jnp.zeros((), dtype=jnp.int32), wgt, inner_tgt, index,
jnp.zeros(wgt.shape[::-1], dtype=jnp.float32), jnp.zeros((), dtype=jnp.float32),
jnp.zeros((), dtype=jnp.float32)), None, inp.shape[0])
dx = dx.transpose(1, 0, 2) / tgt.size # Shape[Features, inp.shape[0] // step, step // devices]
dx = lax.all_gather(dx, ParallelAxes.model, axis=2).reshape(ctx.dims.features, -1).transpose(1, 0)
dx = dx.reshape(original_shape)
@@ -360,7 +354,7 @@ def _grad(dy) -> typing.Tuple[jnp.ndarray, jnp.ndarray, None, jnp.ndarray]:
return _fn(*src)


def body_ctx(ctx: Context, src: jnp.ndarray) -> typing.Union[typing.Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
def _body_ctx(ctx: Context, src: jnp.ndarray) -> typing.Union[typing.Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
src = input_embed(ctx, src)
zero = jnp.zeros_like(src)
src = (ctx.parameters, src, zero, src, zero)
@@ -381,6 +375,33 @@ def body_ctx(ctx: Context, src: jnp.nd
return out, wgt


def _consistency_loss(out: jnp.ndarray, ema: jnp.ndarray) -> jnp.ndarray:
@jax.custom_gradient
def _fn(o: jnp.ndarray, e: jnp.ndarray):
# forward: (o - e) ** 2
# backward: (o - e) * 2
def _grad(dy: jnp.ndarray) -> typing.Tuple[jnp.ndarray, None]:
grad = (o - e) * 2 / out.size
return grad, None

return jnp.zeros(()), _grad

return _fn(out, ema)
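The zero forward value means _consistency_loss only injects a gradient: the student activations are pulled toward the EMA teacher's, and the teacher itself receives no gradient. A self-contained reference with the same gradient, useful for sanity-checking the custom rule (test values are made up; unlike the function above, its forward value is the actual mean-squared error rather than zero):

import jax
import jax.numpy as jnp


def consistency_reference(out: jnp.ndarray, ema: jnp.ndarray) -> jnp.ndarray:
    # mean-squared error against a frozen teacher; d/d(out) = 2 * (out - ema) / out.size,
    # which is exactly what _grad above returns
    return jnp.mean((out - jax.lax.stop_gradient(ema)) ** 2)


o = jnp.array([1.0, 2.0, 3.0])
e = jnp.array([1.5, 2.0, 2.0])
print(jax.grad(consistency_reference)(o, e))  # [-0.333..., 0.0, 0.666...]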


def body_ctx(ctx: Context, src: jnp.ndarray) -> typing.Union[typing.Tuple[jnp.ndarray, jnp.ndarray], jnp.ndarray]:
out = _body_ctx(ctx, src)
name_cache = ctx.name_cache
ema_ctx = ctx.add_to_prefix("ema_weight")
ema_ctx.name_cache = {}
ema_out = _body_ctx(ema_ctx, src)
ctx.name_cache = name_cache

if ctx.is_initializing:
return out
return out[0] + _consistency_loss(out[0], ema_out[0]), out[1]

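body_ctx now runs the network twice: once with the live (student) parameters and once under the "ema_weight" prefix, so get_param resolves to a separate teacher copy of every weight (created during initialization, read afterwards). Caching and restoring name_cache keeps the extra pass from shifting parameter name counters for anything that runs after body_ctx. A repo-independent sketch of the same pattern over a flat parameter dict (apply_fn, the prefix handling and all names here are illustrative, not this codebase's API):

import typing

import jax.numpy as jnp


def forward_with_teacher(params: typing.Dict[str, jnp.ndarray], apply_fn: typing.Callable, inp: jnp.ndarray,
                         prefix: str = "ema_weight/") -> typing.Tuple[jnp.ndarray, jnp.ndarray]:
    # student pass on the live weights
    student = {k: v for k, v in params.items() if not k.startswith(prefix)}
    # teacher pass on the EMA copies stored under the "ema_weight/" prefix
    teacher = {k[len(prefix):]: v for k, v in params.items() if k.startswith(prefix)}
    return apply_fn(student, inp), apply_fn(teacher, inp)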

def compute(params: typing.Dict[str, jnp.ndarray], inp: jnp.ndarray) -> typing.Tuple[jnp.ndarray, jnp.ndarray]:
ctx = Context()
ctx.parameters = params
9 changes: 8 additions & 1 deletion src/optimizer.py
@@ -41,6 +41,10 @@ def small_parameter(param_name: str, grad: jnp.ndarray) -> bool:
return "norm" in param_name.lower() or "rezero" in param_name.lower() or grad.ndim < 2


def weighted_add(x0: jnp.ndarray, x1: jnp.ndarray, beta: float, heavyball: bool = False) -> jnp.ndarray:
return x0 * beta + x1 * (1 if heavyball else (1 - beta))

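weighted_add folds the two interpolation modes ema() needs into one helper: a standard exponential moving average, and heavy-ball momentum, which keeps the full new term instead of scaling it by 1 - beta. A quick illustration with made-up numbers, using weighted_add as defined above:

import jax.numpy as jnp

m = jnp.array([1.0, 1.0])
g = jnp.array([0.0, 2.0])
print(weighted_add(m, g, 0.9))                  # EMA:        0.9 * m + 0.1 * g -> [0.9, 1.1]
print(weighted_add(m, g, 0.9, heavyball=True))  # heavy-ball: 0.9 * m + g       -> [0.9, 2.9]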

def ema(ctx: Context, inp: jnp.ndarray, step: jnp.ndarray, beta: float, prefix: str,
quantize: typing.Optional[bool] = None, init_val: typing.Optional[jnp.ndarray] = None,
heavyball: bool = False) -> jnp.ndarray:
@@ -49,7 +53,7 @@ def ema(ctx: Context, inp: jnp.ndarray, step: jnp.ndarray, beta: float, prefix:
quantize = not small_parameter(ctx.global_prefix, inp)
state = get_param(ctx, "momentum_buffer", inp.shape, dtype=jnp.bfloat16 if quantize else ctx.model.storage_dtype,
init_val=jnp.zeros_like(inp) if init_val is None else init_val)
new_state = state.astype(jnp.float32) * beta + inp * (1 if heavyball else (1 - beta))
new_state = weighted_add(state.astype(jnp.float32), inp, beta, heavyball)
assign(ctx, "momentum_buffer", new_state)
if heavyball:
return new_state
@@ -132,3 +136,6 @@ def update(ctx: Context, grads: typing.Dict[str, jnp.ndarray], step: jnp.ndarray
grad = ema(inner_ctx, grad, step, 1 - ctx.optimizer.momentum_beta, "momentum", heavyball=True)
ctx.parameters[param_name] = (1 + ctx.optimizer.weight_decay * parameter_lr) * ctx.parameters[param_name]
ctx.parameters[param_name] = grad * parameter_lr + ctx.parameters[param_name]
ema_name = [k for k in ctx.parameters.keys() if param_name in k and 'ema_weight' in k][0]
ctx.parameters[ema_name] = weighted_add(ctx.parameters[ema_name], ctx.parameters[param_name],
ctx.optimizer.ema_beta)
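The final hunk keeps the teacher in sync with training: after each parameter step, update() looks up the matching "ema_weight" entry and moves it toward the freshly updated student weights with decay Optimizer.ema_beta (0.99). The same bookkeeping over a flat parameter dict, as a sketch reusing weighted_add from above (the prefix convention and names are illustrative):

import typing

import jax.numpy as jnp


def update_teacher(params: typing.Dict[str, jnp.ndarray], ema_beta: float = 0.99,
                   prefix: str = "ema_weight/") -> typing.Dict[str, jnp.ndarray]:
    # teacher <- ema_beta * teacher + (1 - ema_beta) * student, per parameter
    for name in list(params):
        if name.startswith(prefix):
            params[name] = weighted_add(params[name], params[name[len(prefix):]], ema_beta)
    return params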