Skip to content

Commit

Permalink
lassi rdm refactor put
Browse files Browse the repository at this point in the history
Split the RDM accumulation into two steps: first sum over lroots into a
contiguous buffer, then do the awkward discontiguous put only once, after that sum.
  • Loading branch information
MatthewRHermes committed Aug 7, 2024
1 parent 030b607 commit dd2f846
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 26 deletions.
30 changes: 26 additions & 4 deletions lib/lassi/rdm.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void LASSIRDMdgetwgtfac (double * fac, double * wgt, double * sivec,
}
}

void LASSIRDMdputSD (double * SDdest, double * fac, double * SDsrc,
void LASSIRDMdsumSD (double * SDdest, double * fac, double * SDsrc,
int nroots, int nelem_dest,
int * SDdest_idx, int * SDsrc_idx, int * SDlen,
int nidx)
Expand All @@ -89,9 +89,31 @@ const unsigned int i_one = 1;
mySDdest+iroot*nelem_dest, &i_one);
}
}
//for (int i = 0; i < nidx; i++){
// daxpy_(&nroots, &(SDterm[i]), fac, &i_one, &(SDsum[idx[i]]), &nelem);
//}

}
}

void LASSIRDMdputSD (double * SDdest, double * SDsrc,
                     int nroots, int nelem_dest, int nelem_src,
                     int * SDdest_idx, int * SDsrc_idx, int * SDlen,
                     int nidx)
{
    /* Add blocks of SDsrc into SDdest (SDdest += SDsrc, blockwise) for every root.

       SDdest     : destination array; consecutive roots are nelem_dest apart
       SDsrc      : source array; consecutive roots are nelem_src apart
       nroots     : number of roots (rows) to process
       SDdest_idx : offset of each block within one root of SDdest
       SDsrc_idx  : offset of each block within one root of SDsrc
       SDlen      : length of each block
       nidx       : number of blocks

       The loop over blocks is shared among OpenMP threads. */
    const unsigned int i_one = 1;
    const double d_one = 1.0;
#pragma omp parallel
    {
        double * mySDsrc;
        double * mySDdest;
#pragma omp for
        for (int iidx = 0; iidx < nidx; iidx++){
            mySDsrc = SDsrc + SDsrc_idx[iidx];
            mySDdest = SDdest + SDdest_idx[iidx];
            for (int iroot = 0; iroot < nroots; iroot++){
                daxpy_(SDlen+iidx, &d_one,
                       mySDsrc, &i_one,
                       mySDdest, &i_one);
                /* Step the pointers one root at a time instead of forming
                   iroot*nelem_* as an int product, which can overflow for
                   large arrays. */
                mySDsrc += nelem_src;
                mySDdest += nelem_dest;
            }
        }
    }
}
112 changes: 90 additions & 22 deletions my_pyscf/lassi/op_o1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,8 +1384,12 @@ def _crunch_env_(self, _crunch_fn, *row):
inv = row[2:]
with lib.temporary_env (self, **self._orbrange_env_kwargs (inv)):
self._loop_lroots_(_crunch_fn, row, inv)
self._finalize_crunch_env_(_crunch_fn, row, inv)

def _finalize_crunch_env_(self, _crunch_fn, row, inv): pass

def _orbrange_env_kwargs (self, inv):
t0, w0 = logger.process_clock (), logger.perf_counter ()
fragidx = np.zeros (self.nfrags, dtype=bool)
_orbidx = np.zeros (self.norb, dtype=bool)
for frag in inv:
Expand All @@ -1404,6 +1408,8 @@ def _orbrange_env_kwargs (self, inv):
d2_size = np.prod (d2_shape)
d2 = self.d2.ravel ()[:d2_size].reshape (d2_shape)
env_kwargs = {'nlas': nlas, 'd1': d1, 'd2': d2, '_orbidx': _orbidx}
dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
self.dt_i, self.dw_i = self.dt_i + dt, self.dw_i + dw
return env_kwargs

def _loop_lroots_(self, _crunch_fn, row, inv):
Expand Down Expand Up @@ -1551,11 +1557,14 @@ def _umat_linequiv_(self, ifrag, iroot, umat, *args):

def _orbrange_env_kwargs (self, inv):
    '''Extend the parent's environment kwargs with contiguous h1 and h2 sub-blocks.

    The sub-blocks are selected with the boolean orbital mask '_orbidx' returned by
    the parent implementation. The time spent here is added to self.dt_i and self.dw_i.
    '''
    kwargs = super()._orbrange_env_kwargs (inv)
    cpu0, wall0 = logger.process_clock (), logger.perf_counter ()
    orbmask = kwargs['_orbidx']
    # h1: keep indices 0 and 1 of the leading axis, mask the two orbital axes
    h1_sel = np.ix_([True,True],orbmask,orbmask)
    kwargs['h1'] = np.ascontiguousarray (self.h1[h1_sel])
    # h2: mask all four orbital axes
    h2_sel = np.ix_(orbmask,orbmask,orbmask,orbmask)
    kwargs['h2'] = np.ascontiguousarray (self.h2[h2_sel])
    cpu_elapsed = logger.process_clock () - cpu0
    wall_elapsed = logger.perf_counter () - wall0
    self.dt_i = self.dt_i + cpu_elapsed
    self.dw_i = self.dw_i + wall_elapsed
    return kwargs

def kernel (self):
Expand Down Expand Up @@ -1611,10 +1620,12 @@ def get_contig_blks (mask):

def split_contig_array (arrlen, nthreads):
    '''Divide a contiguous array into chunks to be handled by each thread

    Args:
        arrlen: integer
            Total number of elements to distribute
        nthreads: integer
            Number of chunks

    Returns:
        blkstart: ndarray of shape (nthreads,)
            Offset of the first element of each chunk
        blklen: ndarray of shape (nthreads,)
            Length of each chunk. The first (arrlen % nthreads) chunks are one
            element longer than the rest, so lengths differ by at most one.
    '''
    blklen, rem = divmod (arrlen, nthreads)
    blklen = np.array ([blklen,]*nthreads)
    # Spread the remainder over the leading chunks
    blklen[:rem] += 1
    # Inclusive cumulative sum gives chunk ends; its last entry must be arrlen
    blkstart = np.cumsum (blklen)
    assert (blkstart[-1] == arrlen), '{}'.format (blklen)
    # Convert chunk ends to chunk starts
    blkstart -= blklen
    return blkstart, blklen

class LRRDMint (LSTDMint2):
Expand Down Expand Up @@ -1645,6 +1656,10 @@ def __init__(self, ints, nlas, hopping_index, lroots, si, mask_bra_space=None,
self._si_c = c_arr (self.si)
self._si_c_nrow = c_int (self.si.shape[0])
self._si_c_ncol = c_int (self.si.shape[1])
self.d1buf = np.empty ((self.nroots_si,self.d1.size), dtype=self.d1.dtype)
self.d2buf = np.empty ((self.nroots_si,self.d2.size), dtype=self.d2.dtype)
self._d1buf_c = c_arr (self.d1buf)
self._d2buf_c = c_arr (self.d2buf)

def get_wgt_fac (self, bra, ket, wgt):
#si_dm = self.si[bra,:] * self.si[ket,:].conj ()
Expand All @@ -1658,28 +1673,28 @@ def get_wgt_fac (self, bra, ket, wgt):
def _put_SD1_(self, bra, ket, D1, wgt):
    '''Accumulate D1, scaled per root by the weight factors, into the d1buf buffer.

    The factors come from self.get_wgt_fac (bra, ket, wgt). Nothing is written to
    rdm1s here; that transfer is done by _finalize_crunch_env_. The time spent is
    added to self.dt_s and self.dw_s.
    '''
    t0, w0 = logger.process_clock (), logger.perf_counter ()
    fac = self.get_wgt_fac (bra, ket, wgt)
    # Per the original inline note, this C call stands in for
    # self.d1buf += np.multiply.outer (fac, D1)
    fn = liblassi.LASSIRDMdsumSD
    fn (self._d1buf_c, c_arr (fac), c_arr (D1),
        self._si_c_ncol, self._d1buf_ncol,
        self._d1buf_dblk_idx,
        self._d1buf_sblk_idx,
        self._d1buf_lblk,
        self._d1buf_nblk)
    dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
    self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw

def _put_SD2_(self, bra, ket, D2, wgt):
    '''Accumulate D2, scaled per root by the weight factors, into the d2buf buffer.

    The factors come from self.get_wgt_fac (bra, ket, wgt). Nothing is written to
    rdm2s here; that transfer is done by _finalize_crunch_env_. The time spent is
    added to self.dt_s and self.dw_s.
    '''
    t0, w0 = logger.process_clock (), logger.perf_counter ()
    fac = self.get_wgt_fac (bra, ket, wgt)
    # Per the original inline note, this C call stands in for
    # self.d2buf += np.multiply.outer (fac, D2)
    fn = liblassi.LASSIRDMdsumSD
    fn (self._d2buf_c, c_arr (fac), c_arr (D2),
        self._si_c_ncol, self._d2buf_ncol,
        self._d2buf_dblk_idx,
        self._d2buf_sblk_idx,
        self._d2buf_lblk,
        self._d2buf_nblk)
    dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
    self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw

Expand All @@ -1691,13 +1706,64 @@ def _umat_linequiv_(self, ifrag, iroot, umat, *args):
si = args[0]
return umat_dot_1frag_(si, umat.conj ().T, self.lroots, ifrag, iroot, axis=0)

def _crunch_env_(self, _crunch_fn, *row):
    '''Zero the d1buf and d2buf buffers, then run the parent _crunch_env_.'''
    for buf in (self.d1buf, self.d2buf):
        buf[:] = 0
    super()._crunch_env_(_crunch_fn, *row)

def _finalize_crunch_env_(self, _crunch_fn, row, inv):
    '''Add the buffered d1buf and d2buf contents into rdm1s and rdm2s.

    Both transfers use liblassi.LASSIRDMdputSD. The elapsed time is added to
    the (dt_s, dw_s) and (dt_p, dw_p) timers.
    '''
    t0, w0 = logger.process_clock (), logger.perf_counter ()
    fn = liblassi.LASSIRDMdputSD
    # NOTE(review): the source view had no indentation. Only the rdm1s put is
    # assumed to be guarded by len (inv) < 3, matching the d1buf setup in
    # _orbrange_env_kwargs -- confirm.
    if len (inv) < 3:
        fn (self._rdm1s_c, self._d1buf_c,
            self._si_c_ncol, self._rdm1s_c_ncol, self._d1buf_ncol,
            self._rdm1s_dblk_idx,
            self._rdm1s_sblk_idx,
            self._rdm1s_lblk,
            self._rdm1s_nblk)
    fn (self._rdm2s_c, self._d2buf_c,
        self._si_c_ncol, self._rdm2s_c_ncol, self._d2buf_ncol,
        self._rdm2s_dblk_idx,
        self._rdm2s_sblk_idx,
        self._rdm2s_lblk,
        self._rdm2s_nblk)
    dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
    self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw
    # The wall-time partner of dt_p is dw_p; the original wrote dw_s here,
    # which double-counted dw_s and never updated dw_p.
    self.dt_p, self.dw_p = self.dt_p + dt, self.dw_p + dw

def _orbrange_env_kwargs (self, inv):
env_kwargs = super()._orbrange_env_kwargs (inv)
t0, w0 = logger.process_clock (), logger.perf_counter ()
_orbidx = env_kwargs['_orbidx']
ndest = self.norb
nsrc = np.count_nonzero (_orbidx)
nthreads = lib.num_threads ()
# buffer, always contiguous arrays
if len (inv) < 3: # Otherwise this won't be touched anyway
d1_shape = [self.nroots_si, 2] + [nsrc,]*2
d1_size = np.prod (d1_shape)
d1buf = self.d1buf.ravel ()[:d1_size].reshape (d1_shape)
d1_col = 2*(nsrc**2)
dblk1, lblk1 = split_contig_array (d1_col, nthreads)
env_kwargs['d1buf'] = d1buf
env_kwargs['_d1buf_dblk_idx'] = c_arr (dblk1.astype (np.int32))
env_kwargs['_d1buf_sblk_idx'] = c_arr (dblk1.astype (np.int32))
env_kwargs['_d1buf_lblk'] = c_arr (lblk1.astype (np.int32))
env_kwargs['_d1buf_nblk'] = c_int (len (lblk1))
env_kwargs['_d1buf_ncol'] = c_int (d1_col)
d2_shape = [self.nroots_si, 4,] + [nsrc,]*4
d2_size = np.prod (d2_shape)
d2buf = self.d2buf.ravel ()[:d2_size].reshape (d2_shape)
d2_col = 4*(nsrc**4)
dblk2, lblk2 = split_contig_array (d2_col, nthreads)
env_kwargs['d2buf'] = d2buf
env_kwargs['_d2buf_dblk_idx'] = c_arr (dblk2.astype (np.int32))
env_kwargs['_d2buf_sblk_idx'] = c_arr (dblk2.astype (np.int32))
env_kwargs['_d2buf_lblk'] = c_arr (lblk2.astype (np.int32))
env_kwargs['_d2buf_nblk'] = c_int (len (lblk2))
env_kwargs['_d2buf_ncol'] = c_int (d2_col)
# final, generally discontiguous arrays
if nsrc==ndest:
nthreads = lib.num_threads ()
dblk1, lblk1 = split_contig_array (2*(self.norb**2),nthreads)
dblk2, lblk2 = split_contig_array (4*(self.norb**4),nthreads)
sblk1, sblk2 = dblk1, dblk2
Expand Down Expand Up @@ -1728,6 +1794,8 @@ def _orbrange_env_kwargs (self, inv):
env_kwargs['_rdm2s_sblk_idx'] = c_arr (sblk2.astype (np.int32))
env_kwargs['_rdm2s_lblk'] = c_arr (lblk2.astype (np.int32))
env_kwargs['_rdm2s_nblk'] = c_int (len (lblk2))
dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
self.dt_i, self.dw_i = self.dt_i + dt, self.dw_i + dw
return env_kwargs

def kernel (self):
Expand All @@ -1746,9 +1814,9 @@ def kernel (self):
self.rdm1s = np.zeros ([self.nroots_si,2] + [self.norb,]*2, dtype=self.dtype)
self.rdm2s = np.zeros ([self.nroots_si,4] + [self.norb,]*4, dtype=self.dtype)
self._rdm1s_c = c_arr (self.rdm1s)
self._rdm1s_c_ncol = c_int (self.rdm1s.size // self.nroots_si)
self._rdm1s_c_ncol = c_int (2*(self.norb**2))
self._rdm2s_c = c_arr (self.rdm2s)
self._rdm2s_c_ncol = c_int (self.rdm2s.size // self.nroots_si)
self._rdm2s_c_ncol = c_int (4*(self.norb**4))
self._crunch_all_()
return self.rdm1s, self.rdm2s, t0

Expand Down

0 comments on commit dd2f846

Please sign in to comment.