From 054a204056bd31c1866ddfed6a2c605c3555ba27 Mon Sep 17 00:00:00 2001 From: Matthew R Hermes Date: Fri, 2 Aug 2024 17:02:04 -0500 Subject: [PATCH] Index down the orbital range in lassi op_o1 crunch So that we stop accidentally having N^8 scaling by making each of N^4 interactions address all N^4 orbitals. --- my_pyscf/lassi/op_o1.py | 114 ++++++++++++++++++++++++++++------------ 1 file changed, 81 insertions(+), 33 deletions(-) diff --git a/my_pyscf/lassi/op_o1.py b/my_pyscf/lassi/op_o1.py index 7ae06525..6a5ff6bf 100644 --- a/my_pyscf/lassi/op_o1.py +++ b/my_pyscf/lassi/op_o1.py @@ -703,6 +703,7 @@ def __init__(self, ints, nlas, hopping_index, lroots, mask_bra_space=None, mask_ # buffer self.d1 = np.zeros ([2,]+[self.norb,]*2, dtype=self.dtype) self.d2 = np.zeros ([4,]+[self.norb,]*4, dtype=self.dtype) + self._orbidx = np.ones (self.norb, dtype=bool) def init_profiling (self): self.dt_1d, self.dw_1d = 0.0, 0.0 @@ -1107,7 +1108,9 @@ def _put_D1_(self, bra, ket, D1, *inv): def _put_SD1_(self, bra, ket, D1, wgt): t0, w0 = logger.process_clock (), logger.perf_counter () - self.tdm1s[bra,ket,:] += np.multiply.outer (wgt, D1) + idx = self._orbidx + idx = np.ix_(bra,ket,[True,]*2,idx,idx) + self.tdm1s[idx] += np.multiply.outer (wgt, D1) dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0 self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw @@ -1120,7 +1123,9 @@ def _put_D2_(self, bra, ket, D2, *inv): def _put_SD2_(self, bra, ket, D2, wgt): t0, w0 = logger.process_clock (), logger.perf_counter () - self.tdm2s[bra,ket,:] += np.multiply.outer (wgt, D2) + idx = self._orbidx + idx = np.ix_(bra,ket,[True,]*4,idx,idx,idx,idx) + self.tdm2s[idx] += np.multiply.outer (wgt, D2) dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0 self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw @@ -1369,11 +1374,34 @@ def _crunch_2c_(self, bra, ket, i, j, k, l, s2lt): self.dt_2c, self.dw_2c = self.dt_2c + dt, self.dw_2c + dw self._put_D2_(bra, ket, d2, i, j, k, l) - def _loop_lroots_(self, _crunch_fn, *row): + def _crunch_env_(self, _crunch_fn, *row): if _crunch_fn.__name__ in ('_crunch_1c_', '_crunch_1c1d_', '_crunch_2c_'): inv = row[2:-1] else: inv = row[2:] + with lib.temporary_env (self, **self._orbrange_env_kwargs (inv)): + self._loop_lroots_(_crunch_fn, row, inv) + + def _orbrange_env_kwargs (self, inv): + fragidx = np.zeros (self.nfrags, dtype=bool) + _orbidx = np.zeros (self.norb, dtype=bool) + for frag in inv: + fragidx[frag] = True + p, q = self.get_range (frag) + _orbidx[p:q] = True + nlas = np.array (self.nlas) + nlas[~fragidx] = 0 + norb = sum (nlas) + d1_shape = [2,] + [norb,]*2 + d1_size = np.prod (d1_shape) + d1 = self.d1.ravel ()[:d1_size].reshape (d1_shape) + d2_shape = [4,] + [norb,]*4 + d2_size = np.prod (d2_shape) + d2 = self.d2.ravel ()[:d2_size].reshape (d2_shape) + env_kwargs = {'nlas': nlas, 'd1': d1, 'd2': d2, '_orbidx': _orbidx} + return env_kwargs + + def _loop_lroots_(self, _crunch_fn, row, inv): self._prepare_spec_addr_ovlp_(row[0], row[1], *inv) bra_rng = self._get_addr_range (row[0], *inv) ket_rng = self._get_addr_range (row[1], *inv) @@ -1382,13 +1410,13 @@ def _loop_lroots_(self, _crunch_fn, *row): _crunch_fn (*lrow) def _crunch_all_(self): - for row in self.exc_1d: self._loop_lroots_(self._crunch_1d_, *row) - for row in self.exc_2d: self._loop_lroots_(self._crunch_2d_, *row) - for row in self.exc_1c: self._loop_lroots_(self._crunch_1c_, *row) - for row in self.exc_1c1d: self._loop_lroots_(self._crunch_1c1d_, *row) - for row in self.exc_1s: self._loop_lroots_(self._crunch_1s_, *row) - for row in self.exc_1s1c: self._loop_lroots_(self._crunch_1s1c_, *row) - for row in self.exc_2c: self._loop_lroots_(self._crunch_2c_, *row) + for row in self.exc_1d: self._crunch_env_(self._crunch_1d_, *row) + for row in self.exc_2d: self._crunch_env_(self._crunch_2d_, *row) + for row in self.exc_1c: self._crunch_env_(self._crunch_1c_, *row) + for row in self.exc_1c1d: self._crunch_env_(self._crunch_1c1d_, *row) + for row in self.exc_1s: self._crunch_env_(self._crunch_1s_, *row) + for row in self.exc_1s1c: self._crunch_env_(self._crunch_1s1c_, *row) + for row in self.exc_2c: self._crunch_env_(self._crunch_2c_, *row) self._add_transpose_() def _add_transpose_(self): @@ -1516,6 +1544,15 @@ def _umat_linequiv_(self, ifrag, iroot, umat, *args): ovlp = umat_dot_1frag_(ovlp, umat, self.lroots, ifrag, iroot, axis=1) return ovlp + def _orbrange_env_kwargs (self, inv): + env_kwargs = super()._orbrange_env_kwargs (inv) + _orbidx = env_kwargs['_orbidx'] + idx = np.ix_([True,True],_orbidx,_orbidx) + env_kwargs['h1'] = np.ascontiguousarray (self.h1[idx]) + idx = np.ix_(_orbidx,_orbidx,_orbidx,_orbidx) + env_kwargs['h2'] = np.ascontiguousarray (self.h2[idx]) + return env_kwargs + def kernel (self): ''' Main driver method of class. @@ -1584,31 +1621,31 @@ def __init__(self, ints, nlas, hopping_index, lroots, si, mask_bra_space=None, def _put_SD1_(self, bra, ket, D1, wgt): t0, w0 = logger.process_clock (), logger.perf_counter () - #si_dm = self.si[bra,:] * self.si[ket,:].conj () - #fac = np.dot (wgt, si_dm) - #self.rdm1s[:] += np.multiply.outer (fac, D1) - fn = liblassi.LASSIRDMdputSD - si_nrow, si_ncol = self.si.shape - fn (c_arr(self.rdm1s), c_arr(D1), c_int(D1.size), - c_arr(self.si), c_int(si_nrow), c_int(si_ncol), - c_arr(bra), c_arr(ket), c_arr (wgt), - c_int(len(wgt))) - dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0 - self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw + si_dm = self.si[bra,:] * self.si[ket,:].conj () + fac = np.dot (wgt, si_dm) + self.rdm1s[self.rdm1s_idx] += np.multiply.outer (fac, D1) + #fn = liblassi.LASSIRDMdputSD + #si_nrow, si_ncol = self.si.shape + #fn (c_arr(self.rdm1s), c_arr(D1), c_int(D1.size), + # c_arr(self.si), c_int(si_nrow), c_int(si_ncol), + # c_arr(bra), c_arr(ket), c_arr (wgt), + # c_int(len(wgt))) + #dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0 + #self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw def _put_SD2_(self, bra, ket, D2, wgt): t0, w0 = logger.process_clock (), logger.perf_counter () - #si_dm = self.si[bra,:] * self.si[ket,:].conj () - #fac = np.dot (wgt, si_dm) - #self.rdm2s[:] += np.multiply.outer (fac, D2) - fn = liblassi.LASSIRDMdputSD - si_nrow, si_ncol = self.si.shape - fn (c_arr(self.rdm2s), c_arr(D2), c_int(D2.size), - c_arr(self.si), c_int(si_nrow), c_int(si_ncol), - c_arr(bra), c_arr(ket), c_arr (wgt), - c_int(len(wgt))) - dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0 - self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw + si_dm = self.si[bra,:] * self.si[ket,:].conj () + fac = np.dot (wgt, si_dm) + self.rdm2s[self.rdm2s_idx] += np.multiply.outer (fac, D2) + #fn = liblassi.LASSIRDMdputSD + #si_nrow, si_ncol = self.si.shape + #fn (c_arr(self.rdm2s), c_arr(D2), c_int(D2.size), + # c_arr(self.si), c_int(si_nrow), c_int(si_ncol), + # c_arr(bra), c_arr(ket), c_arr (wgt), + # c_int(len(wgt))) + #dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0 + #self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw def _add_transpose_(self): self.rdm1s += self.rdm1s.conj ().transpose (0,1,3,2) @@ -1618,6 +1655,15 @@ def _umat_linequiv_(self, ifrag, iroot, umat, *args): si = args[0] return umat_dot_1frag_(si, umat.conj ().T, self.lroots, ifrag, iroot, axis=0) + def _orbrange_env_kwargs (self, inv): + env_kwargs = super()._orbrange_env_kwargs (inv) + _orbidx = env_kwargs['_orbidx'] + idx = np.ix_([True,]*self.nroots_si,[True,]*2,_orbidx,_orbidx) + env_kwargs['rdm1s_idx'] = idx + idx = np.ix_([True,]*self.nroots_si,[True,]*4,_orbidx,_orbidx,_orbidx,_orbidx) + env_kwargs['rdm2s_idx'] = idx + return env_kwargs + def kernel (self): ''' Main driver method of class. @@ -1775,7 +1821,9 @@ def _put_Svecs_(self, bra, ket, vecs): self.hci_fr_pabq[i][bra_r][addr,:,:,ket] += vecs[i] def _crunch_all_(self): - for row in self.exc_1c: self._loop_lroots_(self._crunch_1c_, *row) + for row in self.exc_1c: self._crunch_env_(self._crunch_1c_, *row) + + def _orbrange_env_kwargs (self, inv): return {} def _umat_linequiv_(self, ifrag, iroot, umat, *args): # TODO: is this even possible?