Skip to content

Commit

Permalink
safety commit
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewRHermes committed Aug 23, 2024
1 parent baca1d8 commit 47e7c6a
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 48 deletions.
71 changes: 59 additions & 12 deletions my_pyscf/lassi/op_o1/frag.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import numpy as np
from scipy import linalg
from pyscf import lib
from pyscf.fci.direct_spin1 import trans_rdm12s, contract_1e
from pyscf.fci.direct_spin1 import trans_rdm12s, contract_1e, contract_2e, absorb_h1e
from pyscf.fci.direct_uhf import contract_1e as contract_1e_uhf
from pyscf.fci.addons import cre_a, cre_b, des_a, des_b
from pyscf.fci import cistring
from itertools import product, combinations
from mrh.my_pyscf.lassi.citools import get_lroots, get_rootaddr_fragaddr
from mrh.my_pyscf.lassi.op_o1.utilities import *

class LSTDMint1 (object):
''' LAS state transition density matrix intermediate 1: fragment-local data.
class FragTDMInt (object):
''' Fragment-local LAS state transition density matrix intermediate
Quasi-sparse-memory storage for LAS-state transition density matrix single-fragment
intermediates. Stores all local transition density matrix factors. Run the `kernel` method
Expand Down Expand Up @@ -465,7 +466,17 @@ def trans_rdm12s_loop (iroot, bra, ket):
return t0

def contract_h00 (self, h_00, h_11, h_22, ket):
raise NotImplementedError
r = self.rootaddr[ket]
n = self.fragaddr[ket]
norb, nelec = self.norb, self.nelec_r[r]
ci = self.ci[r][n]
h_uhf = (h_11[0] - h_11[1]) / 2
h_uhf = [h_uhf, -h.uhf]
h_11 = h_11.sum (0) / 2
h2eff = absorb_h1e (h_11, h_22, norb, nelec, 0.5)
hci = h_00*ci + contract_2e (h2eff, ci, norb, nelec)
hci += contract_uhf (h_uhf, ci, norb, nelec)
return hci

def contract_h10 (self, spin, h_10, h_21, ket):
r = self.rootaddr[ket]
Expand Down Expand Up @@ -497,16 +508,52 @@ def contract_h01 (self, spin, h_01, h_12, ket):
return hci

def contract_h20 (self, spin, h_20, ket):
raise NotImplementedError
r = self.rootaddr[ket]
n = self.fragaddr[ket]
norb, nelec = self.norb, self.nelec_r[r]
ci = self.ci[r][n]
# 0, 1, 2 = aa, ab, bb
cre_op1 = (cre_a, cre_b)[int (spin>1)]
cre_op2 = (cre_a, cre_b)[int (spin>0)]
hci = 0
for q in range (self.norb):
qci = cre_op2 (ci, norb, nelec, q)
for p in range (self.norb):
hci += h_20[p,q] * cre_op1 (qci, norb, nelec, p)
return hci

def contract_h02 (self, spin, h_02, ket):
raise NotImplementedError
r = self.rootaddr[ket]
n = self.fragaddr[ket]
norb, nelec = self.norb, self.nelec_r[r]
ci = self.ci[r][n]
# 0, 1, 2 = aa, ab, bb
des_op1 = (des_a, des_b)[int (spin>1)]
des_op2 = (des_a, des_b)[int (spin>0)]
hci = 0
for q in range (self.norb):
qci = des_op1 (ci, norb, nelec, q)
for p in range (self.norb):
hci += h_02[p,q] * des_op2 (qci, norb, nelec, p)
return hci

def contract_h11 (self, spin, h_11, ket):
raise NotImplementedError
r = self.rootaddr[ket]
n = self.fragaddr[ket]
norb, nelec = self.norb, self.nelec_r[r]
ci = self.ci[r][n]
# 0, 1 = ab, ba
cre_op = (cre_a, cre_b)[spin]
des_op = (des_b, des_a)[spin]
hci = 0
for q in range (self.norb):
qci = des_op (ci, norb, nelec, q)
for p in range (self.norb):
hci += h_11[p,q] * cre_op (qci, norb, nelec, p)
return hci

def make_ints (las, ci, nelec_frs, screen_linequiv=True, _LSTDMint1_class=LSTDMint1):
''' Build fragment-local intermediates (`LSTDMint1`) for LASSI o1
def make_ints (las, ci, nelec_frs, screen_linequiv=True, nlas=None, _FragTDMInt_class=FragTDMInt):
''' Build fragment-local intermediates (`FragTDMInt`) for LASSI o1
Args:
las : instance of :class:`LASCINoSymm`
Expand All @@ -525,18 +572,18 @@ def make_ints (las, ci, nelec_frs, screen_linequiv=True, _LSTDMint1_class=LSTDMi
hopping_index : ndarray of ints of shape (nfrags, 2, nroots, nroots)
element [i,j,k,l] reports the change of number of electrons of
spin j in fragment i between LAS rootspaces k and l
ints : list of length nfrags of instances of :class:`LSTDMint1`
ints : list of length nfrags of instances of :class:`FragTDMInt`
lroots: ndarray of ints of shape (nfrags, nroots)
Number of states within each fragment and rootspace
'''
nfrags, nroots = nelec_frs.shape[:2]
nlas = las.ncas_sub
if nlas is None: nlas = las.ncas_sub
lroots = get_lroots (ci)
hopping_index, zerop_index, onep_index = lst_hopping_index (nelec_frs)
rootaddr, fragaddr = get_rootaddr_fragaddr (lroots)
ints = []
for ifrag in range (nfrags):
tdmint = _LSTDMint1_class (ci[ifrag], hopping_index[ifrag], zerop_index, onep_index,
tdmint = _FragTDMInt_class (ci[ifrag], hopping_index[ifrag], zerop_index, onep_index,
nlas[ifrag], nroots, nelec_frs[ifrag], rootaddr,
fragaddr[ifrag], ifrag, screen_linequiv=screen_linequiv)
lib.logger.timer (las, 'LAS-state TDM12s fragment {} intermediate crunching'.format (
Expand Down
10 changes: 5 additions & 5 deletions my_pyscf/lassi/op_o1/hams2ovlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
# for two fragments this is (sign?)
# ((d1a_pp - d1b_pp) * (d1a_qq - d1b_qq))/4 - (sp_pp*sm_qq + sm_pp*sp_qq)/2

class HamS2ovlpint (stdm.LSTDMint2):
__doc__ = stdm.LSTDMint2.__doc__ + '''
class HamS2Ovlp (stdm.LSTDM):
__doc__ = stdm.LSTDM.__doc__ + '''
SUBCLASS: Hamiltonian, spin-squared, and overlap matrices
Expand All @@ -29,7 +29,7 @@ class HamS2ovlpint (stdm.LSTDMint2):

def __init__(self, ints, nlas, hopping_index, lroots, h1, h2, mask_bra_space=None,
mask_ket_space=None, log=None, max_memory=2000, dtype=np.float64):
stdm.LSTDMint2.__init__(self, ints, nlas, hopping_index, lroots,
stdm.LSTDM.__init__(self, ints, nlas, hopping_index, lroots,
mask_bra_space=mask_bra_space, mask_ket_space=mask_ket_space,
log=log, max_memory=max_memory, dtype=dtype)
if h1.ndim==2: h1 = np.stack ([h1,h1], axis=0)
Expand Down Expand Up @@ -397,7 +397,7 @@ def _crunch_2c_(self, bra, ket, i, j, k, l, s2lt):
self.dt_2c, self.dw_2c = self.dt_2c + dt, self.dw_2c + dw
return ham, s2, (l, j, i, k)

def ham (las, h1, h2, ci, nelec_frs, _HamS2ovlpint_class=HamS2ovlpint, **kwargs):
def ham (las, h1, h2, ci, nelec_frs, _HamS2Ovlp_class=HamS2Ovlp, **kwargs):
''' Build Hamiltonian, spin-squared, and overlap matrices in LAS product state basis
Args:
Expand Down Expand Up @@ -438,7 +438,7 @@ def ham (las, h1, h2, ci, nelec_frs, _HamS2ovlpint_class=HamS2ovlpint, **kwargs)

# Second pass: upper-triangle
t0 = (lib.logger.process_clock (), lib.logger.perf_counter ())
outerprod = _HamS2ovlpint_class (ints, nlas, hopping_index, lroots, h1, h2, dtype=dtype,
outerprod = _HamS2Ovlp_class (ints, nlas, hopping_index, lroots, h1, h2, dtype=dtype,
max_memory=max_memory, log=log)
lib.logger.timer (las, 'LASSI Hamiltonian second intermediate indexing setup', *t0)
ham, s2, ovlp, t0 = outerprod.kernel ()
Expand Down
Loading

0 comments on commit 47e7c6a

Please sign in to comment.