Skip to content

Commit

Permalink
lassi op_o1 contract_hci syntax safety commit
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewRHermes committed Nov 10, 2023
1 parent 98fd871 commit e3a1a76
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 47 deletions.
3 changes: 1 addition & 2 deletions my_pyscf/lassi/citools.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_lroots (ci):
lroots.append (get_lroots (c))
return np.asarray (lroots)

def envaddr2fragaddr (lroots):
def get_rootaddr_fragaddr (lroots):
'''Generate an index array into a compressed fragment basis for a state in the LASSI model
space
Expand Down Expand Up @@ -65,4 +65,3 @@ def envaddr2fragaddr (lroots):
return rootaddr, fragaddr



100 changes: 57 additions & 43 deletions my_pyscf/lassi/op_o1.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pyscf.fci.addons import cre_a, cre_b, des_a, des_b
from pyscf.fci import cistring
from itertools import product, combinations
from mrh.my_pyscf.lassi.citools import get_lroots, envaddr2fragaddr
from mrh.my_pyscf.lassi.citools import get_lroots, get_rootaddr_fragaddr
import time

# NOTE: PySCF has a strange convention where
Expand Down Expand Up @@ -225,6 +225,12 @@ def __init__(self, ci, hopping_index, zerop_index, onep_index, norb, nroots, nel
self.rootaddr = rootaddr
self.fragaddr = fragaddr
self.idx_frag = idx_frag

# Consistent array shape
self.ndeta_r = [cistring.num_strings (norb, nelec[0]) for nelec in self.nelec_r]
self.ndetb_r = [cistring.num_strings (norb, nelec[1]) for nelec in self.nelec_r]
self.ci = [c.reshape (-1,na,nb) for c, na, nb in zip (self.ci, self.ndeta_r, self.ndetb_r)]

self.time_crunch = self._init_crunch_()

# Exception catching
Expand Down Expand Up @@ -346,17 +352,14 @@ def _init_crunch_(self):
timestamp of entry into this function, for profiling by caller
'''
ci = self.ci
ndeta, ndetb = self.ndeta_r, self.ndetb_r
hopping_index = self.hopping_index
zerop_index = self.zerop_index
onep_index = self.onep_index

nroots, norb = self.nroots, self.norb
t0 = (lib.logger.process_clock (), lib.logger.perf_counter ())

# Consistent array shape
ndeta = [cistring.num_strings (norb, nelec[0]) for nelec in self.nelec_r]
ndetb = [cistring.num_strings (norb, nelec[1]) for nelec in self.nelec_r]
ci = [c.reshape (-1,na,nb) for c, na, nb in zip (ci, ndeta, ndetb)]
lroots = [c.shape[0] for c in ci]

# Overlap matrix
Expand Down Expand Up @@ -503,6 +506,9 @@ def contract_h10 (self, spin, h_10, h_21, ket):
cre_op = (cre_a, cre_b)[spin]
ci = self.ci[r][n]
hci = 0
nelecp = list (nelec)
nelecp[spin] = nelecp[spin] + 1
nelecp = tuple (nelecp)
for p in range (self.norb):
hci += h_10[p] * cre_op (ci, norb, nelec, p)
hci += cre_op (contract_1e (h_21[p], ci, norb, nelec),
Expand All @@ -516,10 +522,18 @@ def contract_h01 (self, spin, h_01, h_12, ket):
des_op = (des_a, des_b)[spin]
ci = self.ci[r][n]
hci = 0
nelecp = list (nelec)
nelecp[spin] = nelecp[spin] - 1
nelecp = tuple (nelecp)
for p in range (self.norb):
hci += h_01[p] * des_op (ci, norb, nelec, p)
try:
hci += h_01[p] * des_op (ci, norb, nelec, p)
except ValueError as err:
print (ci.shape, norb, nelec, p)
print (type (self.ci[r]))
raise (err)
hci += contract_1e (h_12[p], des_op (ci, norb, nelec, p),
norb, nelec)
norb, nelecp)
return hci

def contract_h20 (self, spin, h_20, ket):
Expand Down Expand Up @@ -598,7 +612,7 @@ def __init__(self, ints, nlas, hopping_index, lroots, mask_bra_space=None, mask_
self.nlas = nlas
self.norb = sum (nlas)
self.lroots = lroots
self.rootaddr = envaddr2fragaddr (lroots)[0]
self.rootaddr = get_rootaddr_fragaddr (lroots)[0]
nprods = np.prod (lroots, axis=0)
offs1 = np.cumsum (nprods)
offs0 = offs1 - nprods
Expand Down Expand Up @@ -748,8 +762,8 @@ def make_exc_tables (self, hopping_index):

def mask_exc_table (self, exc, mask_bra_space=None, mask_ket_space=None):
# TODO: PROBLEM: this transposes "bra" and "ket"
exc = mask_exc_table (exc, col=1, mask_space=mask_bra_space)
exc = mask_exc_table (exc, col=0, mask_space=mask_ket_space)
exc = mask_exc_table (exc, col=0, mask_space=mask_bra_space)
exc = mask_exc_table (exc, col=1, mask_space=mask_ket_space)
return exc

def get_range (self, i):
Expand Down Expand Up @@ -1185,32 +1199,32 @@ class ContractHamCI (LSTDMint2):
def __init__(self, ints, nlas, hopping_index, lroots, h1, h2, nbra=1, dtype=np.float64):
nfrags, _, nroots, _ = hopping_index.shape
if nfrags > 2: raise NotImplementedError ("Spectator fragments in _crunch_1c_")
nket = nroots - nbra
HamS2ovlpint.__init__(self, ints, nlas, hopping_index, lroots, h1, h2,
mask_bra_space = list (range (nbra)),
mask_ket_space = list (range (nbra, nroots)),
mask_bra_space = list (range (nket, nroots)),
mask_ket_space = list (range (nket)),
dtype=dtype)
self.hci_fr_pabq, self.bra_offsets, self.ket_offset = self._init_vecs (nbra=nbra)
self.nbra = nbra
self.hci_fr_pabq = self._init_vecs ()

def _init_vecs (self, nbra=1):
def _init_vecs (self):
hci_fr_pabq = []
nfrags, nroots = self.nfrags, self.nroots
nprods_ket = np.sum (np.prod (self.lroots[:,nbra:], axis=0))
bra_offsets = np.prod (self.lroots[:,:nbra], axis=0)
ket_offset = np.sum (bra_offsets)
nfrags, nroots, nbra = self.nfrags, self.nroots, self.nbra
nprods_ket = np.sum (np.prod (self.lroots[:,:-nbra], axis=0))
for i in range (nfrags):
lroots_bra = self.lroots.copy ()[:,:nbra]
lroots_bra = self.lroots.copy ()[:,-nbra:]
lroots_bra[i,:] = 1
nprods_bra = np.prod (lroots_bra, axis=0)
hci_r_pabq = []
norb = self.ints[i].norb
for r in range (nbra):
nelec = self.ints[i].nelec_r[r]
for r in range (self.nbra):
nelec = self.ints[i].nelec_r[r+self.nroots-self.nbra]
ndeta = cistring.num_strings (norb, nelec[0])
ndetb = cistring.num_strings (norb, nelec[1])
hci_r_pabq.append (np.zeros ((nprods_ket, nprods_bra[r], ndeta, ndetb),
dtype=self.dtype).transpose (1,2,3,0))
hci_fr_pabq.append (hci_r_pabq)
return hci_fr_pabq, bra_offsets, ket_offset
return hci_fr_pabq

def _crunch_null_(self, bra, ket):
raise NotImplementedError
Expand Down Expand Up @@ -1241,16 +1255,16 @@ def _crunch_1c_(self, bra, ket, i, j, s1):
h2_iiij = self.h2[p:q,p:q,p:q,r:s]
if i in excfrags:
D_j = self.ints[j].get_h (bra, ket, s1)
h_10 = np.dot (h1_ij, D_j) + np.tensordot (
h2_ijjj, self.ints[j].get_phh (bra, ket, s1),
D_jjj = self.ints[j].get_phh (bra, ket, s1).sum (0)
h_10 = np.dot (h1_ij, D_j) + np.tensordot (h2_ijjj, D_jjj,
axes=((1,2,3),(2,0,1)))
h_21 = np.dot (h2_iiij, D_j)
hci_f_ab[i] += self.ints[i].contract_h10 (s1, h_10, h_21, ket)
if j in excfrags:
D_i = self.ints[i].get_p (bra, ket, s1)
h_01 = np.dot (D_i, h1_ij) + np.tensordot (
self.ints[i].get_pph (bra, ket, s1), h_iiij,
axes=((0,2,1),(0,1,2)))
D_iii = self.ints[i].get_pph (bra, ket, s1).sum (0)
h_01 = np.dot (D_i, h1_ij) + np.tensordot (D_iii, h2_iiij,
axes=((0,1,2),(0,1,2)))
h_12 = np.dot (D_i, h2_ijjj)
hci_f_ab[j] += self.ints[j].contract_h01 (s1, h_01, h_12, ket)
self._put_vecs_(bra, ket, hci_f_ab)
Expand All @@ -1275,20 +1289,20 @@ def env_addr_fragpop (self, bra, i, r):
return bra + rem

def _get_vecs_(self, bra, ket):
ket = ket - self.ket_offset
bra_r = self.rootaddr[bra]
bra_env = np.array ([inti.fragaddr[bra] for inti in self.ints])
lroots_bra_r = self.lroots[:,bra_r]
bra_r = bra_r + self.nbra - self.nroots
hci_f_ab = []
excfrags = set ()
for i, hci_r_pabq in enumerate (self.hci_fr_pabq):
if self.ints[i].fragaddr[bra] != 0:
hci_fr_ab.append (None)
continue
excfrags.add (i)
for r, bra_offset in enumerate (self.bra_offsets):
if bra < bra_offset:
break
bra = bra - bra_offset
env_i = self.env_addr_fragpop (bra, i, r)
hci_f_ab.append (hci_r_pabq[r][env_i,:,:,ket])
excfrags = set (np.where (bra_env==0)[0])
hci_f_ab = [None for i in range (self.nfrags)]
for i in excfrags:
hci_r_pabq = self.hci_fr_pabq[i]
lroots_i = lroots_bra_r.copy ()
lroots_i[i] = 1
strides = np.append ([1], np.cumprod (lroots_i[:-1]))
bra_envaddr = np.dot (strides, bra_env)
hci_f_ab[i] = hci_r_pabq[bra_r][bra_envaddr,:,:,ket]
return hci_f_ab, set (excfrags)

def _put_vecs_(self, bra, ket, vecs):
Expand Down Expand Up @@ -1334,7 +1348,7 @@ def make_ints (las, ci, nelec_frs):
nlas = las.ncas_sub
lroots = get_lroots (ci)
hopping_index, zerop_index, onep_index = lst_hopping_index (nelec_frs)
rootaddr, fragaddr = envaddr2fragaddr (lroots)
rootaddr, fragaddr = get_rootaddr_fragaddr (lroots)
ints = []
for ifrag in range (nfrags):
tdmint = LSTDMint1 (ci[ifrag], hopping_index[ifrag], zerop_index, onep_index, nlas[ifrag],
Expand Down Expand Up @@ -1496,8 +1510,8 @@ def contract_ham_ci (las, h1, h2, ci_fr_ket, nelec_frs_ket, ci_fr_bra, nelec_frs
nlas = las.ncas_sub
nfrags, nbra = nelec_frs_bra.shape[:2]
nket = nelec_frs_ket.shape[1]
ci = [ci_r_bra + ci_r_ket for ci_r_bra, ci_r_ket in zip (ci_fr_bra, ci_fr_ket)]
nelec_frs = np.append (nelec_frs_bra, nelec_frs_ket, axis=1)
ci = [ci_r_ket + ci_r_bra for ci_r_bra, ci_r_ket in zip (ci_fr_bra, ci_fr_ket)]
nelec_frs = np.append (nelec_frs_ket, nelec_frs_bra, axis=1)

# First pass: single-fragment intermediates
hopping_index, ints, lroots = make_ints (las, ci, nelec_frs)
Expand Down
4 changes: 2 additions & 2 deletions tests/lassi/test_opt57_slow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from mrh.my_pyscf.mcscf.lasscf_o0 import LASSCF
from mrh.my_pyscf.mcscf.lasci import get_space_info
from mrh.my_pyscf.lassi.lassi import roots_make_rdm12s, make_stdm12s, ham_2q
from mrh.my_pyscf.lassi.citools import get_lroots, envaddr2fragaddr
from mrh.my_pyscf.lassi.citools import get_lroots, get_rootaddr_fragaddr
from mrh.my_pyscf.lassi import op_o0
from mrh.my_pyscf.lassi import op_o1

Expand Down Expand Up @@ -135,7 +135,7 @@ def test_stdm12s (self):
t2, w2 = lib.logger.process_clock (), lib.logger.perf_counter ()
#print (t1-t0, t2-t1)
#print (w1-w0, w2-w1)
rootaddr, fragaddr = envaddr2fragaddr (get_lroots (las.ci))
rootaddr, fragaddr = get_rootaddr_fragaddr (get_lroots (las.ci))
for r in range (2):
for i, j in product (range (nstates), repeat=2):
with self.subTest (rank=r+1, idx=(i,j), spaces=(rootaddr[i], rootaddr[j]),
Expand Down

0 comments on commit e3a1a76

Please sign in to comment.