diff --git a/src/torchpme/calculators/ewald.py b/src/torchpme/calculators/ewald.py index e8bffb5c..5f627950 100644 --- a/src/torchpme/calculators/ewald.py +++ b/src/torchpme/calculators/ewald.py @@ -129,7 +129,6 @@ def _compute_kspace( ivolume = torch.abs(cell.det()).pow(-1) charge_tot = torch.sum(charges, dim=0) prefac = self.potential.background_correction() - energy -= 2 * prefac * charge_tot * ivolume - + energy -= 2 * prefac * charge_tot * ivolume if charge_tot != 0 else 0 # Compensate for double counting of pairs (i,j) and (j,i) return energy / 2 diff --git a/src/torchpme/potentials/integerspline.py b/src/torchpme/potentials/integerspline.py new file mode 100644 index 00000000..8f214c44 --- /dev/null +++ b/src/torchpme/potentials/integerspline.py @@ -0,0 +1,137 @@ +from typing import Optional + +import torch +from torch.special import gammaln, gammainc + +from .potential import Potential +from .spline import SplinePotential + + +def gamma(x: torch.Tensor) -> torch.Tensor: + """ + (Complete) Gamma function. + + pytorch has not implemented the commonly used (complete) Gamma function. We define + it in a custom way to make autograd work as in + https://discuss.pytorch.org/t/is-there-a-gamma-function-in-pytorch/17122 + """ + return torch.exp(gammaln(x)) + +class InversePowerLawPotentialSpline(Potential): + """ + Inverse power-law potentials of the form :math:`1/r^p`. + + Here :math:`r` is a distance parameter and :math:`p` an exponent. + + It can be used to compute: + + 1. the full :math:`1/r^p` potential + 2. its short-range (SR) and long-range (LR) parts, the split being determined by a + length-scale parameter (called "smearing" in the code) + 3. the Fourier transform of the LR part + + :param exponent: the exponent :math:`p` in :math:`1/r^p` potentials + :param smearing: float or torch.Tensor containing the parameter often called "sigma" + in publications, which determines the length-scale at which the short-range and + long-range parts of the naive :math:`1/r^p` potential are separated. For the + Coulomb potential (:math:`p=1`), this potential can be interpreted as the + effective potential generated by a Gaussian charge density, in which case this + smearing parameter corresponds to the "width" of the Gaussian. + :param: exclusion_radius: float or torch.Tensor containing the length scale + corresponding to a local environment. See also + :class:`Potential`. + :param dtype: type used for the internal buffers and parameters + :param device: device used for the internal buffers and parameters + """ + + def __init__( + self, + exponent: int, + r_grid: torch.Tensor, + smearing: Optional[float] = None, + exclusion_radius: Optional[float] = None, + dtype: Optional[torch.dtype] = None, + device: Optional[torch.device] = None, + ): + super().__init__(smearing, exclusion_radius, dtype, device) + if dtype is None: + dtype = torch.get_default_dtype() + if device is None: + device = torch.device("cpu") + self.r_grid = r_grid + self.register_buffer( + "exponent", torch.tensor(exponent, dtype=dtype, device=device) + ) + + @torch.jit.export + def from_dist(self, dist: torch.Tensor) -> torch.Tensor: + """ + Full :math:`1/r^p` potential as a function of :math:`r`. + + :param dist: torch.tensor containing the distances at which the potential is to + be evaluated. + """ + return torch.pow(dist, -self.exponent) + + @torch.jit.export + def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor: + """ + Long range of the range-separated :math:`1/r^p` potential. + + Used to subtract out the interior contributions after computing the LR part in + reciprocal (Fourier) space. + + For the Coulomb potential, this would return (note that the only change between + the SR and LR parts is the fact that erfc changes to erf) + + .. code-block:: python + + potential = erf(dist / sqrt(2) / smearing) / dist + + :param dist: torch.tensor containing the distances at which the potential is to + be evaluated. + """ + if self.smearing is None: + raise ValueError( + "Cannot compute long-range contribution without specifying `smearing`." + ) + + exponent = self.exponent + smearing = self.smearing + + x = 0.5 * dist**2 / smearing**2 + peff = exponent / 2 + prefac = 1.0 / (2 * smearing**2) ** peff + return prefac * gammainc(peff, x) / x ** peff + + @torch.jit.export + def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor: + r""" + TODO: Fourier transform of the LR part potential in terms of :math:`\mathbf{k^2}`. + """ + spline = SplinePotential( + self.r_grid, self.lr_from_dist(self.r_grid) + ) + return spline.lr_from_k_sq(k_sq) + + def self_contribution(self) -> torch.Tensor: + # self-correction for 1/r^p potential + if self.smearing is None: + raise ValueError( + "Cannot compute self contribution without specifying `smearing`." + ) + phalf = self.exponent / 2 + return 1 / gamma(phalf + 1) / (2 * self.smearing**2) ** phalf + + def background_correction(self) -> torch.Tensor: + # "charge neutrality" correction for 1/r^p potential + if self.smearing is None: + raise ValueError( + "Cannot compute background correction without specifying `smearing`." + ) + prefac = torch.pi**1.5 * (2 * self.smearing**2) ** ((3 - self.exponent) / 2) + prefac /= (3 - self.exponent) * gamma(self.exponent / 2) + return prefac + + self_contribution.__doc__ = Potential.self_contribution.__doc__ + background_correction.__doc__ = Potential.background_correction.__doc__ diff --git a/src/torchpme/potentials/inversepowerlaw.py b/src/torchpme/potentials/inversepowerlaw.py index bd44236e..c3b2ccb7 100644 --- a/src/torchpme/potentials/inversepowerlaw.py +++ b/src/torchpme/potentials/inversepowerlaw.py @@ -1,7 +1,8 @@ from typing import Optional import torch -from torch.special import gammainc, gammaincc, gammaln +from torch.special import gammaln, gammainc +from scipy.special import exp1 from .potential import Potential @@ -17,6 +18,30 @@ def gamma(x: torch.Tensor) -> torch.Tensor: return torch.exp(gammaln(x)) +# Auxilary function for stable Fourier transform implementation +def gammaincc_over_powerlaw(exponent: torch.Tensor, z: torch.Tensor) -> torch.Tensor: + if exponent not in [1, 2, 3, 4, 5, 6]: + raise ValueError(f"Unsupported exponent: {exponent}") + + if exponent == 1: + return torch.exp(-z) / z + if exponent == 2: + return torch.sqrt(torch.pi / z) * torch.erfc(torch.sqrt(z)) + if exponent == 3: + return exp1(z) + if exponent == 4: + return 2 * ( + torch.exp(-z) - torch.sqrt(torch.pi * z) * torch.erfc(torch.sqrt(z)) + ) + if exponent == 5: + return torch.exp(-z) - z * exp1(z) + if exponent == 6: + return ( + (2 - 4 * z) * torch.exp(-z) + + 4 * torch.sqrt(torch.pi) * z**1.5 * torch.erfc(torch.sqrt(z)) + ) / 3 + + class InversePowerLawPotential(Potential): """ Inverse power-law potentials of the form :math:`1/r^p`. @@ -46,7 +71,7 @@ class InversePowerLawPotential(Potential): def __init__( self, - exponent: float, + exponent: int, smearing: Optional[float] = None, exclusion_radius: Optional[float] = None, dtype: Optional[torch.dtype] = None, @@ -58,8 +83,8 @@ def __init__( if device is None: device = torch.device("cpu") - if exponent <= 0 or exponent > 3: - raise ValueError(f"`exponent` p={exponent} has to satisfy 0 < p <= 3") + # function call to check the validity of the exponent + gammaincc_over_powerlaw(exponent, torch.tensor(1.0, dtype=dtype, device=device)) self.register_buffer( "exponent", torch.tensor(exponent, dtype=dtype, device=device) ) @@ -103,7 +128,7 @@ def lr_from_dist(self, dist: torch.Tensor) -> torch.Tensor: x = 0.5 * dist**2 / smearing**2 peff = exponent / 2 prefac = 1.0 / (2 * smearing**2) ** peff - return prefac * gammainc(peff, x) / x**peff + return prefac * gammainc(peff, x) / x ** peff @torch.jit.export def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor: @@ -136,7 +161,7 @@ def lr_from_k_sq(self, k_sq: torch.Tensor) -> torch.Tensor: return torch.where( k_sq == 0, 0.0, - prefac * gammaincc(peff, masked) / masked**peff * gamma(peff), + prefac * gammaincc_over_powerlaw(exponent,masked) ) def self_contribution(self) -> torch.Tensor: diff --git a/tests/calculators/test_values_ewald.py b/tests/calculators/test_values_ewald.py index 208d937d..6b405011 100644 --- a/tests/calculators/test_values_ewald.py +++ b/tests/calculators/test_values_ewald.py @@ -100,7 +100,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name): lr_wavelength = 0.5 * smearing calc = EwaldCalculator( InversePowerLawPotential( - exponent=1.0, + exponent=1, smearing=smearing, ), lr_wavelength=lr_wavelength, @@ -111,7 +111,7 @@ def test_madelung(crystal_name, scaling_factor, calc_name): smearing = sr_cutoff / 5.0 calc = PMECalculator( InversePowerLawPotential( - exponent=1.0, + exponent=1, smearing=smearing, ), mesh_spacing=smearing / 8, @@ -198,7 +198,7 @@ def test_wigner(crystal_name, scaling_factor): # Compute potential and compare against reference calc = EwaldCalculator( InversePowerLawPotential( - exponent=1.0, + exponent=1, smearing=smeareff, ), lr_wavelength=smeareff / 2, diff --git a/tests/helpers.py b/tests/helpers.py index a4d14a86..6322f0ee 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -257,7 +257,7 @@ def neighbor_list( nl = NeighborList(cutoff=cutoff, full_list=full_neighbor_list) neighbor_indices, d, S = nl.compute( - points=positions, box=box, periodic=periodic, quantities="PdS" + points=positions, box=box, periodic=periodic, quantities="pdS" ) neighbor_indices = torch.from_numpy(neighbor_indices.astype(int)).to(