Skip to content

Commit

Permalink
fix kfac with bias
Browse files: browse the repository at this point in the history
  • Loading branch information
tfjgeorge committed Nov 24, 2023
1 parent 23f3100 commit 7cb2897
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 8 deletions.
10 changes: 5 additions & 5 deletions nngeometry/object/pspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def get_dense_tensor(self, split_weight_bias=True):
a, g = self.data[layer_id]
start = self.generator.layer_collection.p_pos[layer_id]
sAG = a.size(0) * g.size(0)
if split_weight_bias:
if split_weight_bias and layer.bias:
reconstruct = torch.cat(
[
torch.cat(
Expand Down Expand Up @@ -517,7 +517,7 @@ def get_diag(self, split_weight_bias=True):
for layer_id, layer in self.generator.layer_collection.layers.items():
a, g = self.data[layer_id]
diag_of_block = torch.diag(g).view(-1, 1) * torch.diag(a).view(1, -1)
if split_weight_bias:
if split_weight_bias and layer.bias:
diags.append(diag_of_block[:, :-1].contiguous().view(-1))
diags.append(diag_of_block[:, -1:].view(-1))
else:
Expand All @@ -535,10 +535,10 @@ def mv(self, vs):
v = torch.cat([v, vs_dict[layer_id][1].unsqueeze(1)], dim=1)
a, g = self.data[layer_id]
mv = torch.mm(torch.mm(g, v), a)
if layer.bias is None:
mv_tuple = (mv.view(*sw),)
else:
if layer.bias:
mv_tuple = (mv[:, :-1].contiguous().view(*sw), mv[:, -1].contiguous())
else:
mv_tuple = (mv.view(*sw),)
out_dict[layer_id] = mv_tuple
return PVector(layer_collection=vs.layer_collection, dict_repr=out_dict)

Expand Down
2 changes: 1 addition & 1 deletion tests/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(self, normalization="none"):
self.bn1 = nn.BatchNorm2d(6)
elif self.normalization == "group_norm":
self.gn1 = nn.GroupNorm(2, 6)
self.conv2 = nn.Conv2d(6, 5, 4, 1)
self.conv2 = nn.Conv2d(6, 5, 4, 1, bias=False)
self.conv3 = nn.Conv2d(5, 7, 3, 1, 1)
if self.normalization == "weight_norm":
self.wn2 = WeightNorm1d(7, 4)
Expand Down
12 changes: 10 additions & 2 deletions tests/test_jacobian_kfac.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from tasks import get_conv_task, get_fullyconnect_task, get_mnist, to_device_model
from tasks import (
get_conv_task,
get_fullyconnect_task,
get_mnist,
to_device_model,
get_conv1d_task,
)
from torch.utils.data import DataLoader, Subset
from utils import angle, check_ratio, check_tensors

Expand Down Expand Up @@ -161,7 +167,7 @@ def test_jacobian_kfac_vs_pblockdiag():


def test_jacobian_kfac():
for get_task in [get_fullyconnect_task, get_conv_task]:
for get_task in [get_conv1d_task, get_fullyconnect_task, get_conv_task]:
loader, lc, parameters, model, function, n_output = get_task()

generator = Jacobian(
Expand Down Expand Up @@ -190,6 +196,8 @@ def test_jacobian_kfac():
# Test mv
mv_direct = torch.mv(G_kfac_split, random_v.get_flat_representation())
mv_kfac = M_kfac.mv(random_v)
print(mv_direct.size(), lc.layers)
pvec = PVector(layer_collection=lc, vector_repr=mv_direct)
check_tensors(mv_direct, mv_kfac.get_flat_representation())

# Test vTMv
Expand Down

0 comments on commit 7cb2897

Please sign in to comment.