Merge pull request #71 from tfjgeorge/float64
now infers dtype from torch Modules when populating representations
tfjgeorge authored Nov 2, 2023
2 parents 78ba46c + 556986d commit a250f27
Showing 2 changed files with 104 additions and 30 deletions.
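
In practical terms, a model whose parameters are stored in float64 now gets its representations allocated in float64 as well, instead of defaulting to float32. A minimal sketch of that usage, modelled on the new tests/test_dtypes.py below (the model architecture and data here are illustrative only, not part of the commit):

import torch as th
from nngeometry.metrics import FIM
from nngeometry.object import PMatDense

# an all-float64 model, mirroring SimpleModel in the new test file
model = th.nn.Sequential(
    th.nn.Linear(10, 5, dtype=th.float64),
    th.nn.Linear(5, 2, dtype=th.float64),
    th.nn.Softmax(dim=-1),
)
dataset = th.utils.data.TensorDataset(
    th.randn(100, 10, dtype=th.float64), th.randint(0, 2, (100,))
)
loader = th.utils.data.DataLoader(dataset, batch_size=10)

# the dense FIM is now populated in float64, with the dtype inferred
# from the modules' weights rather than hard-coded
F = FIM(model, loader, PMatDense, 2, variant="classif_logits")
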
95 changes: 65 additions & 30 deletions nngeometry/generator/jacobian/__init__.py
@@ -48,25 +48,23 @@ def __init__(self, model, function=None, n_output=1,
self.l_to_m, self.m_to_l = \
self.layer_collection.get_layerid_module_maps(model)

def get_device(self):
return next(self.model.parameters()).device

def get_covariance_matrix(self, examples):
# add hooks
self.handles += self._add_hooks(self._hook_savex,
self._hook_compute_flat_grad,
self.l_to_m.values())

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
n_parameters = self.layer_collection.numel()
bs = loader.batch_size
G = torch.zeros((n_parameters, n_parameters), device=device)
self.grads = torch.zeros((1, bs, n_parameters), device=device)
G = torch.zeros((n_parameters, n_parameters), device=device, dtype=dtype)
self.grads = torch.zeros((1, bs, n_parameters), device=device, dtype=dtype)
if self.centering:
grad_mean = torch.zeros((self.n_output, n_parameters),
device=device)
device=device, dtype=dtype)

self.start = 0
self.i_output = 0
@@ -105,11 +103,12 @@ def get_covariance_diag(self, examples):
self._hook_compute_diag,
self.l_to_m.values())

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
n_parameters = self.layer_collection.numel()
self.diag_m = torch.zeros((n_parameters,), device=device)
self.diag_m = torch.zeros((n_parameters,), device=device, dtype=dtype)
self.start = 0
for d in loader:
inputs = d[0]
@@ -139,19 +138,20 @@ def get_covariance_quasidiag(self, examples):
self._hook_compute_quasidiag,
self.l_to_m.values())

device = next(self.model.parameters()).device
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
self._blocks = dict()
for layer_id, layer in self.layer_collection.layers.items():
device = self._infer_device(layer_id)
dtype = self._infer_dtype(layer_id)
s = layer.numel()
if layer.bias is None:
self._blocks[layer_id] = (torch.zeros((s, ), device=device),
self._blocks[layer_id] = (torch.zeros((s, ), device=device, dtype=dtype),
None)
else:
cross_s = layer.weight.size
self._blocks[layer_id] = (torch.zeros((s, ), device=device),
torch.zeros(cross_s, device=device))
self._blocks[layer_id] = (torch.zeros((s, ), device=device, dtype=dtype),
torch.zeros(cross_s, device=device, dtype=dtype))

for d in loader:
inputs = d[0]
@@ -186,13 +186,14 @@ def get_covariance_layer_blocks(self, examples):
self._hook_compute_layer_blocks,
self.l_to_m.values())

device = next(self.model.parameters()).device
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
self._blocks = dict()
for layer_id, layer in self.layer_collection.layers.items():
device = self._infer_device(layer_id)
dtype = self._infer_dtype(layer_id)
s = layer.numel()
self._blocks[layer_id] = torch.zeros((s, s), device=device)
self._blocks[layer_id] = torch.zeros((s, s), device=device, dtype=dtype)

for d in loader:
inputs = d[0]
@@ -220,11 +221,12 @@ def get_kfac_blocks(self, examples):
self._hook_compute_kfac_blocks,
self.l_to_m.values())

device = next(self.model.parameters()).device
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
self._blocks = dict()
for layer_id, layer in self.layer_collection.layers.items():
device = self._infer_device(layer_id)
dtype = self._infer_dtype(layer_id)
layer_class = layer.__class__.__name__
if layer_class == 'LinearLayer':
sG = layer.out_features
@@ -235,8 +237,8 @@
layer.kernel_size[1]
if layer.bias is not None:
sA += 1
self._blocks[layer_id] = (torch.zeros((sA, sA), device=device),
torch.zeros((sG, sG), device=device))
self._blocks[layer_id] = (torch.zeros((sA, sA), device=device, dtype=dtype),
torch.zeros((sG, sG), device=device, dtype=dtype))

for d in loader:
inputs = d[0]
@@ -274,12 +276,13 @@ def get_jacobian(self, examples):
self._hook_compute_flat_grad,
self.l_to_m.values())

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
n_parameters = self.layer_collection.numel()
self.grads = torch.zeros((self.n_output, n_examples, n_parameters),
device=device)
device=device, dtype=dtype)
self.start = 0
for d in loader:
inputs = d[0]
@@ -312,11 +315,12 @@ def get_gram_matrix(self, examples):
self.handles += self._add_hooks(self._hook_savex_io, self._hook_kxy,
self.l_to_m.values())

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
self.G = torch.zeros((self.n_output, n_examples,
self.n_output, n_examples), device=device)
self.n_output, n_examples), device=device, dtype=dtype)
self.x_outer = dict()
self.x_inner = dict()
self.gy_outer = dict()
@@ -392,13 +396,14 @@ def get_kfe_diag(self, kfe, examples):
self._hook_compute_kfe_diag,
self.l_to_m.values())

device = next(self.model.parameters()).device
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
self._diags = dict()
self._kfe = kfe
for layer_id, layer in self.layer_collection.layers.items():
layer_class = layer.__class__.__name__
device = self._infer_device(layer_id)
dtype = self._infer_dtype(layer_id)
if layer_class == 'LinearLayer':
sG = layer.out_features
sA = layer.in_features
@@ -408,7 +413,7 @@ def get_kfe_diag(self, kfe, examples):
layer.kernel_size[1]
if layer.bias is not None:
sA += 1
self._diags[layer_id] = torch.zeros((sG * sA), device=device)
self._diags[layer_id] = torch.zeros((sG * sA), device=device, dtype=dtype)

for d in loader:
inputs = d[0]
@@ -453,7 +458,8 @@ def implicit_mv(self, v, examples):
parameters.append(mod.bias)
output[mod.bias] = torch.zeros_like(mod.bias)

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)

@@ -467,7 +473,7 @@ def implicit_mv(self, v, examples):
f_output = self.function(*d).view(bs, self.n_output)
for i in range(self.n_output):
# TODO reuse instead of reallocating memory
self._Jv = torch.zeros((1, bs), device=device)
self._Jv = torch.zeros((1, bs), device=device, dtype=dtype)

self.compute_switch = True
torch.autograd.grad(f_output[:, i].sum(dim=0), [inputs],
@@ -510,7 +516,8 @@ def implicit_vTMv(self, v, examples):

self._v = v.get_dict_representation()

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)

@@ -532,7 +539,7 @@ def implicit_vTMv(self, v, examples):
f_output = self.function(*d).view(bs, self.n_output).sum(dim=0)
for i in range(self.n_output):
# TODO reuse instead of reallocating memory
self._Jv = torch.zeros((1, bs), device=device)
self._Jv = torch.zeros((1, bs), device=device, dtype=dtype)

torch.autograd.grad(f_output[i], [inputs],
retain_graph=i < self.n_output - 1,
@@ -589,10 +596,11 @@ def implicit_Jv(self, v, examples):

self._v = v.get_dict_representation()

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
self._Jv = torch.zeros((self.n_output, n_examples), device=device)
self._Jv = torch.zeros((self.n_output, n_examples), device=device, dtype=dtype)
self.start = 0
self.compute_switch = True
for d in loader:
@@ -750,3 +758,30 @@ def _get_dataloader(self, examples):
else:
return DataLoader(TensorDataset(*examples),
batch_size=len(examples[0]))

def _infer_dtype(self, layer_id):
return self.l_to_m[layer_id].weight.dtype

def _infer_device(self, layer_id):
return self.l_to_m[layer_id].weight.device

def _check_same_device(self):
device = None
for layer_id in self.layer_collection.layers.keys():
if device is None:
device = self._infer_device(layer_id)
elif device != self._infer_device(layer_id):
raise ValueError("All modules should reside on the same device")
return device

def _check_same_dtype(self):
dtype = None
for layer_id in self.layer_collection.layers.keys():
if dtype is None:
dtype = self._infer_dtype(layer_id)
elif dtype != self._infer_dtype(layer_id):
raise ValueError("All modules should have the same type")
return dtype

def get_device(self):
return self._check_same_device()
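
As a standalone illustration of the inference pattern added above (reading each module's dtype from its weight tensor and requiring all modules to agree), here is a sketch that does not go through nngeometry's internals; module_dtypes and check_same_dtype are hypothetical helpers written only for this example:

import torch as th

def module_dtypes(model):
    # collect the weight dtype of every module that carries a weight,
    # mirroring _infer_dtype above
    return {name: m.weight.dtype for name, m in model.named_modules()
            if getattr(m, 'weight', None) is not None}

def check_same_dtype(model):
    # mirrors _check_same_dtype above: refuse mixed-precision models
    dtypes = set(module_dtypes(model).values())
    if len(dtypes) != 1:
        raise ValueError("All modules should have the same type")
    return dtypes.pop()

mixed = th.nn.Sequential(th.nn.Linear(10, 5, dtype=th.float32),
                         th.nn.Linear(5, 2, dtype=th.float64))
# check_same_dtype(mixed) raises ValueError, just as the new
# _check_same_dtype does when populating a representation
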
39 changes: 39 additions & 0 deletions tests/test_dtypes.py
@@ -0,0 +1,39 @@
import pytest
import torch as th

from nngeometry.metrics import FIM
from nngeometry.object import PMatDense, PMatDiag, PMatEKFAC, PMatKFAC


class SimpleModel(th.nn.Module):
def __init__(self, dtype1, dtype2):
super().__init__()
self.fc1 = th.nn.Linear(10, 5, bias=True, dtype=dtype1)
self.fc2 = th.nn.Linear(5, 2, bias=True, dtype=dtype2)

def forward(self, x):
return th.nn.Softmax(dim=-1)(self.fc2(self.fc1(x)))


def test_same_dtype():
model = SimpleModel(dtype1=th.float32, dtype2=th.float64)
dataset = th.utils.data.TensorDataset(
th.randn(100, 10, dtype=th.float64), th.randint(0, 2, (100,))
)
loader = th.utils.data.DataLoader(dataset, batch_size=10)

for PMatType in [PMatDense, PMatDiag]:
with pytest.raises(ValueError):
FIM(model, loader, PMatType, 2, variant="classif_logits")


def test_dtypes():
for dtype in [th.float32, th.float64]:
model = SimpleModel(dtype1=dtype, dtype2=dtype)
dataset = th.utils.data.TensorDataset(
th.randn(100, 10, dtype=dtype), th.randint(0, 2, (100,))
)
loader = th.utils.data.DataLoader(dataset, batch_size=10)

for PMatType in [PMatDense, PMatDiag, PMatKFAC, PMatEKFAC]:
FIM(model, loader, PMatType, 2, variant="classif_logits")

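The new checks can be exercised directly with pytest from a checkout of the repository; a sketch, using pytest.main as the in-process equivalent of the command-line runner:

import pytest

# equivalent to running `pytest tests/test_dtypes.py -v` from the repository root
pytest.main(["tests/test_dtypes.py", "-v"])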