diff --git a/my_pyscf/lassi/op_o1.py b/my_pyscf/lassi/op_o1.py index 3604df01..85d77239 100644 --- a/my_pyscf/lassi/op_o1.py +++ b/my_pyscf/lassi/op_o1.py @@ -933,6 +933,16 @@ def _get_addr_range (self, raddr, *inv): return addr0 + np.dot (strides_inv, envaddr_inv) def _prepare_spec_addr_ovlp_(self, rbra, rket, *inv): + '''Prepare the cache for _get_spec_addr_ovlp. + + Args: + rbra: integer + Index of bra rootspace for which to prepare the current cache. + rket: integer + Index of ket rootspace for which to prepare the current cache. + *inv: integers + Indices of nonspectator fragments + ''' key = tuple ((rbra,rket)) + inv braket_table = self.nonuniq_exc[key] self._spec_addr_ovlp_cache = [] @@ -941,30 +951,12 @@ def _prepare_spec_addr_ovlp_(self, rbra, rket, *inv): self._spec_addr_ovlp_cache.append ((rbra1, rket1, b, k, o)) return - def _get_cached_spec_addr_ovlp (self, bra, ket, *inv): - rbra, rket = self.rootaddr[bra], self.rootaddr[ket] - braenv = self.envaddr[bra] - ketenv = self.envaddr[ket] - key = tuple ((rbra,rket)) + inv - braket_table = self.nonuniq_exc[key] - bra_rng = [] - ket_rng = [] - facs = [] - for (rbra1, rket1, b, k, o) in self._spec_addr_ovlp_cache: - dbra = np.dot (braenv, self.strides[rbra1]) - dket = np.dot (ketenv, self.strides[rket1]) - bra_rng.append (b+dbra) - ket_rng.append (k+dket) - facs.append (o) - bra_rng = np.concatenate (bra_rng) - ket_rng = np.concatenate (ket_rng) - facs = np.concatenate (facs) - return bra_rng, ket_rng, facs - def _get_spec_addr_ovlp (self, bra, ket, *inv): '''Obtain the integer indices and overlap*permutation factors for all pairs of model states for which a specified list of nonspectator fragments are in same state that they are in a - provided input pair bra, ket. + provided input pair bra, ket. Uses a cache that must be prepared beforehand by the function + _prepare_spec_addr_ovlp_(rbra, rket, *inv), where rbra and rket must be the rootspace + indices corresponding to this function's bra, ket arguments. Args: bra: integer @@ -991,10 +983,9 @@ def _get_spec_addr_ovlp (self, bra, ket, *inv): bra_rng = [] ket_rng = [] facs = [] - for rbra1, rket1 in braket_table: + for (rbra1, rket1, b, k, o) in self._spec_addr_ovlp_cache: dbra = np.dot (braenv, self.strides[rbra1]) dket = np.dot (ketenv, self.strides[rket1]) - b, k, o = self._get_spec_addr_ovlp_1space (rbra1, rket1, *inv) bra_rng.append (b+dbra) ket_rng.append (k+dket) facs.append (o) @@ -1060,14 +1051,14 @@ def _get_D2_(self, bra, ket): return self.d2 def _put_D1_(self, bra, ket, D1, *inv): - bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv) + bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv) self._put_SD1_(bra1, ket1, D1, wgt) def _put_SD1_(self, bra, ket, D1, wgt): self.tdm1s[bra,ket,:] += np.multiply.outer (wgt, D1) def _put_D2_(self, bra, ket, D2, *inv): - bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv) + bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv) self._put_SD2_(bra1, ket1, D2, wgt) def _put_SD2_(self, bra, ket, D2, wgt): @@ -1404,7 +1395,7 @@ def _put_D1_(self, bra, ket, D1, *inv): D1 = D1.sum (0) #self.s2[bra,ket] += (np.trace (M1)/2)**2 + np.trace (D1)/2 s2 = 3*np.trace (D1)/4 - bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv) + bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv) self._put_ham_s2_(bra1, ket1, ham, s2, wgt) def _put_ham_s2_(self, bra, ket, ham, s2, wgt): @@ -1416,7 +1407,7 @@ def _put_D2_(self, bra, ket, D2, *inv): M2 = np.einsum ('sppqq->s', D2) / 4 s2 = M2[0] + M2[3] - M2[1] - M2[2] s2 -= np.einsum ('pqqp->', D2[1] + D2[2]) / 2 - bra1, ket1, wgt = self._get_cached_spec_addr_ovlp (bra, ket, *inv) + bra1, ket1, wgt = self._get_spec_addr_ovlp (bra, ket, *inv) self._put_ham_s2_(bra1, ket1, ham, s2, wgt) def _add_transpose_(self): @@ -1639,7 +1630,7 @@ def _get_vecs_(self, bra, ket): return hci_f_ab, excfrags def _put_vecs_(self, bra, ket, vecs, *inv): - bras, kets, facs = self._get_cached_spec_addr_ovlp (bra, ket, *inv) + bras, kets, facs = self._get_spec_addr_ovlp (bra, ket, *inv) for bra, ket, fac in zip (bras, kets, facs): self._put_Svecs_(bra, ket, [fac*vec for vec in vecs])