From dc3f140396ac9e52943edbf6e255135f81503fc2 Mon Sep 17 00:00:00 2001 From: Thomas George Date: Wed, 22 Nov 2023 20:47:17 +0100 Subject: [PATCH 1/9] wip --- nngeometry/layercollection.py | 34 ++++++++++++++++++++++++++++++++++ setup.py | 2 +- tests/tasks.py | 34 ++++++++++++++++++++++++++++++++++ tests/test_jacobian.py | 3 ++- 4 files changed, 71 insertions(+), 2 deletions(-) diff --git a/nngeometry/layercollection.py b/nngeometry/layercollection.py index 23176fb..344174b 100644 --- a/nngeometry/layercollection.py +++ b/nngeometry/layercollection.py @@ -24,6 +24,7 @@ class LayerCollection: "Cosine1d", "Affine1d", "ConvTranspose2d", + "Conv1d", ] def __init__(self, layers=None): @@ -112,6 +113,13 @@ def _module_to_layer(mod): kernel_size=mod.kernel_size, bias=(mod.bias is not None), ) + elif mod_class == "Conv1d": + return Conv1dLayer( + in_channels=mod.in_channels, + out_channels=mod.out_channels, + kernel_size=mod.kernel_size, + bias=(mod.bias is not None), + ) elif mod_class == "BatchNorm1d": return BatchNorm1dLayer(num_features=mod.num_features) elif mod_class == "BatchNorm2d": @@ -230,6 +238,32 @@ def __eq__(self, other): and self.kernel_size == other.kernel_size ) +class Conv1dLayer(AbstractLayer): + def __init__(self, in_channels, out_channels, kernel_size, bias=True): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.weight = Parameter( + out_channels, in_channels, kernel_size[0] + ) + if bias: + self.bias = Parameter(out_channels) + else: + self.bias = None + + def numel(self): + if self.bias is not None: + return self.weight.numel() + self.bias.numel() + else: + return self.weight.numel() + + def __eq__(self, other): + return ( + self.in_channels == other.in_channels + and self.out_channels == other.out_channels + and self.kernel_size == other.kernel_size + ) + class LinearLayer(AbstractLayer): def __init__(self, in_features, out_features, bias=True): diff --git a/setup.py b/setup.py index edb0c41..dc1b1ad 100644 --- a/setup.py +++ b/setup.py @@ -11,5 +11,5 @@ 'nngeometry.generator', 'nngeometry.generator.jacobian', 'nngeometry.object'], - install_requires=['torch>=2.0.0'], + install_requires=['torch>=2.0.0','torchvision>=0.9.1'], zip_safe=False) diff --git a/tests/tasks.py b/tests/tasks.py index 55bbd2f..cd1d0d3 100644 --- a/tests/tasks.py +++ b/tests/tasks.py @@ -482,3 +482,37 @@ def output_fn(input, target): layer_collection = LayerCollection.from_model(net) return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) + + +class Conv1dNet(nn.Module): + def __init__(self, normalization="none"): + super(Conv1dNet, self).__init__() + if normalization != 'none': + raise NotImplementedError + self.normalization = normalization + self.conv1 = nn.Conv1d(1, 6, 3, 2) + self.conv2 = nn.Conv1d(6, 5, 4, 1) + self.fc1 = nn.Linear(7, 4) + + def forward(self, x): + x = x.reshape(x.size(0), x.size(1), -1) + x = tF.relu(self.conv1(x)) + x = tF.relu(self.conv2(x)) + x = x.view(x.size(0), -1) + x = self.fc1(x) + return x + + +def get_conv1d_task(normalization="none"): + train_set = get_mnist() + train_set = Subset(train_set, range(70)) + train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False) + net = Conv1dNet(normalization=normalization) + to_device_model(net) + net.eval() + + def output_fn(input, target): + return net(to_device(input)) + + layer_collection = LayerCollection.from_model(net) + return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) diff --git a/tests/test_jacobian.py 
b/tests/test_jacobian.py index e9f1348..db6b234 100644 --- a/tests/test_jacobian.py +++ b/tests/test_jacobian.py @@ -7,7 +7,7 @@ get_fullyconnect_onlylast_task, get_fullyconnect_task, get_fullyconnect_wn_task, get_linear_conv_task, get_linear_fc_task, get_small_conv_transpose_task, - get_small_conv_wn_task) + get_small_conv_wn_task, get_conv1d_task) from utils import check_ratio, check_tensors from nngeometry.generator import Jacobian @@ -27,6 +27,7 @@ ] nonlinear_tasks = [ + get_conv1d_task, get_small_conv_transpose_task, get_conv_task, get_fullyconnect_affine_task, From aca036ffa4d4b33ec1b674441db472eac91b4ded Mon Sep 17 00:00:00 2001 From: Thomas George Date: Fri, 24 Nov 2023 22:35:11 +0100 Subject: [PATCH 2/9] conv1d flat grad --- nngeometry/generator/jacobian/grads.py | 27 ++++++++++++++++++-- nngeometry/generator/jacobian/grads_conv.py | 28 +++++++++++++++++++-- tests/tasks.py | 10 +++++--- 3 files changed, 57 insertions(+), 8 deletions(-) diff --git a/nngeometry/generator/jacobian/grads.py b/nngeometry/generator/jacobian/grads.py index ea43c13..06524a1 100644 --- a/nngeometry/generator/jacobian/grads.py +++ b/nngeometry/generator/jacobian/grads.py @@ -5,9 +5,10 @@ BatchNorm2dLayer, Conv2dLayer, ConvTranspose2dLayer, Cosine1dLayer, GroupNormLayer, LinearLayer, - WeightNorm1dLayer, WeightNorm2dLayer) + WeightNorm1dLayer, WeightNorm2dLayer, + Conv1dLayer) -from .grads_conv import conv2d_backward, convtranspose2d_backward +from .grads_conv import conv2d_backward, convtranspose2d_backward, conv1d_backward class JacobianFactory: @@ -325,8 +326,30 @@ def flat_grad(cls, buffer, mod, layer, x, gy): buffer[:, w_numel:].add_(gy) +class Conv1dJacobianFactory(JacobianFactory): + @classmethod + def flat_grad(cls, buffer, mod, layer, x, gy): + bs = x.size(0) + w_numel = layer.weight.numel() + indiv_gw = conv1d_backward(mod, x, gy) + buffer[:, :w_numel].add_(indiv_gw.view(bs, -1)) + if layer.bias is not None: + buffer[:, w_numel:].add_(gy.sum(dim=2)) + + @classmethod + def Jv(cls, buffer, mod, layer, x, gy, v, v_bias): + bs = x.size(0) + gy2 = F.conv1d( + x, v, stride=mod.stride, padding=mod.padding, dilation=mod.dilation + ) + buffer.add_((gy * gy2).view(bs, -1).sum(dim=1)) + if layer.bias is not None: + buffer.add_(torch.mv(gy.sum(dim=2), v_bias)) + + FactoryMap = { LinearLayer: LinearJacobianFactory, + Conv1dLayer: Conv1dJacobianFactory, Conv2dLayer: Conv2dJacobianFactory, ConvTranspose2dLayer: ConvTranspose2dJacobianFactory, BatchNorm1dLayer: BatchNorm1dJacobianFactory, diff --git a/nngeometry/generator/jacobian/grads_conv.py b/nngeometry/generator/jacobian/grads_conv.py index 268c568..f89b27b 100644 --- a/nngeometry/generator/jacobian/grads_conv.py +++ b/nngeometry/generator/jacobian/grads_conv.py @@ -119,7 +119,25 @@ def conv2d_backward_using_unfold(mod, x, gy): def conv2d_backward(*args, **kwargs): - return _conv_grad_impl.get_impl()(*args, **kwargs) + return _conv_grad_impl.get_impl2d()(*args, **kwargs) + + +def conv1d_backward_using_unfold(mod, x, gy): + """Computes per-example gradients for nn.Conv1d layers.""" + ks = (1, mod.weight.size(2)) + gy_s = gy.size() + bs = gy_s[0] + x_unfold = F.unfold( + x.unsqueeze(2), kernel_size=ks, stride=(1, mod.stride[0]), padding=(0, mod.padding[0]), dilation=(1, mod.dilation[0]) + ) + x_unfold_s = x_unfold.size() + return torch.bmm( + gy.view(bs, gy_s[1], -1), x_unfold.view(bs, x_unfold_s[1], -1).permute(0, 2, 1) + ) + + +def conv1d_backward(*args, **kwargs): + return _conv_grad_impl.get_impl1d()(*args, **kwargs) class ConvGradImplManager: @@ -129,12 
+147,18 @@ def __init__(self): def use_unfold(self, choice=True): self._use_unfold = choice - def get_impl(self): + def get_impl2d(self): if self._use_unfold: return conv2d_backward_using_unfold else: return conv2d_backward_using_conv + def get_impl1d(self): + if self._use_unfold: + return conv1d_backward_using_unfold + else: + raise NotImplementedError() + _conv_grad_impl = ConvGradImplManager() diff --git a/tests/tasks.py b/tests/tasks.py index cd1d0d3..bfa98d4 100644 --- a/tests/tasks.py +++ b/tests/tasks.py @@ -490,14 +490,16 @@ def __init__(self, normalization="none"): if normalization != 'none': raise NotImplementedError self.normalization = normalization - self.conv1 = nn.Conv1d(1, 6, 3, 2) - self.conv2 = nn.Conv1d(6, 5, 4, 1) - self.fc1 = nn.Linear(7, 4) + self.conv1 = nn.Conv1d(1, 6, 3, 3) + self.conv2 = nn.Conv1d(6, 5, 4, 8, bias=False) + self.conv3 = nn.Conv1d(5, 2, 4, 4) + self.fc1 = nn.Linear(16, 4) def forward(self, x): x = x.reshape(x.size(0), x.size(1), -1) x = tF.relu(self.conv1(x)) x = tF.relu(self.conv2(x)) + x = tF.relu(self.conv3(x)) x = x.view(x.size(0), -1) x = self.fc1(x) return x @@ -515,4 +517,4 @@ def output_fn(input, target): return net(to_device(input)) layer_collection = LayerCollection.from_model(net) - return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) + return (train_loader, layer_collection, net.parameters(), net, output_fn, 4) From 68e31fb5a4102ef9edda4aa2c458b4c42840e7f8 Mon Sep 17 00:00:00 2001 From: Thomas George Date: Fri, 24 Nov 2023 22:50:41 +0100 Subject: [PATCH 3/9] kfac for conv1d --- nngeometry/generator/jacobian/grads.py | 26 ++++++++++++++++++ nngeometry/generator/jacobian/grads_conv.py | 3 ++- nngeometry/generator/jacobian/jacobian.py | 5 +++- tests/test_jacobian_kfac.py | 29 ++++++++++++++++++++- 4 files changed, 60 insertions(+), 3 deletions(-) diff --git a/nngeometry/generator/jacobian/grads.py b/nngeometry/generator/jacobian/grads.py index 06524a1..fe2bf76 100644 --- a/nngeometry/generator/jacobian/grads.py +++ b/nngeometry/generator/jacobian/grads.py @@ -346,6 +346,32 @@ def Jv(cls, buffer, mod, layer, x, gy, v, v_bias): if layer.bias is not None: buffer.add_(torch.mv(gy.sum(dim=2), v_bias)) + @classmethod + def kfac_xx(cls, buffer, mod, layer, x, gy): + ks = (1, mod.weight.size(2)) + # A_tilda in KFC + A_tilda = F.unfold( + x.unsqueeze(2), + kernel_size=ks, + stride=(1, mod.stride[0]), + padding=(0, mod.padding[0]), + dilation=(1, mod.dilation[0]), + ) + # A_tilda is bs * #locations x #parameters + A_tilda = A_tilda.permute(0, 2, 1).contiguous().view(-1, A_tilda.size(1)) + if layer.bias is not None: + A_tilda = torch.cat([A_tilda, torch.ones_like(A_tilda[:, :1])], dim=1) + # Omega_hat in KFC + buffer.add_(torch.mm(A_tilda.t(), A_tilda)) + + @classmethod + def kfac_gg(cls, buffer, mod, layer, x, gy): + spatial_locations = gy.size(2) + os = gy.size(1) + # DS_tilda in KFC + DS_tilda = gy.permute(0, 2, 1).contiguous().view(-1, os) + buffer.add_(torch.mm(DS_tilda.t(), DS_tilda) / spatial_locations) + FactoryMap = { LinearLayer: LinearJacobianFactory, diff --git a/nngeometry/generator/jacobian/grads_conv.py b/nngeometry/generator/jacobian/grads_conv.py index f89b27b..c7b025c 100644 --- a/nngeometry/generator/jacobian/grads_conv.py +++ b/nngeometry/generator/jacobian/grads_conv.py @@ -128,7 +128,8 @@ def conv1d_backward_using_unfold(mod, x, gy): gy_s = gy.size() bs = gy_s[0] x_unfold = F.unfold( - x.unsqueeze(2), kernel_size=ks, stride=(1, mod.stride[0]), padding=(0, mod.padding[0]), dilation=(1, mod.dilation[0]) + 
x.unsqueeze(2), kernel_size=ks, stride=(1, mod.stride[0]), + padding=(0, mod.padding[0]), dilation=(1, mod.dilation[0]) ) x_unfold_s = x_unfold.size() return torch.bmm( diff --git a/nngeometry/generator/jacobian/jacobian.py b/nngeometry/generator/jacobian/jacobian.py index 96a5125..2c3d82c 100644 --- a/nngeometry/generator/jacobian/jacobian.py +++ b/nngeometry/generator/jacobian/jacobian.py @@ -250,6 +250,9 @@ def get_kfac_blocks(self, examples): elif layer_class == "Conv2dLayer": sG = layer.out_channels sA = layer.in_channels * layer.kernel_size[0] * layer.kernel_size[1] + elif layer_class == "Conv1dLayer": + sG = layer.out_channels + sA = layer.in_channels * layer.kernel_size[0] if layer.bias is not None: sA += 1 self._blocks[layer_id] = ( @@ -761,7 +764,7 @@ def _hook_compute_kfac_blocks(self, mod, gy): layer_id = self.m_to_l[mod] layer = self.layer_collection[layer_id] block = self._blocks[layer_id] - if mod_class in ["Linear", "Conv2d"]: + if mod_class in ["Linear", "Conv2d", "Conv1d"]: FactoryMap[layer.__class__].kfac_gg(block[1], mod, layer, x, gy) if self.i_output == 0: # do this only once if n_output > 1 diff --git a/tests/test_jacobian_kfac.py b/tests/test_jacobian_kfac.py index a3f97e6..aa528f9 100644 --- a/tests/test_jacobian_kfac.py +++ b/tests/test_jacobian_kfac.py @@ -106,6 +106,32 @@ def output_fn(input, target): return (train_loader, layer_collection, net.parameters(), net, output_fn, 4) +class Conv1dNet(nn.Module): + def __init__(self): + super(Conv1dNet, self).__init__() + self.conv1 = nn.Conv1d(3, 4, 3, 1) + + def forward(self, x): + x = self.conv1(x) + return x.sum(axis=(2,)) + + +def get_conv1dnet_kfc_task(bs=5): + train_set = torch.utils.data.TensorDataset( + torch.ones(size=(10, 3, 5)), torch.randint(0, 4, size=(10, 4)) + ) + train_loader = DataLoader(dataset=train_set, batch_size=bs, shuffle=False) + net = Conv1dNet() + to_device_model(net) + net.eval() + + def output_fn(input, target): + return net(to_device(input)) + + layer_collection = LayerCollection.from_model(net) + return (train_loader, layer_collection, net.parameters(), net, output_fn, 4) + + @pytest.fixture(autouse=True) def make_test_deterministic(): torch.manual_seed(1234) @@ -118,7 +144,8 @@ def test_jacobian_kfac_vs_pblockdiag(): where they are the same """ for get_task, mult in zip( - [get_convnet_kfc_task, get_fullyconnect_kfac_task], [15.0, 1.0] + [get_conv1dnet_kfc_task, get_convnet_kfc_task, get_fullyconnect_kfac_task], + [3., 15.0, 1.0] ): loader, lc, parameters, model, function, n_output = get_task() From 23f3100df9adf8df872432fd92a5acf10f44aff8 Mon Sep 17 00:00:00 2001 From: Thomas George Date: Fri, 24 Nov 2023 23:03:33 +0100 Subject: [PATCH 4/9] flake --- nngeometry/generator/jacobian/grads.py | 19 +++++++---- nngeometry/generator/jacobian/grads_conv.py | 7 ++-- nngeometry/layercollection.py | 5 ++- tests/tasks.py | 2 +- tests/test_jacobian.py | 38 ++++++++++++++------- tests/test_jacobian_kfac.py | 2 +- 6 files changed, 48 insertions(+), 25 deletions(-) diff --git a/nngeometry/generator/jacobian/grads.py b/nngeometry/generator/jacobian/grads.py index fe2bf76..77414b8 100644 --- a/nngeometry/generator/jacobian/grads.py +++ b/nngeometry/generator/jacobian/grads.py @@ -1,12 +1,19 @@ import torch import torch.nn.functional as F -from nngeometry.layercollection import (Affine1dLayer, BatchNorm1dLayer, - BatchNorm2dLayer, Conv2dLayer, - ConvTranspose2dLayer, Cosine1dLayer, - GroupNormLayer, LinearLayer, - WeightNorm1dLayer, WeightNorm2dLayer, - Conv1dLayer) +from nngeometry.layercollection import 
( + Affine1dLayer, + BatchNorm1dLayer, + BatchNorm2dLayer, + Conv2dLayer, + ConvTranspose2dLayer, + Cosine1dLayer, + GroupNormLayer, + LinearLayer, + WeightNorm1dLayer, + WeightNorm2dLayer, + Conv1dLayer, +) from .grads_conv import conv2d_backward, convtranspose2d_backward, conv1d_backward diff --git a/nngeometry/generator/jacobian/grads_conv.py b/nngeometry/generator/jacobian/grads_conv.py index c7b025c..565ea5b 100644 --- a/nngeometry/generator/jacobian/grads_conv.py +++ b/nngeometry/generator/jacobian/grads_conv.py @@ -128,8 +128,11 @@ def conv1d_backward_using_unfold(mod, x, gy): gy_s = gy.size() bs = gy_s[0] x_unfold = F.unfold( - x.unsqueeze(2), kernel_size=ks, stride=(1, mod.stride[0]), - padding=(0, mod.padding[0]), dilation=(1, mod.dilation[0]) + x.unsqueeze(2), + kernel_size=ks, + stride=(1, mod.stride[0]), + padding=(0, mod.padding[0]), + dilation=(1, mod.dilation[0]), ) x_unfold_s = x_unfold.size() return torch.bmm( diff --git a/nngeometry/layercollection.py b/nngeometry/layercollection.py index 344174b..e3924de 100644 --- a/nngeometry/layercollection.py +++ b/nngeometry/layercollection.py @@ -238,14 +238,13 @@ def __eq__(self, other): and self.kernel_size == other.kernel_size ) + class Conv1dLayer(AbstractLayer): def __init__(self, in_channels, out_channels, kernel_size, bias=True): self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = kernel_size - self.weight = Parameter( - out_channels, in_channels, kernel_size[0] - ) + self.weight = Parameter(out_channels, in_channels, kernel_size[0]) if bias: self.bias = Parameter(out_channels) else: diff --git a/tests/tasks.py b/tests/tasks.py index bfa98d4..32ea77c 100644 --- a/tests/tasks.py +++ b/tests/tasks.py @@ -487,7 +487,7 @@ def output_fn(input, target): class Conv1dNet(nn.Module): def __init__(self, normalization="none"): super(Conv1dNet, self).__init__() - if normalization != 'none': + if normalization != "none": raise NotImplementedError self.normalization = normalization self.conv1 = nn.Conv1d(1, 6, 3, 3) diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py index db6b234..c188364 100644 --- a/tests/test_jacobian.py +++ b/tests/test_jacobian.py @@ -1,21 +1,35 @@ import pytest import torch -from tasks import (get_batchnorm_conv_linear_task, - get_batchnorm_fc_linear_task, get_conv_gn_task, - get_conv_skip_task, get_conv_task, - get_fullyconnect_affine_task, get_fullyconnect_cosine_task, - get_fullyconnect_onlylast_task, get_fullyconnect_task, - get_fullyconnect_wn_task, get_linear_conv_task, - get_linear_fc_task, get_small_conv_transpose_task, - get_small_conv_wn_task, get_conv1d_task) +from tasks import ( + get_batchnorm_conv_linear_task, + get_batchnorm_fc_linear_task, + get_conv_gn_task, + get_conv_skip_task, + get_conv_task, + get_fullyconnect_affine_task, + get_fullyconnect_cosine_task, + get_fullyconnect_onlylast_task, + get_fullyconnect_task, + get_fullyconnect_wn_task, + get_linear_conv_task, + get_linear_fc_task, + get_small_conv_transpose_task, + get_small_conv_wn_task, + get_conv1d_task, +) from utils import check_ratio, check_tensors from nngeometry.generator import Jacobian from nngeometry.object.fspace import FMatDense -from nngeometry.object.map import (PullBackDense, PushForwardDense, - PushForwardImplicit) -from nngeometry.object.pspace import (PMatBlockDiag, PMatDense, PMatDiag, - PMatImplicit, PMatLowRank, PMatQuasiDiag) +from nngeometry.object.map import PullBackDense, PushForwardDense, PushForwardImplicit +from nngeometry.object.pspace import ( + PMatBlockDiag, + 
PMatDense, + PMatDiag, + PMatImplicit, + PMatLowRank, + PMatQuasiDiag, +) from nngeometry.object.vector import PVector, random_fvector, random_pvector linear_tasks = [ diff --git a/tests/test_jacobian_kfac.py b/tests/test_jacobian_kfac.py index aa528f9..2188197 100644 --- a/tests/test_jacobian_kfac.py +++ b/tests/test_jacobian_kfac.py @@ -145,7 +145,7 @@ def test_jacobian_kfac_vs_pblockdiag(): """ for get_task, mult in zip( [get_conv1dnet_kfc_task, get_convnet_kfc_task, get_fullyconnect_kfac_task], - [3., 15.0, 1.0] + [3.0, 15.0, 1.0], ): loader, lc, parameters, model, function, n_output = get_task() From 7cb28970d482b47f34632920caa6751bb60a21fe Mon Sep 17 00:00:00 2001 From: Thomas George Date: Fri, 24 Nov 2023 23:50:22 +0100 Subject: [PATCH 5/9] fix kfac with bias --- nngeometry/object/pspace.py | 10 +++++----- tests/tasks.py | 2 +- tests/test_jacobian_kfac.py | 12 ++++++++++-- 3 files changed, 16 insertions(+), 8 deletions(-) diff --git a/nngeometry/object/pspace.py b/nngeometry/object/pspace.py index 33c9fb0..ddd6231 100644 --- a/nngeometry/object/pspace.py +++ b/nngeometry/object/pspace.py @@ -487,7 +487,7 @@ def get_dense_tensor(self, split_weight_bias=True): a, g = self.data[layer_id] start = self.generator.layer_collection.p_pos[layer_id] sAG = a.size(0) * g.size(0) - if split_weight_bias: + if split_weight_bias and layer.bias: reconstruct = torch.cat( [ torch.cat( @@ -517,7 +517,7 @@ def get_diag(self, split_weight_bias=True): for layer_id, layer in self.generator.layer_collection.layers.items(): a, g = self.data[layer_id] diag_of_block = torch.diag(g).view(-1, 1) * torch.diag(a).view(1, -1) - if split_weight_bias: + if split_weight_bias and layer.bias: diags.append(diag_of_block[:, :-1].contiguous().view(-1)) diags.append(diag_of_block[:, -1:].view(-1)) else: @@ -535,10 +535,10 @@ def mv(self, vs): v = torch.cat([v, vs_dict[layer_id][1].unsqueeze(1)], dim=1) a, g = self.data[layer_id] mv = torch.mm(torch.mm(g, v), a) - if layer.bias is None: - mv_tuple = (mv.view(*sw),) - else: + if layer.bias: mv_tuple = (mv[:, :-1].contiguous().view(*sw), mv[:, -1].contiguous()) + else: + mv_tuple = (mv.view(*sw),) out_dict[layer_id] = mv_tuple return PVector(layer_collection=vs.layer_collection, dict_repr=out_dict) diff --git a/tests/tasks.py b/tests/tasks.py index 32ea77c..ae67633 100644 --- a/tests/tasks.py +++ b/tests/tasks.py @@ -101,7 +101,7 @@ def __init__(self, normalization="none"): self.bn1 = nn.BatchNorm2d(6) elif self.normalization == "group_norm": self.gn1 = nn.GroupNorm(2, 6) - self.conv2 = nn.Conv2d(6, 5, 4, 1) + self.conv2 = nn.Conv2d(6, 5, 4, 1, bias=False) self.conv3 = nn.Conv2d(5, 7, 3, 1, 1) if self.normalization == "weight_norm": self.wn2 = WeightNorm1d(7, 4) diff --git a/tests/test_jacobian_kfac.py b/tests/test_jacobian_kfac.py index 2188197..8d41ce3 100644 --- a/tests/test_jacobian_kfac.py +++ b/tests/test_jacobian_kfac.py @@ -4,7 +4,13 @@ import torch import torch.nn as nn import torch.nn.functional as F -from tasks import get_conv_task, get_fullyconnect_task, get_mnist, to_device_model +from tasks import ( + get_conv_task, + get_fullyconnect_task, + get_mnist, + to_device_model, + get_conv1d_task, +) from torch.utils.data import DataLoader, Subset from utils import angle, check_ratio, check_tensors @@ -161,7 +167,7 @@ def test_jacobian_kfac_vs_pblockdiag(): def test_jacobian_kfac(): - for get_task in [get_fullyconnect_task, get_conv_task]: + for get_task in [get_conv1d_task, get_fullyconnect_task, get_conv_task]: loader, lc, parameters, model, function, n_output = 
get_task() generator = Jacobian( @@ -190,6 +196,8 @@ def test_jacobian_kfac(): # Test mv mv_direct = torch.mv(G_kfac_split, random_v.get_flat_representation()) mv_kfac = M_kfac.mv(random_v) + print(mv_direct.size(), lc.layers) + pvec = PVector(layer_collection=lc, vector_repr=mv_direct) check_tensors(mv_direct, mv_kfac.get_flat_representation()) # Test vTMv From 9fb878004be02c731623765d671f1792daebb163 Mon Sep 17 00:00:00 2001 From: Thomas George Date: Sat, 25 Nov 2023 00:21:26 +0100 Subject: [PATCH 6/9] ekfac for conv1d --- nngeometry/generator/jacobian/grads.py | 34 +++++++++++++++++++++++ nngeometry/generator/jacobian/jacobian.py | 5 +++- nngeometry/object/pspace.py | 6 ++-- tests/test_jacobian_ekfac.py | 6 ++-- 4 files changed, 44 insertions(+), 7 deletions(-) diff --git a/nngeometry/generator/jacobian/grads.py b/nngeometry/generator/jacobian/grads.py index 77414b8..befa5bb 100644 --- a/nngeometry/generator/jacobian/grads.py +++ b/nngeometry/generator/jacobian/grads.py @@ -379,6 +379,40 @@ def kfac_gg(cls, buffer, mod, layer, x, gy): DS_tilda = gy.permute(0, 2, 1).contiguous().view(-1, os) buffer.add_(torch.mm(DS_tilda.t(), DS_tilda) / spatial_locations) + @classmethod + def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g): + ks = (1, mod.weight.size(2)) + gy_s = gy.size() + bs = gy_s[0] + # project x to kfe + x_unfold = F.unfold( + x.unsqueeze(2), + kernel_size=ks, + stride=(1, mod.stride[0]), + padding=(0, mod.padding[0]), + dilation=(1, mod.dilation[0]), + ) + x_unfold_s = x_unfold.size() + x_unfold = ( + x_unfold.view(bs, x_unfold_s[1], -1) + .permute(0, 2, 1) + .contiguous() + .view(-1, x_unfold_s[1]) + ) + if mod.bias is not None: + x_unfold = torch.cat([x_unfold, torch.ones_like(x_unfold[:, :1])], dim=1) + x_kfe = torch.mm(x_unfold, evecs_a) + + # project gy to kfe + gy = gy.view(bs, gy_s[1], -1).permute(0, 2, 1).contiguous() + gy_kfe = torch.mm(gy.view(-1, gy_s[1]), evecs_g) + gy_kfe = gy_kfe.view(bs, -1, gy_s[1]).permute(0, 2, 1).contiguous() + + indiv_gw = torch.bmm( + gy_kfe.view(bs, gy_s[1], -1), x_kfe.view(bs, -1, x_kfe.size(1)) + ) + buffer.add_((indiv_gw**2).sum(dim=0).view(-1)) + FactoryMap = { LinearLayer: LinearJacobianFactory, diff --git a/nngeometry/generator/jacobian/jacobian.py b/nngeometry/generator/jacobian/jacobian.py index 2c3d82c..ddedff9 100644 --- a/nngeometry/generator/jacobian/jacobian.py +++ b/nngeometry/generator/jacobian/jacobian.py @@ -462,6 +462,9 @@ def get_kfe_diag(self, kfe, examples): elif layer_class == "Conv2dLayer": sG = layer.out_channels sA = layer.in_channels * layer.kernel_size[0] * layer.kernel_size[1] + elif layer_class == "Conv1dLayer": + sG = layer.out_channels + sA = layer.in_channels * layer.kernel_size[0] if layer.bias is not None: sA += 1 self._diags[layer_id] = torch.zeros((sG * sA), device=device, dtype=dtype) @@ -778,7 +781,7 @@ def _hook_compute_kfe_diag(self, mod, gy): layer = self.layer_collection[layer_id] x = self.xs[mod] evecs_a, evecs_g = self._kfe[layer_id] - if mod_class in ["Linear", "Conv2d"]: + if mod_class in ["Linear", "Conv2d", "Conv1d"]: FactoryMap[layer.__class__].kfe_diag( self._diags[layer_id], mod, layer, x, gy, evecs_a, evecs_g ) diff --git a/nngeometry/object/pspace.py b/nngeometry/object/pspace.py index ddd6231..b374684 100644 --- a/nngeometry/object/pspace.py +++ b/nngeometry/object/pspace.py @@ -602,7 +602,7 @@ class PMatEKFAC(PMatAbstract): """ EKFAC representation from *George, Laurent et al., Fast Approximate Natural Gradient Descent - in a Kronecker-factored Eigenbasis, NIPS 2018* + in a 
Kronecker-factored Eigenbasis, NeurIPS 2018* """ @@ -659,9 +659,9 @@ def get_KFE(self, split_weight_bias=True): """ evecs, _ = self.data KFE = dict() - for layer_id, _ in self.generator.layer_collection.layers.items(): + for layer_id, layer in self.generator.layer_collection.layers.items(): evecs_a, evecs_g = evecs[layer_id] - if split_weight_bias: + if split_weight_bias and layer.bias: kronecker(evecs_g, evecs_a[:-1, :]) kronecker(evecs_g, evecs_a[-1:, :].contiguous()) KFE[layer_id] = torch.cat( diff --git a/tests/test_jacobian_ekfac.py b/tests/test_jacobian_ekfac.py index 32b65da..79be84d 100644 --- a/tests/test_jacobian_ekfac.py +++ b/tests/test_jacobian_ekfac.py @@ -1,6 +1,6 @@ import pytest import torch -from tasks import device, get_conv_task, get_fullyconnect_task +from tasks import device, get_conv_task, get_conv1d_task, get_fullyconnect_task from utils import check_ratio, check_tensors from nngeometry.generator import Jacobian @@ -20,7 +20,7 @@ def test_pspace_ekfac_vs_kfac(): sense of the Frobenius norm """ eps = 1e-4 - for get_task in [get_fullyconnect_task, get_conv_task]: + for get_task in [get_conv1d_task, get_fullyconnect_task, get_conv_task]: loader, lc, parameters, model, function, n_output = get_task() model.train() generator = Jacobian( @@ -50,7 +50,7 @@ def test_pspace_ekfac_vs_direct(): Check EKFAC basis operations against direct computation using get_dense_tensor """ - for get_task in [get_fullyconnect_task, get_conv_task]: + for get_task in [get_conv1d_task, get_fullyconnect_task, get_conv_task]: loader, lc, parameters, model, function, n_output = get_task() model.train() From 985b4439d43689e96c34f4074942e89973f71343 Mon Sep 17 00:00:00 2001 From: Thomas George Date: Sat, 25 Nov 2023 00:26:49 +0100 Subject: [PATCH 7/9] fixes pvector when bias is none --- nngeometry/object/vector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/nngeometry/object/vector.py b/nngeometry/object/vector.py index 3b2bbb5..4744a67 100644 --- a/nngeometry/object/vector.py +++ b/nngeometry/object/vector.py @@ -194,7 +194,7 @@ def _dict_to_flat(self): parts = [] for layer_id, layer in self.layer_collection.layers.items(): parts.append(self.dict_repr[layer_id][0].view(-1)) - if len(self.dict_repr[layer_id]) > 1: + if layer.bias: parts.append(self.dict_repr[layer_id][1].view(-1)) return torch.cat(parts) From bce2cc6167072cd6bcdab32e0d0beb8ca5068876 Mon Sep 17 00:00:00 2001 From: Thomas George Date: Sat, 25 Nov 2023 13:52:09 +0100 Subject: [PATCH 8/9] flake8 --- nngeometry/generator/jacobian/grads_conv.py | 5 ----- nngeometry/generator/jacobian/jacobian.py | 5 ++++- nngeometry/object/__init__.py | 12 ++++++++++-- tests/tasks.py | 15 --------------- tests/test_jacobian.py | 1 + tests/test_jacobian_kfac.py | 3 --- tests/test_metrics.py | 11 ++++++++--- tests/test_pickle.py | 9 +++++++-- tests/test_vector.py | 3 +-- 9 files changed, 31 insertions(+), 33 deletions(-) diff --git a/nngeometry/generator/jacobian/grads_conv.py b/nngeometry/generator/jacobian/grads_conv.py index 565ea5b..1ae27d7 100644 --- a/nngeometry/generator/jacobian/grads_conv.py +++ b/nngeometry/generator/jacobian/grads_conv.py @@ -83,11 +83,6 @@ def conv_backward( return weight_bgrad -def conv1d_backward(*args, **kwargs): - """Computes per-example gradients for nn.Conv1d layers.""" - return conv_backward(*args, nd=1, **kwargs) - - def conv2d_backward_using_conv(mod, x, gy): """Computes per-example gradients for nn.Conv2d layers.""" return conv_backward( diff --git a/nngeometry/generator/jacobian/jacobian.py 
b/nngeometry/generator/jacobian/jacobian.py index ddedff9..93b611a 100644 --- a/nngeometry/generator/jacobian/jacobian.py +++ b/nngeometry/generator/jacobian/jacobian.py @@ -41,7 +41,10 @@ def __init__( self.centering = centering if function is None: - function = lambda *x: model(x[0]) + + def function(*x): + return model(x[0]) + self.function = function if layer_collection is None: diff --git a/nngeometry/object/__init__.py b/nngeometry/object/__init__.py index de504b3..d8065db 100644 --- a/nngeometry/object/__init__.py +++ b/nngeometry/object/__init__.py @@ -1,7 +1,15 @@ from .fspace import FMatDense from .map import PullBackDense, PushForwardDense, PushForwardImplicit -from .pspace import (PMatBlockDiag, PMatDense, PMatDiag, PMatEKFAC, - PMatImplicit, PMatKFAC, PMatLowRank, PMatQuasiDiag) +from .pspace import ( + PMatBlockDiag, + PMatDense, + PMatDiag, + PMatEKFAC, + PMatImplicit, + PMatKFAC, + PMatLowRank, + PMatQuasiDiag, +) from .vector import FVector, PVector __all__ = [ diff --git a/tests/tasks.py b/tests/tasks.py index ae67633..6a6469b 100644 --- a/tests/tasks.py +++ b/tests/tasks.py @@ -370,21 +370,6 @@ def get_fullyconnect_affine_task(): return get_fullyconnect_task(normalization="affine") -def get_conv_task(normalization="none"): - train_set = get_mnist() - train_set = Subset(train_set, range(70)) - train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False) - net = ConvNet(normalization=normalization) - to_device_model(net) - net.eval() - - def output_fn(input, target): - return net(to_device(input)) - - layer_collection = LayerCollection.from_model(net) - return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) - - def get_conv_bn_task(): return get_conv_task(normalization="batch_norm") diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py index c188364..05ca2de 100644 --- a/tests/test_jacobian.py +++ b/tests/test_jacobian.py @@ -833,6 +833,7 @@ def test_bn_eval_mode(): model.train() with pytest.raises(RuntimeError): FMat_dense = FMatDense(generator=generator, examples=loader) + FMat_dense.get_dense_tensor() def test_example_passing(): diff --git a/tests/test_jacobian_kfac.py b/tests/test_jacobian_kfac.py index 8d41ce3..f27178c 100644 --- a/tests/test_jacobian_kfac.py +++ b/tests/test_jacobian_kfac.py @@ -3,7 +3,6 @@ import pytest import torch import torch.nn as nn -import torch.nn.functional as F from tasks import ( get_conv_task, get_fullyconnect_task, @@ -196,8 +195,6 @@ def test_jacobian_kfac(): # Test mv mv_direct = torch.mv(G_kfac_split, random_v.get_flat_representation()) mv_kfac = M_kfac.mv(random_v) - print(mv_direct.size(), lc.layers) - pvec = PVector(layer_collection=lc, vector_repr=mv_direct) check_tensors(mv_direct, mv_kfac.get_flat_representation()) # Test vTMv diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 4bd375e..a67c6e9 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,9 +1,14 @@ import pytest import torch import torch.nn.functional as tF -from tasks import (device, get_conv_gn_task, get_conv_task, - get_fullyconnect_segm_task, get_fullyconnect_task, - to_device) +from tasks import ( + device, + get_conv_gn_task, + get_conv_task, + get_fullyconnect_segm_task, + get_fullyconnect_task, + to_device, +) from test_jacobian import get_output_vector, update_model from nngeometry.metrics import FIM, FIM_MonteCarlo diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 5645083..0768d4e 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -4,8 +4,13 @@ from utils import 
check_tensors from nngeometry.generator import Jacobian -from nngeometry.object.pspace import (PMatBlockDiag, PMatDense, PMatDiag, - PMatLowRank, PMatQuasiDiag) +from nngeometry.object.pspace import ( + PMatBlockDiag, + PMatDense, + PMatDiag, + PMatLowRank, + PMatQuasiDiag, +) from nngeometry.object.vector import PVector diff --git a/tests/test_vector.py b/tests/test_vector.py index 3fef91e..1e34eba 100644 --- a/tests/test_vector.py +++ b/tests/test_vector.py @@ -5,8 +5,7 @@ from utils import check_ratio, check_tensors from nngeometry.layercollection import LayerCollection -from nngeometry.object.vector import (PVector, random_pvector, - random_pvector_dict) +from nngeometry.object.vector import PVector, random_pvector, random_pvector_dict class ConvNet(nn.Module): From 418beb97e628e8ce286aababb55c7b8ace0e2b83 Mon Sep 17 00:00:00 2001 From: Thomas George Date: Sat, 25 Nov 2023 14:14:12 +0100 Subject: [PATCH 9/9] removes useless instruction in test --- tests/test_jacobian.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py index 05ca2de..c188364 100644 --- a/tests/test_jacobian.py +++ b/tests/test_jacobian.py @@ -833,7 +833,6 @@ def test_bn_eval_mode(): model.train() with pytest.raises(RuntimeError): FMat_dense = FMatDense(generator=generator, examples=loader) - FMat_dense.get_dense_tensor() def test_example_passing():
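
The per-example Conv1d gradient introduced in patch 2/9 (`conv1d_backward_using_unfold`) and reused for the KFAC factor in patch 3/9 (`kfac_xx`) relies on viewing a 1d convolution as a 2d convolution with a height-1 kernel: the input is `unsqueeze(2)`-ed and `F.unfold` is called with `kernel_size=(1, K)`, `stride=(1, stride)`, and so on. The standalone sketch below illustrates that trick outside nngeometry and checks it against per-example gradients computed one example at a time with autograd; the layer sizes and random tensors are arbitrary illustrations, not values taken from the patches.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
bs, c_in, c_out, length, k, stride = 4, 3, 5, 11, 3, 2
mod = nn.Conv1d(c_in, c_out, k, stride=stride, padding=1)

x = torch.randn(bs, c_in, length)
out = mod(x)
gy = torch.randn_like(out)  # stand-in for dLoss/dOutput, one slice per example

# Conv1d seen as a Conv2d with a height-1 kernel, so F.unfold can be reused:
x_unfold = F.unfold(
    x.unsqueeze(2),                       # (bs, c_in, 1, L)
    kernel_size=(1, mod.kernel_size[0]),
    stride=(1, mod.stride[0]),
    padding=(0, mod.padding[0]),
    dilation=(1, mod.dilation[0]),
)                                         # (bs, c_in * k, L_out)

# one flattened weight gradient per example: (bs, c_out, c_in * k)
indiv_gw = torch.bmm(gy.view(bs, c_out, -1), x_unfold.permute(0, 2, 1))

# reference: per-example gradients computed one example at a time with autograd
for i in range(bs):
    (ref,) = torch.autograd.grad((mod(x[i : i + 1]) * gy[i]).sum(), mod.weight)
    assert torch.allclose(ref.view(c_out, -1), indiv_gw[i], atol=1e-5)
print("per-example Conv1d gradients match autograd")
```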
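
End to end, the series makes `nn.Conv1d` layers usable with the existing Jacobian generator and parameter-space representations. The sketch below mirrors the test tasks (`Conv1dNet` in tests/tasks.py and `get_conv1dnet_kfc_task` in tests/test_jacobian_kfac.py) to build a KFAC approximation for a small Conv1d model. The model, the dummy dataset, and the exact keyword arguments are illustrative assumptions based on how the tests call `Jacobian`, `PMatKFAC`, and `random_pvector`; they are not code from the patches.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from nngeometry.generator import Jacobian
from nngeometry.layercollection import LayerCollection
from nngeometry.object import PMatKFAC
from nngeometry.object.vector import random_pvector


class TinyConv1dNet(nn.Module):
    # hypothetical model, loosely following Conv1dNet from tests/tasks.py
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(3, 4, kernel_size=3)
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        x = torch.relu(self.conv1(x))
        x = x.sum(dim=2)  # pool over the remaining length dimension
        return self.fc(x)


# dummy data: 10 examples of a 3-channel, length-5 signal, as in get_conv1dnet_kfc_task
train_set = TensorDataset(torch.randn(10, 3, 5), torch.randint(0, 2, (10,)))
loader = DataLoader(train_set, batch_size=5, shuffle=False)

model = TinyConv1dNet()
model.eval()


def output_fn(input, target):
    # same convention as the test tasks: the generator passes the whole mini-batch
    return model(input)


# Conv1d layers are now picked up by LayerCollection.from_model (patch 1/9)
lc = LayerCollection.from_model(model)
generator = Jacobian(
    layer_collection=lc, model=model, function=output_fn, n_output=2
)

# KFAC blocks for Conv1d come from the new kfac_xx / kfac_gg factories (patch 3/9)
F_kfac = PMatKFAC(generator=generator, examples=loader)

v = random_pvector(lc)
print(F_kfac.vTMv(v))                                 # quadratic form v^T F v
print(F_kfac.mv(v).get_flat_representation().norm())  # F v, flattened
```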