diff --git a/nngeometry/generator/jacobian/grads.py b/nngeometry/generator/jacobian/grads.py
index befa5bb..66e4a59 100644
--- a/nngeometry/generator/jacobian/grads.py
+++ b/nngeometry/generator/jacobian/grads.py
@@ -13,6 +13,7 @@
     WeightNorm1dLayer,
     WeightNorm2dLayer,
     Conv1dLayer,
+    LayerNormLayer,
 )
 
 from .grads_conv import conv2d_backward, convtranspose2d_backward, conv1d_backward
@@ -56,6 +57,7 @@ def Jv(cls, buffer, mod, layer, x, gy, v, v_bias):
         cls.flat_grad(buffer_flat, mod, layer, x, gy)
         v = v.view(-1)
         if v_bias is not None:
+            v_bias = v_bias.view(-1)
             v = torch.cat((v, v_bias))
         buffer.add_(torch.mv(buffer_flat, v))
 
@@ -269,6 +271,18 @@ def flat_grad(cls, buffer, mod, layer, x, gy):
         buffer[:, w_numel:].add_(gy.sum(dim=(2, 3)))
 
 
+class LayerNormJacobianFactory(JacobianFactory):
+    @classmethod
+    def flat_grad(cls, buffer, mod, layer, x, gy):
+        w_numel = layer.weight.numel()
+        x_normalized = F.layer_norm(
+            x, normalized_shape=mod.normalized_shape, eps=mod.eps
+        )
+        buffer[:, :w_numel].add_((gy * x_normalized).reshape(x.size(0), -1))
+        if layer.bias is not None:
+            buffer[:, w_numel:].add_(gy.reshape(x.size(0), -1))
+
+
 class GroupNormJacobianFactory(JacobianFactory):
     @classmethod
     def flat_grad(cls, buffer, mod, layer, x, gy):
@@ -282,16 +296,39 @@ class WeightNorm1dJacobianFactory(JacobianFactory):
     @classmethod
     def flat_grad(cls, buffer, mod, layer, x, gy):
         bs = x.size(0)
+        gw_prime = (
+            torch.bmm(gy.unsqueeze(2), x.unsqueeze(1))
+            .view(bs, -1)
+            .view(bs, *mod.weight.size())
+        )
         norm2 = (mod.weight**2).sum(dim=1, keepdim=True) + mod.eps
-        gw = torch.bmm(gy.unsqueeze(2) / torch.sqrt(norm2), x.unsqueeze(1))
-        wn2_out = F.linear(x, mod.weight / norm2**1.5)
-        gw -= (gy * wn2_out).unsqueeze(2) * mod.weight.unsqueeze(0)
+
+        gw = gw_prime / torch.sqrt(norm2).unsqueeze(0)
+
+        gw -= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=2, keepdim=True) * (
+            mod.weight * norm2 ** (-1.5)
+        ).unsqueeze(0)
+
         buffer.add_(gw.view(bs, -1))
 
 
 class WeightNorm2dJacobianFactory(JacobianFactory):
     @classmethod
     def flat_grad(cls, buffer, mod, layer, x, gy):
+        bs = x.size(0)
+        gw_prime = conv2d_backward(mod, x, gy).view(bs, *mod.weight.size())
+        norm2 = (mod.weight**2).sum(dim=(1, 2, 3), keepdim=True) + mod.eps
+
+        gw = gw_prime / torch.sqrt(norm2).unsqueeze(0)
+
+        gw -= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2, 3, 4), keepdim=True) * (
+            mod.weight * norm2 ** (-1.5)
+        ).unsqueeze(0)
+
+        buffer.add_(gw.view(bs, -1))
+
+    @classmethod
+    def flat_grad_(cls, buffer, mod, layer, x, gy):
         bs = x.size(0)
         out_dim = mod.weight.size(0)
         norm2 = (mod.weight**2).sum(dim=(1, 2, 3)) + mod.eps
@@ -426,4 +463,5 @@ def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g):
     WeightNorm2dLayer: WeightNorm2dJacobianFactory,
     Cosine1dLayer: Cosine1dJacobianFactory,
     Affine1dLayer: Affine1dJacobianFactory,
+    LayerNormLayer: LayerNormJacobianFactory,
 }
diff --git a/nngeometry/layercollection.py b/nngeometry/layercollection.py
index e3924de..9fcd363 100644
--- a/nngeometry/layercollection.py
+++ b/nngeometry/layercollection.py
@@ -25,6 +25,7 @@ class LayerCollection:
         "Affine1d",
         "ConvTranspose2d",
         "Conv1d",
+        "LayerNorm",
     ]
 
     def __init__(self, layers=None):
@@ -146,6 +147,10 @@ def _module_to_layer(mod):
             return Affine1dLayer(
                 num_features=mod.num_features, bias=(mod.bias is not None)
             )
+        elif mod_class == "LayerNorm":
+            return LayerNormLayer(
+                normalized_shape=mod.normalized_shape, bias=(mod.bias is not None)
+            )
 
     def numel(self):
         """
@@ -313,6 +318,24 @@ def __eq__(self, other):
         return self.num_features == other.num_features
 
 
+class LayerNormLayer(AbstractLayer):
+    def __init__(self, normalized_shape, bias=True):
+        self.weight = Parameter(*normalized_shape)
+        if bias:
+            self.bias = Parameter(*normalized_shape)
+        else:
+            self.bias = None
+
+    def numel(self):
+        if self.bias is not None:
+            return self.weight.numel() + self.bias.numel()
+        else:
+            return self.weight.numel()
+
+    def __eq__(self, other):
+        return self.weight == other.weight and self.bias == other.bias
+
+
 class GroupNormLayer(AbstractLayer):
     def __init__(self, num_groups, num_channels):
         self.num_channels = num_channels
@@ -406,3 +429,6 @@ def __init__(self, *size):
 
     def numel(self):
         return reduce(operator.mul, self.size, 1)
+
+    def __eq__(self, other):
+        return self.size == other.size
diff --git a/tests/tasks.py b/tests/tasks.py
index 6a6469b..2943806 100644
--- a/tests/tasks.py
+++ b/tests/tasks.py
@@ -503,3 +503,61 @@ def output_fn(input, target):
 
     layer_collection = LayerCollection.from_model(net)
     return (train_loader, layer_collection, net.parameters(), net, output_fn, 4)
+
+
+class LayerNormNet(nn.Module):
+    def __init__(self, out_size):
+        super(LayerNormNet, self).__init__()
+
+        self.linear1 = nn.Linear(18 * 18, out_size)
+        self.layer_norm1 = nn.LayerNorm((out_size,))
+
+        self.net = nn.Sequential(self.linear1, self.layer_norm1)
+
+    def forward(self, x):
+        x = x[:, :, 5:-5, 5:-5].contiguous()
+        x = x.view(x.size(0), -1)
+        return self.net(x)
+
+
+def get_layernorm_task():
+    train_set = get_mnist()
+    train_set = Subset(train_set, range(70))
+    train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False)
+    net = LayerNormNet(out_size=3)
+    to_device_model(net)
+    net.eval()
+
+    def output_fn(input, target):
+        return net(to_device(input))
+
+    layer_collection = LayerCollection.from_model(net)
+    return (train_loader, layer_collection, net.parameters(), net, output_fn, 3)
+
+
+class LayerNormConvNet(nn.Module):
+    def __init__(self):
+        super(LayerNormConvNet, self).__init__()
+        self.layer = nn.Conv2d(1, 3, (3, 2), 2)
+        self.layer_norm = nn.LayerNorm((3, 8, 9))
+
+    def forward(self, x):
+        x = x[:, :, 5:-5, 5:-5]
+        x = self.layer(x)
+        x = self.layer_norm(x)
+        return x.sum(dim=(2, 3))
+
+
+def get_layernorm_conv_task():
+    train_set = get_mnist()
+    train_set = Subset(train_set, range(70))
+    train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False)
+    net = LayerNormConvNet()
+    to_device_model(net)
+    net.eval()
+
+    def output_fn(input, target):
+        return net(to_device(input))
+
+    layer_collection = LayerCollection.from_model(net)
+    return (train_loader, layer_collection, net.parameters(), net, output_fn, 3)
diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py
index c188364..8eb0f83 100644
--- a/tests/test_jacobian.py
+++ b/tests/test_jacobian.py
@@ -3,6 +3,7 @@
 from tasks import (
     get_batchnorm_conv_linear_task,
     get_batchnorm_fc_linear_task,
+    get_conv1d_task,
     get_conv_gn_task,
     get_conv_skip_task,
     get_conv_task,
@@ -11,11 +12,12 @@
     get_fullyconnect_onlylast_task,
     get_fullyconnect_task,
     get_fullyconnect_wn_task,
+    get_layernorm_conv_task,
+    get_layernorm_task,
     get_linear_conv_task,
     get_linear_fc_task,
     get_small_conv_transpose_task,
     get_small_conv_wn_task,
-    get_conv1d_task,
 )
 from utils import check_ratio, check_tensors
 
@@ -41,6 +43,8 @@
 ]
 
 nonlinear_tasks = [
+    get_layernorm_conv_task,
+    get_layernorm_task,
     get_conv1d_task,
     get_small_conv_transpose_task,
     get_conv_task,
@@ -104,13 +108,14 @@ def test_jacobian_pushforward_dense_linear():
 
 def test_jacobian_pushforward_dense_nonlinear():
     for get_task in nonlinear_tasks:
+        print(get_task)
         loader, lc, parameters, model, function, n_output = get_task()
         generator = Jacobian(
             layer_collection=lc, model=model, function=function, n_output=n_output
         )
         push_forward = PushForwardDense(generator=generator, examples=loader)
         dw = random_pvector(lc, device=device)
-        dw = 1e-4 / dw.norm() * dw
+        dw = 1e-5 / dw.norm() * dw
         doutput_lin = push_forward.mv(dw)
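
Not part of the diff: the sketch below shows one way to exercise the new LayerNorm support by hand, mirroring the pattern of `test_jacobian_pushforward_dense_nonlinear` above. The tiny model, the synthetic data, and the import paths (`nngeometry.generator`, `nngeometry.layercollection`, `nngeometry.object.map`, `nngeometry.object.vector`) are assumptions based on the package layout visible in this diff; adjust them if the actual module locations differ.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Assumed import locations; verify against the installed nngeometry version.
from nngeometry.generator import Jacobian
from nngeometry.layercollection import LayerCollection
from nngeometry.object.map import PushForwardDense
from nngeometry.object.vector import random_pvector

# Small model ending in the newly supported LayerNorm layer.
model = nn.Sequential(nn.Linear(10, 4), nn.LayerNorm((4,)))

# Tiny synthetic dataset standing in for the MNIST subsets used in tests/tasks.py.
x = torch.randn(30, 10)
y = torch.zeros(30, dtype=torch.long)
loader = DataLoader(TensorDataset(x, y), batch_size=10, shuffle=False)


def output_fn(inputs, targets):
    # Same (input, target) -> output convention as the task output_fn helpers.
    return model(inputs)


lc = LayerCollection.from_model(model)
generator = Jacobian(
    layer_collection=lc, model=model, function=output_fn, n_output=4
)
push_forward = PushForwardDense(generator=generator, examples=loader)

# Push a small random parameter direction through the per-example Jacobian,
# as done in the dense push-forward test above.
dw = random_pvector(lc, device="cpu")
dw = 1e-5 / dw.norm() * dw
doutput_lin = push_forward.mv(dw)  # linearized change in outputs along dw
```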