now infers dtype from torch Modules when populating representations #71

Merged 2 commits on Nov 2, 2023
95 changes: 65 additions & 30 deletions nngeometry/generator/jacobian/__init__.py
@@ -48,25 +48,23 @@
self.l_to_m, self.m_to_l = \
self.layer_collection.get_layerid_module_maps(model)

def get_device(self):
return next(self.model.parameters()).device

def get_covariance_matrix(self, examples):
# add hooks
self.handles += self._add_hooks(self._hook_savex,
self._hook_compute_flat_grad,
self.l_to_m.values())

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
n_parameters = self.layer_collection.numel()
bs = loader.batch_size
G = torch.zeros((n_parameters, n_parameters), device=device)
self.grads = torch.zeros((1, bs, n_parameters), device=device)
G = torch.zeros((n_parameters, n_parameters), device=device, dtype=dtype)
self.grads = torch.zeros((1, bs, n_parameters), device=device, dtype=dtype)
if self.centering:
grad_mean = torch.zeros((self.n_output, n_parameters),
device=device)
device=device, dtype=dtype)

self.start = 0
self.i_output = 0
@@ -105,11 +103,12 @@
self._hook_compute_diag,
self.l_to_m.values())

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
n_parameters = self.layer_collection.numel()
self.diag_m = torch.zeros((n_parameters,), device=device)
self.diag_m = torch.zeros((n_parameters,), device=device, dtype=dtype)
self.start = 0
for d in loader:
inputs = d[0]
@@ -139,19 +138,20 @@
self._hook_compute_quasidiag,
self.l_to_m.values())

device = next(self.model.parameters()).device
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
self._blocks = dict()
for layer_id, layer in self.layer_collection.layers.items():
device = self._infer_device(layer_id)
dtype = self._infer_dtype(layer_id)
s = layer.numel()
if layer.bias is None:
self._blocks[layer_id] = (torch.zeros((s, ), device=device),
self._blocks[layer_id] = (torch.zeros((s, ), device=device, dtype=dtype),
None)
else:
cross_s = layer.weight.size
self._blocks[layer_id] = (torch.zeros((s, ), device=device),
torch.zeros(cross_s, device=device))
self._blocks[layer_id] = (torch.zeros((s, ), device=device, dtype=dtype),
torch.zeros(cross_s, device=device, dtype=dtype))

for d in loader:
inputs = d[0]
@@ -186,13 +186,14 @@
self._hook_compute_layer_blocks,
self.l_to_m.values())

device = next(self.model.parameters()).device
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
self._blocks = dict()
for layer_id, layer in self.layer_collection.layers.items():
device = self._infer_device(layer_id)
dtype = self._infer_dtype(layer_id)
s = layer.numel()
self._blocks[layer_id] = torch.zeros((s, s), device=device)
self._blocks[layer_id] = torch.zeros((s, s), device=device, dtype=dtype)

for d in loader:
inputs = d[0]
@@ -220,11 +221,12 @@
self._hook_compute_kfac_blocks,
self.l_to_m.values())

device = next(self.model.parameters()).device
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
self._blocks = dict()
for layer_id, layer in self.layer_collection.layers.items():
device = self._infer_device(layer_id)
dtype = self._infer_dtype(layer_id)
layer_class = layer.__class__.__name__
if layer_class == 'LinearLayer':
sG = layer.out_features
@@ -235,8 +237,8 @@
layer.kernel_size[1]
if layer.bias is not None:
sA += 1
self._blocks[layer_id] = (torch.zeros((sA, sA), device=device),
torch.zeros((sG, sG), device=device))
self._blocks[layer_id] = (torch.zeros((sA, sA), device=device, dtype=dtype),
torch.zeros((sG, sG), device=device, dtype=dtype))

for d in loader:
inputs = d[0]
@@ -274,12 +276,13 @@
self._hook_compute_flat_grad,
self.l_to_m.values())

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
n_parameters = self.layer_collection.numel()
self.grads = torch.zeros((self.n_output, n_examples, n_parameters),
device=device)
device=device, dtype=dtype)
self.start = 0
for d in loader:
inputs = d[0]
@@ -312,11 +315,12 @@
self.handles += self._add_hooks(self._hook_savex_io, self._hook_kxy,
self.l_to_m.values())

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
self.G = torch.zeros((self.n_output, n_examples,
self.n_output, n_examples), device=device)
self.n_output, n_examples), device=device, dtype=dtype)
self.x_outer = dict()
self.x_inner = dict()
self.gy_outer = dict()
@@ -392,13 +396,14 @@
self._hook_compute_kfe_diag,
self.l_to_m.values())

device = next(self.model.parameters()).device
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
self._diags = dict()
self._kfe = kfe
for layer_id, layer in self.layer_collection.layers.items():
layer_class = layer.__class__.__name__
device = self._infer_device(layer_id)
dtype = self._infer_dtype(layer_id)
if layer_class == 'LinearLayer':
sG = layer.out_features
sA = layer.in_features
@@ -408,7 +413,7 @@
layer.kernel_size[1]
if layer.bias is not None:
sA += 1
self._diags[layer_id] = torch.zeros((sG * sA), device=device)
self._diags[layer_id] = torch.zeros((sG * sA), device=device, dtype=dtype)

for d in loader:
inputs = d[0]
@@ -453,7 +458,8 @@
parameters.append(mod.bias)
output[mod.bias] = torch.zeros_like(mod.bias)

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)

@@ -467,7 +473,7 @@
f_output = self.function(*d).view(bs, self.n_output)
for i in range(self.n_output):
# TODO reuse instead of reallocating memory
self._Jv = torch.zeros((1, bs), device=device)
self._Jv = torch.zeros((1, bs), device=device, dtype=dtype)

self.compute_switch = True
torch.autograd.grad(f_output[:, i].sum(dim=0), [inputs],
@@ -510,7 +516,8 @@

self._v = v.get_dict_representation()

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)

@@ -532,7 +539,7 @@
f_output = self.function(*d).view(bs, self.n_output).sum(dim=0)
for i in range(self.n_output):
# TODO reuse instead of reallocating memory
self._Jv = torch.zeros((1, bs), device=device)
self._Jv = torch.zeros((1, bs), device=device, dtype=dtype)

torch.autograd.grad(f_output[i], [inputs],
retain_graph=i < self.n_output - 1,
@@ -589,10 +596,11 @@

self._v = v.get_dict_representation()

device = next(self.model.parameters()).device
device = self._check_same_device()
dtype = self._check_same_dtype()
loader = self._get_dataloader(examples)
n_examples = len(loader.sampler)
self._Jv = torch.zeros((self.n_output, n_examples), device=device)
self._Jv = torch.zeros((self.n_output, n_examples), device=device, dtype=dtype)
self.start = 0
self.compute_switch = True
for d in loader:
Expand Down Expand Up @@ -750,3 +758,30 @@
else:
return DataLoader(TensorDataset(*examples),
batch_size=len(examples[0]))

def _infer_dtype(self, layer_id):
return self.l_to_m[layer_id].weight.dtype

def _infer_device(self, layer_id):
return self.l_to_m[layer_id].weight.device

def _check_same_device(self):
device = None
for layer_id in self.layer_collection.layers.keys():
if device is None:
device = self._infer_device(layer_id)
elif device != self._infer_device(layer_id):
raise ValueError("All modules should reside on the same device")

Check warning on line 774 in nngeometry/generator/jacobian/__init__.py

View check run for this annotation

Codecov / codecov/patch

nngeometry/generator/jacobian/__init__.py#L774

Added line #L774 was not covered by tests
return device

def _check_same_dtype(self):
dtype = None
for layer_id in self.layer_collection.layers.keys():
if dtype is None:
dtype = self._infer_dtype(layer_id)
elif dtype != self._infer_dtype(layer_id):
raise ValueError("All modules should have the same type")
return dtype

def get_device(self):
return self._check_same_device()
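
Two patterns emerge from this change: representations stored as one flat tensor (the dense covariance, its diagonal, the Jacobian, the Gram matrix, and the implicit Jv products) require every module to share a single device and dtype and go through _check_same_device/_check_same_dtype, while block-structured representations (per-layer blocks, quasi-diagonal, KFAC, KFE diagonal) read the device and dtype per layer via _infer_device/_infer_dtype. For reference, a minimal standalone sketch of the same-dtype check applied to a plain torch.nn.Module, assuming only weight-carrying submodules matter (check_same_dtype is an illustrative helper, not part of nngeometry):

```python
import torch

def check_same_dtype(model: torch.nn.Module) -> torch.dtype:
    # Illustrative helper: walk the weight-carrying submodules and require
    # that they all use the same parameter dtype, as the generator now does.
    dtype = None
    for module in model.modules():
        weight = getattr(module, "weight", None)
        if weight is None:
            continue
        if dtype is None:
            dtype = weight.dtype
        elif dtype != weight.dtype:
            raise ValueError("All modules should have the same dtype")
    return dtype
```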
39 changes: 39 additions & 0 deletions tests/test_dtypes.py
@@ -0,0 +1,39 @@
import pytest
import torch as th

from nngeometry.metrics import FIM
from nngeometry.object import PMatDense, PMatDiag, PMatEKFAC, PMatKFAC


class SimpleModel(th.nn.Module):
def __init__(self, dtype1, dtype2):
super().__init__()
self.fc1 = th.nn.Linear(10, 5, bias=True, dtype=dtype1)
self.fc2 = th.nn.Linear(5, 2, bias=True, dtype=dtype2)

def forward(self, x):
return th.nn.Softmax(dim=-1)(self.fc2(self.fc1(x)))


def test_same_dtype():
model = SimpleModel(dtype1=th.float32, dtype2=th.float64)
dataset = th.utils.data.TensorDataset(
th.randn(100, 10, dtype=th.float64), th.randint(0, 2, (100,))
)
loader = th.utils.data.DataLoader(dataset, batch_size=10)

for PMatType in [PMatDense, PMatDiag]:
with pytest.raises(ValueError):
FIM(model, loader, PMatType, 2, variant="classif_logits")


def test_dtypes():
for dtype in [th.float32, th.float64]:
model = SimpleModel(dtype1=dtype, dtype2=dtype)
dataset = th.utils.data.TensorDataset(
th.randn(100, 10, dtype=dtype), th.randint(0, 2, (100,))
)
loader = th.utils.data.DataLoader(dataset, batch_size=10)

for PMatType in [PMatDense, PMatDiag, PMatKFAC, PMatEKFAC]:
FIM(model, loader, PMatType, 2, variant="classif_logits")
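
The observable effect of the change is that FIM representations now follow the model's parameter dtype instead of defaulting to float32. A small usage sketch, assuming a single-dtype model similar to SimpleModel above and that PMatDense exposes its data through get_dense_tensor() as in current nngeometry (DoubleModel is illustrative, not part of the PR):

```python
import torch as th
from nngeometry.metrics import FIM
from nngeometry.object import PMatDense


class DoubleModel(th.nn.Module):
    # Illustrative double-precision classifier mirroring SimpleModel above.
    def __init__(self):
        super().__init__()
        self.fc1 = th.nn.Linear(10, 5, dtype=th.float64)
        self.fc2 = th.nn.Linear(5, 2, dtype=th.float64)

    def forward(self, x):
        return th.nn.functional.softmax(self.fc2(self.fc1(x)), dim=-1)


model = DoubleModel()
dataset = th.utils.data.TensorDataset(
    th.randn(100, 10, dtype=th.float64), th.randint(0, 2, (100,))
)
loader = th.utils.data.DataLoader(dataset, batch_size=10)

F = FIM(model, loader, PMatDense, 2, variant="classif_logits")
# With this PR, the dense representation is expected to inherit float64
# from the model's parameters rather than being allocated as float32.
assert F.get_dense_tensor().dtype == th.float64
```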