From 35fa53f1bdcb27c3767885f4987dc593a54dfc72 Mon Sep 17 00:00:00 2001 From: Matthew R Hermes Date: Thu, 29 Aug 2024 12:47:42 -0500 Subject: [PATCH] Simplify lassi hci unittesting --- tests/lassi/test_1frag.py | 36 ++--------------------------- tests/lassi/test_22.py | 38 +++---------------------------- tests/lassi/test_opt57_slow.py | 41 ++++------------------------------ 3 files changed, 9 insertions(+), 106 deletions(-) diff --git a/tests/lassi/test_1frag.py b/tests/lassi/test_1frag.py index 0dedc31f..7776b395 100644 --- a/tests/lassi/test_1frag.py +++ b/tests/lassi/test_1frag.py @@ -11,6 +11,7 @@ from mrh.my_pyscf.lassi.lassi import roots_make_rdm12s from mrh.my_pyscf.lassi.op_o1 import get_fdm1_maker from mrh.my_pyscf.lassi import op_o0, op_o1 +from mrh.tests.lassi.addons import case_contract_hlas_ci op = (op_o0, op_o1) @@ -88,40 +89,7 @@ def test_rdms (self): def test_contract_hlas_ci (self): e_roots, si, las = lsi.e_roots, lsi.si, lsi._las h0, h1, h2 = lsi.ham_2q () - nelec = lsi.get_nelec_frs () - print ("huh?", nelec) - ci_fr = las.ci - ham = (si * (e_roots[None,:]-h0)) @ si.conj ().T - ndim = len (e_roots) - - lroots = lsi.get_lroots () - lroots_prod = np.prod (lroots, axis=0) - nj = np.cumsum (lroots_prod) - ni = nj - lroots_prod - for opt in range (2): - hket_fr_pabq = op[opt].contract_ham_ci (las, h1, h2, ci_fr, nelec, ci_fr, nelec) - for f, (ci_r, hket_r_pabq) in enumerate (zip (ci_fr, hket_fr_pabq)): - current_order = list (range (las.nfrags-1, -1, -1)) + [las.nfrags] - current_order.insert (0, current_order.pop (f)) - for r, (ci, hket_pabq) in enumerate (zip (ci_r, hket_r_pabq)): - if ci.ndim < 3: ci = ci[None,:,:] - proper_shape = np.append (lroots[:,r], ndim) - current_shape = proper_shape[current_order] - to_proper_order = list (np.argsort (current_order)) - hket_pq = lib.einsum ('rab,pabq->rpq', ci.conj (), hket_pabq) - hket_pq = hket_pq.reshape (current_shape) - hket_pq = hket_pq.transpose (*to_proper_order) - hket_pq = hket_pq.reshape ((lroots_prod[r], ndim)) - hket_ref = ham[ni[r]:nj[r]] - for s, (k, l) in enumerate (zip (ni, nj)): - hket_pq_s = hket_pq[:,k:l] - hket_ref_s = hket_ref[:,k:l] - with self.subTest (opt=opt, frag=f, bra_space=r, ket_space=s): - print (opt, f, r, s) - print (hket_pq_s) - print (hket_ref_s) - self.assertAlmostEqual (lib.fp (hket_pq_s), lib.fp (hket_ref_s), 8) - + case_contract_hlas_ci (self, las, h0, h1, h2, las.ci, lsi.get_nelec_frs ()) if __name__ == "__main__": print("Full Tests for LASSI single-fragment edge case") diff --git a/tests/lassi/test_22.py b/tests/lassi/test_22.py index 1595a61e..d53f7706 100644 --- a/tests/lassi/test_22.py +++ b/tests/lassi/test_22.py @@ -20,11 +20,12 @@ from mrh.my_pyscf.mcscf.lasscf_o0 import LASSCF from mrh.my_pyscf.lassi import LASSI, LASSIrq, LASSIrqCT from mrh.my_pyscf.lassi.lassi import root_make_rdm12s, make_stdm12s -from mrh.my_pyscf.lassi.spaces import all_single_excitations, SingleLASRootspace +from mrh.my_pyscf.lassi.spaces import all_single_excitations from mrh.my_pyscf.mcscf.lasci import get_space_info from mrh.my_pyscf.lassi import op_o0, op_o1, lassis from mrh.my_pyscf.lassi.op_o1 import get_fdm1_maker from mrh.my_pyscf.lassi.sitools import make_sdm1 +from mrh.tests.lassi.addons import case_contract_hlas_ci def setUpModule (): global mol, mf, lsi, las, mc, op @@ -114,40 +115,7 @@ def test_lassirqct (self): def test_contract_hlas_ci (self): e_roots, si, las = lsi.e_roots, lsi.si, lsi._las h0, h1, h2 = lsi.ham_2q () - nelec = lsi.get_nelec_frs () - ci_fr = las.ci - ham = (si * (e_roots[None,:]-h0)) @ si.conj ().T - ndim = len (e_roots) - - spaces = [SingleLASRootspace (las, m, s, c, 0) for c,m,s,w in zip (*get_space_info (las))] - - lroots = lsi.get_lroots () - lroots_prod = np.prod (lroots, axis=0) - nj = np.cumsum (lroots_prod) - ni = nj - lroots_prod - for opt in range (2): - hket_fr_pabq = op[opt].contract_ham_ci (las, h1, h2, ci_fr, nelec, ci_fr, nelec) - for f, (ci_r, hket_r_pabq) in enumerate (zip (ci_fr, hket_fr_pabq)): - current_order = list (range (las.nfrags-1, -1, -1)) + [las.nfrags] - current_order.insert (0, current_order.pop (f)) - for r, (ci, hket_pabq) in enumerate (zip (ci_r, hket_r_pabq)): - if ci.ndim < 3: ci = ci[None,:,:] - proper_shape = np.append (lroots[:,r], ndim) - current_shape = proper_shape[current_order] - to_proper_order = list (np.argsort (current_order)) - hket_pq = lib.einsum ('rab,pabq->rpq', ci.conj (), hket_pabq) - hket_pq = hket_pq.reshape (current_shape) - hket_pq = hket_pq.transpose (*to_proper_order) - hket_pq = hket_pq.reshape ((lroots_prod[r], ndim)) - hket_ref = ham[ni[r]:nj[r]] - for s, (k, l) in enumerate (zip (ni, nj)): - hket_pq_s = hket_pq[:,k:l] - hket_ref_s = hket_ref[:,k:l] - # TODO: opt>0 for things other than single excitation - #if opt>0 and not spaces[r].is_single_excitation_of (spaces[s]): continue - #elif opt==1: print (r,s, round (lib.fp (hket_pq_s)-lib.fp (hket_ref_s),3)) - with self.subTest (opt=opt, frag=f, bra_space=r, ket_space=s): - self.assertAlmostEqual (lib.fp (hket_pq_s), lib.fp (hket_ref_s), 8) + case_contract_hlas_ci (self, las, h0, h1, h2, las.ci, lsi.get_nelec_frs ()) def test_lassis (self): for opt in (0,1): diff --git a/tests/lassi/test_opt57_slow.py b/tests/lassi/test_opt57_slow.py index cbdafbde..46e9a6ba 100644 --- a/tests/lassi/test_opt57_slow.py +++ b/tests/lassi/test_opt57_slow.py @@ -30,7 +30,7 @@ from mrh.my_pyscf.lassi.citools import get_lroots, get_rootaddr_fragaddr from mrh.my_pyscf.lassi import op_o0 from mrh.my_pyscf.lassi import op_o1 -from mrh.my_pyscf.lassi.spaces import SingleLASRootspace +from mrh.tests.lassi.addons import case_contract_hlas_ci op = (op_o0, op_o1) @@ -183,42 +183,9 @@ def test_rdm12s (self): self.assertAlmostEqual (lib.fp (d12_o0[r][i]), lib.fp (d12_o1[r][i]), 9) - #def test_contract_hlas_ci (self): - # h0, h1, h2 = ham_2q (las, las.mo_coeff) - # nelec = nelec_frs - # ci_fr = las.ci - - # spaces = [SingleLASRootspace (las, m, s, c, 0) for c,m,s,w in zip (*get_space_info (las))] - - # lroots = get_lroots (ci_fr) - # lroots_prod = np.prod (lroots, axis=0) - # nj = np.cumsum (lroots_prod) - # ni = nj - lroots_prod - # ndim = nj[-1] - # for opt in range (2): - # ham = op[opt].ham (las, h1, h2, ci_fr, nelec)[0] - # hket_fr_pabq = op[opt].contract_ham_ci (las, h1, h2, ci_fr, nelec, ci_fr, nelec) - # for f, (ci_r, hket_r_pabq) in enumerate (zip (ci_fr, hket_fr_pabq)): - # current_order = list (range (las.nfrags-1, -1, -1)) + [las.nfrags] - # current_order.insert (0, current_order.pop (f)) - # for r, (ci, hket_pabq) in enumerate (zip (ci_r, hket_r_pabq)): - # if ci.ndim < 3: ci = ci[None,:,:] - # proper_shape = np.append (lroots[:,r], ndim) - # current_shape = proper_shape[current_order] - # to_proper_order = list (np.argsort (current_order)) - # hket_pq = lib.einsum ('rab,pabq->rpq', ci.conj (), hket_pabq) - # hket_pq = hket_pq.reshape (current_shape) - # hket_pq = hket_pq.transpose (*to_proper_order) - # hket_pq = hket_pq.reshape ((lroots_prod[r], ndim)) - # hket_ref = ham[ni[r]:nj[r]] - # for s, (k, l) in enumerate (zip (ni, nj)): - # hket_pq_s = hket_pq[:,k:l] - # hket_ref_s = hket_ref[:,k:l] - # # TODO: opt>0 for things other than single excitation - # #if opt>0 and not spaces[r].is_single_excitation_of (spaces[s]): continue - # #elif opt==1: print (r,s, round (lib.fp (hket_pq_s)-lib.fp (hket_ref_s),3)) - # with self.subTest (opt=opt, frag=f, bra_space=r, ket_space=s): - # self.assertAlmostEqual (lib.fp (hket_pq_s), lib.fp (hket_ref_s), 8) + def test_contract_hlas_ci (self): + h0, h1, h2 = ham_2q (las, las.mo_coeff) + case_contract_hlas_ci (self, las, h0, h1, h2, las.ci, nelec_frs)