@abmazitov observed that float64 is faster than float32 for the alchemical expansion. I observed the same in #21. The difference was more severe when the benchmarks had no warmup (using just `--quick`). I think this is because the `NeighborlistTransformer` takes the dtype of the given numpy arrays, which is usually float64, while in the radial basis the one-hot encoding uses `torch.get_default_dtype()`, which is usually float32: https://github.com/frostedoyster/torch_spex/blob/08cfe0d296a1296b1b05596a868639df9a9ba6d1/torch_spex/radial_basis.py#L44
I assume that the type conversion happening where the two dtypes meet causes the slowdown seen in #21.
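A minimal reproduction of the dtype mismatch described above. It does not reproduce the code in `radial_basis.py`: `positions`, `r` and `one_hot` are stand-ins, assuming the positions arrive as float64 numpy arrays and the one-hot tensor is created with `torch.get_default_dtype()`. Mixing the two lets PyTorch's type promotion pick float64, so the float32 operand has to be converted, which is where the assumed extra cost would come from.

```python
import numpy as np
import torch

# Stand-in for positions coming from numpy: NumPy defaults to float64,
# and torch.from_numpy keeps that dtype.
positions = torch.from_numpy(np.random.default_rng(0).random((8, 3)))
# Distances from atom 0 to the other atoms, still float64.
r = torch.linalg.norm(positions[1:] - positions[0], dim=-1)

# Stand-in for the one-hot encoding: created with the default dtype,
# which is float32 unless torch.set_default_dtype was called.
one_hot = torch.zeros(r.shape[0], 4, dtype=torch.get_default_dtype())

# Type promotion decides the dtype of a mixed float32/float64 operation ...
print(torch.result_type(one_hot, r))  # torch.float64
# ... so the product is computed in float64 and the float32 operand is cast.
mixed = one_hot * r.unsqueeze(-1)
print(mixed.dtype)  # torch.float64
```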
Temporary fix:
Use `torch.set_default_dtype(torch.float64)` before running the code.
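A sketch of applying the temporary fix around a piece of code. Restoring the previous default in a `finally` block is an optional precaution assumed here for use inside tests or benchmarks; in a plain script, calling `torch.set_default_dtype(torch.float64)` once at startup is enough.

```python
import torch

# Remember the current default so it can be restored afterwards.
previous_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.float64)
try:
    # Tensors created without an explicit dtype (like a one-hot encoding
    # built with torch.get_default_dtype()) are now float64, matching the
    # float64 distances that come from numpy.
    default_tensor = torch.zeros(4)
    print(torch.get_default_dtype())  # torch.float64
finally:
    # Restore the previous default so other code is unaffected.
    torch.set_default_dtype(previous_dtype)
```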
Real fix:
Since the fix itself should be rather trivial (e.g., using the dtype of `r`), I want to take the chance to integrate asv benchmarks so that we can actually track the change in performance when fixing this. I started this in PR #21.
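A minimal sketch of the "use the dtype of `r`" fix together with an asv-style benchmark, assuming the one-hot encoding only needs to follow the dtype and device of `r`. `species_one_hot` and `TimeOneHotDtype` are hypothetical names, not torch_spex code. The class follows asv's conventions (`setup` runs before timing, methods prefixed with `time_` are timed, and `params`/`param_names` run the benchmark once per dtype), so float32 and float64 timings could be tracked side by side.

```python
import torch


def species_one_hot(species: torch.Tensor, n_species: int, r: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: one-hot encode species in the dtype/device of ``r``."""
    # torch.nn.functional.one_hot returns an int64 tensor of shape (n, n_species).
    one_hot = torch.nn.functional.one_hot(species, num_classes=n_species)
    # Follow r instead of torch.get_default_dtype(), so float32 inputs stay
    # float32 and float64 inputs stay float64.
    return one_hot.to(dtype=r.dtype, device=r.device)


class TimeOneHotDtype:
    """Hypothetical asv benchmark, run once per dtype name in ``params``."""

    params = ["float32", "float64"]
    param_names = ["dtype"]

    def setup(self, dtype):
        # asv calls setup before timing; build the inputs in the requested dtype.
        n_pairs = 10_000
        self.n_species = 4
        self.r = torch.rand(n_pairs, dtype=getattr(torch, dtype))
        self.species = torch.randint(0, self.n_species, (n_pairs,))

    def time_one_hot_times_r(self, dtype):
        # Both operands share one dtype, so no implicit promotion happens here.
        species_one_hot(self.species, self.n_species, self.r) * self.r.unsqueeze(-1)
```

Casting the int64 output of `one_hot` with `.to(dtype=r.dtype)` keeps every downstream product in a single dtype, independent of `torch.get_default_dtype()`.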
agoscinski changed the title from "RadialBasis Float64 faster the Floa32" to "RadialBasis Float64 faster the Float32" on Jul 22, 2023.