diff --git a/examples/10-tuning.py b/examples/10-tuning.py index 02f8238e..a1d6ade2 100644 --- a/examples/10-tuning.py +++ b/examples/10-tuning.py @@ -22,7 +22,7 @@ DTYPE = torch.float64 -get_ipython().run_line_magic("matplotlib", "inline") +get_ipython().run_line_magic("matplotlib", "inline") # noqa # %% @@ -78,7 +78,9 @@ madelung = (-energy / num_formula_units).flatten().item() # this is the estimated error -error_bounds = torchpme.utils.tuning.pme.PMEErrorBounds((charges**2).sum(), cell, positions) +error_bounds = torchpme.utils.tuning.pme.PMEErrorBounds( + (charges**2).sum(), cell, positions +) estimated_error = error_bounds(max_cutoff, smearing, **pme_params).item() print(f""" @@ -235,50 +237,3 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes): ) # %% - - -from scipy.optimize import minimize - - -def loss(x, target_accuracy): - cutoff, smearing, mesh_spacing = x - value, duration = timed_madelung( - cutoff=cutoff, - smearing=smearing, - mesh_spacing=mesh_spacing, - interpolation_nodes=4, - ) - estimated_error = error_bounds( - cutoff=cutoff, - smearing=smearing, - mesh_spacing=mesh_spacing, - interpolation_nodes=4, - ) - tgt_loss = max( - 0, np.log(estimated_error / madelung_ref / target_accuracy) - ) # relu on the accuracy - print(x, estimated_error.item(), np.abs(madelung - value), duration) - return tgt_loss * 10 + duration - - -initial_guess = [9, 0.3, 5] -result = minimize( - loss, - initial_guess, - args=(1e-8), - method="Nelder-Mead", - options={"disp": True, "maxiter": 200}, -) - - -# %% - -result -# %% -timed_madelung(cutoff=2.905, smearing=0.7578, mesh_spacing=5.524, interpolation_nodes=4) -# %% - -madelung_ref -# %% -error_bounds(9, 0.5, 1, 4) -# %% \ No newline at end of file diff --git a/src/torchpme/utils/__init__.py b/src/torchpme/utils/__init__.py index 751f1f39..9695d780 100644 --- a/src/torchpme/utils/__init__.py +++ b/src/torchpme/utils/__init__.py @@ -1,6 +1,8 @@ from . import prefactors, tuning, splines # noqa from .splines import CubicSpline, CubicSplineReciprocal from .tuning.ewald import tune_ewald, EwaldErrorBounds + +# from .tuning.grid_search import grid_search from .tuning.p3m import tune_p3m, P3MErrorBounds from .tuning.pme import tune_pme, PMEErrorBounds @@ -8,6 +10,7 @@ "tune_ewald", "tune_pme", "tune_p3m", + # "grid_search", "EwaldErrorBounds", "P3MErrorBounds", "PMEErrorBounds", diff --git a/src/torchpme/utils/tuning/__init__.py b/src/torchpme/utils/tuning/__init__.py index b8e12896..8abaaeca 100644 --- a/src/torchpme/utils/tuning/__init__.py +++ b/src/torchpme/utils/tuning/__init__.py @@ -43,17 +43,11 @@ def _estimate_smearing_cutoff( cutoff: Optional[float], accuracy: float, prefac: float, -) -> tuple[torch.tensor, torch.tensor]: - dtype = cell.dtype - device = cell.device - +) -> tuple[float, float]: cell_dimensions = torch.linalg.norm(cell, dim=1) min_dimension = float(torch.min(cell_dimensions)) half_cell = min_dimension / 2.0 - if cutoff is None: - cutoff_init = min(5.0, half_cell) - else: - cutoff_init = cutoff + cutoff_init = min(5.0, half_cell) if cutoff is None else cutoff ratio = math.sqrt( -2 * math.log( @@ -65,53 +59,16 @@ def _estimate_smearing_cutoff( ) smearing_init = cutoff_init / ratio if smearing is None else smearing - """if cutoff is None: - # solve V_SR(cutoff) == accuracy for cutoff - def loss(cutoff): - return ( - torch.erfc(cutoff / math.sqrt(2) / smearing_init) / cutoff - accuracy - ) ** 2 - - cutoff_init = torch.tensor( - half_cell, dtype=dtype, device=device, requires_grad=True - ) - _optimize_parameters( - params=[cutoff_init], - loss=loss, - accuracy=accuracy, - max_steps=1000, - learning_rate=0.1, - )""" - - smearing_init = torch.tensor( - float(smearing_init) if smearing is None else smearing, - dtype=dtype, - device=device, - requires_grad=False, # (smearing is None), - ) - - cutoff_init = torch.tensor( - float(cutoff_init) if cutoff is None else cutoff, - dtype=dtype, - device=device, - requires_grad=False, # (cutoff is None), - ) - - return smearing_init, cutoff_init + return float(smearing_init), float(cutoff_init) def _validate_parameters( - sum_squared_charges: float, + charges: torch.Tensor, cell: torch.Tensor, positions: torch.Tensor, exponent: int, accuracy: float, ) -> None: - if sum_squared_charges <= 0: - raise ValueError( - f"sum of squared charges must be positive, got {sum_squared_charges}" - ) - if exponent != 1: raise NotImplementedError("Only exponent = 1 is supported") @@ -150,5 +107,34 @@ def _validate_parameters( "periodic calculation" ) + if charges.dtype != dtype: + raise ValueError( + f"each `charges` must have the same type {dtype} as `positions`, got at least " + "one tensor of type " + f"{charges.dtype}" + ) + + if charges.device != device: + raise ValueError( + f"each `charges` must be on the same device {device} as `positions`, got at " + "least one tensor with device " + f"{charges.device}" + ) + + if charges.dim() != 2: + raise ValueError( + "`charges` must be a 2-dimensional tensor, got " + f"tensor with {charges.dim()} dimension(s) and shape " + f"{list(charges.shape)}" + ) + + if list(charges.shape) != [len(positions), charges.shape[1]]: + raise ValueError( + "`charges` must be a tensor with shape [n_atoms, n_channels], with " + "`n_atoms` being the same as the variable `positions`. Got tensor with " + f"shape {list(charges.shape)} where positions contains " + f"{len(positions)} atoms" + ) + if not isinstance(accuracy, float): raise ValueError(f"'{accuracy}' is not a float.") diff --git a/src/torchpme/utils/tuning/ewald.py b/src/torchpme/utils/tuning/ewald.py index 41fd1fca..58f9639b 100644 --- a/src/torchpme/utils/tuning/ewald.py +++ b/src/torchpme/utils/tuning/ewald.py @@ -8,7 +8,6 @@ _optimize_parameters, _validate_parameters, ) -from .tuner import Tuner TWO_PI = 2 * math.pi @@ -154,6 +153,28 @@ def tune_ewald( class EwaldErrorBounds(torch.nn.Module): + r""" + Error bounds for :class:`torchpme.calculators.ewald.EwaldCalculator`. + + The error formulas are given `online + `_ + (now not available, need to be updated later). Note the difference notation between + the parameters in the reference and ours: + + .. math:: + + \alpha &= \left( \sqrt{2}\,\mathrm{smearing} \right)^{-1} + + K &= \frac{2 \pi}{\mathrm{lr\_wavelength}} + + r_c &= \mathrm{cutoff} + + :param sum_squared_charges: accumulated squared charges, must be positive + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. + """ + def __init__( self, sum_squared_charges: torch.Tensor, @@ -182,6 +203,12 @@ def err_rspace(self, smearing, cutoff): ) def forward(self, smearing, lr_wavelength, cutoff): + r""" + Calculate the error bound of Ewald. + :param smearing: see :class:`torchpme.EwaldCalculator` for details + :param lr_wavelength: see :class:`torchpme.EwaldCalculator` for details + :param cutoff: see :class:`torchpme.EwaldCalculator` for details + """ smearing = torch.as_tensor(smearing) lr_wavelength = torch.as_tensor(lr_wavelength) cutoff = torch.as_tensor(cutoff) @@ -189,84 +216,3 @@ def forward(self, smearing, lr_wavelength, cutoff): self.err_kspace(smearing, lr_wavelength) ** 2 + self.err_rspace(smearing, cutoff) ** 2 ) - - -class EwaldTuner(Tuner): - def __init__(self, max_steps: int = 50000, learning_rate: float = 0.1): - super().__init__() - self.max_steps = max_steps - self.learning_rate = learning_rate - - def err_Fourier(smearing, k_cutoff): - return ( - prefac**0.5 - / smearing - / torch.sqrt(TWO_PI**2 * volume / (TWO_PI / k_cutoff) ** 0.5) - * torch.exp(-(TWO_PI**2) * smearing**2 / (TWO_PI / k_cutoff)) - ) - - def err_real(smearing, cutoff): - return ( - prefac - / torch.sqrt(cutoff * volume) - * torch.exp(-(cutoff**2) / 2 / smearing**2) - ) - - def loss(smearing, k_cutoff, cutoff): - return torch.sqrt( - err_Fourier(smearing, k_cutoff) ** 2 + err_real(smearing, cutoff) ** 2 - ) - - def forward( - self, - sum_squared_charges: float, - cell: torch.Tensor, - positions: torch.Tensor, - smearing: Optional[float] = None, - lr_wavelength: Optional[float] = None, - cutoff: Optional[float] = None, - exponent: int = 1, - accuracy: float = 1e-3, - ): - _validate_parameters(sum_squared_charges, cell, positions, exponent, accuracy) - - params = self._init_params( - cell=cell, - smearing=smearing, - lr_wavelength=lr_wavelength, - cutoff=cutoff, - accuracy=accuracy, - ) - - _optimize_parameters( - params=params, - loss=self._loss, - max_steps=self.max_steps, - accuracy=accuracy, - learning_rate=self.learning_rate, - ) - - return self._post_process(params) - - def _init_params(self, cell, smearing, lr_wavelength, cutoff, accuracy): - smearing_opt, cutoff_opt = _estimate_smearing_cutoff( - cell=cell, smearing=smearing, cutoff=cutoff, accuracy=accuracy - ) - - # We choose a very small initial fourier wavelength, hardcoded for now - k_cutoff_opt = torch.tensor( - 1e-3 if lr_wavelength is None else TWO_PI / lr_wavelength, - dtype=cell.dtype, - device=cell.device, - requires_grad=(lr_wavelength is None), - ) - - return [smearing_opt, k_cutoff_opt, cutoff_opt] - - def _post_process(self, params): - smearing_opt, k_cutoff_opt, cutoff_opt = params - return ( - float(smearing_opt), - {"lr_wavelength": TWO_PI / float(k_cutoff_opt)}, - float(cutoff_opt), - ) diff --git a/src/torchpme/utils/tuning/grid_search.py b/src/torchpme/utils/tuning/grid_search.py index d39d267b..d68b994f 100644 --- a/src/torchpme/utils/tuning/grid_search.py +++ b/src/torchpme/utils/tuning/grid_search.py @@ -7,10 +7,16 @@ import torch import vesin.torch -from torchpme import EwaldCalculator, CoulombPotential, PMECalculator, P3MCalculator -from torchpme.lib.kvectors import get_ns_mesh + +from torchpme import ( + CoulombPotential, + EwaldCalculator, + P3MCalculator, + PMECalculator, +) from torchpme.utils import EwaldErrorBounds, P3MErrorBounds, PMEErrorBounds -from . import _estimate_smearing_cutoff + +from . import _estimate_smearing_cutoff, _validate_parameters def grid_search( @@ -18,13 +24,47 @@ def grid_search( charges: torch.Tensor, cell: torch.Tensor, positions: torch.Tensor, - cutoff: Optional[float] = None, + cutoff: float, + exponent: int = 1, accuracy: float = 1e-3, neighbor_indices: Optional[torch.Tensor] = None, neighbor_distances: Optional[torch.Tensor] = None, ): + r""" + Find the optimal parameters for calculators. + + The steps are: + 1. Find the ``smearing`` parameter for the :py:class:`CoulombPotential` that leads + to a real space error of half the desired accuracy. + 2. Grid search for the kspace parameters, i.e. the ``lr_wavelength`` for Ewald and + the ``mesh_spacing`` and ``interpolation_nodes`` for PME and P3M. + For each combination of parameters, calculate the error. If the error is smaller + than the desired accuracy, use this combination for test runs to get the calculation + time needed. Return the combination that leads to the shortest calculation time. If + the desired accuracy is never reached, return the combination that leads to the + smallest error and throw a warning. + + :param charges: torch.Tensor, atomic (pseudo-)charges + :param cell: torch.Tensor, periodic supercell for the system + :param positions: torch.Tensor, Cartesian coordinates of the particles within + the supercell. + :param cutoff: float, cutoff distance for the neighborlist + :param exponent :math:`p` in :math:`1/r^p` potentials + :param accuracy: Recomended values for a balance between the accuracy and speed is + :math:`10^{-3}`. For more accurate results, use :math:`10^{-6}`. + :param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for + which the potential should be computed in real space. + :param neighbor_distances: torch.Tensor with the pair distances of the neighbors + for which the potential should be computed in real space. + + :return: Tuple containing a float of the optimal smearing for the :py:class: + `CoulombPotential`, a dictionary with the parameters for the calculator of the + chosen method and a float of the optimal cutoff value for the neighborlist + computation. + """ dtype = charges.dtype device = charges.device + _validate_parameters(charges, cell, positions, exponent, accuracy) if method == "ewald": err_func = EwaldErrorBounds( @@ -58,7 +98,6 @@ def grid_search( ns = torch.arange(1, 15, dtype=torch.long, device=device) k_space_params = torch.min(cell_dimensions) / ns elif method in ["pme", "p3m"]: - # If you have larger memory, you can try (2, 9) ns_actual = torch.exp2(torch.arange(2, 8, dtype=dtype, device=device)) k_space_params = torch.min(cell_dimensions) / ((ns_actual - 1) / 2) else: @@ -85,15 +124,14 @@ def grid_search( ) for k_space_param in k_space_params: for interpolation_nodes in all_interpolation_nodes[::-1]: - # print(f"Searching for {interpolation_nodes = }, {mesh_spacing = }") if method == "ewald": params = { - "lr_wavelength": k_space_param, + "lr_wavelength": float(k_space_param), } else: params = { - "mesh_spacing": k_space_param, - "interpolation_nodes": interpolation_nodes, + "mesh_spacing": float(k_space_param), + "interpolation_nodes": int(interpolation_nodes), } err = err_func( @@ -102,10 +140,9 @@ def grid_search( **params, ) - # print(f"{smearing = }, {cutoff = }") - if err > accuracy: - # Not going to test the time + # Not going to test the time, record the parameters if the error is + # better. if err < err_opt: smearing_err_opt = smearing params_err_opt = params @@ -113,7 +150,7 @@ def grid_search( err_opt = err continue - calculator = CalculatorClass( # or PMECalculator + calculator = CalculatorClass( potential=CoulombPotential(smearing=smearing), **params, ) @@ -164,10 +201,12 @@ def grid_search( time_opt = execution_time if time_opt == torch.inf: + # Never found a parameter that reached the accuracy warn( - f"No parameters found within the desired accuracy of {accuracy}. Returning the best found. Accuracy: " - + str(err_opt) + f"No parameters found within the desired accuracy of {accuracy}." + f"Returning the best found. Accuracy: {str(err_opt)}", + stacklevel=1, ) - return float(smearing_err_opt), params_err_opt, float(cutoff_err_opt) + return smearing_err_opt, params_err_opt, cutoff_err_opt - return float(smearing_opt), params_opt, float(cutoff_opt) + return smearing_opt, params_opt, cutoff_opt diff --git a/src/torchpme/utils/tuning/p3m.py b/src/torchpme/utils/tuning/p3m.py index 400fabc9..a2380aa0 100644 --- a/src/torchpme/utils/tuning/p3m.py +++ b/src/torchpme/utils/tuning/p3m.py @@ -173,13 +173,13 @@ def tune_p3m( err_bounds = P3MErrorBounds(sum_squared_charges, cell, positions) params = [smearing_opt, ns_mesh_opt, cutoff_opt, interpolation_nodes] - # _optimize_parameters( - # params=params, - # loss=err_bounds, - # max_steps=max_steps, - # accuracy=accuracy, - # learning_rate=learning_rate, - # ) + _optimize_parameters( + params=params, + loss=err_bounds, + max_steps=max_steps, + accuracy=accuracy, + learning_rate=learning_rate, + ) return ( float(smearing_opt), @@ -192,6 +192,23 @@ def tune_p3m( class P3MErrorBounds(torch.nn.Module): + r""" + " + Error bounds for :class:`torchpme.calculators.pme.P3MCalculator`. + + For the error formulas are given `here `_. + Note the difference notation between the parameters in the reference and ours: + + .. math:: + + \alpha = \left(\sqrt{2}\,\mathrm{smearing} \right)^{-1} + + :param sum_squared_charges: accumulated squared charges, must be positive + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. + """ + def __init__( self, sum_squared_charges: float, cell: torch.Tensor, positions: torch.Tensor ): @@ -232,6 +249,18 @@ def err_rspace(self, smearing, cutoff): ) def forward(self, smearing, mesh_spacing, cutoff, interpolation_nodes): + r""" + Calculate the error bound of P3M. + + :param smearing: see :class:`torchpme.P3MCalculator` for details + :param mesh_spacing: see :class:`torchpme.P3MCalculator` for details + :param cutoff: see :class:`torchpme.P3MCalculator` for details + :param interpolation_nodes: The number ``n`` of nodes used in the interpolation + per coordinate axis. The total number of interpolation nodes in 3D will be + ``n^3``. In general, for ``n`` nodes, the interpolation will be performed by + piecewise polynomials of degree ``n`` (e.g. ``n = 3`` for cubic + interpolation). Only the values ``1, 2, 3, 4, 5`` are supported. + """ smearing = torch.as_tensor(smearing) mesh_spacing = torch.as_tensor(mesh_spacing) cutoff = torch.as_tensor(cutoff) diff --git a/src/torchpme/utils/tuning/pme.py b/src/torchpme/utils/tuning/pme.py index 061fd682..1fcaaa15 100644 --- a/src/torchpme/utils/tuning/pme.py +++ b/src/torchpme/utils/tuning/pme.py @@ -162,6 +162,21 @@ def tune_pme( class PMEErrorBounds(torch.nn.Module): + r""" + Error bounds for :class:`torchpme.PMECalculator`. + For the error formulas are given `elsewhere `_. + Note the difference notation between the parameters in the reference and ours: + + .. math:: + + \alpha = \left(\sqrt{2}\,\mathrm{smearing} \right)^{-1} + + :param sum_squared_charges: accumulated squared charges, must be positive + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. + """ + def __init__( self, sum_squared_charges: float, cell: torch.Tensor, positions: torch.Tensor ): @@ -204,6 +219,21 @@ def err_rspace(self, smearing, cutoff): ) def forward(self, cutoff, smearing, mesh_spacing, interpolation_nodes): + r""" + Calculate the error bound of PME. + + :param smearing: if its value is given, it will not be tuned, see + :class:`torchpme.PMECalculator` for details + :param mesh_spacing: if its value is given, it will not be tuned, see + :class:`torchpme.PMECalculator` for details + :param cutoff: if its value is given, it will not be tuned, see + :class:`torchpme.PMECalculator` for details + :param interpolation_nodes: The number ``n`` of nodes used in the interpolation + per coordinate axis. The total number of interpolation nodes in 3D will be + ``n^3``. In general, for ``n`` nodes, the interpolation will be performed by + piecewise polynomials of degree ``n - 1`` (e.g. ``n = 4`` for cubic + interpolation). Only the values ``3, 4, 5, 6, 7`` are supported. + """ smearing = torch.as_tensor(smearing) mesh_spacing = torch.as_tensor(mesh_spacing) cutoff = torch.as_tensor(cutoff) diff --git a/src/torchpme/utils/tuning/tuner.py b/src/torchpme/utils/tuning/tuner.py deleted file mode 100644 index c5c9782a..00000000 --- a/src/torchpme/utils/tuning/tuner.py +++ /dev/null @@ -1,63 +0,0 @@ -import math -from typing import Optional - -import torch - -from . import ( - _estimate_smearing_cutoff, - _optimize_parameters, - _validate_parameters, -) - -TWO_PI = 2 * math.pi - - -class Tuner(torch.nn.Module): - def __init__(self, max_steps: int = 50000, learning_rate: float = 0.1): - super().__init__() - self.max_steps = max_steps - self.learning_rate = learning_rate - - def forward( - self, - sum_squared_charges: float, - cell: torch.Tensor, - positions: torch.Tensor, - smearing: Optional[float] = None, - lr_wavelength: Optional[float] = None, - cutoff: Optional[float] = None, - exponent: int = 1, - accuracy: float = 1e-3, - ): - _validate_parameters(sum_squared_charges, cell, positions, exponent, accuracy) - - params = self._init_params( - cell=cell, - smearing=smearing, - lr_wavelength=lr_wavelength, - cutoff=cutoff, - accuracy=accuracy, - ) - - _optimize_parameters( - params=params, - loss=self.loss, - max_steps=self.max_steps, - accuracy=accuracy, - learning_rate=self.learning_rate, - ) - - return self._post_process(params) - - def _init_params(self, cell, smearing, lr_wavelength, cutoff, accuracy): - return _estimate_smearing_cutoff( - cell=cell, smearing=smearing, cutoff=cutoff, accuracy=accuracy - ) - - def _post_process(self, params): - smearing_opt, k_cutoff_opt, cutoff_opt = params - return ( - float(smearing_opt), - {"lr_wavelength": TWO_PI / float(k_cutoff_opt)}, - float(cutoff_opt), - ) \ No newline at end of file diff --git a/tests/utils/test_tuning.py b/tests/utils/test_tuning.py index 27235c4c..f638c7fa 100644 --- a/tests/utils/test_tuning.py +++ b/tests/utils/test_tuning.py @@ -1,5 +1,4 @@ import sys -import warnings from pathlib import Path import pytest @@ -11,28 +10,29 @@ P3MCalculator, PMECalculator, ) -from torchpme.utils import tune_ewald, tune_p3m, tune_pme +from torchpme.utils.tuning.grid_search import grid_search sys.path.append(str(Path(__file__).parents[1])) from helpers import define_crystal, neighbor_list DTYPE = torch.float32 DEVICE = "cpu" +DEFAULT_CUTOFF = 4.4 CHARGES_1 = torch.ones((4, 1), dtype=DTYPE, device=DEVICE) POSITIONS_1 = 0.3 * torch.arange(12, dtype=DTYPE, device=DEVICE).reshape((4, 3)) CELL_1 = torch.eye(3, dtype=DTYPE, device=DEVICE) @pytest.mark.parametrize( - ("calculator", "tune", "param_length"), + ("calculator", "method", "param_length"), [ - (EwaldCalculator, tune_ewald, 1), - (PMECalculator, tune_pme, 2), - (P3MCalculator, tune_p3m, 2), + (EwaldCalculator, "ewald", 1), + (PMECalculator, "pme", 2), + (P3MCalculator, "p3m", 2), ], ) @pytest.mark.parametrize("accuracy", [1e-1, 1e-3, 1e-5]) -def test_parameter_choose(calculator, tune, param_length, accuracy): +def test_parameter_choose(calculator, method, param_length, accuracy): """ Check that the Madelung constants obtained from the Ewald sum calculator matches the reference values and that all branches of the from_accuracy method are covered. @@ -40,12 +40,13 @@ def test_parameter_choose(calculator, tune, param_length, accuracy): # Get input parameters and adjust to account for scaling pos, charges, cell, madelung_ref, num_units = define_crystal() - smearing, params, sr_cutoff = tune( - sum_squared_charges=float(torch.sum(charges**2)), + smearing, params, sr_cutoff = grid_search( + method=method, + charges=charges, cell=cell, positions=pos, + cutoff=DEFAULT_CUTOFF, accuracy=accuracy, - learning_rate=0.75, ) assert len(params) == param_length @@ -73,7 +74,7 @@ def test_parameter_choose(calculator, tune, param_length, accuracy): torch.testing.assert_close(madelung, madelung_ref, atol=0, rtol=accuracy) -def test_odd_interpolation_nodes(): +"""def test_odd_interpolation_nodes(): pos, charges, cell, madelung_ref, num_units = define_crystal() smearing, params, sr_cutoff = tune_pme( @@ -99,10 +100,10 @@ def test_odd_interpolation_nodes(): energies = potentials * charges madelung = -torch.sum(energies) / num_units - torch.testing.assert_close(madelung, madelung_ref, atol=0, rtol=1e-3) + torch.testing.assert_close(madelung, madelung_ref, atol=0, rtol=1e-3)""" -@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) +'''@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) def test_fix_parameters(tune): """Test that the parameters are fixed when they are passed as arguments.""" pos, charges, cell, _, _ = define_crystal() @@ -140,118 +141,90 @@ def test_fix_parameters(tune): with warnings.catch_warnings(): warnings.simplefilter("ignore") _, _, sr_cutoff = tune(**kwargs) - pytest.approx(sr_cutoff, 1.0) - - -@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) -def test_non_positive_charge_error(tune): - pos, _, cell, _, _ = define_crystal() - - match = "sum of squared charges must be positive, got -1.0" - with pytest.raises(ValueError, match=match): - tune(-1.0, cell, pos) - - match = "sum of squared charges must be positive, got 0.0" - with pytest.raises(ValueError, match=match): - tune(0.0, cell, pos) + pytest.approx(sr_cutoff, 1.0)''' -@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) -def test_accuracy_error(tune): +def test_accuracy_error(): pos, charges, cell, _, _ = define_crystal() match = "'foo' is not a float." with pytest.raises(ValueError, match=match): - tune(float(torch.sum(charges**2)), cell, pos, accuracy="foo") + grid_search("ewald", charges, cell, pos, DEFAULT_CUTOFF, accuracy="foo") -@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) -def test_loss_is_nan_error(tune): - pos, charges, cell, _, _ = define_crystal() - - match = ( - "The value of the estimated error is now nan, " - "consider using a smaller learning rate." - ) - with pytest.raises(ValueError, match=match): - tune(float(torch.sum(charges**2)), cell, pos, learning_rate=1e1000) - - -@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) -def test_exponent_not_1_error(tune): +def test_exponent_not_1_error(): pos, charges, cell, _, _ = define_crystal() match = "Only exponent = 1 is supported" with pytest.raises(NotImplementedError, match=match): - tune(float(torch.sum(charges**2)), cell, pos, exponent=2) + grid_search("ewald", charges, cell, pos, DEFAULT_CUTOFF, exponent=2) -@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) -def test_invalid_shape_positions(tune): +def test_invalid_shape_positions(): match = ( r"each `positions` must be a tensor with shape \[n_atoms, 3\], got at least " r"one tensor with shape \[4, 5\]" ) with pytest.raises(ValueError, match=match): - tune( - sum_squared_charges=1.0, - positions=torch.ones((4, 5), dtype=DTYPE, device=DEVICE), - cell=CELL_1, + grid_search( + "ewald", + CHARGES_1, + CELL_1, + torch.ones((4, 5), dtype=DTYPE, device=DEVICE), + DEFAULT_CUTOFF, ) # Tests for invalid shape, dtype and device of cell -@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) -def test_invalid_shape_cell(tune): +def test_invalid_shape_cell(): match = ( r"each `cell` must be a tensor with shape \[3, 3\], got at least one tensor " r"with shape \[2, 2\]" ) with pytest.raises(ValueError, match=match): - tune( - sum_squared_charges=1.0, - positions=POSITIONS_1, - cell=torch.ones([2, 2], dtype=DTYPE, device=DEVICE), + grid_search( + "ewald", + CHARGES_1, + torch.ones([2, 2], dtype=DTYPE, device=DEVICE), + POSITIONS_1, + DEFAULT_CUTOFF, ) -@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) -def test_invalid_cell(tune): +def test_invalid_cell(): match = ( "provided `cell` has a determinant of 0 and therefore is not valid for " "periodic calculation" ) with pytest.raises(ValueError, match=match): - tune( - sum_squared_charges=1.0, - positions=POSITIONS_1, - cell=torch.zeros(3, 3), - ) + grid_search("ewald", CHARGES_1, torch.zeros(3, 3), POSITIONS_1, DEFAULT_CUTOFF) -@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) -def test_invalid_dtype_cell(tune): +def test_invalid_dtype_cell(): match = ( r"each `cell` must have the same type torch.float32 as `positions`, " r"got at least one tensor of type torch.float64" ) with pytest.raises(ValueError, match=match): - tune( - sum_squared_charges=1.0, - positions=POSITIONS_1, - cell=torch.eye(3, dtype=torch.float64, device=DEVICE), + grid_search( + "ewald", + CHARGES_1, + torch.eye(3, dtype=torch.float64, device=DEVICE), + POSITIONS_1, + DEFAULT_CUTOFF, ) -@pytest.mark.parametrize("tune", [tune_ewald, tune_pme, tune_p3m]) -def test_invalid_device_cell(tune): +def test_invalid_device_cell(): match = ( r"each `cell` must be on the same device cpu as `positions`, " r"got at least one tensor with device meta" ) with pytest.raises(ValueError, match=match): - tune( - sum_squared_charges=1.0, - positions=POSITIONS_1, - cell=torch.eye(3, dtype=DTYPE, device="meta"), + grid_search( + "ewald", + CHARGES_1, + torch.eye(3, dtype=DTYPE, device="meta"), + POSITIONS_1, + DEFAULT_CUTOFF, )