Revamped tuning #130

Draft · wants to merge 31 commits into main

Commits (showing changes from all 31 commits)
9edac83  Initial version of `grid_search` (GardevoirX, Nov 26, 2024)
6cc617e  Remove error (GardevoirX, Nov 26, 2024)
5e99892  Allow a precomputed nl (GardevoirX, Nov 26, 2024)
3121c92  Renamed examples, and added a tuning playground (ceriottm, Nov 23, 2024)
d65e48c  Nelder mead (doesn't work because actual error is not a good target) (ceriottm, Nov 24, 2024)
a662759  Added a tuning class (ceriottm, Nov 24, 2024)
ff5c523  I'm not a morning person it seems (ceriottm, Nov 24, 2024)
6210fdd  Examples (ceriottm, Nov 24, 2024)
41579cd  Better plotting (ceriottm, Nov 24, 2024)
e1d568f  Fixes on `H` and `RMS_phi` (GardevoirX, Nov 25, 2024)
cf9c1bb  Some cleaning and test fix (GardevoirX, Nov 25, 2024)
4f67ee5  Further clean (GardevoirX, Nov 26, 2024)
078f36b  Replace `loss` in tuning with `ErrorBounds` and draft for `Tuner` (GardevoirX, Nov 27, 2024)
19b7b61  Supress output (GardevoirX, Nov 27, 2024)
b9be34b  Update `grid_search` (GardevoirX, Nov 28, 2024)
ef7e651  Return something when is cannot reach desired accuracy (GardevoirX, Nov 28, 2024)
e6af9ad  Supress output (GardevoirX, Nov 28, 2024)
ae38063  Repair some errors of the example (GardevoirX, Nov 28, 2024)
91f3909  Add a warning for the case that no parameter can meet the accuracy re… (GardevoirX, Dec 5, 2024)
82a32fb  Update warning (GardevoirX, Dec 5, 2024)
f554b70  Documentations and pytests update (GardevoirX, Dec 18, 2024)
2b3d081  Added a TIP4P example (ceriottm, Dec 20, 2024)
3eebd91  Started to change the API to use full charges rather than the sum of … (ceriottm, Dec 20, 2024)
c5e78d5  Move from `sum_squared_charges` to `charges` (GardevoirX, Dec 28, 2024)
5ed22a5  Refactor the tuning methods with a base class (GardevoirX, Dec 28, 2024)
5e2029d  Fix pytests and make linter happy (GardevoirX, Dec 28, 2024)
deb2ef8  Mini cleanups (ceriottm, Dec 29, 2024)
9150e72  Docs fix (GardevoirX, Dec 29, 2024)
757add8  Separate timings calculator (ceriottm, Dec 29, 2024)
3eaf7bb  Linting (ceriottm, Dec 29, 2024)
cabb742  Try fix github action failures (GardevoirX, Dec 29, 2024)
7 changes: 5 additions & 2 deletions docs/src/references/utils/tuning.rst
@@ -12,8 +12,11 @@
than the given accuracy. Because these methods are gradient-based, be sure to pay
attention to the ``learning_rate`` and ``max_steps`` parameters. A good choice of these
two parameters can enhance the optimization speed and performance.

-.. autoclass:: torchpme.utils.tune_ewald
+.. autoclass:: torchpme.utils.tuning.ewald.EwaldTuner
    :members:

-.. autoclass:: torchpme.utils.tune_pme
+.. autoclass:: torchpme.utils.tuning.pme.PMETuner
    :members:

+.. autoclass:: torchpme.utils.tuning.p3m.P3MTuner
+   :members:
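For orientation, a minimal sketch of how the new tuner classes are called; this is hedged, inferring the `EwaldTuner` signature and return values from the `PMETuner` usage in the example diffs below:

import torch

from torchpme.utils.tuning.ewald import EwaldTuner

# toy CsCl-like system, matching the example further down
charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64)
positions = torch.tensor([(0.0, 0.0, 0.0), (0.5, 0.5, 0.5)], dtype=torch.float64)
cell = torch.eye(3, dtype=torch.float64)

# assumed to mirror PMETuner: a fixed real-space cutoff goes in,
# tuned smearing and reciprocal-space parameters come out
smearing, ewald_params, cutoff = EwaldTuner(
    charges=charges, cell=cell, positions=positions, cutoff=4.4
).tune()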
@@ -37,13 +37,15 @@
from metatensor.torch.atomistic import NeighborListOptions, System

import torchpme
+from torchpme.utils.tuning.pme import PMETuner

# %%
#
# Create the properties CsCl unit cell

symbols = ("Cs", "Cl")
types = torch.tensor([55, 17])
+charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64)
positions = torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=torch.float64)
cell = torch.eye(3, dtype=torch.float64)
pbc = torch.tensor([True, True, True])
@@ -55,9 +57,9 @@
# The ``sum_squared_charges`` is equal to ``2.0`` because each atom either has a charge
# of 1 or -1 in units of elementary charges.

-smearing, pme_params, cutoff = torchpme.utils.tune_pme(
-    sum_squared_charges=2.0, cell=cell, positions=positions
-)
+smearing, pme_params, cutoff = PMETuner(
+    charges=charges, cell=cell, positions=positions, cutoff=4.4
+).tune()

# %%
#
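Not part of the diff, but for context, a hedged sketch of how the tuned values are consumed: `pme_params` is assumed to be a dict of `PMECalculator` keyword arguments (such as the mesh spacing) that can be unpacked directly:

# hedged sketch: plug the tuned parameters into the PME calculator
potential = torchpme.CoulombPotential(smearing=smearing)
calculator = torchpme.PMECalculator(potential=potential, **pme_params)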
@@ -46,6 +46,7 @@
import vesin.torch

import torchpme
+from torchpme.utils.tuning.pme import PMETuner

# %%
#
@@ -93,9 +94,9 @@

sum_squared_charges = float(torch.sum(charges**2))

-smearing, pme_params, cutoff = torchpme.utils.tune_pme(
-    sum_squared_charges=sum_squared_charges, cell=cell, positions=positions
-)
+smearing, pme_params, cutoff = PMETuner(
+    charges=charges, cell=cell, positions=positions, cutoff=4.4
+).tune()

# %%
#
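Again for context only (a hedged sketch, not in the diff): this example already imports `vesin.torch`, so the tuned cutoff would typically feed the neighbor-list computation along these lines:

# hedged sketch: build the half neighbor list with the tuned cutoff
nl = vesin.torch.NeighborList(cutoff=cutoff, full_list=False)
i, j, S = nl.compute(points=positions, box=cell, periodic=True, quantities="ijS")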
File renamed without changes.
File renamed without changes.
86 changes: 84 additions & 2 deletions examples/5-autograd-demo.py → examples/05-autograd-demo.py
@@ -17,6 +17,8 @@
exercise to the reader.
"""

+# %%
+
from time import time

import ase
@@ -477,10 +479,11 @@ def forward(self, positions, cell, charges):
)

# %%
-# We can also time the difference in execution
+# We can also evaluate the difference in execution
# time between the Pytorch and scripted versions of the
# module (depending on the system, the relative efficiency
-# of the two evaluations could go either way!)
+# of the two evaluations could go either way, as this is
+# too small a system to make a difference!)

duration = 0.0
for _i in range(20):
@@ -515,3 +518,82 @@ def forward(self, positions, cell, charges):
print(f"Evaluation time:\nPytorch: {time_python}ms\nJitted: {time_jit}ms")

# %%
# Other auto-differentiation ideas
Review comment (Contributor):

IMHO I wouldn't put this example here, even though I think it is good to have it. The tutorial is already 500 lines and would become super long with this. I'd rather vote for smaller examples, each tackling one specific task. Finding solutions is much easier if they are shorter. See also the beloved matplotlib examples.

# --------------------------------
#
# There are many other ways the auto-differentiation engine of
# ``torch`` can be used to facilitate the evaluation of atomistic
# models.

# %%
# 4-site water models
# ~~~~~~~~~~~~~~~~~~~
#
# Several water models (starting from the venerable TIP4P model of
# `Abascal and C. Vega, JCP (2005) <http://doi.org/10.1063/1.2121687>`_)
# use a center of negative charge that is displaced from the O position.
# This is easily implemented, yielding the forces on the O and H positions
# generated by the displaced charge.

structure = ase.Atoms(
positions=[
[0, 0, 0],
[0, 1, 0],
[1, -0.2, 0],
],
cell=[6, 6, 6],
symbols="OHH",
)

cell = torch.from_numpy(structure.cell.array).to(device=device, dtype=dtype)
positions = torch.from_numpy(structure.positions).to(device=device, dtype=dtype)

# %%
# The key step is to create a "fourth site" based on the O positions
# and use it in the ``interpolate`` step.

charges = torch.tensor(
[[-1.0], [0.5], [0.5]],
dtype=dtype,
device=device,
)

positions.requires_grad_(True)
charges.requires_grad_(True)
cell.requires_grad_(True)

positions_4site = torch.vstack(
[
((positions[1::3] + positions[2::3]) * 0.5 + positions[0::3] * 3) / 4,
positions[1::3],
positions[2::3],
]
)
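# note: the vstack above builds the virtual site at
# (3*r_O + (r_H1 + r_H2)/2) / 4, i.e. on the HOH bisector, a quarter of
# the way from the oxygen towards the midpoint of the two hydrogens
# (illustrative weights rather than the fitted TIP4P geometry)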

ns = torch.tensor([5, 5, 5])
interpolator = torchpme.lib.MeshInterpolator(
cell=cell, ns_mesh=ns, interpolation_nodes=3, method="Lagrange"
)
interpolator.compute_weights(positions_4site)
mesh = interpolator.points_to_mesh(charges)

value = (mesh**2).sum()

# %%
# The gradients can be computed by just running `backward` on the
# end result. Gradients accumulate on the H and O positions: the chain
# rule distributes the virtual-site contribution as 3/4 to the O atom
# and 1/8 to each H atom, following the weights used above.

value.backward()

print(
f"""
Position gradients:
{positions.grad.T}

Cell gradients:
{cell.grad}

Charges gradients:
{charges.grad.T}
"""
)
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.