diff --git a/nngeometry/generator/jacobian/grads.py b/nngeometry/generator/jacobian/grads.py
index 06524a1..fe2bf76 100644
--- a/nngeometry/generator/jacobian/grads.py
+++ b/nngeometry/generator/jacobian/grads.py
@@ -346,6 +346,45 @@ def Jv(cls, buffer, mod, layer, x, gy, v, v_bias):
         if layer.bias is not None:
             buffer.add_(torch.mv(gy.sum(dim=2), v_bias))
 
+    @classmethod
+    def kfac_xx(cls, buffer, mod, layer, x, gy):
+        """Add the input-side KFC factor of a Conv1d layer to `buffer`.
+
+        The input is unfolded into one row per (example, location), a column
+        of ones is appended when the layer has a bias, and the Gram matrix
+        of these rows is added to `buffer` in place.
+        """
+        ks = (1, mod.weight.size(2))
+        # A_tilda in KFC
+        # F.unfold expects image-like input, hence the dummy height dimension
+        A_tilda = F.unfold(
+            x.unsqueeze(2),
+            kernel_size=ks,
+            stride=(1, mod.stride[0]),
+            padding=(0, mod.padding[0]),
+            dilation=(1, mod.dilation[0]),
+        )
+        # A_tilda is bs * #locations x #parameters
+        A_tilda = A_tilda.permute(0, 2, 1).contiguous().view(-1, A_tilda.size(1))
+        if layer.bias is not None:
+            A_tilda = torch.cat([A_tilda, torch.ones_like(A_tilda[:, :1])], dim=1)
+        # Omega_hat in KFC
+        buffer.add_(torch.mm(A_tilda.t(), A_tilda))
+
+    @classmethod
+    def kfac_gg(cls, buffer, mod, layer, x, gy):
+        """Add the `gy`-side KFC factor of a Conv1d layer to `buffer`.
+
+        `gy` is flattened to one row per (example, location); the Gram
+        matrix of these rows, divided by the number of locations, is added
+        to `buffer` in place.
+        """
+        spatial_locations = gy.size(2)
+        out_channels = gy.size(1)
+        # DS_tilda in KFC
+        DS_tilda = gy.permute(0, 2, 1).contiguous().view(-1, out_channels)
+        buffer.add_(torch.mm(DS_tilda.t(), DS_tilda) / spatial_locations)
+
 
 FactoryMap = {
     LinearLayer: LinearJacobianFactory,
diff --git a/nngeometry/generator/jacobian/grads_conv.py b/nngeometry/generator/jacobian/grads_conv.py
index f89b27b..c7b025c 100644
--- a/nngeometry/generator/jacobian/grads_conv.py
+++ b/nngeometry/generator/jacobian/grads_conv.py
@@ -128,7 +128,8 @@ def conv1d_backward_using_unfold(mod, x, gy):
     gy_s = gy.size()
     bs = gy_s[0]
     x_unfold = F.unfold(
-        x.unsqueeze(2), kernel_size=ks, stride=(1, mod.stride[0]), padding=(0, mod.padding[0]), dilation=(1, mod.dilation[0])
+        x.unsqueeze(2), kernel_size=ks, stride=(1, mod.stride[0]),
+        padding=(0, mod.padding[0]), dilation=(1, mod.dilation[0])
     )
     x_unfold_s = x_unfold.size()
     return torch.bmm(
diff --git a/nngeometry/generator/jacobian/jacobian.py b/nngeometry/generator/jacobian/jacobian.py
index 96a5125..2c3d82c 100644
--- a/nngeometry/generator/jacobian/jacobian.py
+++ b/nngeometry/generator/jacobian/jacobian.py
@@ -250,6 +250,9 @@ def get_kfac_blocks(self, examples):
             elif layer_class == "Conv2dLayer":
                 sG = layer.out_channels
                 sA = layer.in_channels * layer.kernel_size[0] * layer.kernel_size[1]
+            elif layer_class == "Conv1dLayer":
+                sG = layer.out_channels
+                sA = layer.in_channels * layer.kernel_size[0]
             if layer.bias is not None:
                 sA += 1
             self._blocks[layer_id] = (
@@ -761,7 +764,7 @@ def _hook_compute_kfac_blocks(self, mod, gy):
         layer_id = self.m_to_l[mod]
         layer = self.layer_collection[layer_id]
         block = self._blocks[layer_id]
-        if mod_class in ["Linear", "Conv2d"]:
+        if mod_class in ["Linear", "Conv2d", "Conv1d"]:
             FactoryMap[layer.__class__].kfac_gg(block[1], mod, layer, x, gy)
             if self.i_output == 0:
                 # do this only once if n_output > 1
diff --git a/tests/test_jacobian_kfac.py b/tests/test_jacobian_kfac.py
index a3f97e6..aa528f9 100644
--- a/tests/test_jacobian_kfac.py
+++ b/tests/test_jacobian_kfac.py
@@ -106,6 +106,35 @@ def output_fn(input, target):
     return (train_loader, layer_collection, net.parameters(), net, output_fn, 4)
 
 
+class Conv1dNet(nn.Module):
+    """Single Conv1d layer whose output is summed over the length axis."""
+
+    def __init__(self):
+        super(Conv1dNet, self).__init__()
+        self.conv1 = nn.Conv1d(3, 4, 3, 1)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        return x.sum(axis=(2,))
+
+
+def get_conv1dnet_kfc_task(bs=5):
+    """Build the all-ones Conv1d task used by the KFC tests."""
+    train_set = torch.utils.data.TensorDataset(
+        torch.ones(size=(10, 3, 5)), torch.randint(0, 4, size=(10, 4))
+    )
+    train_loader = DataLoader(dataset=train_set, batch_size=bs, shuffle=False)
+    net = Conv1dNet()
+    to_device_model(net)
+    net.eval()
+
+    def output_fn(input, target):
+        return net(to_device(input))
+
+    layer_collection = LayerCollection.from_model(net)
+    return (train_loader, layer_collection, net.parameters(), net, output_fn, 4)
+
+
 @pytest.fixture(autouse=True)
 def make_test_deterministic():
     torch.manual_seed(1234)
@@ -118,6 +147,7 @@ def test_jacobian_kfac_vs_pblockdiag():
     where they are the same
     """
     for get_task, mult in zip(
-        [get_convnet_kfc_task, get_fullyconnect_kfac_task], [15.0, 1.0]
+        [get_conv1dnet_kfc_task, get_convnet_kfc_task, get_fullyconnect_kfac_task],
+        [3.0, 15.0, 1.0],
     ):
         loader, lc, parameters, model, function, n_output = get_task()