Skip to content

Commit

Permalink
lassi op_o1 hci constant part
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewRHermes committed Aug 23, 2024
1 parent 47e7c6a commit 9c95ffc
Showing 1 changed file with 18 additions and 14 deletions.
32 changes: 18 additions & 14 deletions my_pyscf/lassi/op_o1/hci.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from scipy import linalg
from pyscf import lib
from pyscf.lib import logger
from pyscf.fci import cistring
Expand Down Expand Up @@ -310,18 +311,19 @@ def contract_ham_ci (las, h1, h2, ci_fr_ket, nelec_frs_ket, ci_fr_bra, nelec_frs
hket_fr_pabq, t0 = contracter.kernel ()
lib.logger.timer (las, 'LASSI Hamiltonian contraction second intermediate crunching', *t0)

#for ifrag in range (nfrags):
# gen_hket = gen_contract_ham_ci_const (ifrag, nbra, las, h1, h2, ci, nelec_frs, soc=soc,
# orbsym=orbsym, wfnsym=wfnsym)
# for ibra, hket_pabq in enumerate (gen_hket):
# hket_fr_pabq[ifrag][ibra][:] += hket_pabq[:]
# Third pass: multiplicative part
for ifrag in range (nfrags):
gen_hket = gen_contract_ham_ci_const (ifrag, nbra, las, h1, h2, ci, nelec_frs, soc=soc,
orbsym=orbsym, wfnsym=wfnsym)
for ibra, hket_pabq in enumerate (gen_hket):
hket_fr_pabq[ifrag][ibra][:] += hket_pabq[:]
return hket_fr_pabq

def gen_contract_ham_ci_const (ifrag, nbra, las, h1, h2, ci, nelec_frs, soc=0, orbsym=None,
wfnsym=None):
'''Constant-term parts of contract_ham_ci for fragment ifrag'''
log = lib.logger.new_logger (las, las.verbose)
nlas = las.ncas_sub
nlas = np.asarray (las.ncas_sub)
nfrags, nroots = nelec_frs.shape[:2]
nket = nroots - nbra
dtype = ci[0][0].dtype
Expand All @@ -334,7 +336,7 @@ def gen_contract_ham_ci_const (ifrag, nbra, las, h1, h2, ci, nelec_frs, soc=0, o
nelec_i_rs = nelec_frs[ifrag]

# index down to omit fragment
idx = np.ones (las.nfrags, dtype=bool)
idx = np.ones (nfrags, dtype=bool)
idx[ifrag] = False
nelec_frs = nelec_frs[idx]
ci_jfrag = [c for i,c in enumerate (ci) if i != ifrag]
Expand Down Expand Up @@ -370,19 +372,21 @@ def gen_contract_ham_ci_const (ifrag, nbra, las, h1, h2, ci, nelec_frs, soc=0, o
ndetb = cistring.num_strings (norb_i, nelec_i[1])
hket_pabq = np.zeros ((nprods_ket, j-i, ndeta, ndetb),
dtype=outerprod.dtype).transpose (1,2,3,0)
m = 0
n = 0
for iket in range (nket):
if tuple (nelec_i_rs[iket]) != tuple (nelec_i): continue
m = n
ci_i_iket = ci_i[iket]
if ci_i_iket.ndim == 2: ci_i_iket = ci_i_iket[None,...]
nq1, na, nb = ci_i_iket.shape
k, l = outerprod.offs_lroots[iket]
hket = np.multiply.outer (ci_i[iket], ham[i:j,k:l]) # qabpq
if ci_i[iket].ndim == 2: hket = hket[None,...]
nq1, na, nb, np1, nq2 = hket.shape
np1, nq2 = ham[i:j,k:l].shape
n = m + nq1*nq2
if tuple (nelec_i_rs[iket]) != tuple (nelec_i): continue
hket = np.multiply.outer (ci_i_iket, ham[i:j,k:l]) # qabpq
new_shape = [nq1,na,nb,np1] + list (outerprod.lroots[::-1,iket])
hket = np.moveaxis (hket.reshape (new_shape), 0, -ifrag)
new_shape = [na,nb,np1,nq1*nq2]
hket = hket.reshape (new_shape).transpose (2,0,1,3)
n = m + nq1*nq2
hket_pabq[:,:,:,m:n] = hket[:]
m = n
yield hket_pabq

0 comments on commit 9c95ffc

Please sign in to comment.