Skip to content

Commit

Permalink
A few more clips. We will likely remove these later on, but leave the…
Browse files Browse the repository at this point in the history
…se for stability purposes
  • Loading branch information
jackbaker1001 committed Aug 25, 2023
1 parent b1163c1 commit 2fe5526
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 42 deletions.
67 changes: 25 additions & 42 deletions grad_dft/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from grad_dft.utils import PyTree, Array, Scalar, Optimizer
from grad_dft.functional import Functional

from grad_dft.molecule import Molecule, eig, make_rdm1, orbital_grad, general_eigh
from grad_dft.molecule import Molecule, make_rdm1, orbital_grad, general_eigh
from grad_dft.train import molecule_predictor
from grad_dft.utils import PyTree, Array, Scalar
from grad_dft.interface.pyscf import (
Expand All @@ -44,7 +44,7 @@
from grad_dft.utils.types import Hartree2kcalmol

def abs_clip(arr, threshold):
return jnp.where(jnp.abs(arr) > threshold, arr, 0)
return jnp.where(jnp.abs(arr) > threshold, arr, 0.0)


######## Test kernel ########
Expand Down Expand Up @@ -103,27 +103,26 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca
*args: Arguments to be passed to predict_molecule function
"""

# Needed to be able to update the chi tensor
mol = mol_from_Molecule(molecule)
_, mf = process_mol(
mol, compute_energy=False, grid_level=int(molecule.grid_level), training=False
)

nelectron = molecule.atom_index.sum() - molecule.charge

predicted_e, fock = predict_molecule(params, molecule, *args)
fock = abs_clip(fock, 1e-20)

# predicted_e, fock = predict_molecule(params, molecule, *args)
# fock = abs_clip(fock, 1e-20)
# fock = molecule.fock
old_e = 100000 # we should set the energy in a molecule object really
for cycle in range(max_cycles):
# Convergence criterion is energy difference (default 1) kcal/mol and norm of gradient of orbitals < g_conv
start_time = time.time()
old_e = predicted_e

# Diagonalize Fock matrix
overlap = abs_clip(molecule.s1e, 1e-20)
mo_energy, mo_coeff = general_eigh(fock, overlap)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)
# old_e = molecule.energy
if cycle == 0:
mo_energy = molecule.mo_energy
mo_coeff = molecule.mo_coeff
fock = molecule.fock
else:
# Diagonalize Fock matrix
overlap = abs_clip(molecule.s1e, 1e-20)
mo_energy, mo_coeff = general_eigh(fock, overlap)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)

# Update the molecular occupation
mo_occ = molecule.get_occ()
Expand All @@ -134,14 +133,13 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca

# Update the density matrix
if cycle == 0:
rdm1 = molecule.make_rdm1()
old_rdm1 = rdm1
old_rdm1 = molecule.make_rdm1()
else:
rdm1 = (1 - mixing_factor)*old_rdm1 + mixing_factor*molecule.make_rdm1()
rdm1 = (1 - mixing_factor)*old_rdm1 + mixing_factor*abs_clip(molecule.make_rdm1(), 1e-20)
rdm1 = abs_clip(rdm1, 1e-20)
molecule = molecule.replace(rdm1=rdm1)
old_rdm1 = rdm1

rdm1 = abs_clip(rdm1, 1e-20)
molecule = molecule.replace(rdm1=rdm1)

computed_charge = jnp.einsum(
"r,ra,rb,sab->", molecule.grid.weights, molecule.ao, molecule.ao, molecule.rdm1
Expand All @@ -150,25 +148,8 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca
nelectron, computed_charge, atol=1e-3
), "Total charge is not conserved"

# Update the chi matrix
if molecule.omegas:
chi_start_time = time.time()
chi = generate_chi_tensor(
molecule.rdm1,
molecule.ao,
molecule.grid.coords,
mf.mol,
omegas=molecule.omegas,
chunk_size=chunk_size,
*args,
)
molecule = molecule.replace(chi=chi)
if verbose > 2:
print(
f"Cycle {cycle} took {time.time() - chi_start_time:.1e} seconds to compute chi matrix"
)

exc_start_time = time.time()

predicted_e, fock = predict_molecule(params, molecule, *args)
fock = abs_clip(fock, 1e-20)

Expand All @@ -184,12 +165,13 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca

if verbose > 1:
print(
f"cycle: {cycle}, energy: {predicted_e:.7e}, energy difference: {abs(predicted_e - old_e):.4e}, norm_gradient_orbitals: {norm_gorb:.2e}, seconds: {time.time() - start_time:.1e}"
f"cycle: {cycle}, energy: {predicted_e:.7e}, energy difference: {abs(predicted_e - old_e):.4e}, seconds: {time.time() - start_time:.1e}"
)
if verbose > 2:
print(
f" relative energy difference: {abs((predicted_e - old_e)/predicted_e):.5e}"
)
old_e = predicted_e

if verbose > 1:
print(
Expand Down Expand Up @@ -719,6 +701,7 @@ def loop_body(cycle, state):
state = loop_body(0, state)
molecule, fock, predicted_e, _, _, _ = final_state


return predicted_e, fock, molecule.rdm1

return scf_jitted_iterator
Expand Down
1 change: 1 addition & 0 deletions grad_dft/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ def energy(self, params: PyTree, molecule: Molecule, *args, **kwargs):
energy = self.apply_and_integrate(params, molecule.grid, cinputs, densities, **kwargs)

if self.is_xc:
# energy += molecule.nonXC()
energy += stop_gradient(molecule.nonXC())

return energy
Expand Down
5 changes: 5 additions & 0 deletions grad_dft/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from grad_dft.external.eigh_impl import eigh2d

from jax import numpy as jnp
from jax import scipy as jsp
from jax.lax import Precision
from jax import vmap, grad
from jax.lax import fori_loop, cond
Expand Down Expand Up @@ -579,6 +580,7 @@ def general_eigh(A, B):
C = L_inv @ A @ L_inv.T
C = abs_clip(C, 1e-20)
eigenvalues, eigenvectors_transformed = jnp.linalg.eigh(C)
# eigenvalues, eigenvectors_transformed = jsp.linalg.eigh(C)
eigenvectors_original = L_inv.T @ eigenvectors_transformed
eigenvectors_original = abs_clip(eigenvectors_original, 1e-20)
eigenvalues = abs_clip(eigenvalues, 1e-20)
Expand Down Expand Up @@ -634,6 +636,7 @@ def nonXC(
@partial(jax.jit)
def symmetrize_rdm1(rdm1):
dm = rdm1.sum(axis=0)
dm = abs_clip(dm, 1e-20)
rdm1 = jnp.stack([dm, dm], axis=0) / 2.0
return rdm1

Expand All @@ -643,12 +646,14 @@ def two_body_energy(rdm1, rep_tensor, precision=Precision.HIGHEST):
v_coul = 2 * jnp.einsum(
"pqrt,srt->spq", rep_tensor, rdm1, precision=precision
) # The 2 is to compensate for the /2 in the dm definition
v_coul = abs_clip(v_coul, 1e-20)
coulomb2e_energy = jnp.einsum("sji,sij->", rdm1, v_coul, precision=precision) / 2.0
return coulomb2e_energy


@partial(jax.jit, static_argnames=["precision"])
def one_body_energy(rdm1, h1e, precision=Precision.HIGHEST):
h1e = abs_clip(h1e, 1e-20)
h1e_energy = jnp.einsum("sij,ji->", rdm1, h1e, precision=precision)
return h1e_energy

Expand Down

0 comments on commit 2fe5526

Please sign in to comment.