Skip to content

Commit

Permalink
Memory efficiency lassis _spin_shuffle_ci_ (#90)
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewRHermes committed Apr 30, 2024
1 parent 52486aa commit de53375
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 6 deletions.
57 changes: 53 additions & 4 deletions my_pyscf/lassi/lassis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mrh.my_pyscf.mcscf.productstate import ProductStateFCISolver
from mrh.my_pyscf.lassi.excitations import ExcitationPSFCISolver
from mrh.my_pyscf.lassi.spaces import spin_shuffle, spin_shuffle_ci
from mrh.my_pyscf.lassi.spaces import _spin_shuffle, _spin_shuffle_ci_
from mrh.my_pyscf.lassi.spaces import _spin_shuffle
from mrh.my_pyscf.lassi.spaces import all_single_excitations, SingleLASRootspace
from mrh.my_pyscf.lassi.spaces import orthogonal_excitations, combine_orthogonal_excitations
from mrh.my_pyscf.lassi.lassi import LASSI
Expand Down Expand Up @@ -141,6 +141,7 @@ def single_excitations_ci (lsi, las2, las1, ncharge=1, sa_heff=True, deactivate_
ciref = [[] for j in range (nfrags)]
for k in range (nfrags):
for space in psref: ciref[k].append (space.ci[k])
spaces[i].set_entmap_(psref[0])
psref = [space.get_product_state_solver () for space in psref]
psexc = ExcitationPSFCISolver (psref, ciref, las2.ncas_sub, las2.nelecas_sub,
stdout=mol.stdout, verbose=mol.verbose,
Expand Down Expand Up @@ -302,14 +303,64 @@ def _spin_flip_products (spaces, spin_flips, nroots_ref=1, frozen_frags=None):
spaces = [space for space in spaces if not ((space in seen) or seen.add (space))]
return spaces

def _spin_shuffle_ci_(spaces, spin_flips, nroots_ref, nroots_refc):
'''Memory-efficient version of the function spaces._spin_shuffle_ci_.
Based on the fact that we know there has only been one independent set
of vectors per fragment Hilbert space and that all possible individual
fragment spins must be accounted for already, so we are just recombining
them.'''
old_idx = []
new_idx = []
nfrag = spaces[0].nfrag
for ix, space in enumerate (spaces):
if space.has_ci ():
old_idx.append (ix)
else:
assert (ix >= nroots_refc)
new_idx.append (ix)
space.ci = [None for ifrag in range (space.nfrag)]
# Prepare charge-hop szrots
spaces_1c = spaces[nroots_ref:nroots_refc]
spaces_1c = [space for space in spaces_1c if len (space.entmap)==1]
ci_szrot_1c = []
for ix, space in enumerate (spaces_1c):
ifrag, jfrag = space.entmap[0] # must be a tuple of length 2
ci_szrot_1c.append (space.get_ci_szrot (ifrags=(ifrag,jfrag)))
charges0 = spaces[0].charges
for ix in new_idx:
idx = spaces[ix].excited_fragments (spaces[0])
space = spaces[ix]
for ifrag in np.where (~idx)[0]:
space.ci[ifrag] = spaces[0].ci[ifrag]
for ifrag in np.where (idx)[0]:
if space.charges[ifrag] != charges0[ifrag]: continue
sf = spin_flips[ifrag]
iflp = sf.smults == space.smults[ifrag]
iflp &= sf.spins == space.spins[ifrag]
assert (np.count_nonzero (iflp) == 1)
iflp = np.where (iflp)[0][0]
space.ci[ifrag] = sf.ci[iflp]
for (ci_i, ci_j), sp_1c in zip (ci_szrot_1c, spaces_1c):
ijfrag = sp_1c.entmap[0]
if ijfrag not in spaces[ix].entmap: continue
if np.any (sp_1c.charges[list(ijfrag)] != space.charges[list(ijfrag)]): continue
if np.any (sp_1c.smults[list(ijfrag)] != space.smults[list(ijfrag)]): continue
ifrag, jfrag = ijfrag
assert (space.ci[ifrag] is None)
assert (space.ci[jfrag] is None)
space.ci[ifrag] = ci_i[space.spins[ifrag]]
space.ci[jfrag] = ci_j[space.spins[jfrag]]
assert (space.has_ci ())
return spaces

def spin_flip_products (las, spaces, spin_flips, nroots_ref=1):
'''Inject spin-flips into las2 in all possible ways'''
log = logger.new_logger (las, las.verbose)
las2_nroots = len (spaces)
spaces = _spin_flip_products (spaces, spin_flips, nroots_ref=nroots_ref)
nfrags = spaces[0].nfrag
spaces = _spin_shuffle (spaces)
spaces = _spin_shuffle_ci_(spaces)
spaces = _spin_shuffle_ci_(spaces, spin_flips, nroots_ref, las2_nroots)
log.info ("LASSIS spin-excitation spaces: %d-%d", las2_nroots, len (spaces)-1)
for i, space in enumerate (spaces[las2_nroots:]):
if np.any (space.nelec != spaces[0].nelec):
Expand All @@ -326,8 +377,6 @@ def charge_excitation_products (lsi, spaces, las1):
nfrags = lsi.nfrags
space0 = spaces[0]
i0, j0 = i, j = las1.nroots, len (spaces)
for space1 in spaces[i:j]:
space1.set_entmap_(space0)
for product_order in range (2, (nfrags//2)+1):
seen = set ()
for i_list in itertools.combinations (range (i,j), product_order):
Expand Down
9 changes: 7 additions & 2 deletions my_pyscf/lassi/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,23 @@ def has_ci (self):
if self.ci is None: return False
return all ([c is not None for c in self.ci])

def get_ci_szrot (self):
def get_ci_szrot (self, ifrags=None):
'''Generate the sets of CI vectors in which each vector for each fragment
has the sz axis rotated in all possible ways.
Kwargs:
ifrags: list of integers
Optionally restrict ci_sz to particular fragments identified by ifrags
Returns:
ci_sz: list of dict of type {integer: ndarray}
dict keys are integerified "spin" quantum numbers; i.e., neleca-nelecb.
dict vals are the corresponding CI vectors
'''
ci_sz = []
ndet = self.get_ndet ()
for ifrag in range (self.nfrag):
if ifrags is None: ifrags = range (self.nfrag)
for ifrag in ifrags:
norb, sz, ci = self.nlas[ifrag], self.spins[ifrag], self.ci[ifrag]
ndeta, ndetb = ndet[ifrag]
nelec = self.neleca[ifrag], self.nelecb[ifrag]
Expand Down

0 comments on commit de53375

Please sign in to comment.