Skip to content

Commit

Permalink
Simplify lassi hci unittesting
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewRHermes committed Aug 29, 2024
1 parent f57b23d commit 35fa53f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 106 deletions.
36 changes: 2 additions & 34 deletions tests/lassi/test_1frag.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mrh.my_pyscf.lassi.lassi import roots_make_rdm12s
from mrh.my_pyscf.lassi.op_o1 import get_fdm1_maker
from mrh.my_pyscf.lassi import op_o0, op_o1
from mrh.tests.lassi.addons import case_contract_hlas_ci

op = (op_o0, op_o1)

Expand Down Expand Up @@ -88,40 +89,7 @@ def test_rdms (self):
def test_contract_hlas_ci (self):
e_roots, si, las = lsi.e_roots, lsi.si, lsi._las
h0, h1, h2 = lsi.ham_2q ()
nelec = lsi.get_nelec_frs ()
print ("huh?", nelec)
ci_fr = las.ci
ham = (si * (e_roots[None,:]-h0)) @ si.conj ().T
ndim = len (e_roots)

lroots = lsi.get_lroots ()
lroots_prod = np.prod (lroots, axis=0)
nj = np.cumsum (lroots_prod)
ni = nj - lroots_prod
for opt in range (2):
hket_fr_pabq = op[opt].contract_ham_ci (las, h1, h2, ci_fr, nelec, ci_fr, nelec)
for f, (ci_r, hket_r_pabq) in enumerate (zip (ci_fr, hket_fr_pabq)):
current_order = list (range (las.nfrags-1, -1, -1)) + [las.nfrags]
current_order.insert (0, current_order.pop (f))
for r, (ci, hket_pabq) in enumerate (zip (ci_r, hket_r_pabq)):
if ci.ndim < 3: ci = ci[None,:,:]
proper_shape = np.append (lroots[:,r], ndim)
current_shape = proper_shape[current_order]
to_proper_order = list (np.argsort (current_order))
hket_pq = lib.einsum ('rab,pabq->rpq', ci.conj (), hket_pabq)
hket_pq = hket_pq.reshape (current_shape)
hket_pq = hket_pq.transpose (*to_proper_order)
hket_pq = hket_pq.reshape ((lroots_prod[r], ndim))
hket_ref = ham[ni[r]:nj[r]]
for s, (k, l) in enumerate (zip (ni, nj)):
hket_pq_s = hket_pq[:,k:l]
hket_ref_s = hket_ref[:,k:l]
with self.subTest (opt=opt, frag=f, bra_space=r, ket_space=s):
print (opt, f, r, s)
print (hket_pq_s)
print (hket_ref_s)
self.assertAlmostEqual (lib.fp (hket_pq_s), lib.fp (hket_ref_s), 8)

case_contract_hlas_ci (self, las, h0, h1, h2, las.ci, lsi.get_nelec_frs ())

if __name__ == "__main__":
print("Full Tests for LASSI single-fragment edge case")
Expand Down
38 changes: 3 additions & 35 deletions tests/lassi/test_22.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,12 @@
from mrh.my_pyscf.mcscf.lasscf_o0 import LASSCF
from mrh.my_pyscf.lassi import LASSI, LASSIrq, LASSIrqCT
from mrh.my_pyscf.lassi.lassi import root_make_rdm12s, make_stdm12s
from mrh.my_pyscf.lassi.spaces import all_single_excitations, SingleLASRootspace
from mrh.my_pyscf.lassi.spaces import all_single_excitations
from mrh.my_pyscf.mcscf.lasci import get_space_info
from mrh.my_pyscf.lassi import op_o0, op_o1, lassis
from mrh.my_pyscf.lassi.op_o1 import get_fdm1_maker
from mrh.my_pyscf.lassi.sitools import make_sdm1
from mrh.tests.lassi.addons import case_contract_hlas_ci

def setUpModule ():
global mol, mf, lsi, las, mc, op
Expand Down Expand Up @@ -114,40 +115,7 @@ def test_lassirqct (self):
def test_contract_hlas_ci (self):
e_roots, si, las = lsi.e_roots, lsi.si, lsi._las
h0, h1, h2 = lsi.ham_2q ()
nelec = lsi.get_nelec_frs ()
ci_fr = las.ci
ham = (si * (e_roots[None,:]-h0)) @ si.conj ().T
ndim = len (e_roots)

spaces = [SingleLASRootspace (las, m, s, c, 0) for c,m,s,w in zip (*get_space_info (las))]

lroots = lsi.get_lroots ()
lroots_prod = np.prod (lroots, axis=0)
nj = np.cumsum (lroots_prod)
ni = nj - lroots_prod
for opt in range (2):
hket_fr_pabq = op[opt].contract_ham_ci (las, h1, h2, ci_fr, nelec, ci_fr, nelec)
for f, (ci_r, hket_r_pabq) in enumerate (zip (ci_fr, hket_fr_pabq)):
current_order = list (range (las.nfrags-1, -1, -1)) + [las.nfrags]
current_order.insert (0, current_order.pop (f))
for r, (ci, hket_pabq) in enumerate (zip (ci_r, hket_r_pabq)):
if ci.ndim < 3: ci = ci[None,:,:]
proper_shape = np.append (lroots[:,r], ndim)
current_shape = proper_shape[current_order]
to_proper_order = list (np.argsort (current_order))
hket_pq = lib.einsum ('rab,pabq->rpq', ci.conj (), hket_pabq)
hket_pq = hket_pq.reshape (current_shape)
hket_pq = hket_pq.transpose (*to_proper_order)
hket_pq = hket_pq.reshape ((lroots_prod[r], ndim))
hket_ref = ham[ni[r]:nj[r]]
for s, (k, l) in enumerate (zip (ni, nj)):
hket_pq_s = hket_pq[:,k:l]
hket_ref_s = hket_ref[:,k:l]
# TODO: opt>0 for things other than single excitation
#if opt>0 and not spaces[r].is_single_excitation_of (spaces[s]): continue
#elif opt==1: print (r,s, round (lib.fp (hket_pq_s)-lib.fp (hket_ref_s),3))
with self.subTest (opt=opt, frag=f, bra_space=r, ket_space=s):
self.assertAlmostEqual (lib.fp (hket_pq_s), lib.fp (hket_ref_s), 8)
case_contract_hlas_ci (self, las, h0, h1, h2, las.ci, lsi.get_nelec_frs ())

def test_lassis (self):
for opt in (0,1):
Expand Down
41 changes: 4 additions & 37 deletions tests/lassi/test_opt57_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from mrh.my_pyscf.lassi.citools import get_lroots, get_rootaddr_fragaddr
from mrh.my_pyscf.lassi import op_o0
from mrh.my_pyscf.lassi import op_o1
from mrh.my_pyscf.lassi.spaces import SingleLASRootspace
from mrh.tests.lassi.addons import case_contract_hlas_ci

op = (op_o0, op_o1)

Expand Down Expand Up @@ -183,42 +183,9 @@ def test_rdm12s (self):
self.assertAlmostEqual (lib.fp (d12_o0[r][i]),
lib.fp (d12_o1[r][i]), 9)

#def test_contract_hlas_ci (self):
# h0, h1, h2 = ham_2q (las, las.mo_coeff)
# nelec = nelec_frs
# ci_fr = las.ci

# spaces = [SingleLASRootspace (las, m, s, c, 0) for c,m,s,w in zip (*get_space_info (las))]

# lroots = get_lroots (ci_fr)
# lroots_prod = np.prod (lroots, axis=0)
# nj = np.cumsum (lroots_prod)
# ni = nj - lroots_prod
# ndim = nj[-1]
# for opt in range (2):
# ham = op[opt].ham (las, h1, h2, ci_fr, nelec)[0]
# hket_fr_pabq = op[opt].contract_ham_ci (las, h1, h2, ci_fr, nelec, ci_fr, nelec)
# for f, (ci_r, hket_r_pabq) in enumerate (zip (ci_fr, hket_fr_pabq)):
# current_order = list (range (las.nfrags-1, -1, -1)) + [las.nfrags]
# current_order.insert (0, current_order.pop (f))
# for r, (ci, hket_pabq) in enumerate (zip (ci_r, hket_r_pabq)):
# if ci.ndim < 3: ci = ci[None,:,:]
# proper_shape = np.append (lroots[:,r], ndim)
# current_shape = proper_shape[current_order]
# to_proper_order = list (np.argsort (current_order))
# hket_pq = lib.einsum ('rab,pabq->rpq', ci.conj (), hket_pabq)
# hket_pq = hket_pq.reshape (current_shape)
# hket_pq = hket_pq.transpose (*to_proper_order)
# hket_pq = hket_pq.reshape ((lroots_prod[r], ndim))
# hket_ref = ham[ni[r]:nj[r]]
# for s, (k, l) in enumerate (zip (ni, nj)):
# hket_pq_s = hket_pq[:,k:l]
# hket_ref_s = hket_ref[:,k:l]
# # TODO: opt>0 for things other than single excitation
# #if opt>0 and not spaces[r].is_single_excitation_of (spaces[s]): continue
# #elif opt==1: print (r,s, round (lib.fp (hket_pq_s)-lib.fp (hket_ref_s),3))
# with self.subTest (opt=opt, frag=f, bra_space=r, ket_space=s):
# self.assertAlmostEqual (lib.fp (hket_pq_s), lib.fp (hket_ref_s), 8)
def test_contract_hlas_ci (self):
h0, h1, h2 = ham_2q (las, las.mo_coeff)
case_contract_hlas_ci (self, las, h0, h1, h2, las.ci, nelec_frs)



Expand Down

0 comments on commit 35fa53f

Please sign in to comment.