From 51ae2d682e23a9741f632b96a8550c7a836d944a Mon Sep 17 00:00:00 2001 From: Matthew R Hermes Date: Fri, 19 Apr 2024 18:35:53 -0500 Subject: [PATCH] Cache spectator fragment idxs and ovlps! Finally! An actual speedup! --- my_pyscf/lassi/op_o1.py | 42 ++++++++++++++++++++++++++++++++++------- 1 file changed, 35 insertions(+), 7 deletions(-) diff --git a/my_pyscf/lassi/op_o1.py b/my_pyscf/lassi/op_o1.py index 281cd4bf..3604df01 100644 --- a/my_pyscf/lassi/op_o1.py +++ b/my_pyscf/lassi/op_o1.py @@ -932,6 +932,35 @@ def _get_addr_range (self, raddr, *inv): strides_inv = self.strides[raddr][inv] return addr0 + np.dot (strides_inv, envaddr_inv) + def _prepare_spec_addr_ovlp_(self, rbra, rket, *inv): + key = tuple ((rbra,rket)) + inv + braket_table = self.nonuniq_exc[key] + self._spec_addr_ovlp_cache = [] + for rbra1, rket1 in braket_table: + b, k, o = self._get_spec_addr_ovlp_1space (rbra1, rket1, *inv) + self._spec_addr_ovlp_cache.append ((rbra1, rket1, b, k, o)) + return + + def _get_cached_spec_addr_ovlp (self, bra, ket, *inv): + rbra, rket = self.rootaddr[bra], self.rootaddr[ket] + braenv = self.envaddr[bra] + ketenv = self.envaddr[ket] + key = tuple ((rbra,rket)) + inv + braket_table = self.nonuniq_exc[key] + bra_rng = [] + ket_rng = [] + facs = [] + for (rbra1, rket1, b, k, o) in self._spec_addr_ovlp_cache: + dbra = np.dot (braenv, self.strides[rbra1]) + dket = np.dot (ketenv, self.strides[rket1]) + bra_rng.append (b+dbra) + ket_rng.append (k+dket) + facs.append (o) + bra_rng = np.concatenate (bra_rng) + ket_rng = np.concatenate (ket_rng) + facs = np.concatenate (facs) + return bra_rng, ket_rng, facs + def _get_spec_addr_ovlp (self, bra, ket, *inv): '''Obtain the integer indices and overlap*permutation factors for all pairs of model states for which a specified list of nonspectator fragments are in same state that they are in a @@ -964,9 +993,7 @@ def _get_spec_addr_ovlp (self, bra, ket, *inv): facs = [] for rbra1, rket1 in braket_table: dbra = np.dot (braenv, self.strides[rbra1]) - bra1 = self.offs_lroots[rbra1][0] + dbra dket = np.dot (ketenv, self.strides[rket1]) - ket1 = self.offs_lroots[rket1][0] + dket b, k, o = self._get_spec_addr_ovlp_1space (rbra1, rket1, *inv) bra_rng.append (b+dbra) ket_rng.append (k+dket) @@ -1033,14 +1060,14 @@ def _get_D2_(self, bra, ket): return self.d2 def _put_D1_(self, bra, ket, D1, *inv): - bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv) + bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv) self._put_SD1_(bra1, ket1, D1, wgt) def _put_SD1_(self, bra, ket, D1, wgt): self.tdm1s[bra,ket,:] += np.multiply.outer (wgt, D1) def _put_D2_(self, bra, ket, D2, *inv): - bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv) + bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv) self._put_SD2_(bra1, ket1, D2, wgt) def _put_SD2_(self, bra, ket, D2, wgt): @@ -1296,6 +1323,7 @@ def _loop_lroots_(self, _crunch_fn, *row): inv = row[2:-1] else: inv = row[2:] + self._prepare_spec_addr_ovlp_(row[0], row[1], *inv) bra_rng = self._get_addr_range (row[0], *inv) ket_rng = self._get_addr_range (row[1], *inv) lrow = [l for l in row] @@ -1376,7 +1404,7 @@ def _put_D1_(self, bra, ket, D1, *inv): D1 = D1.sum (0) #self.s2[bra,ket] += (np.trace (M1)/2)**2 + np.trace (D1)/2 s2 = 3*np.trace (D1)/4 - bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv) + bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv) self._put_ham_s2_(bra1, ket1, ham, s2, wgt) def _put_ham_s2_(self, bra, ket, ham, s2, wgt): @@ -1388,7 +1416,7 @@ def _put_D2_(self, bra, ket, D2, *inv): M2 = np.einsum ('sppqq->s', D2) / 4 s2 = M2[0] + M2[3] - M2[1] - M2[2] s2 -= np.einsum ('pqqp->', D2[1] + D2[2]) / 2 - bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv) + bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv) self._put_ham_s2_(bra1, ket1, ham, s2, wgt) def _add_transpose_(self): @@ -1611,7 +1639,7 @@ def _get_vecs_(self, bra, ket): return hci_f_ab, excfrags def _put_vecs_(self, bra, ket, vecs, *inv): - bras, kets, facs = self._get_spec_addr_ovlp (bra, ket, *inv) + bras, kets, facs = self._get_cached_spec_addr_ovlp (bra, ket, *inv) for bra, ket, fac in zip (bras, kets, facs): self._put_Svecs_(bra, ket, [fac*vec for vec in vecs])