Skip to content

Commit

Permalink
lassi op_o1 spinless mapping
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewRHermes committed Aug 29, 2024
1 parent 062a580 commit 67a0d06
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 12 deletions.
33 changes: 29 additions & 4 deletions my_pyscf/lassi/op_o1/hams2ovlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,8 @@ def _crunch_2c_(self, bra, ket, i, j, k, l, s2lt):
self.dt_2c, self.dw_2c = self.dt_2c + dt, self.dw_2c + dw
return ham, s2, (l, j, i, k)

def ham (las, h1, h2, ci, nelec_frs, _HamS2Ovlp_class=HamS2Ovlp, _do_kernel=True, **kwargs):
def ham (las, h1, h2, ci, nelec_frs, soc=0, nlas=None, _HamS2Ovlp_class=HamS2Ovlp, _do_kernel=True,
**kwargs):
''' Build Hamiltonian, spin-squared, and overlap matrices in LAS product state basis
Args:
Expand All @@ -413,6 +414,10 @@ def ham (las, h1, h2, ci, nelec_frs, _HamS2Ovlp_class=HamS2Ovlp, _do_kernel=True
fragment
Kwargs:
soc : integer
Order of spin-orbit coupling included in the Hamiltonian
nlas : sequence of length (nfrags)
Number of orbitals in each fragment
_HamS2Ovlp_class : class
The main intermediate class
_do_kernel : logical
Expand All @@ -428,12 +433,32 @@ def ham (las, h1, h2, ci, nelec_frs, _HamS2Ovlp_class=HamS2Ovlp, _do_kernel=True
Overlap matrix of LAS product states
'''
log = lib.logger.new_logger (las, las.verbose)
nlas = las.ncas_sub
if nlas is None: nlas = las.ncas_sub
max_memory = getattr (las, 'max_memory', las.mol.max_memory)
dtype = ci[0][0].dtype
dtype = h1.dtype
if soc>1: raise NotImplementedError ("Spin-orbit coupling of second order")

# Handle possible SOC
n = sum (nlas)
nelec_rs = [tuple (x) for x in nelec_frs.sum (0)]
spin_pure = len (set (nelec_rs))
if soc and spin_pure: # In this scenario, the off-diagonal sector of h1 is pointless
h1 = np.stack ([h1[:n,:n], h1[n:,n:]], axis=0)
if not spin_pure: # Engage the ``spinless mapping''
if not soc: h1 = linalg.block_diag (h1, h1)
h2_ = np.zeros ([2*n,]*4, dtype=h2.dtype)
h2_[:n,:n,:n,:n] = h2[:]
h2_[:n,:n,n:,n:] = h2[:]
h2_[n:,n:,:n,:n] = h2[:]
h2_[n:,n:,n:,n:] = h2[:]
h2 = h2_
ci = ci_map2spinless (ci, nlas, nelec_frs)
nlas = [2*x for x in nlas]
nelec_frs[:,:,0] += nelec_frs[:,:,1]
nelec_frs[:,:,1] = 0

# First pass: single-fragment intermediates
hopping_index, ints, lroots = frag.make_ints (las, ci, nelec_frs)
hopping_index, ints, lroots = frag.make_ints (las, ci, nelec_frs, nlas=nlas)
nstates = np.sum (np.prod (lroots, axis=0))

# Memory check
Expand Down
8 changes: 4 additions & 4 deletions my_pyscf/lassi/op_o1/hci.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,7 +450,8 @@ def contract_ham_ci (las, h1, h2, ci_fr_ket, nelec_frs_ket, ci_fr_bra, nelec_frs
nelec_frs = np.append (nelec_frs_ket, nelec_frs_bra, axis=1)

# First pass: single-fragment intermediates
hopping_index, ints, lroots = frag.make_ints (las, ci, nelec_frs, screen_linequiv=False)
hopping_index, ints, lroots = frag.make_ints (las, ci, nelec_frs, nlas=nlas,
screen_linequiv=False)

# Second pass: upper-triangle
t0 = (lib.logger.process_clock (), lib.logger.perf_counter ())
Expand Down Expand Up @@ -516,9 +517,8 @@ def fermion_frag_shuffle (self, iroot, frags):
# Get the intermediate object, rather than just the ham matrix, so that I can use the members
# of the intermediate to keep track of the difference between the full-system indices and the
# nfrag-1--system indices
with lib.temporary_env (las, ncas_sub=nlas_j):
outerprod = hams2ovlp.ham (las, h1, h2, ci_jfrag, nelec_frs_j, _HamS2Ovlp_class=HamS2Ovlp,
_do_kernel=False)
outerprod = hams2ovlp.ham (las, h1, h2, ci_jfrag, nelec_frs_j, nlas=nlas_j,
_HamS2Ovlp_class=HamS2Ovlp, _do_kernel=False)
ham = outerprod.kernel ()[0]

for ibra in range (nket, nroots):
Expand Down
29 changes: 26 additions & 3 deletions my_pyscf/lassi/op_o1/rdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ def get_fdm1_maker (las, ci, nelec_frs, si, **kwargs):
dtype = ci[0][0].dtype

# First pass: single-fragment intermediates
hopping_index, ints, lroots = frag.make_ints (las, ci, nelec_frs)
hopping_index, ints, lroots = frag.make_ints (las, ci, nelec_frs, nlas=nlas)
nstates = np.sum (np.prod (lroots, axis=0))

# Second pass: upper-triangle
Expand Down Expand Up @@ -624,10 +624,20 @@ def roots_make_rdm12s (las, ci, nelec_frs, si, **kwargs):
ncas = las.ncas
nroots_si = si.shape[-1]
max_memory = getattr (las, 'max_memory', las.mol.max_memory)
dtype = ci[0][0].dtype
dtype = si.dtype

# Handle possible SOC
nelec_rs = [tuple (x) for x in nelec_frs.sum (0)]
spin_pure = len (set (nelec_rs))
if not spin_pure: # Engage the ``spinless mapping''
ci = ci_map2spinless (ci, nlas, nelec_frs)
nlas = [2*x for x in nlas]
nelec_frs[:,:,0] += nelec_frs[:,:,1]
nelec_frs[:,:,1] = 0

# First pass: single-fragment intermediates
hopping_index, ints, lroots = frag.make_ints (las, ci, nelec_frs, _FragTDMInt_class=FragTDMInt)
hopping_index, ints, lroots = frag.make_ints (las, ci, nelec_frs, nlas=nlas,
_FragTDMInt_class=FragTDMInt)
nstates = np.sum (np.prod (lroots, axis=0))

# Memory check
Expand All @@ -650,6 +660,19 @@ def roots_make_rdm12s (las, ci, nelec_frs, si, **kwargs):
# Put rdm1s in PySCF convention: [p,q] -> q'p
rdm1s = rdm1s.transpose (0,1,3,2)
rdm2s = rdm2s.reshape (nroots_si, 2, 2, ncas, ncas, ncas, ncas).transpose (0,1,3,4,2,5,6)

# Clean up the ``spinless mapping''
if not spin_pure:
rdm1s = rdm1s[0,:,:]
# TODO: 2e- SOC
n = sum (nlas) // 2
rdm2s_ = np.zeros ((2, n, n, 2, n, n), dtype=rdm2s.dtype)
rdm2s_[0,:,:,0,:,:] = rdm2s[0,:n,:n,0,:n,:n]
rdm2s_[0,:,:,1,:,:] = rdm2s[0,:n,:n,0,n:,n:]
rdm2s_[1,:,:,0,:,:] = rdm2s[0,n:,n:,0,:n,:n]
rdm2s_[1,:,:,1,:,:] = rdm2s[0,n:,n:,0,n:,n:]
rdm2s = rdm2s_

return rdm1s, rdm2s


Expand Down
24 changes: 23 additions & 1 deletion my_pyscf/lassi/op_o1/stdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1001,8 +1001,17 @@ def make_stdm12s (las, ci, nelec_frs, **kwargs):
dtype = ci[0][0].dtype
max_memory = getattr (las, 'max_memory', las.mol.max_memory)

# Handle possible SOC
nelec_rs = [tuple (x) for x in nelec_frs.sum (0)]
spin_pure = len (set (nelec_rs))
if not spin_pure: # Engage the ``spinless mapping''
ci = ci_map2spinless (ci, nlas, nelec_frs)
nlas = [2*x for x in nlas]
nelec_frs[:,:,0] += nelec_frs[:,:,1]
nelec_frs[:,:,1] = 0

# First pass: single-fragment intermediates
hopping_index, ints, lroots = frag.make_ints (las, ci, nelec_frs)
hopping_index, ints, lroots = frag.make_ints (las, ci, nelec_frs, nlas=nlas)
nstates = np.sum (np.prod (lroots, axis=0))

# Memory check
Expand All @@ -1025,6 +1034,19 @@ def make_stdm12s (las, ci, nelec_frs, **kwargs):
# Put tdm1s in PySCF convention: [p,q] -> q'p
tdm1s = tdm1s.transpose (0,2,4,3,1)
tdm2s = tdm2s.reshape (nstates,nstates,2,2,ncas,ncas,ncas,ncas).transpose (0,2,4,5,3,6,7,1)

# Clean up the ``spinless mapping''
if not spin_pure:
tdm1s = tdm1s[:,0,:,:,:]
n = sum (nlas) // 2
tdm2s_ = np.zeros ((nroots, nroots, 2, n, n, 2, n, n),
dtype=tdm2s.dtype).transpose (0,2,3,4,5,6,7,1)
tdm2s_[:,0,:,:,0,:,:,:] = tdm2s[:,0,:n,:n,0,:n,:n,:]
tdm2s_[:,0,:,:,1,:,:,:] = tdm2s[:,0,:n,:n,0,n:,n:,:]
tdm2s_[:,1,:,:,0,:,:,:] = tdm2s[:,0,n:,n:,0,:n,:n,:]
tdm2s_[:,1,:,:,1,:,:,:] = tdm2s[:,0,n:,n:,0,n:,n:,:]
tdm2s = tdm2s_

return tdm1s, tdm2s


22 changes: 22 additions & 0 deletions my_pyscf/lassi/op_o1/utilities.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import numpy as np
from mrh.my_pyscf.lassi.citools import umat_dot_1frag_
from mrh.my_pyscf.lassi.op_o0 import civec_spinless_repr

def fermion_spin_shuffle (na_list, nb_list):
''' Compute the sign factor corresponding to the convention
Expand Down Expand Up @@ -140,4 +141,25 @@ def split_contig_array (arrlen, nthreads):
blkstart -= blklen
return blkstart, blklen

def ci_map2spinless (ci0_fr, norb_f, nelec_frs):
nfrags, nroots = nelec_frs.shape[:2]

# Only transform unique CI vectors
ci_ptrs = np.asarray ([[c.__array_interface__['data'][0] for c in ci_r] for ci_r in ci0_fr])
_, idx, inv = np.unique (ci_ptrs, return_index=True, return_inverse=True)
inv = inv.reshape (ci_ptrs.shape)
ci1 = []
for ix in idx:
i, j = divmod (ix, nroots)
if isinstance (ci0_fr[i][j], (list,tuple)) or ci0_fr[i][j].ndim>2:
ci0 = ci0_fr[i][j]
else:
ci0 = [ci0_fr[i][j],]
nelec = [nelec_frs[i][j],]*len(ci0)
ci1.append (civec_spinless_repr (ci0, norb_f[i], nelec))

return [[ci1[inv[i,j]] for j in range (nroots)] for i in range (nfrags)]




0 comments on commit 67a0d06

Please sign in to comment.