Skip to content

Commit

Permalink
safety commit
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewRHermes committed Dec 19, 2024
1 parent d7ad373 commit 2112eef
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 25 deletions.
68 changes: 68 additions & 0 deletions my_pyscf/lassi/citools.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
import numpy as np
import functools
from pyscf.scf.addons import canonical_orth_
from pyscf import __config__

LINDEP_THRESH = getattr (__config__, 'lassi_lindep_thresh', 1.0e-5)

def get_lroots (ci):
'''Generate a table showing the number of states contained in a (optionally nested) list
Expand Down Expand Up @@ -114,3 +119,66 @@ def _umat_dot_1frag (target, umat, lroots, ifrag):
old_shape2 = list (old_shape1)
old_shape2[0] = old_shape2[0] * ncol2 // ncol1
return target.reshape (*old_shape2)

def get_orth_basis (ci_fr, norb_f, nelec_frs, _get_ovlp=None):
if _get_ovlp is None:
from mrh.my_pyscf.lassi.op_o0 import get_ovlp
_get_ovlp = functools.partial (get_ovlp, ci_fr, norb_f, nelec_frs)
nfrags, nroots = nelec_frs.shape[:2]
unique, uniq_idx, inverse, cnts = np.unique (nelec_frs, axis=1, return_index=True,
return_inverse=True, return_counts=True)
if not np.count_nonzero (cnts>1):
def raw2orth (rawarr):
return rawarr
def orth2raw (ortharr):
return ortharr
return raw2orth, orth2raw
lroots_fr = np.array ([[1 if c.ndim<3 else c.shape[0]
for c in ci_r]
for ci_r in ci_fr])
nprods_r = np.prod (lroots_fr, axis=0)
offs1 = np.cumsum (nprods_r)
offs0 = offs1 - nprods_r
uniq_prod_idx = []
for i in uniq_idx[cnts==1]: uniq_prod_idx.extend (list(range(offs0[i],offs1[i])))
manifolds_prod_idx = []
manifolds_xmat = []
nuniq_prod = north = len (uniq_prod_idx)
for manifold_idx in np.where (cnts>1)[0]:
manifold = np.where (inverse==manifold_idx)[0]
manifold_prod_idx = []
for i in manifold: manifold_prod_idx.extend (list(range(offs0[i],offs1[i])))
manifolds_prod_idx.append (manifold_prod_idx)
ovlp = _get_ovlp (rootidx=manifold)
xmat = canonical_orth_(ovlp, thr=LINDEP_THRESH)
north += xmat.shape[1]
manifolds_xmat.append (xmat)

nraw = offs1[-1]
def raw2orth (rawarr):
col_shape = rawarr.shape[1:]
orth_shape = [north,] + list (col_shape)
ortharr = np.zeros (orth_shape, dtype=rawarr.dtype)
ortharr[:nuniq_prod] = rawarr[uniq_prod_idx]
i = nuniq_prod
for prod_idx, xmat in zip (manifolds_prod_idx, manifolds_xmat):
j = i + xmat.shape[1]
ortharr[i:j] = np.tensordot (xmat.T, rawarr[prod_idx], axes=1)
i = j
return ortharr

def orth2raw (ortharr):
col_shape = ortharr.shape[1:]
raw_shape = [nraw,] + list (col_shape)
rawarr = np.zeros (raw_shape, dtype=ortharr.dtype)
rawarr[uniq_prod_idx] = ortharr[:nuniq_prod]
i = nuniq_prod
for prod_idx, xmat in zip (manifolds_prod_idx, manifolds_xmat):
j = i + xmat.shape[1]
rawarr[prod_idx] = np.tensordot (xmat.conj (), ortharr[i:j], axes=1)
i = j
return rawarr

return raw2orth, orth2raw


3 changes: 2 additions & 1 deletion my_pyscf/lassi/lassi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from mrh.my_pyscf.lassi import op_o0
from mrh.my_pyscf.lassi import op_o1
from mrh.my_pyscf.lassi import chkfile
from mrh.my_pyscf.lassi import citools
from mrh.my_pyscf.lassi.citools import get_lroots
from pyscf import lib, symm, ao2mo
from pyscf.scf.addons import canonical_orth_
Expand Down Expand Up @@ -429,7 +430,7 @@ def _eig_block (las, e0, h1, h2, ci_blk, nelec_blk, rootsym, soc, orbsym, wfnsym
lc = 'checking if LASSI basis has lindeps: |ovlp| = {:.6e}'.format (ovlp_det)
lib.logger.info (las, 'Caught error %s, %s', str (err), lc)
if ovlp_det < LINDEP_THRESH:
raw2orth, orth2raw = op_o0.get_orth_basis (ci_blk, las.ncas_sub, nelec_blk)
raw2orth, orth2raw = citools.get_orth_basis (ci_blk, las.ncas_sub, nelec_blk)
xhx = raw2orth (ham_blk).conj ().T
#x = canonical_orth_(ovlp_blk, thr=LINDEP_THRESH)
lib.logger.info (las, '%d/%d linearly independent model states',
Expand Down
8 changes: 0 additions & 8 deletions my_pyscf/lassi/op_o0.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,11 @@
from pyscf.fci.direct_nosym import contract_1e as contract_1e_nosym
from pyscf.fci.direct_spin1 import _unpack_nelec
from pyscf.fci.spin_op import contract_ss, spin_square
from pyscf.scf.addons import canonical_orth_
from pyscf.data import nist
from itertools import combinations
from mrh.my_pyscf.mcscf import soc_int as soc_int
from mrh.my_pyscf.lassi import dms as lassi_dms
from mrh.my_pyscf.fci.csf import unpack_h1e_cs
from pyscf import __config__

LINDEP_THRESH = getattr (__config__, 'lassi_lindep_thresh', 1.0e-5)

def memcheck (las, ci, soc=None):
'''Check if the system has enough memory to run these functions! ONLY checks
Expand Down Expand Up @@ -440,8 +436,6 @@ def get_orth_basis (ci_fr, norb_f, nelec_frs, _get_ovlp=get_ovlp):
unique, uniq_idx, inverse, cnts = np.unique (nelec_frs, axis=1, return_index=True,
return_inverse=True, return_counts=True)
if not np.count_nonzero (cnts>1):
print ("escape")
print (cnts)
def raw2orth (rawarr):
return rawarr
def orth2raw (ortharr):
Expand Down Expand Up @@ -469,7 +463,6 @@ def orth2raw (ortharr):
manifolds_xmat.append (xmat)

nraw = offs1[-1]
print (nraw, north)
def raw2orth (rawarr):
col_shape = rawarr.shape[1:]
orth_shape = [north,] + list (col_shape)
Expand All @@ -480,7 +473,6 @@ def raw2orth (rawarr):
j = i + xmat.shape[1]
ortharr[i:j] = np.tensordot (xmat.T, rawarr[prod_idx], axes=1)
i = j
print (rawarr.shape, ortharr.shape)
return ortharr

def orth2raw (ortharr):
Expand Down
17 changes: 1 addition & 16 deletions my_pyscf/lassi/op_o1/hams2ovlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,22 +71,7 @@ def kernel (self):
self.s2 = np.zeros ([self.nstates,]*2, dtype=self.dtype)
self._crunch_all_()
t1, w1 = lib.logger.process_clock (), lib.logger.perf_counter ()
ovlp = np.zeros ([self.nstates,]*2, dtype=self.dtype)
def crunch_ovlp (bra_sp, ket_sp):
i = self.ints[-1]
b, k = i.unique_root[bra_sp], i.unique_root[ket_sp]
o = i.ovlp[b][k] / (1 + int (bra_sp==ket_sp))
for i in self.ints[-2::-1]:
b, k = i.unique_root[bra_sp], i.unique_root[ket_sp]
o = np.multiply.outer (o, i.ovlp[b][k]).transpose (0,2,1,3)
o = o.reshape (o.shape[0]*o.shape[1], o.shape[2]*o.shape[3])
o *= self.spin_shuffle[bra_sp]
o *= self.spin_shuffle[ket_sp]
i0, i1 = self.offs_lroots[bra_sp]
j0, j1 = self.offs_lroots[ket_sp]
ovlp[i0:i1,j0:j1] = o
for bra_sp, ket_sp in self.exc_null: crunch_ovlp (bra_sp, ket_sp)
ovlp += ovlp.T
ovlp = self.get_ovlp ()
dt, dw = logger.process_clock () - t1, logger.perf_counter () - w1
self.dt_o, self.dw_o = self.dt_o + dt, self.dw_o + dw
self._umat_linequiv_loop_(ovlp)
Expand Down
12 changes: 12 additions & 0 deletions my_pyscf/lassi/op_o1/stdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,18 @@ def get_ovlp (self, rootidx=None):
exc_null = self.exc_null
offs_lroots = self.offs_lroots
nstates = self.nstates
if rootidx is not None:
rootidx = np.atleast_1d (rootidx)
bra_null = np.isin (self.exc_null[:,0], rootidx)
ket_null = np.isin (self.exc_null[:,1], rootidx)
idx_null = bra_null & ket_null
exc_null = exc_null[idx_null,:]
lroots = self.lroots[:,idx_null]
nprods = np.prod (lroots, axis=0)
offs1 = np.cumsum (nprods)
offs0 = offs1 - nprods
offs_lroots = np.stack ([offs0, offs1], axis=1)
nstates = offs1[-1]
ovlp = np.zeros ([nstates,]*2, dtype=self.dtype)
for bra, ket in exc_null:
i0, i1 = offs_lroots[bra]
Expand Down

0 comments on commit 2112eef

Please sign in to comment.