Error with float64 tensors #70

Closed
Xuzzo opened this issue Oct 31, 2023 · 2 comments

Comments


Xuzzo commented Oct 31, 2023

Hello and thanks for your work.
EKFAC seems to have issues with models that use double precision. Here is code to reproduce it:

from nngeometry.metrics import FIM
from nngeometry.object import PMatEKFAC
import torch as th

dtype = th.float64

class SimpleModel(th.nn.Module):

    def __init__(
        self,
        n_input: int,
        n_output: int,
    ):
        super().__init__()
        self.fc1 = th.nn.Linear(n_input, n_output, bias=True, dtype=dtype)

    def forward(self, x):
        return th.nn.Softmax(dim=-1)(self.fc1(x))
    


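if __name__ == "__main__":
    model = SimpleModel(10, 3)
    # 100 float64 inputs with 10 features and integer class labels in {0, 1, 2}
    dataset = th.utils.data.TensorDataset(th.randn(100, 10, dtype=dtype), th.randint(0, 3, (100,), dtype=th.long))
    loader = th.utils.data.DataLoader(dataset, batch_size=10)
    # EKFAC representation of the Fisher information matrix for a 3-output classifier
    F_ekfac = FIM(model, loader, PMatEKFAC, 3, variant='classif_logits')
    # recompute the EKFAC diagonal over the same data
    F_ekfac.update_diag(loader)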

Running it raises "RuntimeError: expected scalar type Double but found Float".
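A minimal sketch of how this class of error can arise. It assumes the cause is an internal tensor stored in torch's default float32 while the model and data are float64. This is not nngeometry's implementation and does not claim to reflect what the fix changes. The function `ekfac_diag_sketch` and its `basis_dtype` parameter are hypothetical. The sketch computes an EKFAC-style diagonal for a single linear layer and shows that `torch.matmul` raises when a float32 eigenbasis meets float64 activations. Keeping the bases in the data's dtype avoids the error.

```python
import torch


def ekfac_diag_sketch(x, g, basis_dtype=None):
    """Hypothetical EKFAC-style diagonal for one linear layer (not nngeometry's code).

    x: layer inputs, shape (n, d_in); g: gradients w.r.t. layer outputs, shape (n, d_out).
    basis_dtype: dtype used to store the eigenbases; None means "same as x".
    """
    n = x.shape[0]
    # Kronecker factors A = E[x x^T] and G = E[g g^T]
    a = x.t() @ x / n
    gg = g.t() @ g / n
    # Orthonormal eigenbases of the two factors
    _, u_a = torch.linalg.eigh(a)
    _, u_g = torch.linalg.eigh(gg)
    # Storing the bases in a different dtype than the data reproduces the
    # mismatch: torch.matmul does not promote float32 and float64 operands.
    dtype = x.dtype if basis_dtype is None else basis_dtype
    u_a, u_g = u_a.to(dtype), u_g.to(dtype)
    # Project each example into the eigenbasis. The per-example weight gradient
    # is g_i x_i^T, whose (j, k) entry in the eigenbasis is
    # (u_g^T g_i)_j * (u_a^T x_i)_k.
    x_proj = x @ u_a
    g_proj = g @ u_g
    # EKFAC diagonal: mean of the squared projected entries, shape (d_out, d_in)
    return (g_proj ** 2).t() @ (x_proj ** 2) / n


x = torch.randn(100, 10, dtype=torch.float64)
g = torch.randn(100, 3, dtype=torch.float64)

diag = ekfac_diag_sketch(x, g)  # bases follow the data dtype: works
print(diag.dtype, diag.shape)

try:
    ekfac_diag_sketch(x, g, basis_dtype=torch.float32)
except RuntimeError as err:
    print("RuntimeError:", err)  # exact message depends on the PyTorch version
```

Allocating internal buffers with `dtype=x.dtype` (or casting with `.to(x.dtype)`) instead of relying on the global default dtype is the usual way to keep this kind of code precision-agnostic.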

tfjgeorge (Owner) commented

Hi, thanks for pointing this out.

PR #71 should fix this.

It still needs a bit more testing before it is merged into master, but you can use it in the meantime.


Xuzzo commented Nov 3, 2023

Works now! Thanks a lot.

Xuzzo closed this as completed Nov 3, 2023