Skip to content

Commit

Permalink
Cache spectator fragment idxs and ovlps!
Browse files Browse the repository at this point in the history
Finally! An actual speedup!
  • Loading branch information
MatthewRHermes committed Apr 19, 2024
1 parent 64c5f06 commit 51ae2d6
Showing 1 changed file with 35 additions and 7 deletions.
42 changes: 35 additions & 7 deletions my_pyscf/lassi/op_o1.py
Original file line number Diff line number Diff line change
Expand Up @@ -932,6 +932,35 @@ def _get_addr_range (self, raddr, *inv):
strides_inv = self.strides[raddr][inv]
return addr0 + np.dot (strides_inv, envaddr_inv)

def _prepare_spec_addr_ovlp_(self, rbra, rket, *inv):
    '''Rebuild self._spec_addr_ovlp_cache for the lookup key (rbra, rket, *inv).

    The key selects a table of (bra, ket) pairs from self.nonuniq_exc. For every pair in that
    table, self._get_spec_addr_ovlp_1space is evaluated once and its three return values are
    stored, together with the pair itself, as one 5-tuple in the cache. Any previous cache
    content is discarded.

    Args:
        rbra: first element of the self.nonuniq_exc key (presumably a bra rootspace index;
            confirm against the caller)
        rket: second element of the self.nonuniq_exc key (presumably a ket rootspace index)
        *inv: remaining elements of the key; also forwarded unchanged to
            self._get_spec_addr_ovlp_1space

    Raises:
        KeyError: if (rbra, rket, *inv) is not a key of self.nonuniq_exc. The existing cache
            is left untouched in that case because the lookup happens before the reset.
    '''
    pair_table = self.nonuniq_exc[(rbra, rket) + inv]
    self._spec_addr_ovlp_cache = []
    for pair_bra, pair_ket in pair_table:
        addr_bra, addr_ket, ovlp = self._get_spec_addr_ovlp_1space (pair_bra, pair_ket, *inv)
        entry = (pair_bra, pair_ket, addr_bra, addr_ket, ovlp)
        self._spec_addr_ovlp_cache.append (entry)

def _get_cached_spec_addr_ovlp (self, bra, ket, *inv):
rbra, rket = self.rootaddr[bra], self.rootaddr[ket]
braenv = self.envaddr[bra]
ketenv = self.envaddr[ket]
key = tuple ((rbra,rket)) + inv
braket_table = self.nonuniq_exc[key]
bra_rng = []
ket_rng = []
facs = []
for (rbra1, rket1, b, k, o) in self._spec_addr_ovlp_cache:
dbra = np.dot (braenv, self.strides[rbra1])
dket = np.dot (ketenv, self.strides[rket1])
bra_rng.append (b+dbra)
ket_rng.append (k+dket)
facs.append (o)
bra_rng = np.concatenate (bra_rng)
ket_rng = np.concatenate (ket_rng)
facs = np.concatenate (facs)
return bra_rng, ket_rng, facs

def _get_spec_addr_ovlp (self, bra, ket, *inv):
'''Obtain the integer indices and overlap*permutation factors for all pairs of model states
for which a specified list of nonspectator fragments are in same state that they are in a
Expand Down Expand Up @@ -964,9 +993,7 @@ def _get_spec_addr_ovlp (self, bra, ket, *inv):
facs = []
for rbra1, rket1 in braket_table:
dbra = np.dot (braenv, self.strides[rbra1])
bra1 = self.offs_lroots[rbra1][0] + dbra
dket = np.dot (ketenv, self.strides[rket1])
ket1 = self.offs_lroots[rket1][0] + dket
b, k, o = self._get_spec_addr_ovlp_1space (rbra1, rket1, *inv)
bra_rng.append (b+dbra)
ket_rng.append (k+dket)
Expand Down Expand Up @@ -1033,14 +1060,14 @@ def _get_D2_(self, bra, ket):
return self.d2

def _put_D1_(self, bra, ket, D1, *inv):
    '''Look up the target indices and weights for (bra, ket, *inv) and hand D1 to _put_SD1_.

    The indices and weights come from self._get_cached_spec_addr_ovlp only. The uncached
    self._get_spec_addr_ovlp call that preceded it is removed: its result was overwritten
    before being read, so it was pure overhead.

    Args:
        bra: bra index forwarded to the cached lookup
        ket: ket index forwarded to the cached lookup
        D1: array passed through unchanged to self._put_SD1_
        *inv: forwarded unchanged to the cached lookup
    '''
    bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv)
    self._put_SD1_(bra1, ket1, D1, wgt)

def _put_SD1_(self, bra, ket, D1, wgt):
    '''Add the outer product of wgt and D1 into self.tdm1s[bra,ket,:] in place.

    Args:
        bra: index (or index array) for the first axis of self.tdm1s
        ket: index (or index array) for the second axis of self.tdm1s
        D1: array that is scaled by each element of wgt
        wgt: weights; np.multiply.outer puts its axes ahead of the axes of D1
    '''
    weighted = np.multiply.outer (wgt, D1)
    self.tdm1s[bra,ket,:] += weighted

def _put_D2_(self, bra, ket, D2, *inv):
    '''Look up the target indices and weights for (bra, ket, *inv) and hand D2 to _put_SD2_.

    The indices and weights come from self._get_cached_spec_addr_ovlp only. The uncached
    self._get_spec_addr_ovlp call that preceded it is removed: its result was overwritten
    before being read, so it was pure overhead.

    Args:
        bra: bra index forwarded to the cached lookup
        ket: ket index forwarded to the cached lookup
        D2: array passed through unchanged to self._put_SD2_
        *inv: forwarded unchanged to the cached lookup
    '''
    bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv)
    self._put_SD2_(bra1, ket1, D2, wgt)

def _put_SD2_(self, bra, ket, D2, wgt):
Expand Down Expand Up @@ -1296,6 +1323,7 @@ def _loop_lroots_(self, _crunch_fn, *row):
inv = row[2:-1]
else:
inv = row[2:]
self._prepare_spec_addr_ovlp_(row[0], row[1], *inv)
bra_rng = self._get_addr_range (row[0], *inv)
ket_rng = self._get_addr_range (row[1], *inv)
lrow = [l for l in row]
Expand Down Expand Up @@ -1376,7 +1404,7 @@ def _put_D1_(self, bra, ket, D1, *inv):
D1 = D1.sum (0)
#self.s2[bra,ket] += (np.trace (M1)/2)**2 + np.trace (D1)/2
s2 = 3*np.trace (D1)/4
bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv)
bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv)
self._put_ham_s2_(bra1, ket1, ham, s2, wgt)

def _put_ham_s2_(self, bra, ket, ham, s2, wgt):
Expand All @@ -1388,7 +1416,7 @@ def _put_D2_(self, bra, ket, D2, *inv):
M2 = np.einsum ('sppqq->s', D2) / 4
s2 = M2[0] + M2[3] - M2[1] - M2[2]
s2 -= np.einsum ('pqqp->', D2[1] + D2[2]) / 2
bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv)
bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv)
self._put_ham_s2_(bra1, ket1, ham, s2, wgt)

def _add_transpose_(self):
Expand Down Expand Up @@ -1611,7 +1639,7 @@ def _get_vecs_(self, bra, ket):
return hci_f_ab, excfrags

def _put_vecs_(self, bra, ket, vecs, *inv):
    '''Scale vecs by each cached factor and pass the result to _put_Svecs_ per index pair.

    The index pairs and factors come from self._get_cached_spec_addr_ovlp only. The uncached
    self._get_spec_addr_ovlp call that preceded it is removed: its result was overwritten
    before being read, so it was pure overhead.

    Args:
        bra: bra index forwarded to the cached lookup
        ket: ket index forwarded to the cached lookup
        vecs: iterable of vectors; each is multiplied by the factor of the current pair
        *inv: forwarded unchanged to the cached lookup
    '''
    bras, kets, facs = self._get_cached_spec_addr_ovlp (bra, ket, *inv)
    # Separate loop names keep the bra/ket arguments from being rebound by the iteration.
    for bra1, ket1, fac in zip (bras, kets, facs):
        self._put_Svecs_(bra1, ket1, [fac*vec for vec in vecs])

Expand Down

0 comments on commit 51ae2d6

Please sign in to comment.