From dd2f84656acf494bd06a8b9947eb85bc04c0ad76 Mon Sep 17 00:00:00 2001
From: Matthew R Hermes
Date: Tue, 6 Aug 2024 23:07:48 -0500
Subject: [PATCH] lassi rdm refactor put

Split into two steps and do the awkward discontiguous put only after
summing over lroots
---
Reviewer notes (text in this section is ignored by `git am`):
 * The last timer update added to LRRDMint._finalize_crunch_env_ used to read
   `self.dt_p, self.dw_s = self.dt_p + dt, self.dw_s + dw`, which added the
   wall time to dw_s a second time and never updated dw_p. It now updates
   dw_p, matching the dt_x/dw_x pairing used by every other timer in this
   patch. Assumes dw_p is initialised alongside dt_p -- TODO confirm.
 * This copy of the patch had lost its line breaks and leading whitespace.
   Line boundaries were rebuilt from the hunk line counts (all hunks and the
   diffstat add up); indentation was rebuilt from the code structure and is a
   best guess, in particular for context lines and for which statements sit
   under `if len (inv) < 3:`. Verify against the tree, or apply with
   `git apply --ignore-whitespace`.

 lib/lassi/rdm.c         |  30 +++++++++--
 my_pyscf/lassi/op_o1.py | 112 ++++++++++++++++++++++++++++++++--------
 2 files changed, 116 insertions(+), 26 deletions(-)

diff --git a/lib/lassi/rdm.c b/lib/lassi/rdm.c
index 535d819e..a1e99fe2 100644
--- a/lib/lassi/rdm.c
+++ b/lib/lassi/rdm.c
@@ -69,7 +69,7 @@ void LASSIRDMdgetwgtfac (double * fac, double * wgt, double * sivec,
 }
 }
 
-void LASSIRDMdputSD (double * SDdest, double * fac, double * SDsrc,
+void LASSIRDMdsumSD (double * SDdest, double * fac, double * SDsrc,
                      int nroots, int nelem_dest,
                      int * SDdest_idx, int * SDsrc_idx, int * SDlen,
                      int nidx)
@@ -89,9 +89,31 @@ const unsigned int i_one = 1;
                    mySDdest+iroot*nelem_dest, &i_one);
         }
     }
-    //for (int i = 0; i < nidx; i++){
-    //    daxpy_(&nroots, &(SDterm[i]), fac, &i_one, &(SDsum[idx[i]]), &nelem);
-    //}
+
+}
+}
+
+void LASSIRDMdputSD (double * SDdest, double * SDsrc,
+                     int nroots, int nelem_dest, int nelem_src,
+                     int * SDdest_idx, int * SDsrc_idx, int * SDlen,
+                     int nidx)
+{
+const unsigned int i_one = 1;
+const double d_one = 1.0;
+#pragma omp parallel
+{
+    double * mySDsrc;
+    double * mySDdest;
+    #pragma omp for
+    for (int iidx = 0; iidx < nidx; iidx++){
+        mySDsrc = SDsrc + SDsrc_idx[iidx];
+        mySDdest = SDdest + SDdest_idx[iidx];
+        for (int iroot = 0; iroot < nroots; iroot++){
+            daxpy_(SDlen+iidx, &d_one,
+                   mySDsrc+iroot*nelem_src, &i_one,
+                   mySDdest+iroot*nelem_dest, &i_one);
+        }
+    }
 }
 }
 
diff --git a/my_pyscf/lassi/op_o1.py b/my_pyscf/lassi/op_o1.py
index 538f382e..34d19f8c 100644
--- a/my_pyscf/lassi/op_o1.py
+++ b/my_pyscf/lassi/op_o1.py
@@ -1384,8 +1384,12 @@ def _crunch_env_(self, _crunch_fn, *row):
         inv = row[2:]
         with lib.temporary_env (self, **self._orbrange_env_kwargs (inv)):
             self._loop_lroots_(_crunch_fn, row, inv)
+            self._finalize_crunch_env_(_crunch_fn, row, inv)
+
+    def _finalize_crunch_env_(self, _crunch_fn, row, inv): pass
 
     def _orbrange_env_kwargs (self, inv):
+        t0, w0 = logger.process_clock (), logger.perf_counter ()
         fragidx = np.zeros (self.nfrags, dtype=bool)
         _orbidx = np.zeros (self.norb, dtype=bool)
         for frag in inv:
@@ -1404,6 +1408,8 @@ def _orbrange_env_kwargs (self, inv):
         d2_size = np.prod (d2_shape)
         d2 = self.d2.ravel ()[:d2_size].reshape (d2_shape)
         env_kwargs = {'nlas': nlas, 'd1': d1, 'd2': d2, '_orbidx': _orbidx}
+        dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
+        self.dt_i, self.dw_i = self.dt_i + dt, self.dw_i + dw
         return env_kwargs
 
     def _loop_lroots_(self, _crunch_fn, row, inv):
@@ -1551,11 +1557,14 @@ def _umat_linequiv_(self, ifrag, iroot, umat, *args):
 
     def _orbrange_env_kwargs (self, inv):
         env_kwargs = super()._orbrange_env_kwargs (inv)
+        t0, w0 = logger.process_clock (), logger.perf_counter ()
         _orbidx = env_kwargs['_orbidx']
         idx = np.ix_([True,True],_orbidx,_orbidx)
         env_kwargs['h1'] = np.ascontiguousarray (self.h1[idx])
         idx = np.ix_(_orbidx,_orbidx,_orbidx,_orbidx)
         env_kwargs['h2'] = np.ascontiguousarray (self.h2[idx])
+        dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
+        self.dt_i, self.dw_i = self.dt_i + dt, self.dw_i + dw
         return env_kwargs
 
     def kernel (self):
@@ -1611,10 +1620,12 @@ def get_contig_blks (mask):
 
 def split_contig_array (arrlen, nthreads):
     '''Divide a contiguous array into chunks to be handled by each thread'''
-    blklen = (arrlen + (nthreads-1)) // nthreads;
+    blklen, rem = divmod (arrlen, nthreads);
     blklen = np.array ([blklen,]*nthreads)
-    blklen[-1] = arrlen - blklen[:-1].sum ()
-    blkstart = np.cumsum (np.append ([0],blklen[:-1]))
+    blklen[:rem] += 1
+    blkstart = np.cumsum (blklen)
+    assert (blkstart[-1] == arrlen), '{}'.format (blklen)
+    blkstart -= blklen
     return blkstart, blklen
 
 class LRRDMint (LSTDMint2):
@@ -1645,6 +1656,10 @@ def __init__(self, ints, nlas, hopping_index, lroots, si, mask_bra_space=None,
         self._si_c = c_arr (self.si)
         self._si_c_nrow = c_int (self.si.shape[0])
         self._si_c_ncol = c_int (self.si.shape[1])
+        self.d1buf = np.empty ((self.nroots_si,self.d1.size), dtype=self.d1.dtype)
+        self.d2buf = np.empty ((self.nroots_si,self.d2.size), dtype=self.d2.dtype)
+        self._d1buf_c = c_arr (self.d1buf)
+        self._d2buf_c = c_arr (self.d2buf)
 
     def get_wgt_fac (self, bra, ket, wgt):
         #si_dm = self.si[bra,:] * self.si[ket,:].conj ()
@@ -1658,28 +1673,28 @@ def get_wgt_fac (self, bra, ket, wgt):
     def _put_SD1_(self, bra, ket, D1, wgt):
         t0, w0 = logger.process_clock (), logger.perf_counter ()
         fac = self.get_wgt_fac (bra, ket, wgt)
-        fn = liblassi.LASSIRDMdputSD
-        fn (self._rdm1s_c, c_arr (fac), c_arr (D1),
-            self._si_c_ncol, self._rdm1s_c_ncol,
-            self._rdm1s_dblk_idx,
-            self._rdm1s_sblk_idx,
-            self._rdm1s_lblk,
-            self._rdm1s_nblk)
-        #self.rdm1s[self.rdm1s_idx] += np.multiply.outer (fac, D1)
+        #self.d1buf += np.multiply.outer (fac, D1)
+        fn = liblassi.LASSIRDMdsumSD
+        fn (self._d1buf_c, c_arr (fac), c_arr (D1),
+            self._si_c_ncol, self._d1buf_ncol,
+            self._d1buf_dblk_idx,
+            self._d1buf_sblk_idx,
+            self._d1buf_lblk,
+            self._d1buf_nblk)
         dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
         self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw
 
     def _put_SD2_(self, bra, ket, D2, wgt):
         t0, w0 = logger.process_clock (), logger.perf_counter ()
         fac = self.get_wgt_fac (bra, ket, wgt)
-        fn = liblassi.LASSIRDMdputSD
-        fn (self._rdm2s_c, c_arr (fac), c_arr (D2),
-            self._si_c_ncol, self._rdm2s_c_ncol,
-            self._rdm2s_dblk_idx,
-            self._rdm2s_sblk_idx,
-            self._rdm2s_lblk,
-            self._rdm2s_nblk)
-        #self.rdm2s[self.rdm2s_idx] += np.multiply.outer (fac, D2)
+        #self.d2buf += np.multiply.outer (fac, D2)
+        fn = liblassi.LASSIRDMdsumSD
+        fn (self._d2buf_c, c_arr (fac), c_arr (D2),
+            self._si_c_ncol, self._d2buf_ncol,
+            self._d2buf_dblk_idx,
+            self._d2buf_sblk_idx,
+            self._d2buf_lblk,
+            self._d2buf_nblk)
         dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
         self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw
 
@@ -1691,13 +1706,64 @@ def _umat_linequiv_(self, ifrag, iroot, umat, *args):
         si = args[0]
         return umat_dot_1frag_(si, umat.conj ().T, self.lroots, ifrag, iroot, axis=0)
 
+    def _crunch_env_(self, _crunch_fn, *row):
+        self.d1buf[:] = 0
+        self.d2buf[:] = 0
+        super()._crunch_env_(_crunch_fn, *row)
+
+    def _finalize_crunch_env_(self, _crunch_fn, row, inv):
+        t0, w0 = logger.process_clock (), logger.perf_counter ()
+        fn = liblassi.LASSIRDMdputSD
+        if len (inv) < 3:
+            fn (self._rdm1s_c, self._d1buf_c,
+                self._si_c_ncol, self._rdm1s_c_ncol, self._d1buf_ncol,
+                self._rdm1s_dblk_idx,
+                self._rdm1s_sblk_idx,
+                self._rdm1s_lblk,
+                self._rdm1s_nblk)
+        fn (self._rdm2s_c, self._d2buf_c,
+            self._si_c_ncol, self._rdm2s_c_ncol, self._d2buf_ncol,
+            self._rdm2s_dblk_idx,
+            self._rdm2s_sblk_idx,
+            self._rdm2s_lblk,
+            self._rdm2s_nblk)
+        dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
+        self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw
+        self.dt_p, self.dw_p = self.dt_p + dt, self.dw_p + dw
+
     def _orbrange_env_kwargs (self, inv):
         env_kwargs = super()._orbrange_env_kwargs (inv)
+        t0, w0 = logger.process_clock (), logger.perf_counter ()
         _orbidx = env_kwargs['_orbidx']
         ndest = self.norb
         nsrc = np.count_nonzero (_orbidx)
+        nthreads = lib.num_threads ()
+        # buffer, always contiguous arrays
+        if len (inv) < 3: # Otherwise this won't be touched anyway
+            d1_shape = [self.nroots_si, 2] + [nsrc,]*2
+            d1_size = np.prod (d1_shape)
+            d1buf = self.d1buf.ravel ()[:d1_size].reshape (d1_shape)
+            d1_col = 2*(nsrc**2)
+            dblk1, lblk1 = split_contig_array (d1_col, nthreads)
+            env_kwargs['d1buf'] = d1buf
+            env_kwargs['_d1buf_dblk_idx'] = c_arr (dblk1.astype (np.int32))
+            env_kwargs['_d1buf_sblk_idx'] = c_arr (dblk1.astype (np.int32))
+            env_kwargs['_d1buf_lblk'] = c_arr (lblk1.astype (np.int32))
+            env_kwargs['_d1buf_nblk'] = c_int (len (lblk1))
+            env_kwargs['_d1buf_ncol'] = c_int (d1_col)
+        d2_shape = [self.nroots_si, 4,] + [nsrc,]*4
+        d2_size = np.prod (d2_shape)
+        d2buf = self.d2buf.ravel ()[:d2_size].reshape (d2_shape)
+        d2_col = 4*(nsrc**4)
+        dblk2, lblk2 = split_contig_array (d2_col, nthreads)
+        env_kwargs['d2buf'] = d2buf
+        env_kwargs['_d2buf_dblk_idx'] = c_arr (dblk2.astype (np.int32))
+        env_kwargs['_d2buf_sblk_idx'] = c_arr (dblk2.astype (np.int32))
+        env_kwargs['_d2buf_lblk'] = c_arr (lblk2.astype (np.int32))
+        env_kwargs['_d2buf_nblk'] = c_int (len (lblk2))
+        env_kwargs['_d2buf_ncol'] = c_int (d2_col)
+        # final, generally discontiguous arrays
         if nsrc==ndest:
-            nthreads = lib.num_threads ()
             dblk1, lblk1 = split_contig_array (2*(self.norb**2),nthreads)
             dblk2, lblk2 = split_contig_array (4*(self.norb**4),nthreads)
             sblk1, sblk2 = dblk1, dblk2
@@ -1728,6 +1794,8 @@ def _orbrange_env_kwargs (self, inv):
         env_kwargs['_rdm2s_sblk_idx'] = c_arr (sblk2.astype (np.int32))
         env_kwargs['_rdm2s_lblk'] = c_arr (lblk2.astype (np.int32))
         env_kwargs['_rdm2s_nblk'] = c_int (len (lblk2))
+        dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
+        self.dt_i, self.dw_i = self.dt_i + dt, self.dw_i + dw
         return env_kwargs
 
     def kernel (self):
@@ -1746,9 +1814,9 @@ def kernel (self):
         self.rdm1s = np.zeros ([self.nroots_si,2] + [self.norb,]*2, dtype=self.dtype)
         self.rdm2s = np.zeros ([self.nroots_si,4] + [self.norb,]*4, dtype=self.dtype)
         self._rdm1s_c = c_arr (self.rdm1s)
-        self._rdm1s_c_ncol = c_int (self.rdm1s.size // self.nroots_si)
+        self._rdm1s_c_ncol = c_int (2*(self.norb**2))
         self._rdm2s_c = c_arr (self.rdm2s)
-        self._rdm2s_c_ncol = c_int (self.rdm2s.size // self.nroots_si)
+        self._rdm2s_c_ncol = c_int (4*(self.norb**4))
         self._crunch_all_()
         return self.rdm1s, self.rdm2s, t0
 