Commit 0a95b34

Clip constant set to 1e-30, scf_train_loops returns (energy, molecule)
PabloAMC committed Sep 14, 2023
1 parent 57e9f5a commit 0a95b34
Showing 15 changed files with 154 additions and 115 deletions.
README.md (8 changes: 3 additions & 5 deletions)
@@ -2,12 +2,10 @@

 # Grad-DFT: a software library for machine learning density functional theory
 
-[![arXiv](http://img.shields.io/badge/arXiv-2101.10279-B31B1B.svg "Grad-DFT")](https://arxiv.org/abs/2101.10279)
+[![arXiv](http://img.shields.io/badge/arXiv-2101.10279-B31B1B.svg "Grad-DFT")](https://arxiv.org/abs/2101.10279) ![License](https://img.shields.io/badge/License-Apache%202.0-9F9F9F "https://github.com/XanaduAI/DiffDFT/blob/main/LICENSE")
 
 </div>
 
-
-
 Grad-DFT is a JAX-based library enabling the differentiable design and experimentation of exchange-correlation functionals using machine learning techniques. This library supports a parametrization of exchange-correlation functionals based on energy densities and associated coefficient functions; the latter typically constructed using neural networks:
 
 ```math
@@ -72,8 +70,8 @@ from flax import linen as nn
 from grad_dft.functional import NeuralFunctional
 
 def coefficient_inputs(molecule):
-    rho = jnp.clip(molecule.density(), a_min = 1e-27)
-    kinetic = jnp.clip(molecule.kinetic_density(), a_min = 1e-27)
+    rho = jnp.clip(molecule.density(), a_min = 1e-30)
+    kinetic = jnp.clip(molecule.kinetic_density(), a_min = 1e-30)
     return jnp.concatenate((rho, kinetic))
 
 def coefficients(self, rhoinputs):
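For orientation, the README paragraph above describes functionals assembled as a grid integral of energy densities weighted by learned coefficient functions. A minimal sketch of that parametrization, with symbols chosen here for illustration rather than copied from the README's own (collapsed) equation:

```math
E_{xc}^{\theta}[\rho] = \int d\mathbf{r}\; \mathbf{c}_{\theta}[\rho](\mathbf{r}) \cdot \mathbf{e}[\rho](\mathbf{r})
```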
examples/advanced_examples/train_scf_loop.py (9 changes: 4 additions & 5 deletions)
@@ -32,9 +32,8 @@
 from grad_dft.evaluate import make_jitted_scf_loop
 
 from jax.config import config
-
-config.update("jax_disable_jit", True)
-
+config.update("jax_enable_x64", True)
+config.update('jax_debug_nans', True)
 
 orbax_checkpointer = PyTreeCheckpointer()
 
@@ -123,13 +122,13 @@ def coefficients(instance, rhoinputs, *_, **__):
 
 # Here we use one of the following. We will use the second here.
 molecule_predict = molecule_predictor(functional)
-scf_train_loop = make_jitted_scf_loop(functional, max_cycles=1)
+scf_train_loop = make_jitted_scf_loop(functional, max_cycles=50)
 
 
 @partial(value_and_grad, has_aux=True)
 def loss(params, molecule, ground_truth_energy):
     # predicted_energy, fock = molecule_predict(params, molecule)
-    predicted_energy, fock, rdm1 = scf_train_loop(params, molecule)
+    predicted_energy, molecule = scf_train_loop(params, molecule)
     cost_value = (predicted_energy - ground_truth_energy) ** 2
 
     # We may want to add a regularization term to the cost, be it one of the
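For context, not part of the commit: with `scf_train_loop` now returning `(predicted_energy, molecule)`, the decorated `loss` above yields `((cost_value, aux), grads)` when called. A minimal sketch of the surrounding gradient step, assuming an optax optimizer named `tx` and assuming the loss returns `(cost_value, predicted_energy)` as in the basic training example:

```python
import optax

tx = optax.adam(learning_rate=1e-3)  # assumed optimizer and learning rate
opt_state = tx.init(params)          # params, molecule, ground_truth_energy come from the script above

# value_and_grad(..., has_aux=True) returns ((cost, aux), grads) in one call.
(cost_value, predicted_energy), grads = loss(params, molecule, ground_truth_energy)
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```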
examples/basic_examples/example_lda_functional_02.py (2 changes: 1 addition & 1 deletion)
@@ -44,7 +44,7 @@
 # First a features method, which takes a molecule and returns an array of features
 # It computes what in the article appears as potential e_\theta(r), as well as the
 # input to the neural network to compute the density.
-def energy_densities(molecule: Molecule, clip_cte: float = 1e-27):
+def energy_densities(molecule: Molecule, clip_cte: float = 1e-30):
     r"""Auxiliary function to generate the features of LSDA."""
     # Molecule can compute the density matrix.
     rho = molecule.density()
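Background on the `clip_cte` change, as an illustration rather than code from the repository: LSDA-style energy densities involve fractional powers of the density, e.g. the Dirac exchange term scales as `-rho**(4/3)`, so near-zero densities on the grid are clipped at a tiny constant to keep both values and autodiff gradients finite. A standalone sketch with an illustrative function name:

```python
import jax.numpy as jnp

def lsda_exchange_energy_density(rho, clip_cte=1e-30):
    # Clip vanishing densities so rho**(4/3) and its gradient stay finite;
    # 1e-30 matches the constant adopted throughout this commit.
    rho = jnp.clip(rho, a_min=clip_cte)
    return -(3.0 / 4.0) * (3.0 / jnp.pi) ** (1.0 / 3.0) * rho ** (4.0 / 3.0)
```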
examples/basic_examples/example_molecule_01.py (2 changes: 1 addition & 1 deletion)
@@ -128,7 +128,7 @@ def parallelized_density(rdm1: Array, ao: Array) -> Array:
 density = HF_molecule.density()
 # We need to avoid dividing by zero
 x = jnp.where(
-    density > 1e-27,
+    density > 1e-30,
     grad_density_norm / (2 * (3 * jnp.pi**2) ** (1 / 3) * density ** (4 / 3)),
     0.0,
 )
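A side note on the guard above, as an illustration only: `jnp.where` evaluates both branches, so when `density` is exactly zero the division still produces non-finite values whose gradients can propagate as NaNs. A common safeguard is the double-`where` pattern, sketched here with an illustrative function name:

```python
import jax.numpy as jnp

def reduced_gradient(grad_density_norm, density, clip_cte=1e-30):
    # Make the denominator safe *before* dividing; a single jnp.where does not
    # stop NaNs from the discarded branch leaking into gradients.
    safe_density = jnp.where(density > clip_cte, density, clip_cte)
    s = grad_density_norm / (2 * (3 * jnp.pi**2) ** (1 / 3) * safe_density ** (4 / 3))
    return jnp.where(density > clip_cte, s, 0.0)
```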
examples/basic_examples/example_neural_functional_03.py (8 changes: 4 additions & 4 deletions)
@@ -52,11 +52,11 @@
 # which compute the two vectors that get dot-multiplied and then integrated over space. If the functional is a
 # neural functional we also need to define coefficient_inputs, for which in this case we will reuse the densities function.
 def coefficient_inputs(molecule: Molecule, *_, **__):
-    rho = jnp.clip(molecule.density(), a_min = 1e-27)
-    kinetic = jnp.clip(molecule.kinetic_density(), a_min = 1e-27)
+    rho = jnp.clip(molecule.density(), a_min = 1e-30)
+    kinetic = jnp.clip(molecule.kinetic_density(), a_min = 1e-30)
     return jnp.concatenate((rho, kinetic), axis = 1)
 
-def energy_densities(molecule: Molecule, clip_cte: float = 1e-27, *_, **__):
+def energy_densities(molecule: Molecule, clip_cte: float = 1e-30, *_, **__):
     r"""Auxiliary function to generate the features of LSDA."""
     # Molecule can compute the density matrix.
     rho = molecule.density()
@@ -134,7 +134,7 @@ def coefficients(instance, rhoinputs):
 # We can alternatively use the jit-ed version of the scf loop
 HH_molecule = molecule_from_pyscf(mf)
 scf_iterator = make_jitted_scf_loop(neuralfunctional, cycles=5)
-jitted_energy, _, _ = scf_iterator(params, HH_molecule)
+jitted_energy, HH_molecule = scf_iterator(params, HH_molecule)
 print("Energy from the jitted scf loop:", jitted_energy)
 
 # We can even use a direct optimizer of the orbitals
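For orientation, and as an illustration rather than the library's API: a functional of this kind evaluates its coefficient functions and energy densities on the grid, contracts them pointwise, and integrates with the quadrature weights, mirroring the `einsum("rf,rf->r", ...)` pattern used in `grad_dft/constraints.py` below. A minimal sketch with assumed array names:

```python
import jax.numpy as jnp

def xc_energy(coefficients, energy_densities, grid_weights):
    # coefficients and energy_densities have shape (n_grid, n_features);
    # grid_weights holds the quadrature weights, shape (n_grid,).
    e_xc = jnp.einsum("rf,rf->r", coefficients, energy_densities)  # pointwise dot product
    return jnp.einsum("r,r->", grid_weights, e_xc)                 # integrate over the grid
```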
examples/basic_examples/example_neural_scf_training_04.py (10 changes: 5 additions & 5 deletions)
@@ -61,11 +61,11 @@
 # Now we create our neuralfunctional. We need to define at least the following methods: densities and coefficients
 # which compute the two vectors that get dot-multiplied and then integrated over space. If the functional is a
 # neural functional we also need to define coefficient_inputs, for which in this case we will reuse the densities function.
-def coefficient_inputs(molecule: Molecule, *_, **__):
-    rho = jnp.clip(molecule.density(), a_min = 1e-20)
+def coefficient_inputs(molecule: Molecule, clip_cte: float = 1e-30, *_, **__):
+    rho = jnp.clip(molecule.density(), a_min = clip_cte)
     return jnp.concatenate((rho, ), axis = 1)
 
-def energy_densities(molecule: Molecule, clip_cte: float = 1e-27, *_, **__):
+def energy_densities(molecule: Molecule, clip_cte: float = 1e-30, *_, **__):
     r"""Auxiliary function to generate the features of LSDA."""
     # Molecule can compute the density matrix.
     rho = molecule.density()
@@ -127,7 +127,7 @@ def loss(params, molecule_predict, molecule, trueenergy):
     it will compute the gradients with respect to params.
     """
 
-    predictedenergy, _, _ = molecule_predict(params, molecule)
+    predictedenergy, molecule = molecule_predict(params, molecule)
     cost_value = (predictedenergy - trueenergy) ** 2
 
     return cost_value, predictedenergy
@@ -139,7 +139,7 @@ def loss(params, molecule_predict, molecule, trueenergy):
 
 # and implement the optimization loop
 n_epochs = 20
-scf_iterator = make_jitted_scf_loop(neuralfunctional, cycles=3)
+scf_iterator = make_jitted_scf_loop(neuralfunctional, cycles=30)
 
 for iteration in tqdm(range(n_epochs), desc="Training epoch"):
     (cost_value, predicted_energy), grads = loss(
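A conceptual aside, simplified and not Grad-DFT's implementation: a jitted SCF loop with a fixed cycle count behaves like a fixed-point iteration unrolled with `lax.fori_loop`, which is why raising `cycles` from 3 to 30 tightens self-consistency at the cost of more work per loss evaluation. A toy sketch with an assumed `update_density` callable:

```python
import jax

def fixed_cycle_scf(update_density, rho0, cycles=30):
    # Run a fixed number of self-consistency cycles so the loop stays
    # jit-compatible and differentiable end to end (toy illustration).
    body = lambda _, rho: update_density(rho)
    return jax.lax.fori_loop(0, cycles, body, rho0)
```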
grad_dft/constraints.py (40 changes: 21 additions & 19 deletions)
@@ -287,6 +287,7 @@ def constraint_x4(
     precision=Precision.HIGHEST,
     lower_bound=1e-15,
     upper_bound=1e-5,
+    clip_cte = 1e-30,
 ):
     r"""
     .. math::
@@ -302,8 +303,8 @@
     grad_density = molecule.grad_density().sum(axis=-1)
     lapl_density = molecule.lapl_density().sum(axis=-1)
 
-    s = jnp.where(jnp.greater_equal(density, 1e-20), grad_density / density ** (4 / 3), 0)
-    q = jnp.where(jnp.greater_equal(density, 1e-20), lapl_density / density ** (5 / 3), 0)
+    s = jnp.where(jnp.greater_equal(density, clip_cte), grad_density / density ** (4 / 3), 0)
+    q = jnp.where(jnp.greater_equal(density, clip_cte), lapl_density / density ** (5 / 3), 0)
 
     s2_cond = jnp.logical_and(
         jnp.logical_and(
@@ -328,7 +329,7 @@
     lda_cinputs = lda_functional.compute_coefficient_inputs(molecule)
     lda_coefficients = lda_functional.apply(params, lda_cinputs)
     lda_e = jnp.einsum("rf,rf->r", lda_coefficients, lda_densities)
-    lda_e = abs_clip(lda_e, 1e-20)
+    lda_e = abs_clip(lda_e, clip_cte)
 
     cinputs = functional.compute_coefficient_inputs(molecule)
     densities = functional.compute_densities(molecule)
@@ -378,6 +379,7 @@ def constraint_x5(
     precision=Precision.HIGHEST,
     multiplier1=1e5,
     multiplier2=1e7,
+    clip_cte = 1e-30,
 ):
     r"""
     .. math::
@@ -390,7 +392,7 @@
     lda_cinputs = lda_functional.compute_coefficient_inputs(molecule)
     lda_coefficients = lda_functional.apply(params, lda_cinputs)
     lda_e = jnp.einsum("rf,rf->r", lda_coefficients, lda_densities)
-    lda_e = jnp.expand_dims(abs_clip(lda_e, 1e-20), axis=-1)
+    lda_e = jnp.expand_dims(abs_clip(lda_e, clip_cte), axis=-1)
 
     @struct.dataclass
     class modMolecule(Molecule):
Expand Down Expand Up @@ -437,7 +439,7 @@ def grad_density(self, *args, **kwargs) -> Array:
density1 = modmolecule.density()
grad_density1 = modmolecule.grad_density().sum(axis=-1)

s1 = jnp.where(jnp.greater_equal(density1, 1e-27), grad_density1 / density1 ** (4 / 3), 0)
s1 = jnp.where(jnp.greater_equal(density1, 1e-30), grad_density1 / density1 ** (4 / 3), 0)
a = jnp.isnan(s1).any()

cinputs1 = functional.compute_coefficient_inputs(modmolecule)
@@ -447,7 +449,7 @@ def grad_density(self, *args, **kwargs) -> Array:
     ex1 = jnp.expand_dims(ex1, axis=-1)
 
     fx1 = jnp.where(
-        jnp.greater_equal(jnp.abs(lda_e * jnp.sqrt(s1)), 1e-27), ex1 / (lda_e * jnp.sqrt(s1)), 0
+        jnp.greater_equal(jnp.abs(lda_e * jnp.sqrt(s1)), 1e-30), ex1 / (lda_e * jnp.sqrt(s1)), 0
     )
     a = jnp.isnan(fx1).any()

@@ -456,7 +458,7 @@ def grad_density(self, *args, **kwargs) -> Array:
     density2 = modmolecule.density()
     grad_density2 = modmolecule.grad_density().sum(axis=-1)
 
-    s2 = jnp.where(jnp.greater_equal(density2, 1e-27), grad_density2 / density2 ** (4 / 3), 0)
+    s2 = jnp.where(jnp.greater_equal(density2, 1e-30), grad_density2 / density2 ** (4 / 3), 0)
     a = jnp.isnan(s2).any()
 
     cinputs2 = functional.compute_coefficient_inputs(modmolecule)
@@ -466,7 +468,7 @@
     ex2 = jnp.expand_dims(ex2, axis=-1)
 
     fx2 = jnp.where(
-        jnp.greater_equal(jnp.abs(lda_e * jnp.sqrt(s2)), 1e-27), ex2 / (lda_e * jnp.sqrt(s2)), 0
+        jnp.greater_equal(jnp.abs(lda_e * jnp.sqrt(s2)), 1e-30), ex2 / (lda_e * jnp.sqrt(s2)), 0
     )
     a = jnp.isnan(fx2).any()

@@ -475,7 +477,7 @@
     density_1 = modmolecule.density()
     grad_density_1 = modmolecule.grad_density().sum(axis=-1)
 
-    s_1 = jnp.where(jnp.greater_equal(density_1, 1e-27), grad_density_1 / density_1 ** (4 / 3), 0)
+    s_1 = jnp.where(jnp.greater_equal(density_1, 1e-30), grad_density_1 / density_1 ** (4 / 3), 0)
     a = jnp.isnan(s_1).any()
 
     cinputs_1 = functional.compute_coefficient_inputs(modmolecule)
@@ -485,7 +487,7 @@
     ex_1 = jnp.expand_dims(ex_1, axis=-1)
 
     fx_1 = jnp.where(
-        jnp.greater_equal(jnp.abs(lda_e * jnp.sqrt(s_1)), 1e-27),
+        jnp.greater_equal(jnp.abs(lda_e * jnp.sqrt(s_1)), 1e-30),
         ex_1 / (lda_e * jnp.sqrt(s_1)),
         0,
     )
@@ -496,7 +498,7 @@
     density_2 = modmolecule.density()
     grad_density_2 = modmolecule.grad_density().sum(axis=-1)
 
-    s_2 = jnp.where(jnp.greater_equal(density_2, 1e-27), grad_density_2 / density_2 ** (4 / 3), 0)
+    s_2 = jnp.where(jnp.greater_equal(density_2, 1e-30), grad_density_2 / density_2 ** (4 / 3), 0)
     a = jnp.isnan(s2).any()
 
     cinputs_2 = functional.compute_coefficient_inputs(modmolecule)
@@ -506,7 +508,7 @@
     ex_2 = jnp.expand_dims(ex_2, axis=-1)
 
     fx_2 = jnp.where(
-        jnp.greater_equal(jnp.abs(lda_e * jnp.sqrt(s_2)), 1e-27),
+        jnp.greater_equal(jnp.abs(lda_e * jnp.sqrt(s_2)), 1e-30),
         ex_2 / (lda_e * jnp.sqrt(s_2)),
         0,
     )
@@ -524,7 +526,7 @@ def grad_density(self, *args, **kwargs) -> Array:
 
 
 def constraint_x6(
-    functional: Functional, params: PyTree, molecule: Molecule, precision=Precision.HIGHEST
+    functional: Functional, params: PyTree, molecule: Molecule, precision=Precision.HIGHEST, clip_cte = 1e-30
 ):
     r"""
     .. math::
@@ -537,7 +539,7 @@ def constraint_x6(
     lda_cinputs = lda_functional.compute_coefficient_inputs(molecule)
     lda_coefficients = lda_functional.apply(params, lda_cinputs)
     lda_e = jnp.einsum("rf,rf->r", lda_coefficients, lda_densities)
-    lda_e = abs_clip(lda_e, 1e-20)
+    lda_e = abs_clip(lda_e, clip_cte)
 
     # Symmetrize the reduced density matrix
     rdm1 = molecule.rdm1
@@ -562,7 +564,7 @@ def constraint_x6(
 
 
 def constraint_x7(
-    functional: Functional, params: PyTree, molecule2e: Molecule, precision=Precision.HIGHEST
+    functional: Functional, params: PyTree, molecule2e: Molecule, precision=Precision.HIGHEST, clip_cte = 1e-30
 ):
     r"""
     For a two electron system:
@@ -576,7 +578,7 @@ def kinetic_density(self: Molecule, *args, **kwargs) -> Array:
             r"""Weizsacker kinetic energy"""
             drho = self.grad_density(*args, **kwargs)
             rho = self.density(*args, **kwargs)
-            return jnp.where(jnp.greater_equal(rho, 1e-27), drho**2 / (8 * rho), 0)
+            return jnp.where(jnp.greater_equal(rho, 1e-30), drho**2 / (8 * rho), 0)
 
     modmolecule = modMolecule(
         molecule2e.grid,
@@ -618,11 +620,11 @@ def kinetic_density(self: Molecule, *args, **kwargs) -> Array:
     lda_cinputs = lda_functional.compute_coefficient_inputs(modmolecule)
     lda_coefficients = lda_functional.apply(params, lda_cinputs)
    lda_e = jnp.einsum("rf,rf->r", lda_coefficients, lda_densities)
-    lda_e = abs_clip(lda_e, 1e-20)
+    lda_e = abs_clip(lda_e, clip_cte)
 
-    # return jnp.where(jnp.greater_equal(lsda_e, 1e-27), jnp.less_equal(functional_e / lsda_e, 1.174), True).all()
+    # return jnp.where(jnp.greater_equal(lsda_e, 1e-30), jnp.less_equal(functional_e / lsda_e, 1.174), True).all()
     return functional._integrate(
-        jnp.where(jnp.greater_equal(lda_e, 1e-27), (relu(functional_e / lda_e - 1.174) ** 2), 0),
+        jnp.where(jnp.greater_equal(lda_e, 1e-30), (relu(functional_e / lda_e - 1.174) ** 2), 0),
         molecule2e.grid.weights,
     )
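Restating what `constraint_x7` above computes, in notation chosen here: for a two-electron system the penalty is a squared hinge on the ratio of the functional's energy density to the LDA one, integrated over the grid and evaluated only where the LDA energy density is at least `clip_cte`:

```math
P = \int d\mathbf{r}\; \mathrm{relu}\!\left(\frac{\epsilon^{\theta}(\mathbf{r})}{\epsilon^{\mathrm{LDA}}(\mathbf{r})} - 1.174\right)^{2}
```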