diff --git a/nngeometry/generator/jacobian/grads.py b/nngeometry/generator/jacobian/grads.py index ea43c13..befa5bb 100644 --- a/nngeometry/generator/jacobian/grads.py +++ b/nngeometry/generator/jacobian/grads.py @@ -1,13 +1,21 @@ import torch import torch.nn.functional as F -from nngeometry.layercollection import (Affine1dLayer, BatchNorm1dLayer, - BatchNorm2dLayer, Conv2dLayer, - ConvTranspose2dLayer, Cosine1dLayer, - GroupNormLayer, LinearLayer, - WeightNorm1dLayer, WeightNorm2dLayer) - -from .grads_conv import conv2d_backward, convtranspose2d_backward +from nngeometry.layercollection import ( + Affine1dLayer, + BatchNorm1dLayer, + BatchNorm2dLayer, + Conv2dLayer, + ConvTranspose2dLayer, + Cosine1dLayer, + GroupNormLayer, + LinearLayer, + WeightNorm1dLayer, + WeightNorm2dLayer, + Conv1dLayer, +) + +from .grads_conv import conv2d_backward, convtranspose2d_backward, conv1d_backward class JacobianFactory: @@ -325,8 +333,90 @@ def flat_grad(cls, buffer, mod, layer, x, gy): buffer[:, w_numel:].add_(gy) +class Conv1dJacobianFactory(JacobianFactory): + @classmethod + def flat_grad(cls, buffer, mod, layer, x, gy): + bs = x.size(0) + w_numel = layer.weight.numel() + indiv_gw = conv1d_backward(mod, x, gy) + buffer[:, :w_numel].add_(indiv_gw.view(bs, -1)) + if layer.bias is not None: + buffer[:, w_numel:].add_(gy.sum(dim=2)) + + @classmethod + def Jv(cls, buffer, mod, layer, x, gy, v, v_bias): + bs = x.size(0) + gy2 = F.conv1d( + x, v, stride=mod.stride, padding=mod.padding, dilation=mod.dilation + ) + buffer.add_((gy * gy2).view(bs, -1).sum(dim=1)) + if layer.bias is not None: + buffer.add_(torch.mv(gy.sum(dim=2), v_bias)) + + @classmethod + def kfac_xx(cls, buffer, mod, layer, x, gy): + ks = (1, mod.weight.size(2)) + # A_tilda in KFC + A_tilda = F.unfold( + x.unsqueeze(2), + kernel_size=ks, + stride=(1, mod.stride[0]), + padding=(0, mod.padding[0]), + dilation=(1, mod.dilation[0]), + ) + # A_tilda is bs * #locations x #parameters + A_tilda = A_tilda.permute(0, 2, 1).contiguous().view(-1, A_tilda.size(1)) + if layer.bias is not None: + A_tilda = torch.cat([A_tilda, torch.ones_like(A_tilda[:, :1])], dim=1) + # Omega_hat in KFC + buffer.add_(torch.mm(A_tilda.t(), A_tilda)) + + @classmethod + def kfac_gg(cls, buffer, mod, layer, x, gy): + spatial_locations = gy.size(2) + os = gy.size(1) + # DS_tilda in KFC + DS_tilda = gy.permute(0, 2, 1).contiguous().view(-1, os) + buffer.add_(torch.mm(DS_tilda.t(), DS_tilda) / spatial_locations) + + @classmethod + def kfe_diag(cls, buffer, mod, layer, x, gy, evecs_a, evecs_g): + ks = (1, mod.weight.size(2)) + gy_s = gy.size() + bs = gy_s[0] + # project x to kfe + x_unfold = F.unfold( + x.unsqueeze(2), + kernel_size=ks, + stride=(1, mod.stride[0]), + padding=(0, mod.padding[0]), + dilation=(1, mod.dilation[0]), + ) + x_unfold_s = x_unfold.size() + x_unfold = ( + x_unfold.view(bs, x_unfold_s[1], -1) + .permute(0, 2, 1) + .contiguous() + .view(-1, x_unfold_s[1]) + ) + if mod.bias is not None: + x_unfold = torch.cat([x_unfold, torch.ones_like(x_unfold[:, :1])], dim=1) + x_kfe = torch.mm(x_unfold, evecs_a) + + # project gy to kfe + gy = gy.view(bs, gy_s[1], -1).permute(0, 2, 1).contiguous() + gy_kfe = torch.mm(gy.view(-1, gy_s[1]), evecs_g) + gy_kfe = gy_kfe.view(bs, -1, gy_s[1]).permute(0, 2, 1).contiguous() + + indiv_gw = torch.bmm( + gy_kfe.view(bs, gy_s[1], -1), x_kfe.view(bs, -1, x_kfe.size(1)) + ) + buffer.add_((indiv_gw**2).sum(dim=0).view(-1)) + + FactoryMap = { LinearLayer: LinearJacobianFactory, + Conv1dLayer: Conv1dJacobianFactory, Conv2dLayer: Conv2dJacobianFactory, ConvTranspose2dLayer: ConvTranspose2dJacobianFactory, BatchNorm1dLayer: BatchNorm1dJacobianFactory, diff --git a/nngeometry/generator/jacobian/grads_conv.py b/nngeometry/generator/jacobian/grads_conv.py index 268c568..1ae27d7 100644 --- a/nngeometry/generator/jacobian/grads_conv.py +++ b/nngeometry/generator/jacobian/grads_conv.py @@ -83,11 +83,6 @@ def conv_backward( return weight_bgrad -def conv1d_backward(*args, **kwargs): - """Computes per-example gradients for nn.Conv1d layers.""" - return conv_backward(*args, nd=1, **kwargs) - - def conv2d_backward_using_conv(mod, x, gy): """Computes per-example gradients for nn.Conv2d layers.""" return conv_backward( @@ -119,7 +114,29 @@ def conv2d_backward_using_unfold(mod, x, gy): def conv2d_backward(*args, **kwargs): - return _conv_grad_impl.get_impl()(*args, **kwargs) + return _conv_grad_impl.get_impl2d()(*args, **kwargs) + + +def conv1d_backward_using_unfold(mod, x, gy): + """Computes per-example gradients for nn.Conv1d layers.""" + ks = (1, mod.weight.size(2)) + gy_s = gy.size() + bs = gy_s[0] + x_unfold = F.unfold( + x.unsqueeze(2), + kernel_size=ks, + stride=(1, mod.stride[0]), + padding=(0, mod.padding[0]), + dilation=(1, mod.dilation[0]), + ) + x_unfold_s = x_unfold.size() + return torch.bmm( + gy.view(bs, gy_s[1], -1), x_unfold.view(bs, x_unfold_s[1], -1).permute(0, 2, 1) + ) + + +def conv1d_backward(*args, **kwargs): + return _conv_grad_impl.get_impl1d()(*args, **kwargs) class ConvGradImplManager: @@ -129,12 +146,18 @@ def __init__(self): def use_unfold(self, choice=True): self._use_unfold = choice - def get_impl(self): + def get_impl2d(self): if self._use_unfold: return conv2d_backward_using_unfold else: return conv2d_backward_using_conv + def get_impl1d(self): + if self._use_unfold: + return conv1d_backward_using_unfold + else: + raise NotImplementedError() + _conv_grad_impl = ConvGradImplManager() diff --git a/nngeometry/generator/jacobian/jacobian.py b/nngeometry/generator/jacobian/jacobian.py index 96a5125..93b611a 100644 --- a/nngeometry/generator/jacobian/jacobian.py +++ b/nngeometry/generator/jacobian/jacobian.py @@ -41,7 +41,10 @@ def __init__( self.centering = centering if function is None: - function = lambda *x: model(x[0]) + + def function(*x): + return model(x[0]) + self.function = function if layer_collection is None: @@ -250,6 +253,9 @@ def get_kfac_blocks(self, examples): elif layer_class == "Conv2dLayer": sG = layer.out_channels sA = layer.in_channels * layer.kernel_size[0] * layer.kernel_size[1] + elif layer_class == "Conv1dLayer": + sG = layer.out_channels + sA = layer.in_channels * layer.kernel_size[0] if layer.bias is not None: sA += 1 self._blocks[layer_id] = ( @@ -459,6 +465,9 @@ def get_kfe_diag(self, kfe, examples): elif layer_class == "Conv2dLayer": sG = layer.out_channels sA = layer.in_channels * layer.kernel_size[0] * layer.kernel_size[1] + elif layer_class == "Conv1dLayer": + sG = layer.out_channels + sA = layer.in_channels * layer.kernel_size[0] if layer.bias is not None: sA += 1 self._diags[layer_id] = torch.zeros((sG * sA), device=device, dtype=dtype) @@ -761,7 +770,7 @@ def _hook_compute_kfac_blocks(self, mod, gy): layer_id = self.m_to_l[mod] layer = self.layer_collection[layer_id] block = self._blocks[layer_id] - if mod_class in ["Linear", "Conv2d"]: + if mod_class in ["Linear", "Conv2d", "Conv1d"]: FactoryMap[layer.__class__].kfac_gg(block[1], mod, layer, x, gy) if self.i_output == 0: # do this only once if n_output > 1 @@ -775,7 +784,7 @@ def _hook_compute_kfe_diag(self, mod, gy): layer = self.layer_collection[layer_id] x = self.xs[mod] evecs_a, evecs_g = self._kfe[layer_id] - if mod_class in ["Linear", "Conv2d"]: + if mod_class in ["Linear", "Conv2d", "Conv1d"]: FactoryMap[layer.__class__].kfe_diag( self._diags[layer_id], mod, layer, x, gy, evecs_a, evecs_g ) diff --git a/nngeometry/layercollection.py b/nngeometry/layercollection.py index 23176fb..e3924de 100644 --- a/nngeometry/layercollection.py +++ b/nngeometry/layercollection.py @@ -24,6 +24,7 @@ class LayerCollection: "Cosine1d", "Affine1d", "ConvTranspose2d", + "Conv1d", ] def __init__(self, layers=None): @@ -112,6 +113,13 @@ def _module_to_layer(mod): kernel_size=mod.kernel_size, bias=(mod.bias is not None), ) + elif mod_class == "Conv1d": + return Conv1dLayer( + in_channels=mod.in_channels, + out_channels=mod.out_channels, + kernel_size=mod.kernel_size, + bias=(mod.bias is not None), + ) elif mod_class == "BatchNorm1d": return BatchNorm1dLayer(num_features=mod.num_features) elif mod_class == "BatchNorm2d": @@ -231,6 +239,31 @@ def __eq__(self, other): ) +class Conv1dLayer(AbstractLayer): + def __init__(self, in_channels, out_channels, kernel_size, bias=True): + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.weight = Parameter(out_channels, in_channels, kernel_size[0]) + if bias: + self.bias = Parameter(out_channels) + else: + self.bias = None + + def numel(self): + if self.bias is not None: + return self.weight.numel() + self.bias.numel() + else: + return self.weight.numel() + + def __eq__(self, other): + return ( + self.in_channels == other.in_channels + and self.out_channels == other.out_channels + and self.kernel_size == other.kernel_size + ) + + class LinearLayer(AbstractLayer): def __init__(self, in_features, out_features, bias=True): self.in_features = in_features diff --git a/nngeometry/object/__init__.py b/nngeometry/object/__init__.py index de504b3..d8065db 100644 --- a/nngeometry/object/__init__.py +++ b/nngeometry/object/__init__.py @@ -1,7 +1,15 @@ from .fspace import FMatDense from .map import PullBackDense, PushForwardDense, PushForwardImplicit -from .pspace import (PMatBlockDiag, PMatDense, PMatDiag, PMatEKFAC, - PMatImplicit, PMatKFAC, PMatLowRank, PMatQuasiDiag) +from .pspace import ( + PMatBlockDiag, + PMatDense, + PMatDiag, + PMatEKFAC, + PMatImplicit, + PMatKFAC, + PMatLowRank, + PMatQuasiDiag, +) from .vector import FVector, PVector __all__ = [ diff --git a/nngeometry/object/pspace.py b/nngeometry/object/pspace.py index 33c9fb0..b374684 100644 --- a/nngeometry/object/pspace.py +++ b/nngeometry/object/pspace.py @@ -487,7 +487,7 @@ def get_dense_tensor(self, split_weight_bias=True): a, g = self.data[layer_id] start = self.generator.layer_collection.p_pos[layer_id] sAG = a.size(0) * g.size(0) - if split_weight_bias: + if split_weight_bias and layer.bias: reconstruct = torch.cat( [ torch.cat( @@ -517,7 +517,7 @@ def get_diag(self, split_weight_bias=True): for layer_id, layer in self.generator.layer_collection.layers.items(): a, g = self.data[layer_id] diag_of_block = torch.diag(g).view(-1, 1) * torch.diag(a).view(1, -1) - if split_weight_bias: + if split_weight_bias and layer.bias: diags.append(diag_of_block[:, :-1].contiguous().view(-1)) diags.append(diag_of_block[:, -1:].view(-1)) else: @@ -535,10 +535,10 @@ def mv(self, vs): v = torch.cat([v, vs_dict[layer_id][1].unsqueeze(1)], dim=1) a, g = self.data[layer_id] mv = torch.mm(torch.mm(g, v), a) - if layer.bias is None: - mv_tuple = (mv.view(*sw),) - else: + if layer.bias: mv_tuple = (mv[:, :-1].contiguous().view(*sw), mv[:, -1].contiguous()) + else: + mv_tuple = (mv.view(*sw),) out_dict[layer_id] = mv_tuple return PVector(layer_collection=vs.layer_collection, dict_repr=out_dict) @@ -602,7 +602,7 @@ class PMatEKFAC(PMatAbstract): """ EKFAC representation from *George, Laurent et al., Fast Approximate Natural Gradient Descent - in a Kronecker-factored Eigenbasis, NIPS 2018* + in a Kronecker-factored Eigenbasis, NeurIPS 2018* """ @@ -659,9 +659,9 @@ def get_KFE(self, split_weight_bias=True): """ evecs, _ = self.data KFE = dict() - for layer_id, _ in self.generator.layer_collection.layers.items(): + for layer_id, layer in self.generator.layer_collection.layers.items(): evecs_a, evecs_g = evecs[layer_id] - if split_weight_bias: + if split_weight_bias and layer.bias: kronecker(evecs_g, evecs_a[:-1, :]) kronecker(evecs_g, evecs_a[-1:, :].contiguous()) KFE[layer_id] = torch.cat( diff --git a/nngeometry/object/vector.py b/nngeometry/object/vector.py index 3b2bbb5..4744a67 100644 --- a/nngeometry/object/vector.py +++ b/nngeometry/object/vector.py @@ -194,7 +194,7 @@ def _dict_to_flat(self): parts = [] for layer_id, layer in self.layer_collection.layers.items(): parts.append(self.dict_repr[layer_id][0].view(-1)) - if len(self.dict_repr[layer_id]) > 1: + if layer.bias: parts.append(self.dict_repr[layer_id][1].view(-1)) return torch.cat(parts) diff --git a/setup.py b/setup.py index edb0c41..dc1b1ad 100644 --- a/setup.py +++ b/setup.py @@ -11,5 +11,5 @@ 'nngeometry.generator', 'nngeometry.generator.jacobian', 'nngeometry.object'], - install_requires=['torch>=2.0.0'], + install_requires=['torch>=2.0.0','torchvision>=0.9.1'], zip_safe=False) diff --git a/tests/tasks.py b/tests/tasks.py index 55bbd2f..6a6469b 100644 --- a/tests/tasks.py +++ b/tests/tasks.py @@ -101,7 +101,7 @@ def __init__(self, normalization="none"): self.bn1 = nn.BatchNorm2d(6) elif self.normalization == "group_norm": self.gn1 = nn.GroupNorm(2, 6) - self.conv2 = nn.Conv2d(6, 5, 4, 1) + self.conv2 = nn.Conv2d(6, 5, 4, 1, bias=False) self.conv3 = nn.Conv2d(5, 7, 3, 1, 1) if self.normalization == "weight_norm": self.wn2 = WeightNorm1d(7, 4) @@ -370,21 +370,6 @@ def get_fullyconnect_affine_task(): return get_fullyconnect_task(normalization="affine") -def get_conv_task(normalization="none"): - train_set = get_mnist() - train_set = Subset(train_set, range(70)) - train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False) - net = ConvNet(normalization=normalization) - to_device_model(net) - net.eval() - - def output_fn(input, target): - return net(to_device(input)) - - layer_collection = LayerCollection.from_model(net) - return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) - - def get_conv_bn_task(): return get_conv_task(normalization="batch_norm") @@ -482,3 +467,39 @@ def output_fn(input, target): layer_collection = LayerCollection.from_model(net) return (train_loader, layer_collection, net.parameters(), net, output_fn, 3) + + +class Conv1dNet(nn.Module): + def __init__(self, normalization="none"): + super(Conv1dNet, self).__init__() + if normalization != "none": + raise NotImplementedError + self.normalization = normalization + self.conv1 = nn.Conv1d(1, 6, 3, 3) + self.conv2 = nn.Conv1d(6, 5, 4, 8, bias=False) + self.conv3 = nn.Conv1d(5, 2, 4, 4) + self.fc1 = nn.Linear(16, 4) + + def forward(self, x): + x = x.reshape(x.size(0), x.size(1), -1) + x = tF.relu(self.conv1(x)) + x = tF.relu(self.conv2(x)) + x = tF.relu(self.conv3(x)) + x = x.view(x.size(0), -1) + x = self.fc1(x) + return x + + +def get_conv1d_task(normalization="none"): + train_set = get_mnist() + train_set = Subset(train_set, range(70)) + train_loader = DataLoader(dataset=train_set, batch_size=30, shuffle=False) + net = Conv1dNet(normalization=normalization) + to_device_model(net) + net.eval() + + def output_fn(input, target): + return net(to_device(input)) + + layer_collection = LayerCollection.from_model(net) + return (train_loader, layer_collection, net.parameters(), net, output_fn, 4) diff --git a/tests/test_jacobian.py b/tests/test_jacobian.py index e9f1348..c188364 100644 --- a/tests/test_jacobian.py +++ b/tests/test_jacobian.py @@ -1,21 +1,35 @@ import pytest import torch -from tasks import (get_batchnorm_conv_linear_task, - get_batchnorm_fc_linear_task, get_conv_gn_task, - get_conv_skip_task, get_conv_task, - get_fullyconnect_affine_task, get_fullyconnect_cosine_task, - get_fullyconnect_onlylast_task, get_fullyconnect_task, - get_fullyconnect_wn_task, get_linear_conv_task, - get_linear_fc_task, get_small_conv_transpose_task, - get_small_conv_wn_task) +from tasks import ( + get_batchnorm_conv_linear_task, + get_batchnorm_fc_linear_task, + get_conv_gn_task, + get_conv_skip_task, + get_conv_task, + get_fullyconnect_affine_task, + get_fullyconnect_cosine_task, + get_fullyconnect_onlylast_task, + get_fullyconnect_task, + get_fullyconnect_wn_task, + get_linear_conv_task, + get_linear_fc_task, + get_small_conv_transpose_task, + get_small_conv_wn_task, + get_conv1d_task, +) from utils import check_ratio, check_tensors from nngeometry.generator import Jacobian from nngeometry.object.fspace import FMatDense -from nngeometry.object.map import (PullBackDense, PushForwardDense, - PushForwardImplicit) -from nngeometry.object.pspace import (PMatBlockDiag, PMatDense, PMatDiag, - PMatImplicit, PMatLowRank, PMatQuasiDiag) +from nngeometry.object.map import PullBackDense, PushForwardDense, PushForwardImplicit +from nngeometry.object.pspace import ( + PMatBlockDiag, + PMatDense, + PMatDiag, + PMatImplicit, + PMatLowRank, + PMatQuasiDiag, +) from nngeometry.object.vector import PVector, random_fvector, random_pvector linear_tasks = [ @@ -27,6 +41,7 @@ ] nonlinear_tasks = [ + get_conv1d_task, get_small_conv_transpose_task, get_conv_task, get_fullyconnect_affine_task, diff --git a/tests/test_jacobian_ekfac.py b/tests/test_jacobian_ekfac.py index 32b65da..79be84d 100644 --- a/tests/test_jacobian_ekfac.py +++ b/tests/test_jacobian_ekfac.py @@ -1,6 +1,6 @@ import pytest import torch -from tasks import device, get_conv_task, get_fullyconnect_task +from tasks import device, get_conv_task, get_conv1d_task, get_fullyconnect_task from utils import check_ratio, check_tensors from nngeometry.generator import Jacobian @@ -20,7 +20,7 @@ def test_pspace_ekfac_vs_kfac(): sense of the Frobenius norm """ eps = 1e-4 - for get_task in [get_fullyconnect_task, get_conv_task]: + for get_task in [get_conv1d_task, get_fullyconnect_task, get_conv_task]: loader, lc, parameters, model, function, n_output = get_task() model.train() generator = Jacobian( @@ -50,7 +50,7 @@ def test_pspace_ekfac_vs_direct(): Check EKFAC basis operations against direct computation using get_dense_tensor """ - for get_task in [get_fullyconnect_task, get_conv_task]: + for get_task in [get_conv1d_task, get_fullyconnect_task, get_conv_task]: loader, lc, parameters, model, function, n_output = get_task() model.train() diff --git a/tests/test_jacobian_kfac.py b/tests/test_jacobian_kfac.py index a3f97e6..f27178c 100644 --- a/tests/test_jacobian_kfac.py +++ b/tests/test_jacobian_kfac.py @@ -3,8 +3,13 @@ import pytest import torch import torch.nn as nn -import torch.nn.functional as F -from tasks import get_conv_task, get_fullyconnect_task, get_mnist, to_device_model +from tasks import ( + get_conv_task, + get_fullyconnect_task, + get_mnist, + to_device_model, + get_conv1d_task, +) from torch.utils.data import DataLoader, Subset from utils import angle, check_ratio, check_tensors @@ -106,6 +111,32 @@ def output_fn(input, target): return (train_loader, layer_collection, net.parameters(), net, output_fn, 4) +class Conv1dNet(nn.Module): + def __init__(self): + super(Conv1dNet, self).__init__() + self.conv1 = nn.Conv1d(3, 4, 3, 1) + + def forward(self, x): + x = self.conv1(x) + return x.sum(axis=(2,)) + + +def get_conv1dnet_kfc_task(bs=5): + train_set = torch.utils.data.TensorDataset( + torch.ones(size=(10, 3, 5)), torch.randint(0, 4, size=(10, 4)) + ) + train_loader = DataLoader(dataset=train_set, batch_size=bs, shuffle=False) + net = Conv1dNet() + to_device_model(net) + net.eval() + + def output_fn(input, target): + return net(to_device(input)) + + layer_collection = LayerCollection.from_model(net) + return (train_loader, layer_collection, net.parameters(), net, output_fn, 4) + + @pytest.fixture(autouse=True) def make_test_deterministic(): torch.manual_seed(1234) @@ -118,7 +149,8 @@ def test_jacobian_kfac_vs_pblockdiag(): where they are the same """ for get_task, mult in zip( - [get_convnet_kfc_task, get_fullyconnect_kfac_task], [15.0, 1.0] + [get_conv1dnet_kfc_task, get_convnet_kfc_task, get_fullyconnect_kfac_task], + [3.0, 15.0, 1.0], ): loader, lc, parameters, model, function, n_output = get_task() @@ -134,7 +166,7 @@ def test_jacobian_kfac_vs_pblockdiag(): def test_jacobian_kfac(): - for get_task in [get_fullyconnect_task, get_conv_task]: + for get_task in [get_conv1d_task, get_fullyconnect_task, get_conv_task]: loader, lc, parameters, model, function, n_output = get_task() generator = Jacobian( diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 4bd375e..a67c6e9 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -1,9 +1,14 @@ import pytest import torch import torch.nn.functional as tF -from tasks import (device, get_conv_gn_task, get_conv_task, - get_fullyconnect_segm_task, get_fullyconnect_task, - to_device) +from tasks import ( + device, + get_conv_gn_task, + get_conv_task, + get_fullyconnect_segm_task, + get_fullyconnect_task, + to_device, +) from test_jacobian import get_output_vector, update_model from nngeometry.metrics import FIM, FIM_MonteCarlo diff --git a/tests/test_pickle.py b/tests/test_pickle.py index 5645083..0768d4e 100644 --- a/tests/test_pickle.py +++ b/tests/test_pickle.py @@ -4,8 +4,13 @@ from utils import check_tensors from nngeometry.generator import Jacobian -from nngeometry.object.pspace import (PMatBlockDiag, PMatDense, PMatDiag, - PMatLowRank, PMatQuasiDiag) +from nngeometry.object.pspace import ( + PMatBlockDiag, + PMatDense, + PMatDiag, + PMatLowRank, + PMatQuasiDiag, +) from nngeometry.object.vector import PVector diff --git a/tests/test_vector.py b/tests/test_vector.py index 3fef91e..1e34eba 100644 --- a/tests/test_vector.py +++ b/tests/test_vector.py @@ -5,8 +5,7 @@ from utils import check_ratio, check_tensors from nngeometry.layercollection import LayerCollection -from nngeometry.object.vector import (PVector, random_pvector, - random_pvector_dict) +from nngeometry.object.vector import PVector, random_pvector, random_pvector_dict class ConvNet(nn.Module):