Skip to content

Commit

Permalink
kfac for conv1d
Browse files Browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Nov 24, 2023
1 parent aca036f commit 68e31fb
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 3 deletions.
26 changes: 26 additions & 0 deletions nngeometry/generator/jacobian/grads.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,32 @@ def Jv(cls, buffer, mod, layer, x, gy, v, v_bias):
if layer.bias is not None:
buffer.add_(torch.mv(gy.sum(dim=2), v_bias))

@classmethod
def kfac_xx(cls, buffer, mod, layer, x, gy):
ks = (1, mod.weight.size(2))
# A_tilda in KFC
A_tilda = F.unfold(
x.unsqueeze(2),
kernel_size=ks,
stride=(1, mod.stride[0]),
padding=(0, mod.padding[0]),
dilation=(1, mod.dilation[0]),
)
# A_tilda is bs * #locations x #parameters
A_tilda = A_tilda.permute(0, 2, 1).contiguous().view(-1, A_tilda.size(1))
if layer.bias is not None:
A_tilda = torch.cat([A_tilda, torch.ones_like(A_tilda[:, :1])], dim=1)
# Omega_hat in KFC
buffer.add_(torch.mm(A_tilda.t(), A_tilda))

@classmethod
def kfac_gg(cls, buffer, mod, layer, x, gy):
spatial_locations = gy.size(2)
os = gy.size(1)
# DS_tilda in KFC
DS_tilda = gy.permute(0, 2, 1).contiguous().view(-1, os)
buffer.add_(torch.mm(DS_tilda.t(), DS_tilda) / spatial_locations)


FactoryMap = {
LinearLayer: LinearJacobianFactory,
Expand Down
3 changes: 2 additions & 1 deletion nngeometry/generator/jacobian/grads_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,8 @@ def conv1d_backward_using_unfold(mod, x, gy):
gy_s = gy.size()
bs = gy_s[0]
x_unfold = F.unfold(
x.unsqueeze(2), kernel_size=ks, stride=(1, mod.stride[0]), padding=(0, mod.padding[0]), dilation=(1, mod.dilation[0])
x.unsqueeze(2), kernel_size=ks, stride=(1, mod.stride[0]),
padding=(0, mod.padding[0]), dilation=(1, mod.dilation[0])
)
x_unfold_s = x_unfold.size()
return torch.bmm(
Expand Down
5 changes: 4 additions & 1 deletion nngeometry/generator/jacobian/jacobian.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ def get_kfac_blocks(self, examples):
elif layer_class == "Conv2dLayer":
sG = layer.out_channels
sA = layer.in_channels * layer.kernel_size[0] * layer.kernel_size[1]
elif layer_class == "Conv1dLayer":
sG = layer.out_channels
sA = layer.in_channels * layer.kernel_size[0]
if layer.bias is not None:
sA += 1
self._blocks[layer_id] = (
Expand Down Expand Up @@ -761,7 +764,7 @@ def _hook_compute_kfac_blocks(self, mod, gy):
layer_id = self.m_to_l[mod]
layer = self.layer_collection[layer_id]
block = self._blocks[layer_id]
if mod_class in ["Linear", "Conv2d"]:
if mod_class in ["Linear", "Conv2d", "Conv1d"]:
FactoryMap[layer.__class__].kfac_gg(block[1], mod, layer, x, gy)
if self.i_output == 0:
# do this only once if n_output > 1
Expand Down
29 changes: 28 additions & 1 deletion tests/test_jacobian_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,32 @@ def output_fn(input, target):
return (train_loader, layer_collection, net.parameters(), net, output_fn, 4)


class Conv1dNet(nn.Module):
def __init__(self):
super(Conv1dNet, self).__init__()
self.conv1 = nn.Conv1d(3, 4, 3, 1)

def forward(self, x):
x = self.conv1(x)
return x.sum(axis=(2,))


def get_conv1dnet_kfc_task(bs=5):
train_set = torch.utils.data.TensorDataset(
torch.ones(size=(10, 3, 5)), torch.randint(0, 4, size=(10, 4))
)
train_loader = DataLoader(dataset=train_set, batch_size=bs, shuffle=False)
net = Conv1dNet()
to_device_model(net)
net.eval()

def output_fn(input, target):
return net(to_device(input))

layer_collection = LayerCollection.from_model(net)
return (train_loader, layer_collection, net.parameters(), net, output_fn, 4)


@pytest.fixture(autouse=True)
def make_test_deterministic():
torch.manual_seed(1234)
Expand All @@ -118,7 +144,8 @@ def test_jacobian_kfac_vs_pblockdiag():
where they are the same
"""
for get_task, mult in zip(
[get_convnet_kfc_task, get_fullyconnect_kfac_task], [15.0, 1.0]
[get_conv1dnet_kfc_task, get_convnet_kfc_task, get_fullyconnect_kfac_task],
[3., 15.0, 1.0]
):
loader, lc, parameters, model, function, n_output = get_task()

Expand Down

0 comments on commit 68e31fb

Please sign in to comment.