diff --git a/grad_dft/molecule.py b/grad_dft/molecule.py index 04f5654..dfafa09 100644 --- a/grad_dft/molecule.py +++ b/grad_dft/molecule.py @@ -628,13 +628,24 @@ def nonXC( """ rdm1 = symmetrize_rdm1(rdm1) h1e_energy = one_body_energy(rdm1, h1e, precision) - coulomb2e_energy = two_body_energy(rdm1, rep_tensor, precision) + coulomb2e_energy = coulomb_energy(rdm1, rep_tensor, precision) return nuclear_repulsion + h1e_energy + coulomb2e_energy @partial(jax.jit) def symmetrize_rdm1(rdm1): + r"""A function that symmetrizes and clips the reduced density matrix. + + Parameters + ---------- + rdm1 : Array + The 1-Reduced Density Matrix. + + Returns + ------- + Array + """ dm = rdm1.sum(axis=0) dm = abs_clip(dm, 1e-20) rdm1 = jnp.stack([dm, dm], axis=0) / 2.0 @@ -642,10 +653,23 @@ def symmetrize_rdm1(rdm1): @partial(jax.jit, static_argnames=["precision"]) -def two_body_energy(rdm1, rep_tensor, precision=Precision.HIGHEST): - v_coul = 2 * jnp.einsum( - "pqrt,srt->spq", rep_tensor, rdm1, precision=precision - ) # The 2 is to compensate for the /2 in the dm definition +def coulomb_energy(rdm1, rep_tensor, precision=Precision.HIGHEST): + r"""A function that computes the Coulomb two-body energy of a DFT functional. + + Parameters + ---------- + rdm1 : Array + The 1-Reduced Density Matrix. + shape: (n_spin, n_orb, n_orb) + rep_tensor : Array + The repulsion tensor. + shape: (n_orb, n_orb, n_orb, n_orb) + + Returns + ------- + Scalar + """ + v_coul = coulomb_potential(rdm1, rep_tensor, precision) v_coul = abs_clip(v_coul, 1e-20) coulomb2e_energy = jnp.einsum("sji,sij->", rdm1, v_coul, precision=precision) / 2.0 return coulomb2e_energy @@ -653,6 +677,21 @@ def two_body_energy(rdm1, rep_tensor, precision=Precision.HIGHEST): @partial(jax.jit, static_argnames=["precision"]) def one_body_energy(rdm1, h1e, precision=Precision.HIGHEST): + r"""A function that computes the one-body energy of a DFT functional. + + Parameters + ---------- + rdm1 : Array + The 1-Reduced Density Matrix. + shape: (n_spin, n_orb, n_orb) + h1e : Array + The 1-electron Hamiltonian. + shape: (n_orb, n_orb) + + Returns + ------- + Scalar + """ h1e = abs_clip(h1e, 1e-20) h1e_energy = jnp.einsum("sij,ji->", rdm1, h1e, precision=precision) return h1e_energy @@ -680,6 +719,7 @@ def coulomb_potential(rdm1, rep_tensor, precision=Precision.HIGHEST): Scalar Coulomb potential matrix. """ + # The 2 is to compensate for the /2 in the dm definition return 2 * jnp.einsum("pqrt,srt->spq", rep_tensor, rdm1, precision=precision)