diff --git a/grad_dft/evaluate.py b/grad_dft/evaluate.py index 7fd9376..21a2d51 100644 --- a/grad_dft/evaluate.py +++ b/grad_dft/evaluate.py @@ -32,7 +32,7 @@ from grad_dft.utils import PyTree, Array, Scalar, Optimizer from grad_dft.functional import Functional -from grad_dft.molecule import Molecule, eig, make_rdm1, orbital_grad, general_eigh +from grad_dft.molecule import Molecule, make_rdm1, orbital_grad, general_eigh from grad_dft.train import molecule_predictor from grad_dft.utils import PyTree, Array, Scalar from grad_dft.interface.pyscf import ( @@ -44,7 +44,7 @@ from grad_dft.utils.types import Hartree2kcalmol def abs_clip(arr, threshold): - return jnp.where(jnp.abs(arr) > threshold, arr, 0) + return jnp.where(jnp.abs(arr) > threshold, arr, 0.0) ######## Test kernel ######## @@ -103,27 +103,26 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca *args: Arguments to be passed to predict_molecule function """ - # Needed to be able to update the chi tensor - mol = mol_from_Molecule(molecule) - _, mf = process_mol( - mol, compute_energy=False, grid_level=int(molecule.grid_level), training=False - ) - nelectron = molecule.atom_index.sum() - molecule.charge - predicted_e, fock = predict_molecule(params, molecule, *args) - fock = abs_clip(fock, 1e-20) - + # predicted_e, fock = predict_molecule(params, molecule, *args) + # fock = abs_clip(fock, 1e-20) + # fock = molecule.fock + old_e = 100000 # we should set the energy in a molecule object really for cycle in range(max_cycles): # Convergence criterion is energy difference (default 1) kcal/mol and norm of gradient of orbitals < g_conv start_time = time.time() - old_e = predicted_e - - # Diagonalize Fock matrix - overlap = abs_clip(molecule.s1e, 1e-20) - mo_energy, mo_coeff = general_eigh(fock, overlap) - molecule = molecule.replace(mo_coeff=mo_coeff) - molecule = molecule.replace(mo_energy=mo_energy) + # old_e = molecule.energy + if cycle == 0: + mo_energy = molecule.mo_energy + mo_coeff = molecule.mo_coeff + fock = molecule.fock + else: + # Diagonalize Fock matrix + overlap = abs_clip(molecule.s1e, 1e-20) + mo_energy, mo_coeff = general_eigh(fock, overlap) + molecule = molecule.replace(mo_coeff=mo_coeff) + molecule = molecule.replace(mo_energy=mo_energy) # Update the molecular occupation mo_occ = molecule.get_occ() @@ -134,14 +133,13 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca # Update the density matrix if cycle == 0: - rdm1 = molecule.make_rdm1() - old_rdm1 = rdm1 + old_rdm1 = molecule.make_rdm1() else: - rdm1 = (1 - mixing_factor)*old_rdm1 + mixing_factor*molecule.make_rdm1() + rdm1 = (1 - mixing_factor)*old_rdm1 + mixing_factor*abs_clip(molecule.make_rdm1(), 1e-20) + rdm1 = abs_clip(rdm1, 1e-20) + molecule = molecule.replace(rdm1=rdm1) old_rdm1 = rdm1 - rdm1 = abs_clip(rdm1, 1e-20) - molecule = molecule.replace(rdm1=rdm1) computed_charge = jnp.einsum( "r,ra,rb,sab->", molecule.grid.weights, molecule.ao, molecule.ao, molecule.rdm1 @@ -150,25 +148,8 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca nelectron, computed_charge, atol=1e-3 ), "Total charge is not conserved" - # Update the chi matrix - if molecule.omegas: - chi_start_time = time.time() - chi = generate_chi_tensor( - molecule.rdm1, - molecule.ao, - molecule.grid.coords, - mf.mol, - omegas=molecule.omegas, - chunk_size=chunk_size, - *args, - ) - molecule = molecule.replace(chi=chi) - if verbose > 2: - print( - f"Cycle {cycle} took {time.time() - chi_start_time:.1e} seconds to compute chi matrix" - ) - exc_start_time = time.time() + predicted_e, fock = predict_molecule(params, molecule, *args) fock = abs_clip(fock, 1e-20) @@ -184,12 +165,13 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca if verbose > 1: print( - f"cycle: {cycle}, energy: {predicted_e:.7e}, energy difference: {abs(predicted_e - old_e):.4e}, norm_gradient_orbitals: {norm_gorb:.2e}, seconds: {time.time() - start_time:.1e}" + f"cycle: {cycle}, energy: {predicted_e:.7e}, energy difference: {abs(predicted_e - old_e):.4e}, seconds: {time.time() - start_time:.1e}" ) if verbose > 2: print( f" relative energy difference: {abs((predicted_e - old_e)/predicted_e):.5e}" ) + old_e = predicted_e if verbose > 1: print( @@ -719,6 +701,7 @@ def loop_body(cycle, state): state = loop_body(0, state) molecule, fock, predicted_e, _, _, _ = final_state + return predicted_e, fock, molecule.rdm1 return scf_jitted_iterator diff --git a/grad_dft/functional.py b/grad_dft/functional.py index ff011dd..7bea20d 100644 --- a/grad_dft/functional.py +++ b/grad_dft/functional.py @@ -275,6 +275,7 @@ def energy(self, params: PyTree, molecule: Molecule, *args, **kwargs): energy = self.apply_and_integrate(params, molecule.grid, cinputs, densities, **kwargs) if self.is_xc: + # energy += molecule.nonXC() energy += stop_gradient(molecule.nonXC()) return energy diff --git a/grad_dft/molecule.py b/grad_dft/molecule.py index 2276a47..04f5654 100644 --- a/grad_dft/molecule.py +++ b/grad_dft/molecule.py @@ -19,6 +19,7 @@ from grad_dft.external.eigh_impl import eigh2d from jax import numpy as jnp +from jax import scipy as jsp from jax.lax import Precision from jax import vmap, grad from jax.lax import fori_loop, cond @@ -579,6 +580,7 @@ def general_eigh(A, B): C = L_inv @ A @ L_inv.T C = abs_clip(C, 1e-20) eigenvalues, eigenvectors_transformed = jnp.linalg.eigh(C) + # eigenvalues, eigenvectors_transformed = jsp.linalg.eigh(C) eigenvectors_original = L_inv.T @ eigenvectors_transformed eigenvectors_original = abs_clip(eigenvectors_original, 1e-20) eigenvalues = abs_clip(eigenvalues, 1e-20) @@ -634,6 +636,7 @@ def nonXC( @partial(jax.jit) def symmetrize_rdm1(rdm1): dm = rdm1.sum(axis=0) + dm = abs_clip(dm, 1e-20) rdm1 = jnp.stack([dm, dm], axis=0) / 2.0 return rdm1 @@ -643,12 +646,14 @@ def two_body_energy(rdm1, rep_tensor, precision=Precision.HIGHEST): v_coul = 2 * jnp.einsum( "pqrt,srt->spq", rep_tensor, rdm1, precision=precision ) # The 2 is to compensate for the /2 in the dm definition + v_coul = abs_clip(v_coul, 1e-20) coulomb2e_energy = jnp.einsum("sji,sij->", rdm1, v_coul, precision=precision) / 2.0 return coulomb2e_energy @partial(jax.jit, static_argnames=["precision"]) def one_body_energy(rdm1, h1e, precision=Precision.HIGHEST): + h1e = abs_clip(h1e, 1e-20) h1e_energy = jnp.einsum("sij,ji->", rdm1, h1e, precision=precision) return h1e_energy