diff --git a/my_pyscf/lassi/citools.py b/my_pyscf/lassi/citools.py index e399e758..f9ac4acb 100644 --- a/my_pyscf/lassi/citools.py +++ b/my_pyscf/lassi/citools.py @@ -1,4 +1,9 @@ import numpy as np +import functools +from pyscf.scf.addons import canonical_orth_ +from pyscf import __config__ + +LINDEP_THRESH = getattr (__config__, 'lassi_lindep_thresh', 1.0e-5) def get_lroots (ci): '''Generate a table showing the number of states contained in a (optionally nested) list @@ -114,3 +119,66 @@ def _umat_dot_1frag (target, umat, lroots, ifrag): old_shape2 = list (old_shape1) old_shape2[0] = old_shape2[0] * ncol2 // ncol1 return target.reshape (*old_shape2) + +def get_orth_basis (ci_fr, norb_f, nelec_frs, _get_ovlp=None): + if _get_ovlp is None: + from mrh.my_pyscf.lassi.op_o0 import get_ovlp + _get_ovlp = functools.partial (get_ovlp, ci_fr, norb_f, nelec_frs) + nfrags, nroots = nelec_frs.shape[:2] + unique, uniq_idx, inverse, cnts = np.unique (nelec_frs, axis=1, return_index=True, + return_inverse=True, return_counts=True) + if not np.count_nonzero (cnts>1): + def raw2orth (rawarr): + return rawarr + def orth2raw (ortharr): + return ortharr + return raw2orth, orth2raw + lroots_fr = np.array ([[1 if c.ndim<3 else c.shape[0] + for c in ci_r] + for ci_r in ci_fr]) + nprods_r = np.prod (lroots_fr, axis=0) + offs1 = np.cumsum (nprods_r) + offs0 = offs1 - nprods_r + uniq_prod_idx = [] + for i in uniq_idx[cnts==1]: uniq_prod_idx.extend (list(range(offs0[i],offs1[i]))) + manifolds_prod_idx = [] + manifolds_xmat = [] + nuniq_prod = north = len (uniq_prod_idx) + for manifold_idx in np.where (cnts>1)[0]: + manifold = np.where (inverse==manifold_idx)[0] + manifold_prod_idx = [] + for i in manifold: manifold_prod_idx.extend (list(range(offs0[i],offs1[i]))) + manifolds_prod_idx.append (manifold_prod_idx) + ovlp = _get_ovlp (rootidx=manifold) + xmat = canonical_orth_(ovlp, thr=LINDEP_THRESH) + north += xmat.shape[1] + manifolds_xmat.append (xmat) + + nraw = offs1[-1] + def raw2orth (rawarr): + col_shape = rawarr.shape[1:] + orth_shape = [north,] + list (col_shape) + ortharr = np.zeros (orth_shape, dtype=rawarr.dtype) + ortharr[:nuniq_prod] = rawarr[uniq_prod_idx] + i = nuniq_prod + for prod_idx, xmat in zip (manifolds_prod_idx, manifolds_xmat): + j = i + xmat.shape[1] + ortharr[i:j] = np.tensordot (xmat.T, rawarr[prod_idx], axes=1) + i = j + return ortharr + + def orth2raw (ortharr): + col_shape = ortharr.shape[1:] + raw_shape = [nraw,] + list (col_shape) + rawarr = np.zeros (raw_shape, dtype=ortharr.dtype) + rawarr[uniq_prod_idx] = ortharr[:nuniq_prod] + i = nuniq_prod + for prod_idx, xmat in zip (manifolds_prod_idx, manifolds_xmat): + j = i + xmat.shape[1] + rawarr[prod_idx] = np.tensordot (xmat.conj (), ortharr[i:j], axes=1) + i = j + return rawarr + + return raw2orth, orth2raw + + diff --git a/my_pyscf/lassi/lassi.py b/my_pyscf/lassi/lassi.py index bc628a7e..60632a5c 100644 --- a/my_pyscf/lassi/lassi.py +++ b/my_pyscf/lassi/lassi.py @@ -4,6 +4,7 @@ from mrh.my_pyscf.lassi import op_o0 from mrh.my_pyscf.lassi import op_o1 from mrh.my_pyscf.lassi import chkfile +from mrh.my_pyscf.lassi import citools from mrh.my_pyscf.lassi.citools import get_lroots from pyscf import lib, symm, ao2mo from pyscf.scf.addons import canonical_orth_ @@ -429,7 +430,7 @@ def _eig_block (las, e0, h1, h2, ci_blk, nelec_blk, rootsym, soc, orbsym, wfnsym lc = 'checking if LASSI basis has lindeps: |ovlp| = {:.6e}'.format (ovlp_det) lib.logger.info (las, 'Caught error %s, %s', str (err), lc) if ovlp_det < LINDEP_THRESH: - raw2orth, orth2raw = op_o0.get_orth_basis (ci_blk, las.ncas_sub, nelec_blk) + raw2orth, orth2raw = citools.get_orth_basis (ci_blk, las.ncas_sub, nelec_blk) xhx = raw2orth (ham_blk).conj ().T #x = canonical_orth_(ovlp_blk, thr=LINDEP_THRESH) lib.logger.info (las, '%d/%d linearly independent model states', diff --git a/my_pyscf/lassi/op_o0.py b/my_pyscf/lassi/op_o0.py index 7cd0c434..c3ae5f51 100644 --- a/my_pyscf/lassi/op_o0.py +++ b/my_pyscf/lassi/op_o0.py @@ -6,15 +6,11 @@ from pyscf.fci.direct_nosym import contract_1e as contract_1e_nosym from pyscf.fci.direct_spin1 import _unpack_nelec from pyscf.fci.spin_op import contract_ss, spin_square -from pyscf.scf.addons import canonical_orth_ from pyscf.data import nist from itertools import combinations from mrh.my_pyscf.mcscf import soc_int as soc_int from mrh.my_pyscf.lassi import dms as lassi_dms from mrh.my_pyscf.fci.csf import unpack_h1e_cs -from pyscf import __config__ - -LINDEP_THRESH = getattr (__config__, 'lassi_lindep_thresh', 1.0e-5) def memcheck (las, ci, soc=None): '''Check if the system has enough memory to run these functions! ONLY checks @@ -440,8 +436,6 @@ def get_orth_basis (ci_fr, norb_f, nelec_frs, _get_ovlp=get_ovlp): unique, uniq_idx, inverse, cnts = np.unique (nelec_frs, axis=1, return_index=True, return_inverse=True, return_counts=True) if not np.count_nonzero (cnts>1): - print ("escape") - print (cnts) def raw2orth (rawarr): return rawarr def orth2raw (ortharr): @@ -469,7 +463,6 @@ def orth2raw (ortharr): manifolds_xmat.append (xmat) nraw = offs1[-1] - print (nraw, north) def raw2orth (rawarr): col_shape = rawarr.shape[1:] orth_shape = [north,] + list (col_shape) @@ -480,7 +473,6 @@ def raw2orth (rawarr): j = i + xmat.shape[1] ortharr[i:j] = np.tensordot (xmat.T, rawarr[prod_idx], axes=1) i = j - print (rawarr.shape, ortharr.shape) return ortharr def orth2raw (ortharr): diff --git a/my_pyscf/lassi/op_o1/hams2ovlp.py b/my_pyscf/lassi/op_o1/hams2ovlp.py index 9395a611..b3547219 100644 --- a/my_pyscf/lassi/op_o1/hams2ovlp.py +++ b/my_pyscf/lassi/op_o1/hams2ovlp.py @@ -71,22 +71,7 @@ def kernel (self): self.s2 = np.zeros ([self.nstates,]*2, dtype=self.dtype) self._crunch_all_() t1, w1 = lib.logger.process_clock (), lib.logger.perf_counter () - ovlp = np.zeros ([self.nstates,]*2, dtype=self.dtype) - def crunch_ovlp (bra_sp, ket_sp): - i = self.ints[-1] - b, k = i.unique_root[bra_sp], i.unique_root[ket_sp] - o = i.ovlp[b][k] / (1 + int (bra_sp==ket_sp)) - for i in self.ints[-2::-1]: - b, k = i.unique_root[bra_sp], i.unique_root[ket_sp] - o = np.multiply.outer (o, i.ovlp[b][k]).transpose (0,2,1,3) - o = o.reshape (o.shape[0]*o.shape[1], o.shape[2]*o.shape[3]) - o *= self.spin_shuffle[bra_sp] - o *= self.spin_shuffle[ket_sp] - i0, i1 = self.offs_lroots[bra_sp] - j0, j1 = self.offs_lroots[ket_sp] - ovlp[i0:i1,j0:j1] = o - for bra_sp, ket_sp in self.exc_null: crunch_ovlp (bra_sp, ket_sp) - ovlp += ovlp.T + ovlp = self.get_ovlp () dt, dw = logger.process_clock () - t1, logger.perf_counter () - w1 self.dt_o, self.dw_o = self.dt_o + dt, self.dw_o + dw self._umat_linequiv_loop_(ovlp) diff --git a/my_pyscf/lassi/op_o1/stdm.py b/my_pyscf/lassi/op_o1/stdm.py index e6c87f8c..633cbd4a 100644 --- a/my_pyscf/lassi/op_o1/stdm.py +++ b/my_pyscf/lassi/op_o1/stdm.py @@ -427,6 +427,18 @@ def get_ovlp (self, rootidx=None): exc_null = self.exc_null offs_lroots = self.offs_lroots nstates = self.nstates + if rootidx is not None: + rootidx = np.atleast_1d (rootidx) + bra_null = np.isin (self.exc_null[:,0], rootidx) + ket_null = np.isin (self.exc_null[:,1], rootidx) + idx_null = bra_null & ket_null + exc_null = exc_null[idx_null,:] + lroots = self.lroots[:,idx_null] + nprods = np.prod (lroots, axis=0) + offs1 = np.cumsum (nprods) + offs0 = offs1 - nprods + offs_lroots = np.stack ([offs0, offs1], axis=1) + nstates = offs1[-1] ovlp = np.zeros ([nstates,]*2, dtype=self.dtype) for bra, ket in exc_null: i0, i1 = offs_lroots[bra]