From 8f85cd01698a2d4bb6c1f0e4946210fd1299f652 Mon Sep 17 00:00:00 2001 From: Thomas George Date: Mon, 29 Jan 2024 22:38:42 +0100 Subject: [PATCH 1/6] layer norm work in progress --- nngeometry/generator/jacobian/grads.py | 38 ++++++++++++++++++++++++-- nngeometry/layercollection.py | 26 ++++++++++++++++++ tests/test_jacobian.py | 10 ++++++- tests/test_tasks/datasets.py | 10 +++++++ tests/test_tasks/device.py | 21 ++++++++++++++ tests/test_tasks/layernorm.py | 33 ++++++++++++++++++++++ 6 files changed, 134 insertions(+), 4 deletions(-) create mode 100644 tests/test_tasks/datasets.py create mode 100644 tests/test_tasks/device.py create mode 100644 tests/test_tasks/layernorm.py diff --git a/nngeometry/generator/jacobian/grads.py b/nngeometry/generator/jacobian/grads.py index befa5bb..22004b9 100644 --- a/nngeometry/generator/jacobian/grads.py +++ b/nngeometry/generator/jacobian/grads.py @@ -13,6 +13,7 @@ WeightNorm1dLayer, WeightNorm2dLayer, Conv1dLayer, + LayerNormLayer ) from .grads_conv import conv2d_backward, convtranspose2d_backward, conv1d_backward @@ -269,6 +270,18 @@ def flat_grad(cls, buffer, mod, layer, x, gy): buffer[:, w_numel:].add_(gy.sum(dim=(2, 3))) +class LayerNormJacobianFactory(JacobianFactory): + @classmethod + def flat_grad(cls, buffer, mod, layer, x, gy): + w_numel = layer.weight.numel() + x_normalized = F.layer_norm( + x, normalized_shape=mod.normalized_shape, eps=mod.eps + ) + buffer[:, :w_numel].add_(gy * x_normalized) + if layer.bias is not None: + buffer[:, w_numel:].add_(gy) + + class GroupNormJacobianFactory(JacobianFactory): @classmethod def flat_grad(cls, buffer, mod, layer, x, gy): @@ -279,19 +292,37 @@ def flat_grad(cls, buffer, mod, layer, x, gy): class WeightNorm1dJacobianFactory(JacobianFactory): + @classmethod def flat_grad(cls, buffer, mod, layer, x, gy): bs = x.size(0) + gw_prime = torch.bmm(gy.unsqueeze(2), x.unsqueeze(1)).view(bs, -1).view(bs, *mod.weight.size()) norm2 = (mod.weight**2).sum(dim=1, keepdim=True) + mod.eps - gw = torch.bmm(gy.unsqueeze(2) / torch.sqrt(norm2), x.unsqueeze(1)) - wn2_out = F.linear(x, mod.weight / norm2**1.5) - gw -= (gy * wn2_out).unsqueeze(2) * mod.weight.unsqueeze(0) + + gw = gw_prime / torch.sqrt(norm2).unsqueeze(0) + + gw-= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=2, keepdim=True) * (mod.weight * norm2**(-1.5)).unsqueeze(0) + buffer.add_(gw.view(bs, -1)) class WeightNorm2dJacobianFactory(JacobianFactory): @classmethod def flat_grad(cls, buffer, mod, layer, x, gy): + bs = x.size(0) + gw_prime = conv2d_backward(mod, x, gy).view(bs, *mod.weight.size()) + norm2 = (mod.weight**2).sum(dim=(1,2,3), keepdim=True) + mod.eps + + gw = gw_prime / torch.sqrt(norm2).unsqueeze(0) + # print((gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2,3,4), keepdim=True).size()) + # print((mod.weight * norm2**(-1.5)).unsqueeze(0).size()) + + gw-= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2,3,4), keepdim=True) * (mod.weight * norm2**(-1.5)).unsqueeze(0) + + buffer.add_(gw.view(bs, -1)) + + @classmethod + def flat_grad_(cls, buffer, mod, layer, x, gy): bs = x.size(0) out_dim = mod.weight.size(0) norm2 = (mod.weight**2).sum(dim=(1, 2, 3)) + mod.eps @@ -426,4 +457,5 @@ def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g): WeightNorm2dLayer: WeightNorm2dJacobianFactory, Cosine1dLayer: Cosine1dJacobianFactory, Affine1dLayer: Affine1dJacobianFactory, + LayerNormLayer: LayerNormJacobianFactory, } diff --git a/nngeometry/layercollection.py b/nngeometry/layercollection.py index e3924de..9fcd363 100644 --- a/nngeometry/layercollection.py +++ b/nngeometry/layercollection.py @@ -25,6 +25,7 @@ class LayerCollection: "Affine1d", "ConvTranspose2d", "Conv1d", + "LayerNorm" ] def __init__(self, layers=None): @@ -146,6 +147,10 @@ def _module_to_layer(mod): return Affine1dLayer( num_features=mod.num_features, bias=(mod.bias is not None) ) + elif mod_class == "LayerNorm": + return LayerNormLayer( + normalized_shape=mod.normalized_shape, bias=(mod.bias is not None) + ) def numel(self): """ @@ -313,6 +318,24 @@ def __eq__(self, other): return self.num_features == other.num_features +class LayerNormLayer(AbstractLayer): + def __init__(self, normalized_shape, bias=True): + self.weight = Parameter(*normalized_shape) + if bias: + self.bias = Parameter(*normalized_shape) + else: + self.bias = None + + def numel(self): + if self.bias is not None: + return self.weight.numel() + self.bias.numel() + else: + return self.weight.numel() + + def __eq__(self, other): + return self.weight == other.weight and self.bias == other.bias + + class GroupNormLayer(AbstractLayer): def __init__(self, num_groups, num_channels): self.num_channels = num_channels @@ -406,3 +429,6 @@ def __init__(self, *size): def numel(self): return reduce(operator.mul, self.size, 1) + + def __eq__(self, other): + return self.size == other.size diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py index c188364..8ec8c40 100644 --- a/tests/test_jacobian.py +++ b/tests/test_jacobian.py @@ -18,6 +18,7 @@ get_conv1d_task, ) from utils import check_ratio, check_tensors +from test_tasks.layernorm import get_layernorm_task from nngeometry.generator import Jacobian from nngeometry.object.fspace import FMatDense @@ -41,6 +42,7 @@ ] nonlinear_tasks = [ + get_layernorm_task, get_conv1d_task, get_small_conv_transpose_task, get_conv_task, @@ -104,6 +106,7 @@ def test_jacobian_pushforward_dense_linear(): def test_jacobian_pushforward_dense_nonlinear(): for get_task in nonlinear_tasks: + print(get_task) loader, lc, parameters, model, function, n_output = get_task() generator = Jacobian( layer_collection=lc, model=model, function=function, n_output=n_output @@ -123,8 +126,13 @@ def test_jacobian_pushforward_dense_nonlinear(): check_tensors( output_after - output_before, doutput_lin.get_flat_representation().t(), - eps=5e-3, + eps=5e-3, only_print_diff=True, ) + # check_tensors( + # output_after - output_before, + # doutput_lin.get_flat_representation().t(), + # eps=5e-3, + # ) def test_jacobian_pushforward_implicit(): diff --git a/tests/test_tasks/datasets.py b/tests/test_tasks/datasets.py new file mode 100644 index 0000000..78d6cfc --- /dev/null +++ b/tests/test_tasks/datasets.py @@ -0,0 +1,10 @@ +from torchvision import datasets, transforms +default_datapath = "tmp" + +def get_mnist(): + return datasets.MNIST( + root=default_datapath, + train=True, + download=True, + transform=transforms.ToTensor(), + ) diff --git a/tests/test_tasks/device.py b/tests/test_tasks/device.py new file mode 100644 index 0000000..2545b18 --- /dev/null +++ b/tests/test_tasks/device.py @@ -0,0 +1,21 @@ +import torch + +if torch.cuda.is_available(): + device = "cuda" + + def to_device(tensor): + return tensor.to(device) + + def to_device_model(model): + model.to("cuda") + +else: + device = "cpu" + + # on cpu we need to use double as otherwise ill-conditioning in sums + # causes numerical instability + def to_device(tensor): + return tensor.double() + + def to_device_model(model): + model.double() diff --git a/tests/test_tasks/layernorm.py b/tests/test_tasks/layernorm.py new file mode 100644 index 0000000..8407b1b --- /dev/null +++ b/tests/test_tasks/layernorm.py @@ -0,0 +1,33 @@ +import torch.nn as nn +from .datasets import get_mnist +from .device import to_device_model,to_device +from torch.utils.data import DataLoader, Subset +from nngeometry.layercollection import LayerCollection + +class LayerNormNet(nn.Module): + def __init__(self, out_size): + super(LayerNormNet, self).__init__() + + self.linear1 = nn.Linear(18*18, out_size) + self.layer_norm1 = nn.LayerNorm((out_size,)) + + self.net = nn.Sequential(self.linear1, self.layer_norm1) + + def forward(self, x): + x = x[:, :, 5:-5, 5:-5].contiguous() + x = x.view(x.size(0), -1) + return self.net(x) + +def get_layernorm_task(normalization="none"): + train_set = get_mnist() + train_set = Subset(train_set, range(70)) + train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False) + net = LayerNormNet(out_size=3) + to_device_model(net) + net.eval() + + def output_fn(input, target): + return net(to_device(input)) + + layer_collection = LayerCollection.from_model(net) + return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) From 719f1e7786928fdae9b363facb90c39be480921f Mon Sep 17 00:00:00 2001 From: Thomas George Date: Wed, 31 Jan 2024 10:04:19 +0100 Subject: [PATCH 2/6] should fix layernorm, other tests still not passing --- nngeometry/generator/jacobian/grads.py | 25 +++++++----- tests/test_jacobian.py | 54 ++++++++++---------------- tests/test_tasks/datasets.py | 2 + tests/test_tasks/layernorm.py | 40 +++++++++++++++++-- 4 files changed, 75 insertions(+), 46 deletions(-) diff --git a/nngeometry/generator/jacobian/grads.py b/nngeometry/generator/jacobian/grads.py index 22004b9..b7372c0 100644 --- a/nngeometry/generator/jacobian/grads.py +++ b/nngeometry/generator/jacobian/grads.py @@ -13,7 +13,7 @@ WeightNorm1dLayer, WeightNorm2dLayer, Conv1dLayer, - LayerNormLayer + LayerNormLayer, ) from .grads_conv import conv2d_backward, convtranspose2d_backward, conv1d_backward @@ -277,9 +277,9 @@ def flat_grad(cls, buffer, mod, layer, x, gy): x_normalized = F.layer_norm( x, normalized_shape=mod.normalized_shape, eps=mod.eps ) - buffer[:, :w_numel].add_(gy * x_normalized) + buffer[:, :w_numel].add_((gy * x_normalized).reshape(x.size(0), -1)) if layer.bias is not None: - buffer[:, w_numel:].add_(gy) + buffer[:, w_numel:].add_(gy.reshape(x.size(0), -1)) class GroupNormJacobianFactory(JacobianFactory): @@ -292,16 +292,21 @@ def flat_grad(cls, buffer, mod, layer, x, gy): class WeightNorm1dJacobianFactory(JacobianFactory): - @classmethod def flat_grad(cls, buffer, mod, layer, x, gy): bs = x.size(0) - gw_prime = torch.bmm(gy.unsqueeze(2), x.unsqueeze(1)).view(bs, -1).view(bs, *mod.weight.size()) + gw_prime = ( + torch.bmm(gy.unsqueeze(2), x.unsqueeze(1)) + .view(bs, -1) + .view(bs, *mod.weight.size()) + ) norm2 = (mod.weight**2).sum(dim=1, keepdim=True) + mod.eps gw = gw_prime / torch.sqrt(norm2).unsqueeze(0) - - gw-= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=2, keepdim=True) * (mod.weight * norm2**(-1.5)).unsqueeze(0) + + gw -= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=2, keepdim=True) * ( + mod.weight * norm2 ** (-1.5) + ).unsqueeze(0) buffer.add_(gw.view(bs, -1)) @@ -311,13 +316,15 @@ class WeightNorm2dJacobianFactory(JacobianFactory): def flat_grad(cls, buffer, mod, layer, x, gy): bs = x.size(0) gw_prime = conv2d_backward(mod, x, gy).view(bs, *mod.weight.size()) - norm2 = (mod.weight**2).sum(dim=(1,2,3), keepdim=True) + mod.eps + norm2 = (mod.weight**2).sum(dim=(1, 2, 3), keepdim=True) + mod.eps gw = gw_prime / torch.sqrt(norm2).unsqueeze(0) # print((gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2,3,4), keepdim=True).size()) # print((mod.weight * norm2**(-1.5)).unsqueeze(0).size()) - gw-= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2,3,4), keepdim=True) * (mod.weight * norm2**(-1.5)).unsqueeze(0) + gw -= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2, 3, 4), keepdim=True) * ( + mod.weight * norm2 ** (-1.5) + ).unsqueeze(0) buffer.add_(gw.view(bs, -1)) diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py index 8ec8c40..7ee6d80 100644 --- a/tests/test_jacobian.py +++ b/tests/test_jacobian.py @@ -1,36 +1,22 @@ import pytest import torch -from tasks import ( - get_batchnorm_conv_linear_task, - get_batchnorm_fc_linear_task, - get_conv_gn_task, - get_conv_skip_task, - get_conv_task, - get_fullyconnect_affine_task, - get_fullyconnect_cosine_task, - get_fullyconnect_onlylast_task, - get_fullyconnect_task, - get_fullyconnect_wn_task, - get_linear_conv_task, - get_linear_fc_task, - get_small_conv_transpose_task, - get_small_conv_wn_task, - get_conv1d_task, -) +from tasks import (get_batchnorm_conv_linear_task, + get_batchnorm_fc_linear_task, get_conv1d_task, + get_conv_gn_task, get_conv_skip_task, get_conv_task, + get_fullyconnect_affine_task, get_fullyconnect_cosine_task, + get_fullyconnect_onlylast_task, get_fullyconnect_task, + get_fullyconnect_wn_task, get_linear_conv_task, + get_linear_fc_task, get_small_conv_transpose_task, + get_small_conv_wn_task) +from test_tasks.layernorm import get_layernorm_conv_task, get_layernorm_task from utils import check_ratio, check_tensors -from test_tasks.layernorm import get_layernorm_task from nngeometry.generator import Jacobian from nngeometry.object.fspace import FMatDense -from nngeometry.object.map import PullBackDense, PushForwardDense, PushForwardImplicit -from nngeometry.object.pspace import ( - PMatBlockDiag, - PMatDense, - PMatDiag, - PMatImplicit, - PMatLowRank, - PMatQuasiDiag, -) +from nngeometry.object.map import (PullBackDense, PushForwardDense, + PushForwardImplicit) +from nngeometry.object.pspace import (PMatBlockDiag, PMatDense, PMatDiag, + PMatImplicit, PMatLowRank, PMatQuasiDiag) from nngeometry.object.vector import PVector, random_fvector, random_pvector linear_tasks = [ @@ -42,6 +28,7 @@ ] nonlinear_tasks = [ + get_layernorm_conv_task, get_layernorm_task, get_conv1d_task, get_small_conv_transpose_task, @@ -126,13 +113,14 @@ def test_jacobian_pushforward_dense_nonlinear(): check_tensors( output_after - output_before, doutput_lin.get_flat_representation().t(), - eps=5e-3, only_print_diff=True, + eps=5e-3, + only_print_diff=True, + ) + check_tensors( + output_after - output_before, + doutput_lin.get_flat_representation().t(), + eps=5e-3, ) - # check_tensors( - # output_after - output_before, - # doutput_lin.get_flat_representation().t(), - # eps=5e-3, - # ) def test_jacobian_pushforward_implicit(): diff --git a/tests/test_tasks/datasets.py b/tests/test_tasks/datasets.py index 78d6cfc..590c21e 100644 --- a/tests/test_tasks/datasets.py +++ b/tests/test_tasks/datasets.py @@ -1,6 +1,8 @@ from torchvision import datasets, transforms + default_datapath = "tmp" + def get_mnist(): return datasets.MNIST( root=default_datapath, diff --git a/tests/test_tasks/layernorm.py b/tests/test_tasks/layernorm.py index 8407b1b..7b868dd 100644 --- a/tests/test_tasks/layernorm.py +++ b/tests/test_tasks/layernorm.py @@ -1,14 +1,17 @@ import torch.nn as nn -from .datasets import get_mnist -from .device import to_device_model,to_device from torch.utils.data import DataLoader, Subset + from nngeometry.layercollection import LayerCollection +from .datasets import get_mnist +from .device import to_device, to_device_model + + class LayerNormNet(nn.Module): def __init__(self, out_size): super(LayerNormNet, self).__init__() - self.linear1 = nn.Linear(18*18, out_size) + self.linear1 = nn.Linear(18 * 18, out_size) self.layer_norm1 = nn.LayerNorm((out_size,)) self.net = nn.Sequential(self.linear1, self.layer_norm1) @@ -18,7 +21,8 @@ def forward(self, x): x = x.view(x.size(0), -1) return self.net(x) -def get_layernorm_task(normalization="none"): + +def get_layernorm_task(): train_set = get_mnist() train_set = Subset(train_set, range(70)) train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False) @@ -31,3 +35,31 @@ def output_fn(input, target): layer_collection = LayerCollection.from_model(net) return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) + + +class LayerNormConvNet(nn.Module): + def __init__(self): + super(LayerNormConvNet, self).__init__() + self.layer = nn.Conv2d(1, 3, (3, 2), 2) + self.layer_norm = nn.LayerNorm((3,8,9)) + + def forward(self, x): + x = x[:, :, 5:-5, 5:-5] + x = self.layer(x) + x = self.layer_norm(x) + return x.sum(dim=(2, 3)) + + +def get_layernorm_conv_task(): + train_set = get_mnist() + train_set = Subset(train_set, range(70)) + train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False) + net = LayerNormConvNet() + to_device_model(net) + net.eval() + + def output_fn(input, target): + return net(to_device(input)) + + layer_collection = LayerCollection.from_model(net) + return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) From 0e8164ed8e9707541f0ee85ba9277f7ecb666d9b Mon Sep 17 00:00:00 2001 From: Thomas George Date: Sat, 3 Feb 2024 09:44:14 +0100 Subject: [PATCH 3/6] step for finite diff --- tests/test_jacobian.py | 46 +++++++++++++++++++++++++----------------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py index 7ee6d80..ebe7288 100644 --- a/tests/test_jacobian.py +++ b/tests/test_jacobian.py @@ -1,22 +1,36 @@ import pytest import torch -from tasks import (get_batchnorm_conv_linear_task, - get_batchnorm_fc_linear_task, get_conv1d_task, - get_conv_gn_task, get_conv_skip_task, get_conv_task, - get_fullyconnect_affine_task, get_fullyconnect_cosine_task, - get_fullyconnect_onlylast_task, get_fullyconnect_task, - get_fullyconnect_wn_task, get_linear_conv_task, - get_linear_fc_task, get_small_conv_transpose_task, - get_small_conv_wn_task) +from tasks import ( + get_batchnorm_conv_linear_task, + get_batchnorm_fc_linear_task, + get_conv1d_task, + get_conv_gn_task, + get_conv_skip_task, + get_conv_task, + get_fullyconnect_affine_task, + get_fullyconnect_cosine_task, + get_fullyconnect_onlylast_task, + get_fullyconnect_task, + get_fullyconnect_wn_task, + get_linear_conv_task, + get_linear_fc_task, + get_small_conv_transpose_task, + get_small_conv_wn_task, +) from test_tasks.layernorm import get_layernorm_conv_task, get_layernorm_task from utils import check_ratio, check_tensors from nngeometry.generator import Jacobian from nngeometry.object.fspace import FMatDense -from nngeometry.object.map import (PullBackDense, PushForwardDense, - PushForwardImplicit) -from nngeometry.object.pspace import (PMatBlockDiag, PMatDense, PMatDiag, - PMatImplicit, PMatLowRank, PMatQuasiDiag) +from nngeometry.object.map import PullBackDense, PushForwardDense, PushForwardImplicit +from nngeometry.object.pspace import ( + PMatBlockDiag, + PMatDense, + PMatDiag, + PMatImplicit, + PMatLowRank, + PMatQuasiDiag, +) from nngeometry.object.vector import PVector, random_fvector, random_pvector linear_tasks = [ @@ -100,7 +114,7 @@ def test_jacobian_pushforward_dense_nonlinear(): ) push_forward = PushForwardDense(generator=generator, examples=loader) dw = random_pvector(lc, device=device) - dw = 1e-4 / dw.norm() * dw + dw = 1e-5 / dw.norm() * dw doutput_lin = push_forward.mv(dw) @@ -110,12 +124,6 @@ def test_jacobian_pushforward_dense_nonlinear(): # This is non linear, so we don't expect the finite difference # estimate to be very accurate. We use a larger eps value - check_tensors( - output_after - output_before, - doutput_lin.get_flat_representation().t(), - eps=5e-3, - only_print_diff=True, - ) check_tensors( output_after - output_before, doutput_lin.get_flat_representation().t(), From 8a143d0d87e8793e2937fe9d72413bea6268a4e2 Mon Sep 17 00:00:00 2001 From: Thomas George Date: Sat, 3 Feb 2024 09:52:05 +0100 Subject: [PATCH 4/6] removes test task dir - NB it would be a good idea to add it in the future --- tests/tasks.py | 58 +++++++++++++++++++++++++++++++ tests/test_jacobian.py | 3 +- tests/test_tasks/datasets.py | 12 ------- tests/test_tasks/device.py | 21 ----------- tests/test_tasks/layernorm.py | 65 ----------------------------------- 5 files changed, 60 insertions(+), 99 deletions(-) delete mode 100644 tests/test_tasks/datasets.py delete mode 100644 tests/test_tasks/device.py delete mode 100644 tests/test_tasks/layernorm.py diff --git a/tests/tasks.py b/tests/tasks.py index 6a6469b..2943806 100644 --- a/tests/tasks.py +++ b/tests/tasks.py @@ -503,3 +503,61 @@ def output_fn(input, target): layer_collection = LayerCollection.from_model(net) return (train_loader, layer_collection, net.parameters(), net, output_fn, 4) + + +class LayerNormNet(nn.Module): + def __init__(self, out_size): + super(LayerNormNet, self).__init__() + + self.linear1 = nn.Linear(18 * 18, out_size) + self.layer_norm1 = nn.LayerNorm((out_size,)) + + self.net = nn.Sequential(self.linear1, self.layer_norm1) + + def forward(self, x): + x = x[:, :, 5:-5, 5:-5].contiguous() + x = x.view(x.size(0), -1) + return self.net(x) + + +def get_layernorm_task(): + train_set = get_mnist() + train_set = Subset(train_set, range(70)) + train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False) + net = LayerNormNet(out_size=3) + to_device_model(net) + net.eval() + + def output_fn(input, target): + return net(to_device(input)) + + layer_collection = LayerCollection.from_model(net) + return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) + + +class LayerNormConvNet(nn.Module): + def __init__(self): + super(LayerNormConvNet, self).__init__() + self.layer = nn.Conv2d(1, 3, (3, 2), 2) + self.layer_norm = nn.LayerNorm((3, 8, 9)) + + def forward(self, x): + x = x[:, :, 5:-5, 5:-5] + x = self.layer(x) + x = self.layer_norm(x) + return x.sum(dim=(2, 3)) + + +def get_layernorm_conv_task(): + train_set = get_mnist() + train_set = Subset(train_set, range(70)) + train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False) + net = LayerNormConvNet() + to_device_model(net) + net.eval() + + def output_fn(input, target): + return net(to_device(input)) + + layer_collection = LayerCollection.from_model(net) + return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py index ebe7288..8eb0f83 100644 --- a/tests/test_jacobian.py +++ b/tests/test_jacobian.py @@ -12,12 +12,13 @@ get_fullyconnect_onlylast_task, get_fullyconnect_task, get_fullyconnect_wn_task, + get_layernorm_conv_task, + get_layernorm_task, get_linear_conv_task, get_linear_fc_task, get_small_conv_transpose_task, get_small_conv_wn_task, ) -from test_tasks.layernorm import get_layernorm_conv_task, get_layernorm_task from utils import check_ratio, check_tensors from nngeometry.generator import Jacobian diff --git a/tests/test_tasks/datasets.py b/tests/test_tasks/datasets.py deleted file mode 100644 index 590c21e..0000000 --- a/tests/test_tasks/datasets.py +++ /dev/null @@ -1,12 +0,0 @@ -from torchvision import datasets, transforms - -default_datapath = "tmp" - - -def get_mnist(): - return datasets.MNIST( - root=default_datapath, - train=True, - download=True, - transform=transforms.ToTensor(), - ) diff --git a/tests/test_tasks/device.py b/tests/test_tasks/device.py deleted file mode 100644 index 2545b18..0000000 --- a/tests/test_tasks/device.py +++ /dev/null @@ -1,21 +0,0 @@ -import torch - -if torch.cuda.is_available(): - device = "cuda" - - def to_device(tensor): - return tensor.to(device) - - def to_device_model(model): - model.to("cuda") - -else: - device = "cpu" - - # on cpu we need to use double as otherwise ill-conditioning in sums - # causes numerical instability - def to_device(tensor): - return tensor.double() - - def to_device_model(model): - model.double() diff --git a/tests/test_tasks/layernorm.py b/tests/test_tasks/layernorm.py deleted file mode 100644 index 7b868dd..0000000 --- a/tests/test_tasks/layernorm.py +++ /dev/null @@ -1,65 +0,0 @@ -import torch.nn as nn -from torch.utils.data import DataLoader, Subset - -from nngeometry.layercollection import LayerCollection - -from .datasets import get_mnist -from .device import to_device, to_device_model - - -class LayerNormNet(nn.Module): - def __init__(self, out_size): - super(LayerNormNet, self).__init__() - - self.linear1 = nn.Linear(18 * 18, out_size) - self.layer_norm1 = nn.LayerNorm((out_size,)) - - self.net = nn.Sequential(self.linear1, self.layer_norm1) - - def forward(self, x): - x = x[:, :, 5:-5, 5:-5].contiguous() - x = x.view(x.size(0), -1) - return self.net(x) - - -def get_layernorm_task(): - train_set = get_mnist() - train_set = Subset(train_set, range(70)) - train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False) - net = LayerNormNet(out_size=3) - to_device_model(net) - net.eval() - - def output_fn(input, target): - return net(to_device(input)) - - layer_collection = LayerCollection.from_model(net) - return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) - - -class LayerNormConvNet(nn.Module): - def __init__(self): - super(LayerNormConvNet, self).__init__() - self.layer = nn.Conv2d(1, 3, (3, 2), 2) - self.layer_norm = nn.LayerNorm((3,8,9)) - - def forward(self, x): - x = x[:, :, 5:-5, 5:-5] - x = self.layer(x) - x = self.layer_norm(x) - return x.sum(dim=(2, 3)) - - -def get_layernorm_conv_task(): - train_set = get_mnist() - train_set = Subset(train_set, range(70)) - train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False) - net = LayerNormConvNet() - to_device_model(net) - net.eval() - - def output_fn(input, target): - return net(to_device(input)) - - layer_collection = LayerCollection.from_model(net) - return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) From 4ce6305f771e5d0d744e7090b0c2a8025aef4e45 Mon Sep 17 00:00:00 2001 From: Thomas George Date: Sat, 3 Feb 2024 09:55:24 +0100 Subject: [PATCH 5/6] removes useless print in test --- nngeometry/generator/jacobian/grads.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/nngeometry/generator/jacobian/grads.py b/nngeometry/generator/jacobian/grads.py index b7372c0..5d73ebf 100644 --- a/nngeometry/generator/jacobian/grads.py +++ b/nngeometry/generator/jacobian/grads.py @@ -319,8 +319,6 @@ def flat_grad(cls, buffer, mod, layer, x, gy): norm2 = (mod.weight**2).sum(dim=(1, 2, 3), keepdim=True) + mod.eps gw = gw_prime / torch.sqrt(norm2).unsqueeze(0) - # print((gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2,3,4), keepdim=True).size()) - # print((mod.weight * norm2**(-1.5)).unsqueeze(0).size()) gw -= (gw_prime * mod.weight.unsqueeze(0)).sum(dim=(2, 3, 4), keepdim=True) * ( mod.weight * norm2 ** (-1.5) From 1d61aed81d4655f31a4e438d512066ae2321605d Mon Sep 17 00:00:00 2001 From: Thomas George Date: Sat, 3 Feb 2024 10:02:37 +0100 Subject: [PATCH 6/6] fixes bias shape --- nngeometry/generator/jacobian/grads.py | 1 + 1 file changed, 1 insertion(+) diff --git a/nngeometry/generator/jacobian/grads.py b/nngeometry/generator/jacobian/grads.py index 5d73ebf..66e4a59 100644 --- a/nngeometry/generator/jacobian/grads.py +++ b/nngeometry/generator/jacobian/grads.py @@ -57,6 +57,7 @@ def Jv(cls, buffer, mod, layer, x, gy, v, v_bias): cls.flat_grad(buffer_flat, mod, layer, x, gy) v = v.view(-1) if v_bias is not None: + v_bias = v_bias.view(-1) v = torch.cat((v, v_bias)) buffer.add_(torch.mv(buffer_flat, v))