Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Nov 22, 2023
1 parent 5658a74 commit dc3f140
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 2 deletions.
34 changes: 34 additions & 0 deletions nngeometry/layercollection.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class LayerCollection:
"Cosine1d",
"Affine1d",
"ConvTranspose2d",
"Conv1d",
]

def __init__(self, layers=None):
Expand Down Expand Up @@ -112,6 +113,13 @@ def _module_to_layer(mod):
kernel_size=mod.kernel_size,
bias=(mod.bias is not None),
)
elif mod_class == "Conv1d":
return Conv1dLayer(
in_channels=mod.in_channels,
out_channels=mod.out_channels,
kernel_size=mod.kernel_size,
bias=(mod.bias is not None),
)
elif mod_class == "BatchNorm1d":
return BatchNorm1dLayer(num_features=mod.num_features)
elif mod_class == "BatchNorm2d":
Expand Down Expand Up @@ -230,6 +238,32 @@ def __eq__(self, other):
and self.kernel_size == other.kernel_size
)

class Conv1dLayer(AbstractLayer):
def __init__(self, in_channels, out_channels, kernel_size, bias=True):
self.in_channels = in_channels
self.out_channels = out_channels
self.kernel_size = kernel_size
self.weight = Parameter(
out_channels, in_channels, kernel_size[0]
)
if bias:
self.bias = Parameter(out_channels)
else:
self.bias = None

def numel(self):
if self.bias is not None:
return self.weight.numel() + self.bias.numel()
else:
return self.weight.numel()

def __eq__(self, other):
return (
self.in_channels == other.in_channels
and self.out_channels == other.out_channels
and self.kernel_size == other.kernel_size
)


class LinearLayer(AbstractLayer):
def __init__(self, in_features, out_features, bias=True):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,5 @@
'nngeometry.generator',
'nngeometry.generator.jacobian',
'nngeometry.object'],
install_requires=['torch>=2.0.0'],
install_requires=['torch>=2.0.0','torchvision>=0.9.1'],
zip_safe=False)
34 changes: 34 additions & 0 deletions tests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,3 +482,37 @@ def output_fn(input, target):

layer_collection = LayerCollection.from_model(net)
return (train_loader, layer_collection, net.parameters(), net, output_fn, 3)


class Conv1dNet(nn.Module):
def __init__(self, normalization="none"):
super(Conv1dNet, self).__init__()
if normalization != 'none':
raise NotImplementedError
self.normalization = normalization
self.conv1 = nn.Conv1d(1, 6, 3, 2)
self.conv2 = nn.Conv1d(6, 5, 4, 1)
self.fc1 = nn.Linear(7, 4)

def forward(self, x):
x = x.reshape(x.size(0), x.size(1), -1)
x = tF.relu(self.conv1(x))
x = tF.relu(self.conv2(x))
x = x.view(x.size(0), -1)
x = self.fc1(x)
return x


def get_conv1d_task(normalization="none"):
train_set = get_mnist()
train_set = Subset(train_set, range(70))
train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False)
net = Conv1dNet(normalization=normalization)
to_device_model(net)
net.eval()

def output_fn(input, target):
return net(to_device(input))

layer_collection = LayerCollection.from_model(net)
return (train_loader, layer_collection, net.parameters(), net, output_fn, 3)
3 changes: 2 additions & 1 deletion tests/test_jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
get_fullyconnect_onlylast_task, get_fullyconnect_task,
get_fullyconnect_wn_task, get_linear_conv_task,
get_linear_fc_task, get_small_conv_transpose_task,
get_small_conv_wn_task)
get_small_conv_wn_task, get_conv1d_task)
from utils import check_ratio, check_tensors

from nngeometry.generator import Jacobian
Expand All @@ -27,6 +27,7 @@
]

nonlinear_tasks = [
get_conv1d_task,
get_small_conv_transpose_task,
get_conv_task,
get_fullyconnect_affine_task,
Expand Down

0 comments on commit dc3f140

Please sign in to comment.