Skip to content

Commit

Permalink
Index down the orbital range in lassi op_o1 crunch
Browse files Browse the repository at this point in the history
So that we stop accidentally having N^8 scaling by making each of
N^4 interactions address all N^4 orbitals.
  • Loading branch information
MatthewRHermes committed Aug 2, 2024
1 parent d83e528 commit 054a204
Showing 1 changed file with 81 additions and 33 deletions.
114 changes: 81 additions & 33 deletions my_pyscf/lassi/op_o1.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ def __init__(self, ints, nlas, hopping_index, lroots, mask_bra_space=None, mask_
# buffer
self.d1 = np.zeros ([2,]+[self.norb,]*2, dtype=self.dtype)
self.d2 = np.zeros ([4,]+[self.norb,]*4, dtype=self.dtype)
self._orbidx = np.ones (self.norb, dtype=bool)

def init_profiling (self):
self.dt_1d, self.dw_1d = 0.0, 0.0
Expand Down Expand Up @@ -1107,7 +1108,9 @@ def _put_D1_(self, bra, ket, D1, *inv):

def _put_SD1_(self, bra, ket, D1, wgt):
t0, w0 = logger.process_clock (), logger.perf_counter ()
self.tdm1s[bra,ket,:] += np.multiply.outer (wgt, D1)
idx = self._orbidx
idx = np.ix_(bra,ket,[True,]*2,idx,idx)
self.tdm1s[idx] += np.multiply.outer (wgt, D1)
dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw

Expand All @@ -1120,7 +1123,9 @@ def _put_D2_(self, bra, ket, D2, *inv):

def _put_SD2_(self, bra, ket, D2, wgt):
t0, w0 = logger.process_clock (), logger.perf_counter ()
self.tdm2s[bra,ket,:] += np.multiply.outer (wgt, D2)
idx = self._orbidx
idx = np.ix_(bra,ket,[True,]*4,idx,idx,idx,idx)
self.tdm2s[idx] += np.multiply.outer (wgt, D2)
dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw

Expand Down Expand Up @@ -1369,11 +1374,34 @@ def _crunch_2c_(self, bra, ket, i, j, k, l, s2lt):
self.dt_2c, self.dw_2c = self.dt_2c + dt, self.dw_2c + dw
self._put_D2_(bra, ket, d2, i, j, k, l)

def _loop_lroots_(self, _crunch_fn, *row):
def _crunch_env_(self, _crunch_fn, *row):
if _crunch_fn.__name__ in ('_crunch_1c_', '_crunch_1c1d_', '_crunch_2c_'):
inv = row[2:-1]
else:
inv = row[2:]
with lib.temporary_env (self, **self._orbrange_env_kwargs (inv)):
self._loop_lroots_(_crunch_fn, row, inv)

def _orbrange_env_kwargs (self, inv):
fragidx = np.zeros (self.nfrags, dtype=bool)
_orbidx = np.zeros (self.norb, dtype=bool)
for frag in inv:
fragidx[frag] = True
p, q = self.get_range (frag)
_orbidx[p:q] = True
nlas = np.array (self.nlas)
nlas[~fragidx] = 0
norb = sum (nlas)
d1_shape = [2,] + [norb,]*2
d1_size = np.prod (d1_shape)
d1 = self.d1.ravel ()[:d1_size].reshape (d1_shape)
d2_shape = [4,] + [norb,]*4
d2_size = np.prod (d2_shape)
d2 = self.d2.ravel ()[:d2_size].reshape (d2_shape)
env_kwargs = {'nlas': nlas, 'd1': d1, 'd2': d2, '_orbidx': _orbidx}
return env_kwargs

def _loop_lroots_(self, _crunch_fn, row, inv):
self._prepare_spec_addr_ovlp_(row[0], row[1], *inv)
bra_rng = self._get_addr_range (row[0], *inv)
ket_rng = self._get_addr_range (row[1], *inv)
Expand All @@ -1382,13 +1410,13 @@ def _loop_lroots_(self, _crunch_fn, *row):
_crunch_fn (*lrow)

def _crunch_all_(self):
for row in self.exc_1d: self._loop_lroots_(self._crunch_1d_, *row)
for row in self.exc_2d: self._loop_lroots_(self._crunch_2d_, *row)
for row in self.exc_1c: self._loop_lroots_(self._crunch_1c_, *row)
for row in self.exc_1c1d: self._loop_lroots_(self._crunch_1c1d_, *row)
for row in self.exc_1s: self._loop_lroots_(self._crunch_1s_, *row)
for row in self.exc_1s1c: self._loop_lroots_(self._crunch_1s1c_, *row)
for row in self.exc_2c: self._loop_lroots_(self._crunch_2c_, *row)
for row in self.exc_1d: self._crunch_env_(self._crunch_1d_, *row)
for row in self.exc_2d: self._crunch_env_(self._crunch_2d_, *row)
for row in self.exc_1c: self._crunch_env_(self._crunch_1c_, *row)
for row in self.exc_1c1d: self._crunch_env_(self._crunch_1c1d_, *row)
for row in self.exc_1s: self._crunch_env_(self._crunch_1s_, *row)
for row in self.exc_1s1c: self._crunch_env_(self._crunch_1s1c_, *row)
for row in self.exc_2c: self._crunch_env_(self._crunch_2c_, *row)
self._add_transpose_()

def _add_transpose_(self):
Expand Down Expand Up @@ -1516,6 +1544,15 @@ def _umat_linequiv_(self, ifrag, iroot, umat, *args):
ovlp = umat_dot_1frag_(ovlp, umat, self.lroots, ifrag, iroot, axis=1)
return ovlp

def _orbrange_env_kwargs (self, inv):
env_kwargs = super()._orbrange_env_kwargs (inv)
_orbidx = env_kwargs['_orbidx']
idx = np.ix_([True,True],_orbidx,_orbidx)
env_kwargs['h1'] = np.ascontiguousarray (self.h1[idx])
idx = np.ix_(_orbidx,_orbidx,_orbidx,_orbidx)
env_kwargs['h2'] = np.ascontiguousarray (self.h2[idx])
return env_kwargs

def kernel (self):
''' Main driver method of class.
Expand Down Expand Up @@ -1584,31 +1621,31 @@ def __init__(self, ints, nlas, hopping_index, lroots, si, mask_bra_space=None,

def _put_SD1_(self, bra, ket, D1, wgt):
t0, w0 = logger.process_clock (), logger.perf_counter ()
#si_dm = self.si[bra,:] * self.si[ket,:].conj ()
#fac = np.dot (wgt, si_dm)
#self.rdm1s[:] += np.multiply.outer (fac, D1)
fn = liblassi.LASSIRDMdputSD
si_nrow, si_ncol = self.si.shape
fn (c_arr(self.rdm1s), c_arr(D1), c_int(D1.size),
c_arr(self.si), c_int(si_nrow), c_int(si_ncol),
c_arr(bra), c_arr(ket), c_arr (wgt),
c_int(len(wgt)))
dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw
si_dm = self.si[bra,:] * self.si[ket,:].conj ()
fac = np.dot (wgt, si_dm)
self.rdm1s[self.rdm1s_idx] += np.multiply.outer (fac, D1)
#fn = liblassi.LASSIRDMdputSD
#si_nrow, si_ncol = self.si.shape
#fn (c_arr(self.rdm1s), c_arr(D1), c_int(D1.size),
# c_arr(self.si), c_int(si_nrow), c_int(si_ncol),
# c_arr(bra), c_arr(ket), c_arr (wgt),
# c_int(len(wgt)))
#dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
#self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw

def _put_SD2_(self, bra, ket, D2, wgt):
t0, w0 = logger.process_clock (), logger.perf_counter ()
#si_dm = self.si[bra,:] * self.si[ket,:].conj ()
#fac = np.dot (wgt, si_dm)
#self.rdm2s[:] += np.multiply.outer (fac, D2)
fn = liblassi.LASSIRDMdputSD
si_nrow, si_ncol = self.si.shape
fn (c_arr(self.rdm2s), c_arr(D2), c_int(D2.size),
c_arr(self.si), c_int(si_nrow), c_int(si_ncol),
c_arr(bra), c_arr(ket), c_arr (wgt),
c_int(len(wgt)))
dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw
si_dm = self.si[bra,:] * self.si[ket,:].conj ()
fac = np.dot (wgt, si_dm)
self.rdm2s[self.rdm2s_idx] += np.multiply.outer (fac, D2)
#fn = liblassi.LASSIRDMdputSD
#si_nrow, si_ncol = self.si.shape
#fn (c_arr(self.rdm2s), c_arr(D2), c_int(D2.size),
# c_arr(self.si), c_int(si_nrow), c_int(si_ncol),
# c_arr(bra), c_arr(ket), c_arr (wgt),
# c_int(len(wgt)))
#dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
#self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw

def _add_transpose_(self):
self.rdm1s += self.rdm1s.conj ().transpose (0,1,3,2)
Expand All @@ -1618,6 +1655,15 @@ def _umat_linequiv_(self, ifrag, iroot, umat, *args):
si = args[0]
return umat_dot_1frag_(si, umat.conj ().T, self.lroots, ifrag, iroot, axis=0)

def _orbrange_env_kwargs (self, inv):
env_kwargs = super()._orbrange_env_kwargs (inv)
_orbidx = env_kwargs['_orbidx']
idx = np.ix_([True,]*self.nroots_si,[True,]*2,_orbidx,_orbidx)
env_kwargs['rdm1s_idx'] = idx
idx = np.ix_([True,]*self.nroots_si,[True,]*4,_orbidx,_orbidx,_orbidx,_orbidx)
env_kwargs['rdm2s_idx'] = idx
return env_kwargs

def kernel (self):
''' Main driver method of class.
Expand Down Expand Up @@ -1775,7 +1821,9 @@ def _put_Svecs_(self, bra, ket, vecs):
self.hci_fr_pabq[i][bra_r][addr,:,:,ket] += vecs[i]

def _crunch_all_(self):
for row in self.exc_1c: self._loop_lroots_(self._crunch_1c_, *row)
for row in self.exc_1c: self._crunch_env_(self._crunch_1c_, *row)

def _orbrange_env_kwargs (self, inv): return {}

def _umat_linequiv_(self, ifrag, iroot, umat, *args):
# TODO: is this even possible?
Expand Down

0 comments on commit 054a204

Please sign in to comment.