Hello and thanks for your work. EKFAC seems to have issues with models that use double precision (float64). Here is code to reproduce it:
```python
from nngeometry.metrics import FIM
from nngeometry.object import PMatEKFAC
import torch as th

dtype = th.float64


class SimpleModel(th.nn.Module):
    def __init__(
        self,
        n_input: int,
        n_output: int,
    ):
        super().__init__()
        self.fc1 = th.nn.Linear(n_input, n_output, bias=True, dtype=dtype)

    def forward(self, x):
        return th.nn.Softmax(dim=-1)(self.fc1(x))


if __name__ == "__main__":
    model = SimpleModel(10, 3)
    dataset = th.utils.data.TensorDataset(
        th.randn(100, 10, dtype=dtype),
        th.randint(0, 3, (100,), dtype=th.long),
    )
    loader = th.utils.data.DataLoader(dataset, batch_size=10)
    F_ekfac = FIM(model, loader, PMatEKFAC, 3, variant='classif_logits')
    F_ekfac.update_diag(loader)
```
I get "RuntimeError: expected scalar type Double but found Float"
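This message is PyTorch refusing to combine a float64 tensor with a float32 one. A plausible cause (an assumption, not confirmed from nngeometry's source or from the PR below) is a tensor allocated inside the library without an explicit dtype, which then falls back to torch's default of float32. The plain-PyTorch sketch below reproduces that kind of mismatch and shows the usual fix of allocating tensors with the dtype and device of the incoming data. `project`, `make_basis_default` and `make_basis_like` are hypothetical helpers for illustration, not nngeometry functions, and the exact error text varies between PyTorch versions.

```python
import torch as th


def project(x: th.Tensor, basis: th.Tensor) -> th.Tensor:
    # Hypothetical stand-in for an EKFAC-style step: project activations onto a basis.
    return x @ basis


def make_basis_default(n: int) -> th.Tensor:
    # Buggy pattern: no dtype given, so torch uses its global default (float32).
    return th.eye(n)


def make_basis_like(n: int, ref: th.Tensor) -> th.Tensor:
    # Fixed pattern: follow the dtype and device of the data being processed.
    return th.eye(n, dtype=ref.dtype, device=ref.device)


if __name__ == "__main__":
    x = th.randn(10, 4, dtype=th.float64)  # double-precision activations

    try:
        project(x, make_basis_default(4))  # float64 @ float32 -> RuntimeError
    except RuntimeError as err:
        print("dtype mismatch:", err)

    out = project(x, make_basis_like(4, x))
    print(out.dtype)  # torch.float64
```

Matrix multiplication in PyTorch does not promote mixed float dtypes, so the mismatch has to be avoided at allocation time rather than fixed afterwards.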
Hi, thanks for pointing this out.
PR #71 should fix this.
It still needs a bit more testing before it gets merged into master, but in the meantime you can use it.
Works now! Thanks a lot.