diff --git a/my_pyscf/lassi/lassis.py b/my_pyscf/lassi/lassis.py index 32f084c1..b5aae0a0 100644 --- a/my_pyscf/lassi/lassis.py +++ b/my_pyscf/lassi/lassis.py @@ -12,7 +12,7 @@ from mrh.my_pyscf.mcscf.productstate import ProductStateFCISolver from mrh.my_pyscf.lassi.excitations import ExcitationPSFCISolver from mrh.my_pyscf.lassi.spaces import spin_shuffle, spin_shuffle_ci -from mrh.my_pyscf.lassi.spaces import _spin_shuffle, _spin_shuffle_ci_ +from mrh.my_pyscf.lassi.spaces import _spin_shuffle from mrh.my_pyscf.lassi.spaces import all_single_excitations, SingleLASRootspace from mrh.my_pyscf.lassi.spaces import orthogonal_excitations, combine_orthogonal_excitations from mrh.my_pyscf.lassi.lassi import LASSI @@ -141,6 +141,7 @@ def single_excitations_ci (lsi, las2, las1, ncharge=1, sa_heff=True, deactivate_ ciref = [[] for j in range (nfrags)] for k in range (nfrags): for space in psref: ciref[k].append (space.ci[k]) + spaces[i].set_entmap_(psref[0]) psref = [space.get_product_state_solver () for space in psref] psexc = ExcitationPSFCISolver (psref, ciref, las2.ncas_sub, las2.nelecas_sub, stdout=mol.stdout, verbose=mol.verbose, @@ -302,6 +303,56 @@ def _spin_flip_products (spaces, spin_flips, nroots_ref=1, frozen_frags=None): spaces = [space for space in spaces if not ((space in seen) or seen.add (space))] return spaces +def _spin_shuffle_ci_(spaces, spin_flips, nroots_ref, nroots_refc): + '''Memory-efficient version of the function spaces._spin_shuffle_ci_. + Based on the fact that we know there has only been one independent set + of vectors per fragment Hilbert space and that all possible individual + fragment spins must be accounted for already, so we are just recombining + them.''' + old_idx = [] + new_idx = [] + nfrag = spaces[0].nfrag + for ix, space in enumerate (spaces): + if space.has_ci (): + old_idx.append (ix) + else: + assert (ix >= nroots_refc) + new_idx.append (ix) + space.ci = [None for ifrag in range (space.nfrag)] + # Prepare charge-hop szrots + spaces_1c = spaces[nroots_ref:nroots_refc] + spaces_1c = [space for space in spaces_1c if len (space.entmap)==1] + ci_szrot_1c = [] + for ix, space in enumerate (spaces_1c): + ifrag, jfrag = space.entmap[0] # must be a tuple of length 2 + ci_szrot_1c.append (space.get_ci_szrot (ifrags=(ifrag,jfrag))) + charges0 = spaces[0].charges + for ix in new_idx: + idx = spaces[ix].excited_fragments (spaces[0]) + space = spaces[ix] + for ifrag in np.where (~idx)[0]: + space.ci[ifrag] = spaces[0].ci[ifrag] + for ifrag in np.where (idx)[0]: + if space.charges[ifrag] != charges0[ifrag]: continue + sf = spin_flips[ifrag] + iflp = sf.smults == space.smults[ifrag] + iflp &= sf.spins == space.spins[ifrag] + assert (np.count_nonzero (iflp) == 1) + iflp = np.where (iflp)[0][0] + space.ci[ifrag] = sf.ci[iflp] + for (ci_i, ci_j), sp_1c in zip (ci_szrot_1c, spaces_1c): + ijfrag = sp_1c.entmap[0] + if ijfrag not in spaces[ix].entmap: continue + if np.any (sp_1c.charges[list(ijfrag)] != space.charges[list(ijfrag)]): continue + if np.any (sp_1c.smults[list(ijfrag)] != space.smults[list(ijfrag)]): continue + ifrag, jfrag = ijfrag + assert (space.ci[ifrag] is None) + assert (space.ci[jfrag] is None) + space.ci[ifrag] = ci_i[space.spins[ifrag]] + space.ci[jfrag] = ci_j[space.spins[jfrag]] + assert (space.has_ci ()) + return spaces + def spin_flip_products (las, spaces, spin_flips, nroots_ref=1): '''Inject spin-flips into las2 in all possible ways''' log = logger.new_logger (las, las.verbose) @@ -309,7 +360,7 @@ def spin_flip_products (las, spaces, spin_flips, nroots_ref=1): spaces = _spin_flip_products (spaces, spin_flips, nroots_ref=nroots_ref) nfrags = spaces[0].nfrag spaces = _spin_shuffle (spaces) - spaces = _spin_shuffle_ci_(spaces) + spaces = _spin_shuffle_ci_(spaces, spin_flips, nroots_ref, las2_nroots) log.info ("LASSIS spin-excitation spaces: %d-%d", las2_nroots, len (spaces)-1) for i, space in enumerate (spaces[las2_nroots:]): if np.any (space.nelec != spaces[0].nelec): @@ -326,8 +377,6 @@ def charge_excitation_products (lsi, spaces, las1): nfrags = lsi.nfrags space0 = spaces[0] i0, j0 = i, j = las1.nroots, len (spaces) - for space1 in spaces[i:j]: - space1.set_entmap_(space0) for product_order in range (2, (nfrags//2)+1): seen = set () for i_list in itertools.combinations (range (i,j), product_order): diff --git a/my_pyscf/lassi/spaces.py b/my_pyscf/lassi/spaces.py index aaef2810..7915a2a1 100644 --- a/my_pyscf/lassi/spaces.py +++ b/my_pyscf/lassi/spaces.py @@ -170,10 +170,14 @@ def has_ci (self): if self.ci is None: return False return all ([c is not None for c in self.ci]) - def get_ci_szrot (self): + def get_ci_szrot (self, ifrags=None): '''Generate the sets of CI vectors in which each vector for each fragment has the sz axis rotated in all possible ways. + Kwargs: + ifrags: list of integers + Optionally restrict ci_sz to particular fragments identified by ifrags + Returns: ci_sz: list of dict of type {integer: ndarray} dict keys are integerified "spin" quantum numbers; i.e., neleca-nelecb. @@ -181,7 +185,8 @@ def get_ci_szrot (self): ''' ci_sz = [] ndet = self.get_ndet () - for ifrag in range (self.nfrag): + if ifrags is None: ifrags = range (self.nfrag) + for ifrag in ifrags: norb, sz, ci = self.nlas[ifrag], self.spins[ifrag], self.ci[ifrag] ndeta, ndetb = ndet[ifrag] nelec = self.neleca[ifrag], self.nelecb[ifrag]