Skip to content

Commit

Permalink
Cleanup and docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewRHermes committed Apr 20, 2024
1 parent 51ae2d6 commit 0555eee
Showing 1 changed file with 19 additions and 28 deletions.
47 changes: 19 additions & 28 deletions my_pyscf/lassi/op_o1.py
Original file line number Diff line number Diff line change
Expand Up @@ -933,6 +933,16 @@ def _get_addr_range (self, raddr, *inv):
return addr0 + np.dot (strides_inv, envaddr_inv)

def _prepare_spec_addr_ovlp_(self, rbra, rket, *inv):
'''Prepare the cache for _get_spec_addr_ovlp.
Args:
rbra: integer
Index of bra rootspace for which to prepare the current cache.
rket: integer
Index of ket rootspace for which to prepare the current cache.
*inv: integers
Indices of nonspectator fragments
'''
key = tuple ((rbra,rket)) + inv
braket_table = self.nonuniq_exc[key]
self._spec_addr_ovlp_cache = []
Expand All @@ -941,30 +951,12 @@ def _prepare_spec_addr_ovlp_(self, rbra, rket, *inv):
self._spec_addr_ovlp_cache.append ((rbra1, rket1, b, k, o))
return

def _get_cached_spec_addr_ovlp (self, bra, ket, *inv):
rbra, rket = self.rootaddr[bra], self.rootaddr[ket]
braenv = self.envaddr[bra]
ketenv = self.envaddr[ket]
key = tuple ((rbra,rket)) + inv
braket_table = self.nonuniq_exc[key]
bra_rng = []
ket_rng = []
facs = []
for (rbra1, rket1, b, k, o) in self._spec_addr_ovlp_cache:
dbra = np.dot (braenv, self.strides[rbra1])
dket = np.dot (ketenv, self.strides[rket1])
bra_rng.append (b+dbra)
ket_rng.append (k+dket)
facs.append (o)
bra_rng = np.concatenate (bra_rng)
ket_rng = np.concatenate (ket_rng)
facs = np.concatenate (facs)
return bra_rng, ket_rng, facs

def _get_spec_addr_ovlp (self, bra, ket, *inv):
'''Obtain the integer indices and overlap*permutation factors for all pairs of model states
for which a specified list of nonspectator fragments are in same state that they are in a
provided input pair bra, ket.
provided input pair bra, ket. Uses a cache that must be prepared beforehand by the function
_prepare_spec_addr_ovlp_(rbra, rket, *inv), where rbra and rket must be the rootspace
indices corresponding to this function's bra, ket arguments.
Args:
bra: integer
Expand All @@ -991,10 +983,9 @@ def _get_spec_addr_ovlp (self, bra, ket, *inv):
bra_rng = []
ket_rng = []
facs = []
for rbra1, rket1 in braket_table:
for (rbra1, rket1, b, k, o) in self._spec_addr_ovlp_cache:
dbra = np.dot (braenv, self.strides[rbra1])
dket = np.dot (ketenv, self.strides[rket1])
b, k, o = self._get_spec_addr_ovlp_1space (rbra1, rket1, *inv)
bra_rng.append (b+dbra)
ket_rng.append (k+dket)
facs.append (o)
Expand Down Expand Up @@ -1060,14 +1051,14 @@ def _get_D2_(self, bra, ket):
return self.d2

def _put_D1_(self, bra, ket, D1, *inv):
bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv)
bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv)
self._put_SD1_(bra1, ket1, D1, wgt)

def _put_SD1_(self, bra, ket, D1, wgt):
self.tdm1s[bra,ket,:] += np.multiply.outer (wgt, D1)

def _put_D2_(self, bra, ket, D2, *inv):
bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv)
bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv)
self._put_SD2_(bra1, ket1, D2, wgt)

def _put_SD2_(self, bra, ket, D2, wgt):
Expand Down Expand Up @@ -1404,7 +1395,7 @@ def _put_D1_(self, bra, ket, D1, *inv):
D1 = D1.sum (0)
#self.s2[bra,ket] += (np.trace (M1)/2)**2 + np.trace (D1)/2
s2 = 3*np.trace (D1)/4
bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv)
bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv)
self._put_ham_s2_(bra1, ket1, ham, s2, wgt)

def _put_ham_s2_(self, bra, ket, ham, s2, wgt):
Expand All @@ -1416,7 +1407,7 @@ def _put_D2_(self, bra, ket, D2, *inv):
M2 = np.einsum ('sppqq->s', D2) / 4
s2 = M2[0] + M2[3] - M2[1] - M2[2]
s2 -= np.einsum ('pqqp->', D2[1] + D2[2]) / 2
bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv)
bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv)
self._put_ham_s2_(bra1, ket1, ham, s2, wgt)

def _add_transpose_(self):
Expand Down Expand Up @@ -1639,7 +1630,7 @@ def _get_vecs_(self, bra, ket):
return hci_f_ab, excfrags

def _put_vecs_(self, bra, ket, vecs, *inv):
bras, kets, facs = self._get_cached_spec_addr_ovlp (bra, ket, *inv)
bras, kets, facs = self._get_spec_addr_ovlp (bra, ket, *inv)
for bra, ket, fac in zip (bras, kets, facs):
self._put_Svecs_(bra, ket, [fac*vec for vec in vecs])

Expand Down

0 comments on commit 0555eee

Please sign in to comment.