diff --git a/my_pyscf/lassi/lassis.py b/my_pyscf/lassi/lassis.py index 1674530e..7f202c63 100644 --- a/my_pyscf/lassi/lassis.py +++ b/my_pyscf/lassi/lassis.py @@ -56,11 +56,7 @@ def single_excitations_ci (lsi, las2, las1, nmax_charge=1, sa_heff=True, deactiv mol = lsi.mol nfrags = lsi.nfrags e_roots = np.append (las1.e_states, np.zeros (las2.nroots-las1.nroots)) - #psrefs = [] ci = [[ci_ij for ci_ij in ci_i] for ci_i in las2.ci] - #for j in range (las1.nroots): - # solvers = [b.fcisolvers[j] for b in las1.fciboxes] - # psrefs.append (ProductStateFCISolver (solvers, stdout=mol.stdout, verbose=mol.verbose)) spaces = [SingleLASRootspace (las2, m, s, c, las2.weights[ix], ci=[c[ix] for c in ci]) for ix, (c, m, s, w) in enumerate (zip (*get_space_info (las2)))] ncsf = las2.get_ugg ().ncsf_sub @@ -87,8 +83,6 @@ def single_excitations_ci (lsi, las2, las1, nmax_charge=1, sa_heff=True, deactiv dest_frag, dest_ds) excfrags[spaces[i].excited_fragments (spaces[j])] = True psref.append (spaces[j]) - #for k in range (nfrags): - # ciref[k].append (las1.ci[k][j]) #psref = _spin_halfexcitation_products (psref, spin_halfexcs, nroots_ref=len(psref), # frozen_frags=(~excfrags)) ciref = [[] for j in range (nfrags)] diff --git a/my_pyscf/lassi/states.py b/my_pyscf/lassi/states.py index ac077fb6..35dbcb15 100644 --- a/my_pyscf/lassi/states.py +++ b/my_pyscf/lassi/states.py @@ -9,7 +9,7 @@ from mrh.my_pyscf.fci.spin_op import contract_sdown, contract_sup from mrh.my_pyscf.fci.csfstring import CSFTransformer from mrh.my_pyscf.fci.csfstring import ImpossibleSpinError -from mrh.my_pyscf.mcscf.productstate import ProductStateFCISolver +from mrh.my_pyscf.mcscf.productstate import ImpureProductStateFCISolver import itertools class SingleLASRootspace (object): @@ -228,12 +228,17 @@ def excited_fragments (self, other): idx_same = (dneleca==0) & (dnelecb==0) & (dsmults==0) return ~idx_same + def get_lroots (self): + if not self.has_ci (): return None + lroots = [] + for c, n in zip (self.ci, self.get_ndet ()): + c = np.asarray (c).reshape (-1, n[0], n[1]) + lroots.append (c.shape[0]) + return lroots + + def table_printlog (self, lroots=None): - if lroots is None and self.has_ci (): - lroots = [] - for c, n in zip (self.ci, self.get_ndet ()): - c = np.asarray (c).reshape (-1, n[0], n[1]) - lroots.append (c.shape[0]) + if lroots is None: lroots = self.get_lroots () log = logger.new_logger (self, self.verbose) fmt_str = " {:4s} {:>11s} {:>4s} {:>3s}" header = fmt_str.format ("Frag", "Nelec,Norb", "2S+1", "Ir") @@ -287,9 +292,19 @@ def get_fcisolvers (self): fcisolvers.append (solver) return fcisolvers - def get_product_state_solver (self): + def get_product_state_solver (self, lroots=None, lweights='gs'): fcisolvers = self.get_fcisolvers () - return ProductStateFCISolver (fcisolvers, stdout=self.stdout, verbose=self.verbose) + if lroots is None: lroots = self.get_lroots () + lw = [np.zeros (l) for l in lroots] + if 'gs' in lweights.lower (): + for l in lw: l[0] = 1.0 + elif 'sa' in lweights.lower (): + for l in lw: l[:] = 1.0/len (l) + else: + raise RuntimeError ('valid lweights are "gs" and "sa"') + lweights=lw + return ImpureProductStateFCISolver (fcisolvers, stdout=self.stdout, lweights=lweights, + verbose=self.verbose) def all_single_excitations (las, verbose=None):