From dc3f140396ac9e52943edbf6e255135f81503fc2 Mon Sep 17 00:00:00 2001 From: Thomas George Date: Wed, 22 Nov 2023 20:47:17 +0100 Subject: [PATCH] wip --- nngeometry/layercollection.py | 34 ++++++++++++++++++++++++++++++++++ setup.py | 2 +- tests/tasks.py | 34 ++++++++++++++++++++++++++++++++++ tests/test_jacobian.py | 3 ++- 4 files changed, 71 insertions(+), 2 deletions(-) diff --git a/nngeometry/layercollection.py b/nngeometry/layercollection.py index 23176fb..344174b 100644 --- a/nngeometry/layercollection.py +++ b/nngeometry/layercollection.py @@ -24,6 +24,7 @@ class LayerCollection: "Cosine1d", "Affine1d", "ConvTranspose2d", + "Conv1d", ] def __init__(self, layers=None): @@ -112,6 +113,13 @@ def _module_to_layer(mod): kernel_size=mod.kernel_size, bias=(mod.bias is not None), ) + elif mod_class == "Conv1d": + return Conv1dLayer( + in_channels=mod.in_channels, + out_channels=mod.out_channels, + kernel_size=mod.kernel_size, + bias=(mod.bias is not None), + ) elif mod_class == "BatchNorm1d": return BatchNorm1dLayer(num_features=mod.num_features) elif mod_class == "BatchNorm2d": @@ -230,6 +238,32 @@ def __eq__(self, other): and self.kernel_size == other.kernel_size ) +class Conv1dLayer(AbstractLayer): + def __init__(self, in_channels, out_channels, kernel_size, bias=True): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.weight = Parameter( + out_channels, in_channels, kernel_size[0] + ) + if bias: + self.bias = Parameter(out_channels) + else: + self.bias = None + + def numel(self): + if self.bias is not None: + return self.weight.numel() + self.bias.numel() + else: + return self.weight.numel() + + def __eq__(self, other): + return ( + self.in_channels == other.in_channels + and self.out_channels == other.out_channels + and self.kernel_size == other.kernel_size + ) + class LinearLayer(AbstractLayer): def __init__(self, in_features, out_features, bias=True): diff --git a/setup.py b/setup.py index 
edb0c41..dc1b1ad 100644 --- a/setup.py +++ b/setup.py @@ -11,5 +11,5 @@ 'nngeometry.generator', 'nngeometry.generator.jacobian', 'nngeometry.object'], - install_requires=['torch>=2.0.0'], + install_requires=['torch>=2.0.0','torchvision>=0.9.1'], zip_safe=False) diff --git a/tests/tasks.py b/tests/tasks.py index 55bbd2f..cd1d0d3 100644 --- a/tests/tasks.py +++ b/tests/tasks.py @@ -482,3 +482,37 @@ def output_fn(input, target): layer_collection = LayerCollection.from_model(net) return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) + + +class Conv1dNet(nn.Module): + def __init__(self, normalization="none"): + super(Conv1dNet, self).__init__() + if normalization != 'none': + raise NotImplementedError + self.normalization = normalization + self.conv1 = nn.Conv1d(1, 6, 3, 2) + self.conv2 = nn.Conv1d(6, 5, 4, 1) + self.fc1 = nn.Linear(1940, 3) + + def forward(self, x): + x = x.reshape(x.size(0), x.size(1), -1) + x = tF.relu(self.conv1(x)) + x = tF.relu(self.conv2(x)) + x = x.view(x.size(0), -1) + x = self.fc1(x) + return x + + +def get_conv1d_task(normalization="none"): + train_set = get_mnist() + train_set = Subset(train_set, range(70)) + train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False) + net = Conv1dNet(normalization=normalization) + to_device_model(net) + net.eval() + + def output_fn(input, target): + return net(to_device(input)) + + layer_collection = LayerCollection.from_model(net) + return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py index e9f1348..db6b234 100644 --- a/tests/test_jacobian.py +++ b/tests/test_jacobian.py @@ -7,7 +7,7 @@ get_fullyconnect_onlylast_task, get_fullyconnect_task, get_fullyconnect_wn_task, get_linear_conv_task, get_linear_fc_task, get_small_conv_transpose_task, - get_small_conv_wn_task) + get_small_conv_wn_task, get_conv1d_task) from utils import check_ratio, check_tensors from nngeometry.generator import
Jacobian @@ -27,6 +27,7 @@ ] nonlinear_tasks = [ + get_conv1d_task, get_small_conv_transpose_task, get_conv_task, get_fullyconnect_affine_task,