Skip to content

Commit

Permalink
Documentations and pytests update
Browse files Browse the repository at this point in the history
  • Loading branch information
GardevoirX authored and ceriottm committed Dec 20, 2024
1 parent 82a32fb commit f554b70
Show file tree
Hide file tree
Showing 9 changed files with 239 additions and 341 deletions.
53 changes: 4 additions & 49 deletions examples/10-tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

DTYPE = torch.float64

get_ipython().run_line_magic("matplotlib", "inline")
get_ipython().run_line_magic("matplotlib", "inline") # noqa

# %%

Expand Down Expand Up @@ -78,7 +78,9 @@
madelung = (-energy / num_formula_units).flatten().item()

# this is the estimated error
error_bounds = torchpme.utils.tuning.pme.PMEErrorBounds((charges**2).sum(), cell, positions)
error_bounds = torchpme.utils.tuning.pme.PMEErrorBounds(
(charges**2).sum(), cell, positions
)

estimated_error = error_bounds(max_cutoff, smearing, **pme_params).item()
print(f"""
Expand Down Expand Up @@ -235,50 +237,3 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):
)

# %%


from scipy.optimize import minimize


def loss(x, target_accuracy):
cutoff, smearing, mesh_spacing = x
value, duration = timed_madelung(
cutoff=cutoff,
smearing=smearing,
mesh_spacing=mesh_spacing,
interpolation_nodes=4,
)
estimated_error = error_bounds(
cutoff=cutoff,
smearing=smearing,
mesh_spacing=mesh_spacing,
interpolation_nodes=4,
)
tgt_loss = max(
0, np.log(estimated_error / madelung_ref / target_accuracy)
) # relu on the accuracy
print(x, estimated_error.item(), np.abs(madelung - value), duration)
return tgt_loss * 10 + duration


initial_guess = [9, 0.3, 5]
result = minimize(
loss,
initial_guess,
args=(1e-8),
method="Nelder-Mead",
options={"disp": True, "maxiter": 200},
)


# %%

result
# %%
timed_madelung(cutoff=2.905, smearing=0.7578, mesh_spacing=5.524, interpolation_nodes=4)
# %%

madelung_ref
# %%
error_bounds(9, 0.5, 1, 4)
# %%
3 changes: 3 additions & 0 deletions src/torchpme/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from . import prefactors, tuning, splines # noqa
from .splines import CubicSpline, CubicSplineReciprocal
from .tuning.ewald import tune_ewald, EwaldErrorBounds

# from .tuning.grid_search import grid_search
from .tuning.p3m import tune_p3m, P3MErrorBounds
from .tuning.pme import tune_pme, PMEErrorBounds

__all__ = [
"tune_ewald",
"tune_pme",
"tune_p3m",
# "grid_search",
"EwaldErrorBounds",
"P3MErrorBounds",
"PMEErrorBounds",
Expand Down
80 changes: 33 additions & 47 deletions src/torchpme/utils/tuning/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,11 @@ def _estimate_smearing_cutoff(
cutoff: Optional[float],
accuracy: float,
prefac: float,
) -> tuple[torch.tensor, torch.tensor]:
dtype = cell.dtype
device = cell.device

) -> tuple[float, float]:
cell_dimensions = torch.linalg.norm(cell, dim=1)
min_dimension = float(torch.min(cell_dimensions))
half_cell = min_dimension / 2.0
if cutoff is None:
cutoff_init = min(5.0, half_cell)
else:
cutoff_init = cutoff
cutoff_init = min(5.0, half_cell) if cutoff is None else cutoff
ratio = math.sqrt(
-2
* math.log(
Expand All @@ -65,53 +59,16 @@ def _estimate_smearing_cutoff(
)
smearing_init = cutoff_init / ratio if smearing is None else smearing

"""if cutoff is None:
# solve V_SR(cutoff) == accuracy for cutoff
def loss(cutoff):
return (
torch.erfc(cutoff / math.sqrt(2) / smearing_init) / cutoff - accuracy
) ** 2
cutoff_init = torch.tensor(
half_cell, dtype=dtype, device=device, requires_grad=True
)
_optimize_parameters(
params=[cutoff_init],
loss=loss,
accuracy=accuracy,
max_steps=1000,
learning_rate=0.1,
)"""

smearing_init = torch.tensor(
float(smearing_init) if smearing is None else smearing,
dtype=dtype,
device=device,
requires_grad=False, # (smearing is None),
)

cutoff_init = torch.tensor(
float(cutoff_init) if cutoff is None else cutoff,
dtype=dtype,
device=device,
requires_grad=False, # (cutoff is None),
)

return smearing_init, cutoff_init
return float(smearing_init), float(cutoff_init)


def _validate_parameters(
sum_squared_charges: float,
charges: torch.Tensor,
cell: torch.Tensor,
positions: torch.Tensor,
exponent: int,
accuracy: float,
) -> None:
if sum_squared_charges <= 0:
raise ValueError(
f"sum of squared charges must be positive, got {sum_squared_charges}"
)

if exponent != 1:
raise NotImplementedError("Only exponent = 1 is supported")

Expand Down Expand Up @@ -150,5 +107,34 @@ def _validate_parameters(
"periodic calculation"
)

if charges.dtype != dtype:
raise ValueError(
f"each `charges` must have the same type {dtype} as `positions`, got at least "
"one tensor of type "
f"{charges.dtype}"
)

if charges.device != device:
raise ValueError(
f"each `charges` must be on the same device {device} as `positions`, got at "
"least one tensor with device "
f"{charges.device}"
)

if charges.dim() != 2:
raise ValueError(
"`charges` must be a 2-dimensional tensor, got "
f"tensor with {charges.dim()} dimension(s) and shape "
f"{list(charges.shape)}"
)

if list(charges.shape) != [len(positions), charges.shape[1]]:
raise ValueError(
"`charges` must be a tensor with shape [n_atoms, n_channels], with "
"`n_atoms` being the same as the variable `positions`. Got tensor with "
f"shape {list(charges.shape)} where positions contains "
f"{len(positions)} atoms"
)

if not isinstance(accuracy, float):
raise ValueError(f"'{accuracy}' is not a float.")
110 changes: 28 additions & 82 deletions src/torchpme/utils/tuning/ewald.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
_optimize_parameters,
_validate_parameters,
)
from .tuner import Tuner

TWO_PI = 2 * math.pi

Expand Down Expand Up @@ -154,6 +153,28 @@ def tune_ewald(


class EwaldErrorBounds(torch.nn.Module):
r"""
Error bounds for :class:`torchpme.calculators.ewald.EwaldCalculator`.
The error formulas are given `online
<https://www2.icp.uni-stuttgart.de/~icp/mediawiki/images/4/4d/Script_Longrange_Interactions.pdf>`_
(now not available, need to be updated later). Note the difference notation between
the parameters in the reference and ours:
.. math::
\alpha &= \left( \sqrt{2}\,\mathrm{smearing} \right)^{-1}
K &= \frac{2 \pi}{\mathrm{lr\_wavelength}}
r_c &= \mathrm{cutoff}
:param sum_squared_charges: accumulated squared charges, must be positive
:param cell: single tensor of shape (3, 3), describing the bounding
:param positions: single tensor of shape (``len(charges), 3``) containing the
Cartesian positions of all point charges in the system.
"""

def __init__(
self,
sum_squared_charges: torch.Tensor,
Expand Down Expand Up @@ -182,91 +203,16 @@ def err_rspace(self, smearing, cutoff):
)

def forward(self, smearing, lr_wavelength, cutoff):
r"""
Calculate the error bound of Ewald.
:param smearing: see :class:`torchpme.EwaldCalculator` for details
:param lr_wavelength: see :class:`torchpme.EwaldCalculator` for details
:param cutoff: see :class:`torchpme.EwaldCalculator` for details
"""
smearing = torch.as_tensor(smearing)
lr_wavelength = torch.as_tensor(lr_wavelength)
cutoff = torch.as_tensor(cutoff)
return torch.sqrt(
self.err_kspace(smearing, lr_wavelength) ** 2
+ self.err_rspace(smearing, cutoff) ** 2
)


class EwaldTuner(Tuner):
def __init__(self, max_steps: int = 50000, learning_rate: float = 0.1):
super().__init__()
self.max_steps = max_steps
self.learning_rate = learning_rate

def err_Fourier(smearing, k_cutoff):
return (
prefac**0.5
/ smearing
/ torch.sqrt(TWO_PI**2 * volume / (TWO_PI / k_cutoff) ** 0.5)
* torch.exp(-(TWO_PI**2) * smearing**2 / (TWO_PI / k_cutoff))
)

def err_real(smearing, cutoff):
return (
prefac
/ torch.sqrt(cutoff * volume)
* torch.exp(-(cutoff**2) / 2 / smearing**2)
)

def loss(smearing, k_cutoff, cutoff):
return torch.sqrt(
err_Fourier(smearing, k_cutoff) ** 2 + err_real(smearing, cutoff) ** 2
)

def forward(
self,
sum_squared_charges: float,
cell: torch.Tensor,
positions: torch.Tensor,
smearing: Optional[float] = None,
lr_wavelength: Optional[float] = None,
cutoff: Optional[float] = None,
exponent: int = 1,
accuracy: float = 1e-3,
):
_validate_parameters(sum_squared_charges, cell, positions, exponent, accuracy)

params = self._init_params(
cell=cell,
smearing=smearing,
lr_wavelength=lr_wavelength,
cutoff=cutoff,
accuracy=accuracy,
)

_optimize_parameters(
params=params,
loss=self._loss,
max_steps=self.max_steps,
accuracy=accuracy,
learning_rate=self.learning_rate,
)

return self._post_process(params)

def _init_params(self, cell, smearing, lr_wavelength, cutoff, accuracy):
smearing_opt, cutoff_opt = _estimate_smearing_cutoff(
cell=cell, smearing=smearing, cutoff=cutoff, accuracy=accuracy
)

# We choose a very small initial fourier wavelength, hardcoded for now
k_cutoff_opt = torch.tensor(
1e-3 if lr_wavelength is None else TWO_PI / lr_wavelength,
dtype=cell.dtype,
device=cell.device,
requires_grad=(lr_wavelength is None),
)

return [smearing_opt, k_cutoff_opt, cutoff_opt]

def _post_process(self, params):
smearing_opt, k_cutoff_opt, cutoff_opt = params
return (
float(smearing_opt),
{"lr_wavelength": TWO_PI / float(k_cutoff_opt)},
float(cutoff_opt),
)
Loading

0 comments on commit f554b70

Please sign in to comment.