Merge pull request #80 from tfjgeorge/layernorm
layer norm
tfjgeorge authored Feb 3, 2024
2 parents 5c61d0f + 1d61aed commit 0b97e1a
Showing 4 changed files with 132 additions and 5 deletions.
44 changes: 41 additions & 3 deletions nngeometry/generator/jacobian/grads.py
@@ -13,6 +13,7 @@
WeightNorm1dLayer,
WeightNorm2dLayer,
Conv1dLayer,
LayerNormLayer,
)

from .grads_conv import conv2d_backward, convtranspose2d_backward, conv1d_backward
@@ -56,6 +57,7 @@ def Jv(cls, buffer, mod, layer, x, gy, v, v_bias):
cls.flat_grad(buffer_flat, mod, layer, x, gy)
v = v.view(-1)
if v_bias is not None:
v_bias = v_bias.view(-1)
v = torch.cat((v, v_bias))
buffer.add_(torch.mv(buffer_flat, v))

@@ -269,6 +271,18 @@ def flat_grad(cls, buffer, mod, layer, x, gy):
buffer[:, w_numel:].add_(gy.sum(dim=(2, 3)))


class LayerNormJacobianFactory(JacobianFactory):
@classmethod
def flat_grad(cls, buffer, mod, layer, x, gy):
w_numel = layer.weight.numel()
x_normalized = F.layer_norm(
x, normalized_shape=mod.normalized_shape, eps=mod.eps
)
buffer[:, :w_numel].add_((gy * x_normalized).reshape(x.size(0), -1))
if layer.bias is not None:
buffer[:, w_numel:].add_(gy.reshape(x.size(0), -1))


class GroupNormJacobianFactory(JacobianFactory):
@classmethod
def flat_grad(cls, buffer, mod, layer, x, gy):
@@ -282,16 +296,39 @@ class WeightNorm1dJacobianFactory(JacobianFactory):
@classmethod
def flat_grad(cls, buffer, mod, layer, x, gy):
bs = x.size(0)
gw_prime = (
torch.bmm(gy.unsqueeze(2), x.unsqueeze(1))
.view(bs, -1)
.view(bs, *mod.weight.size())
)
norm2 = (mod.weight**2).sum(dim=1, keepdim=True) + mod.eps
gw = torch.bmm(gy.unsqueeze(2) / torch.sqrt(norm2), x.unsqueeze(1))
wn2_out = F.linear(x, mod.weight / norm2**1.5)
gw -= (gy * wn2_out).unsqueeze(2) * mod.weight.unsqueeze(0)

gw = gw_prime / torch.sqrt(norm2).unsqueeze(0)

gw -= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=2, keepdim=True) * (
mod.weight * norm2 ** (-1.5)
).unsqueeze(0)

buffer.add_(gw.view(bs, -1))


class WeightNorm2dJacobianFactory(JacobianFactory):
@classmethod
def flat_grad(cls, buffer, mod, layer, x, gy):
bs = x.size(0)
gw_prime = conv2d_backward(mod, x, gy).view(bs, *mod.weight.size())
norm2 = (mod.weight**2).sum(dim=(1, 2, 3), keepdim=True) + mod.eps

gw = gw_prime / torch.sqrt(norm2).unsqueeze(0)

gw -= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2, 3, 4), keepdim=True) * (
mod.weight * norm2 ** (-1.5)
).unsqueeze(0)

buffer.add_(gw.view(bs, -1))

@classmethod
def flat_grad_(cls, buffer, mod, layer, x, gy):
bs = x.size(0)
out_dim = mod.weight.size(0)
norm2 = (mod.weight**2).sum(dim=(1, 2, 3)) + mod.eps
@@ -426,4 +463,5 @@ def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g):
WeightNorm2dLayer: WeightNorm2dJacobianFactory,
Cosine1dLayer: Cosine1dJacobianFactory,
Affine1dLayer: Affine1dJacobianFactory,
LayerNormLayer: LayerNormJacobianFactory,
}
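
A note on the identity behind the new factory (not part of the commit): for a LayerNorm layer y = w ⊙ x̂ + b, where x̂ is the input normalized over normalized_shape, the per-sample gradients are ∂L/∂w = gy ⊙ x̂ and ∂L/∂b = gy, which is what LayerNormJacobianFactory.flat_grad accumulates. A minimal sketch checking this against autograd (module and shapes are illustrative, and it assumes normalized_shape covers all non-batch dimensions, as in the tests below):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
bs, d = 4, 5
ln = torch.nn.LayerNorm((d,))
ln.weight.data.normal_()  # non-trivial affine parameters
ln.bias.data.normal_()
x = torch.randn(bs, d)

# Per-sample gradients computed one sample at a time with autograd.
gws, gbs = [], []
for i in range(bs):
    ln.zero_grad()
    ln(x[i:i + 1]).sum().backward()  # scalar loss, so gy is all ones here
    gws.append(ln.weight.grad.clone())
    gbs.append(ln.bias.grad.clone())
gws, gbs = torch.stack(gws), torch.stack(gbs)

# Identity used by flat_grad: grad_w = gy * x_normalized, grad_b = gy.
gy = torch.ones(bs, d)
x_normalized = F.layer_norm(x, normalized_shape=ln.normalized_shape, eps=ln.eps)
assert torch.allclose(gws, gy * x_normalized, atol=1e-6)
assert torch.allclose(gbs, gy, atol=1e-6)

The WeightNorm1d/WeightNorm2d rewrite in the same file appears to split the per-sample gradient into two steps: first gw′, the gradient with respect to the unnormalized weight, then the correction gw = gw′/‖w‖ − (gw′ · w) w/‖w‖³ coming from d(w/‖w‖)/dw, which also lets WeightNorm2d reuse conv2d_backward for gw′.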
26 changes: 26 additions & 0 deletions nngeometry/layercollection.py
@@ -25,6 +25,7 @@ class LayerCollection:
"Affine1d",
"ConvTranspose2d",
"Conv1d",
"LayerNorm"
]

def __init__(self, layers=None):
@@ -146,6 +147,10 @@ def _module_to_layer(mod):
return Affine1dLayer(
num_features=mod.num_features, bias=(mod.bias is not None)
)
elif mod_class == "LayerNorm":
return LayerNormLayer(
normalized_shape=mod.normalized_shape, bias=(mod.bias is not None)
)

def numel(self):
"""
@@ -313,6 +318,24 @@ def __eq__(self, other):
return self.num_features == other.num_features


class LayerNormLayer(AbstractLayer):
def __init__(self, normalized_shape, bias=True):
self.weight = Parameter(*normalized_shape)
if bias:
self.bias = Parameter(*normalized_shape)
else:
self.bias = None

def numel(self):
if self.bias is not None:
return self.weight.numel() + self.bias.numel()
else:
return self.weight.numel()

def __eq__(self, other):
return self.weight == other.weight and self.bias == other.bias


class GroupNormLayer(AbstractLayer):
def __init__(self, num_groups, num_channels):
self.num_channels = num_channels
@@ -406,3 +429,6 @@ def __init__(self, *size):

def numel(self):
return reduce(operator.mul, self.size, 1)

def __eq__(self, other):
return self.size == other.size
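
A hedged usage sketch (not from the commit): with LayerNorm handled by _module_to_layer, a model containing nn.LayerNorm can be mapped to a LayerCollection directly; the toy model and the expected parameter count are illustrative.

import torch.nn as nn
from nngeometry.layercollection import LayerCollection

net = nn.Sequential(nn.Linear(10, 6), nn.LayerNorm((6,)))
lc = LayerCollection.from_model(net)
# Linear: 10*6 weights + 6 biases; LayerNorm: 6 weights + 6 biases.
print(lc.numel())  # expected: 78, assuming both affine parameter sets are registered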
58 changes: 58 additions & 0 deletions tests/tasks.py
@@ -503,3 +503,61 @@ def output_fn(input, target):

layer_collection = LayerCollection.from_model(net)
return (train_loader, layer_collection, net.parameters(), net, output_fn, 4)


class LayerNormNet(nn.Module):
def __init__(self, out_size):
super(LayerNormNet, self).__init__()

self.linear1 = nn.Linear(18 * 18, out_size)
self.layer_norm1 = nn.LayerNorm((out_size,))

self.net = nn.Sequential(self.linear1, self.layer_norm1)

def forward(self, x):
x = x[:, :, 5:-5, 5:-5].contiguous()
x = x.view(x.size(0), -1)
return self.net(x)


def get_layernorm_task():
train_set = get_mnist()
train_set = Subset(train_set, range(70))
train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False)
net = LayerNormNet(out_size=3)
to_device_model(net)
net.eval()

def output_fn(input, target):
return net(to_device(input))

layer_collection = LayerCollection.from_model(net)
return (train_loader, layer_collection, net.parameters(), net, output_fn, 3)


class LayerNormConvNet(nn.Module):
def __init__(self):
super(LayerNormConvNet, self).__init__()
self.layer = nn.Conv2d(1, 3, (3, 2), 2)
self.layer_norm = nn.LayerNorm((3, 8, 9))

def forward(self, x):
x = x[:, :, 5:-5, 5:-5]
x = self.layer(x)
x = self.layer_norm(x)
return x.sum(dim=(2, 3))


def get_layernorm_conv_task():
train_set = get_mnist()
train_set = Subset(train_set, range(70))
train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False)
net = LayerNormConvNet()
to_device_model(net)
net.eval()

def output_fn(input, target):
return net(to_device(input))

layer_collection = LayerCollection.from_model(net)
return (train_loader, layer_collection, net.parameters(), net, output_fn, 3)
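
Why normalized_shape is (3, 8, 9) in LayerNormConvNet (a reconstruction from the code, not stated in the commit): MNIST inputs are 28×28, the [5:-5] crop leaves 18×18, and Conv2d(1, 3, (3, 2), 2) maps that to ⌊(18−3)/2⌋+1 = 8 rows and ⌊(18−2)/2⌋+1 = 9 columns, so the LayerNorm spans the full per-sample activation. A quick shape check:

import torch
import torch.nn as nn

x = torch.randn(2, 1, 28, 28)   # MNIST-sized batch
x = x[:, :, 5:-5, 5:-5]         # crop to 18x18, as in the task
y = nn.Conv2d(1, 3, (3, 2), 2)(x)
print(y.shape)                  # torch.Size([2, 3, 8, 9]), matching nn.LayerNorm((3, 8, 9))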
9 changes: 7 additions & 2 deletions tests/test_jacobian.py
@@ -3,6 +3,7 @@
from tasks import (
get_batchnorm_conv_linear_task,
get_batchnorm_fc_linear_task,
get_conv1d_task,
get_conv_gn_task,
get_conv_skip_task,
get_conv_task,
Expand All @@ -11,11 +12,12 @@
get_fullyconnect_onlylast_task,
get_fullyconnect_task,
get_fullyconnect_wn_task,
get_layernorm_conv_task,
get_layernorm_task,
get_linear_conv_task,
get_linear_fc_task,
get_small_conv_transpose_task,
get_small_conv_wn_task,
get_conv1d_task,
)
from utils import check_ratio, check_tensors

@@ -41,6 +43,8 @@
]

nonlinear_tasks = [
get_layernorm_conv_task,
get_layernorm_task,
get_conv1d_task,
get_small_conv_transpose_task,
get_conv_task,
@@ -104,13 +108,14 @@ def test_jacobian_pushforward_dense_linear():

def test_jacobian_pushforward_dense_nonlinear():
for get_task in nonlinear_tasks:
print(get_task)
loader, lc, parameters, model, function, n_output = get_task()
generator = Jacobian(
layer_collection=lc, model=model, function=function, n_output=n_output
)
push_forward = PushForwardDense(generator=generator, examples=loader)
dw = random_pvector(lc, device=device)
dw = 1e-4 / dw.norm() * dw
dw = 1e-5 / dw.norm() * dw

doutput_lin = push_forward.mv(dw)

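The smaller perturbation in the last hunk (norm 1e-5 instead of 1e-4) presumably keeps the finite-difference comparison f(w + dw) ≈ f(w) + J·dw within the test's tolerance for the newly added LayerNorm tasks, whose linearization error is larger at the previous step size.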
