Skip to content

Commit

Permalink
Merge pull request #26 from XanaduAI/25-substitute-the-eigh-implement…
Browse files Browse the repository at this point in the history
…ation-by-something-fully-differentiable

Differentiable general eigenvalue problem for hermitian matrices
  • Loading branch information
jackbaker1001 authored Aug 21, 2023
2 parents 81444ce + 0092f17 commit f11caf9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
8 changes: 4 additions & 4 deletions grad_dft/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from grad_dft.utils import PyTree, Array, Scalar, Optimizer
from grad_dft.functional import Functional

from grad_dft.molecule import Molecule, eig, make_rdm1, orbital_grad
from grad_dft.molecule import Molecule, eig, make_rdm1, orbital_grad, general_eigh
from grad_dft.train import molecule_predictor
from grad_dft.utils import PyTree, Array, Scalar
from grad_dft.interface.pyscf import (
Expand Down Expand Up @@ -154,7 +154,7 @@ def scf_iterator(params: PyTree, molecule: Molecule, *args) -> Tuple[Scalar, Sca
)

# Diagonalize Fock matrix
mo_energy, mo_coeff = eig(fock, molecule.s1e)
mo_energy, mo_coeff = general_eigh(fock, molecule.s1e)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)

Expand Down Expand Up @@ -251,7 +251,7 @@ def nelec_cost_fn(m, mo_es, sigma, _nelectron):
if abs(predicted_e - old_e) * Hartree2kcalmol < e_conv and norm_gorb < g_conv:
# We perform an extra diagonalization to remove the level shift
# Solve eigenvalue problem
mo_energy, mo_coeff = eig(fock, molecule.s1e)
mo_energy, mo_coeff = general_eigh(fock, molecule.s1e)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)

Expand Down Expand Up @@ -541,7 +541,7 @@ def loop_body(cycle, state):
fock, diis_data = diis.run(new_data, diis_data, cycle)

# Diagonalize Fock matrix
mo_energy, mo_coeff = eig(fock, molecule.s1e)
mo_energy, mo_coeff = general_eigh(fock, molecule.s1e)
molecule = molecule.replace(mo_coeff=mo_coeff)
molecule = molecule.replace(mo_energy=mo_energy)

Expand Down
8 changes: 8 additions & 0 deletions grad_dft/molecule.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,14 @@ def eig(h, x):
e1, c1 = eigh2d(h[1], x)
return jnp.stack((e0, e1), axis=0), jnp.stack((c0, c1), axis=0)

def general_eigh(A, B):
L = jnp.linalg.cholesky(B)
L_inv = jnp.linalg.inv(L)
C = L_inv @ A @ L_inv.T
eigenvalues, eigenvectors_transformed = jnp.linalg.eigh(C)
eigenvectors_original = L_inv.T @ eigenvectors_transformed
return eigenvalues, eigenvectors_original


######################################################################

Expand Down

0 comments on commit f11caf9

Please sign in to comment.