Skip to content

Commit

Permalink
Merge pull request #47 from XanaduAI/46-bug-in-the-calculation-of-the…
Browse files Browse the repository at this point in the history
…-correlation-energy

Fixing VWN and PW92
  • Loading branch information
PabloAMC authored Sep 13, 2023
2 parents 071a974 + 126efa5 commit 57e9f5a
Show file tree
Hide file tree
Showing 8 changed files with 239 additions and 55 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/install_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
- name: Run integration tests
run: |
pytest -v tests/integration/test_non_xc_energy.py
pytest -v tests/integration/test_classical_functionals.py
pytest -v tests/integration/test_functional_implementations.py
pytest -v tests/integration/test_Harris.py
pytest -v tests/integration/test_predict_B88.py
pytest -v tests/integration/test_predict_B3LYP.py
Expand Down
133 changes: 133 additions & 0 deletions examples/test_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright 2023 Xanadu Quantum Technologies Inc.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from flax.core import freeze
from grad_dft.popular_functionals import B3LYP, LYP

from grad_dft.interface.pyscf import molecule_from_pyscf

# This file aims to test some of the constraints implemented in constraints.py.

from jax import config

config.update("jax_enable_x64", True)

# First we define a molecule:
from pyscf import gto, dft

mol = gto.M(atom="H 0 0 0; F 0 0 1.1")

grids = dft.gen_grid.Grids(mol)
grids.level = 2
grids.build()

mf = dft.UKS(mol)
mf.grids = grids
mf.xc = "b3lyp"
ground_truth_energy = mf.kernel()

molecule = molecule_from_pyscf(mf, omegas=[0.0])

# H atom:
molH = gto.M(atom="H 0 0 0", spin=1, basis="cc-pvqz")
grids = dft.gen_grid.Grids(molH)
grids.level = 3
grids.build()
mf = dft.UKS(molH)
mf.grids = grids
mf.xc = "b3lyp"
ground_truth_energy = mf.kernel()
molecule1e = molecule_from_pyscf(mf, omegas=[0.0])

# Negatively charged H atom
molHp = gto.M(atom="H 0 0 0", charge=-1, spin=0, basis="cc-pvqz")
grids = dft.gen_grid.Grids(molHp)
grids.level = 3
grids.build()
mf = dft.UKS(molHp)
mf.grids = grids
mf.xc = "b3lyp"
ground_truth_energy = mf.kernel()
molecule2e = molecule_from_pyscf(mf, omegas=[0.0])

params = freeze({"params": {}})

from grad_dft.constraints import (
constraint_c6,
constraint_x4,
constraint_x6,
constraint_x7,
constraint_xc2,
constraint_xc4,
constraints_x1_c1,
constraint_x2,
constraint_c2,
constraints_fractional_charge_spin,
constraints_x3_c3_c4,
constraint_x5,
)

#### Constraint x1 ####
x1, c1 = constraints_x1_c1(B3LYP, params, molecule)
print(f"Quadratic loss of the functional B3LYP from constraint x1?", x1)
print(f"Quadratic loss of the functional B3LYP from constraint c1?", c1)

#### Constraint x2 ####
x2 = constraint_x2(B3LYP, params, molecule)
print(f"Quadratic loss of the functional B3LYP from constraint x2?", x2)

#### Constraint c2 ####
c2 = constraint_c2(LYP, params, molecule)
print(f"Quadratic loss of the functional LYP from constraint c2?", c2)

#### Constraint x3, c3, c4 ####
x3, (c3, c4) = constraints_x3_c3_c4(B3LYP, params, molecule, gamma=2.0)
print(f"Quadratic loss of the functional B3LYP from constraint x3?", x3)
print(f"Quadratic loss of the functional B3LYP from constraint c3?", c3)
print(f"Quadratic loss of the functional B3LYP from constraint c4?", c4)

#### Constraint x4 #### This requires masks for the appropriate functional
# x4s2, x4q2, x4qs2, x4s4 = constraint_x4(B3LYP, params, molecule, s2_mask, q2_mask, qs2_mask, s4_mask)
# print(f'Quadratic loss of the functional B3LYP from constraints x4?', x4s2, x4q2, x4qs2, x4s4)

#### Constraint x5 ####
x5inf, x50 = constraint_x5(B3LYP, params, molecule)
print(f"Quadratic loss of the functional B3LYP from constraint x5?", x5inf, x50)

#### Constraint x6 ####
x61, x62 = constraint_x6(B3LYP, params, molecule)
print(f"Quadratic loss of the functional B3LYP from constraint x6?", x61, x62)

#### Constraint x7 ####
x7 = constraint_x7(B3LYP, params, molecule2e)
print(f"Quadratic loss of the functional B3LYP from constraint x7?", x7)

#### Constraint c6 ####
c6 = constraint_c6(B3LYP, params, molecule)
print(f"Quadratic loss of the functional B3LYP from constraint c6?", c6)

#### Constraint xc2 ####
xc2 = constraint_xc2(B3LYP, params, molecule)
print(f"Quadratic loss of the functional B3LYP from constraint xc2?", xc2)

#### Constraint xc4 ####
xc4 = constraint_xc4(B3LYP, params, molecule2e)
print(f"Quadratic loss of the functional B3LYP from constraint xc4?", xc4)

#### Constraint fractional charge & spin ####
fcs = constraints_fractional_charge_spin(B3LYP, params, molecule1e, molecule2e, gamma=0.5, mol=molH)
print(
f"Quadratic loss of the functional B3LYP from the fractional charge & spin constrain (xc1)?",
fcs,
)
38 changes: 21 additions & 17 deletions grad_dft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def compute_densities(self, molecule: Molecule, *args, **kwargs):

elif self.nograd_densities:
densities = stop_gradient(self.nograd_densities(molecule, *args, **kwargs))
densities = abs_clip(densities, 1e-20)
densities = abs_clip(densities, 1e-20) #todo: investigate if we can lower this
return densities

def compute_coefficient_inputs(self, molecule: Molecule, *args, **kwargs):
Expand Down Expand Up @@ -309,6 +309,7 @@ def _integrate(
Scalar
"""

#todo: study if we can lower this clipping constants
return jnp.einsum("r,r->", abs_clip(gridweights, 1e-20), abs_clip(energy_density, 1e-20), precision=precision)


Expand Down Expand Up @@ -684,7 +685,8 @@ def dm21_hfgrads_densities(
)
return vxc_hf.sum(axis=0) # Sum over omega


@jaxtyped
@typechecked
def dm21_hfgrads_cinputs(
functional: nn.Module,
params: PyTree,
Expand Down Expand Up @@ -955,24 +957,27 @@ def _canonicalize_fxc(fxc: Functional) -> Callable:
################ Spin polarization correction functions ################


def exchange_polarization_correction(e_PF, rho):
def exchange_polarization_correction(
e_PF: Float[Array, "spin grid"],
rho: Float[Array, "spin grid"]
) -> Float[Array, "grid"]:
r"""Spin polarization correction to an exchange functional using eq 2.71 from
Carsten A. Ullrich, "Time-Dependent Density-Functional Theory".
Parameters
----------
e_PF:
Array, shape (2, n_grid)
Float[Array, "spin grid"]
The paramagnetic/ferromagnetic energy contributions on the grid, to be combined.
rho:
Array, shape (2, n_grid)
Float[Array, "spin grid"]
The electronic density of each spin polarization at each grid point.
Returns
----------
e_tilde
Array, shape (n_grid)
Float[Array, "grid"]
The ready to be integrated electronic energy density.
"""
zeta = (rho[:, 0] - rho[:, 1]) / rho.sum(axis=1)
Expand All @@ -984,18 +989,20 @@ def fzeta(z):
return e_PF[:, 0] + (e_PF[:, 1] - e_PF[:, 0]) * fzeta(zeta)


def correlation_polarization_correction(e_PF: Array, rho: Array, clip_cte: float = 1e-27):
def correlation_polarization_correction(
e_tilde_PF: Float[Array, "spin grid"],
rho: Float[Array, "spin grid"],
clip_cte: float = 1e-27
) -> Float[Array, "grid"]:
r"""Spin polarization correction to a correlation functional using eq 2.75 from
Carsten A. Ullrich, "Time-Dependent Density-Functional Theory".
Parameters
----------
e_PF:
Array, shape (2, n_grid)
e_tilde_PF: Float[Array, "spin grid"]
The paramagnetic/ferromagnetic energy contributions on the grid, to be combined.
rho:
Array, shape (2, n_grid)
rho: Float[Array, "spin grid"]
The electronic density of each spin polarization at each grid point.
clip_cte:
Expand All @@ -1004,13 +1011,10 @@ def correlation_polarization_correction(e_PF: Array, rho: Array, clip_cte: float
Returns
----------
e_tilde
Array, shape (n_grid)
e_tilde: Float[Array, "grid"]
The ready to be integrated electronic energy density.
"""

e_tilde_PF = jnp.einsum("rs,r->rs", e_PF, rho.sum(axis=1))

log_rho = jnp.log2(jnp.clip(rho.sum(axis=1), a_min=clip_cte))
# assert not jnp.isnan(log_rho).any() and not jnp.isinf(log_rho).any()
log_rs = jnp.log2((3 / (4 * jnp.pi)) ** (1 / 3)) - log_rho / 3.0
Expand Down Expand Up @@ -1038,8 +1042,8 @@ def fzeta(z):
alphac = 2 * A_ * (1 + ars) * jnp.log(1 + (1 / (2 * A_)) / (brs_1_2 + brs + brs_3_2 + brs2))
# assert not jnp.isnan(alphac).any() and not jnp.isinf(alphac).any()

fz = jnp.round(fzeta(zeta), int(math.log10(clip_cte)))
z4 = jnp.round(2 ** (4 * jnp.log2(jnp.clip(zeta, a_min=clip_cte))), int(math.log10(clip_cte)))
fz = fzeta(zeta) #jnp.round(fzeta(zeta), int(math.log10(clip_cte)))
z4 = zeta**4 #jnp.round(2 ** (4 * jnp.log2(jnp.clip(zeta, a_min=clip_cte))), int(math.log10(clip_cte)))

e_tilde = (
e_tilde_PF[:, 0]
Expand Down
31 changes: 21 additions & 10 deletions grad_dft/popular_functionals.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def pw92_c_e(rho: Float[Array, "grid spin"], clip_cte: float = 1e-27) -> Float[A

e_tilde = correlation_polarization_correction(e_PF, rho, clip_cte)

return e_tilde
return e_tilde * rho.sum(axis = 1)

def vwn_c_e(rho: Float[Array, "grid spin"], clip_cte: float = 1e-27) -> Float[Array, "grid"]:
r"""
Expand Down Expand Up @@ -191,8 +191,8 @@ def vwn_c_e(rho: Float[Array, "grid spin"], clip_cte: float = 1e-27) -> Float[Ar

e_tilde = correlation_polarization_correction(e_PF, rho, clip_cte)

# We have to integrate e_tilde = e * n as per eq 2.1 in original LYP article
return e_tilde
# We have to integrate e = e_tilde * n as per eq 2.1 in original VWN article
return e_tilde * rho.sum(axis = 1)

def lyp_c_e(rho: Float[Array, "grid spin"], grad_rho: Float[Array, "grid spin 3"], grad2rho: Float[Array, "grid spin"], clip_cte=1e-27) -> Float[Array, "grid"]:
r"""
Expand All @@ -213,6 +213,17 @@ def lyp_c_e(rho: Float[Array, "grid spin"], grad_rho: Float[Array, "grid spin 3"
Returns
-------
Float[Array, "grid"]
Notes:
------
Libxc implementation:
https://github.com/ElectronicStructureLibrary/libxc/blob/master/maple/gga_exc/gga_c_lyp.mpl
Important: This implementation uses the original LYP functional definition
in C. Lee, W. Yang, and R. G. Parr., Phys. Rev. B 37, 785 (1988) (doi: 10.1103/PhysRevB.37.785)
instead of the one in libxc: B. Miehlich, A. Savin, H. Stoll, and H. Preuss., Chem. Phys. Lett. 157, 200 (1989) (doi: 10.1016/0009-2614(89)87234-3)
This sometimes gives rise to <1 kcal/mol differences in spin-polarized systems.
"""

a = 0.04918
Expand Down Expand Up @@ -242,19 +253,19 @@ def lyp_c_e(rho: Float[Array, "grid spin"], grad_rho: Float[Array, "grid spin 3"
rho_grad2rho = (rho * grad2rho).sum(axis=1)
# assert not jnp.isnan(rho_grad2rho).any() and not jnp.isinf(rho_grad2rho).any()

exp_factor = jnp.where(rho.sum(axis=1) > 0, jnp.exp(-c * rho.sum(axis=1) ** (-1 / 3)), 0)
# assert not jnp.isnan(exp_factor).any() and not jnp.isinf(exp_factor).any()

rhom1_3 = (rho.sum(axis=1)) ** (-1 / 3)
rho8_3 = (rho ** (8 / 3.0)).sum(axis=1)
rho8_3 = (rho ** (8 / 3)).sum(axis=1)
rhom5_3 = (rho.sum(axis=1)) ** (-5 / 3)

par = 2 ** (2 / 3) * CF * (rho8_3) - rhos_ts + rho_t / 9 + rho_grad2rho / 18
exp_factor = jnp.where(rho.sum(axis=1) > 0, jnp.exp(-c * rhom1_3), 0)
# assert not jnp.isnan(exp_factor).any() and not jnp.isinf(exp_factor).any()

parenthesis = 2 ** (2 / 3) * CF * (rho8_3) - rhos_ts + rho_t / 9 + rho_grad2rho / 18

sum_ = jnp.where(rho.sum(axis=1) > clip_cte, 2 * b * rhom5_3 * par * exp_factor, 0.0)
braket_m_rho = jnp.where(rho.sum(axis=1) > clip_cte, 2 * b * rhom5_3 * parenthesis * exp_factor, 0.0)

return -a * jnp.where(
rho.sum(axis=1) > clip_cte, gamma / (1 + d * rhom1_3) * (rho.sum(axis=1) + sum_), 0.0
rho.sum(axis=1) > clip_cte, gamma / (1 + d * rhom1_3) * (rho.sum(axis=1) + braket_m_rho), 0.0
)

def lsda_density(molecule: Molecule, clip_cte: float = 1e-27, *_, **__) -> Float[Array, "grid densities"]:
Expand Down
Loading

0 comments on commit 57e9f5a

Please sign in to comment.