Skip to content

Commit

Permalink
Attempt improved C offload...
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewRHermes committed Aug 5, 2024
1 parent ff58b96 commit cd985be
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 8 deletions.
28 changes: 24 additions & 4 deletions lib/lassi/rdm.c
Original file line number Diff line number Diff line change
Expand Up @@ -69,16 +69,36 @@ void LASSIRDMdgetwgtfac (double * fac, double * wgt, double * sivec,
}
}

void LASSIRDMdputSD (double * fac, double * SDsum, int nroots, int nelem,
double * SDterm, long * idx, int nidx)
void LASSIRDMdputSD (double * fac, double * SDsum, int nroots, int nelem_sum,
int * SDsum_idx, int nSDsum_idx,
double * SDterm, int norb_term,
int * blkstart, int * blklen, int nblks)
{
const unsigned int i_one = 1;
#pragma omp parallel
{
double * mySDsum;
double * mySDterm;
int os, ot, l;
#pragma omp for schedule(static)
for (int i = 0; i < nidx; i++){
daxpy_(&nroots, &(SDterm[i]), fac, &i_one, &(SDsum[idx[i]]), &nelem);
for (int iidx = 0; iidx < nSDsum_idx; iidx++){
mySDterm = SDterm + (iidx*norb_term);
for (int iroot = 0; iroot < nroots; iroot++){
mySDsum = SDsum + (iroot*nelem_sum) + SDsum_idx[iidx];
ot = 0;
for (int iblk = 0; iblk < nblks; iblk++){
os = blkstart[iblk];
l = blklen[iblk];
daxpy_(&blklen[iblk], &(fac[iroot]),
&(mySDterm[ot]), &i_one,
&(mySDsum[os]), &i_one);
ot += l;
}
}
}
//for (int i = 0; i < nidx; i++){
// daxpy_(&nroots, &(SDterm[i]), fac, &i_one, &(SDsum[idx[i]]), &nelem);
//}

}
}
25 changes: 21 additions & 4 deletions my_pyscf/lassi/op_o1.py
Original file line number Diff line number Diff line change
Expand Up @@ -1638,7 +1638,9 @@ def _put_SD1_(self, bra, ket, D1, wgt):
fac = self.get_wgt_fac (bra, ket, wgt)
fn = liblassi.LASSIRDMdputSD
fn (c_arr (fac), self._rdm1s_c, self._si_c_ncol, self._rdm1s_c_ncol,
c_arr (D1), self._rdm1s_idx_c, self._rdm1s_idx_c_len)
self._rdm1s_idx_c, self._rdm1s_idx_c_len,
c_arr (D1), self._dm_norb_c,
self._dm_blkstart_c, self._dm_blklen_c, self._dm_nblks_c)
#self.rdm1s[self.rdm1s_idx] += np.multiply.outer (fac, D1)
dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw
Expand All @@ -1648,7 +1650,9 @@ def _put_SD2_(self, bra, ket, D2, wgt):
fac = self.get_wgt_fac (bra, ket, wgt)
fn = liblassi.LASSIRDMdputSD
fn (c_arr (fac), self._rdm2s_c, self._si_c_ncol, self._rdm2s_c_ncol,
c_arr (D2), self._rdm2s_idx_c, self._rdm2s_idx_c_len)
self._rdm2s_idx_c, self._rdm2s_idx_c_len,
c_arr (D2), self._dm_norb_c,
self._dm_blkstart_c, self._dm_blklen_c, self._dm_nblks_c)
#self.rdm2s[self.rdm2s_idx] += np.multiply.outer (fac, D2)
dt, dw = logger.process_clock () - t0, logger.perf_counter () - w0
self.dt_s, self.dw_s = self.dt_s + dt, self.dw_s + dw
Expand All @@ -1664,16 +1668,29 @@ def _umat_linequiv_(self, ifrag, iroot, umat, *args):
def _orbrange_env_kwargs (self, inv):
env_kwargs = super()._orbrange_env_kwargs (inv)
_orbidx = env_kwargs['_orbidx']
_blkstart = [i for i in range (self.norb)
if (_orbidx[i] and (i==0 or (not _orbidx[i-1])))]
_blkend = [i+1 for i in range (self.norb)
if (_orbidx[i] and ((i+1==self.norb) or (not _orbidx[i+1])))]
_blklen = [end - start for end, start in zip (_blkend, _blkstart)]
env_kwargs['_dm_blkstart_c'] = c_arr (np.array (_blkstart, dtype=np.int32) - _blkstart[0])
env_kwargs['_dm_blklen_c'] = c_arr (np.array (_blklen, dtype=np.int32))
env_kwargs['_dm_nblks_c'] = c_int (len (_blkstart))
env_kwargs['_dm_norb_c'] = c_int (np.count_nonzero (_orbidx))
idx = np.ix_([True,]*2,_orbidx,_orbidx)
mask = np.zeros ((2,self.norb,self.norb), dtype=bool)
mask[idx] = True
idx = np.where (mask.ravel ())[0]
if _blkstart[0] > 0: mask[...,:_blkstart[0]] = False
if _blkstart[0] < self.norb-1: mask[...,_blkstart[0]+1:] = False
idx = np.array (np.where (mask.ravel ())[0], dtype=np.int32)
env_kwargs['_rdm1s_idx_c'] = c_arr (idx)
env_kwargs['_rdm1s_idx_c_len'] = c_int (len (idx))
idx = np.ix_([True,]*4,_orbidx,_orbidx,_orbidx,_orbidx)
mask = np.zeros ((4,self.norb,self.norb,self.norb,self.norb), dtype=bool)
mask[idx] = True
idx = np.where (mask.ravel ())[0]
if _blkstart[0] > 0: mask[...,:_blkstart[0]] = False
if _blkstart[0] < self.norb-1: mask[...,_blkstart[0]+1:] = False
idx = np.array (np.where (mask.ravel ())[0], dtype=np.int32)
env_kwargs['_rdm2s_idx_c'] = c_arr (idx)
env_kwargs['_rdm2s_idx_c_len'] = c_int (len (idx))
return env_kwargs
Expand Down

0 comments on commit cd985be

Please sign in to comment.