Skip to content

Commit

Permalink
lassis impure solvers as ref solvers for exc
Browse files Browse the repository at this point in the history
Necessary for spin-excitation as charge-hop ref
  • Loading branch information
MatthewRHermes committed Oct 13, 2023
1 parent 06f3c81 commit ac3097a
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 14 deletions.
6 changes: 0 additions & 6 deletions my_pyscf/lassi/lassis.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,7 @@ def single_excitations_ci (lsi, las2, las1, nmax_charge=1, sa_heff=True, deactiv
mol = lsi.mol
nfrags = lsi.nfrags
e_roots = np.append (las1.e_states, np.zeros (las2.nroots-las1.nroots))
#psrefs = []
ci = [[ci_ij for ci_ij in ci_i] for ci_i in las2.ci]
#for j in range (las1.nroots):
# solvers = [b.fcisolvers[j] for b in las1.fciboxes]
# psrefs.append (ProductStateFCISolver (solvers, stdout=mol.stdout, verbose=mol.verbose))
spaces = [SingleLASRootspace (las2, m, s, c, las2.weights[ix], ci=[c[ix] for c in ci])
for ix, (c, m, s, w) in enumerate (zip (*get_space_info (las2)))]
ncsf = las2.get_ugg ().ncsf_sub
Expand All @@ -87,8 +83,6 @@ def single_excitations_ci (lsi, las2, las1, nmax_charge=1, sa_heff=True, deactiv
dest_frag, dest_ds)
excfrags[spaces[i].excited_fragments (spaces[j])] = True
psref.append (spaces[j])
#for k in range (nfrags):
# ciref[k].append (las1.ci[k][j])
#psref = _spin_halfexcitation_products (psref, spin_halfexcs, nroots_ref=len(psref),
# frozen_frags=(~excfrags))
ciref = [[] for j in range (nfrags)]
Expand Down
31 changes: 23 additions & 8 deletions my_pyscf/lassi/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from mrh.my_pyscf.fci.spin_op import contract_sdown, contract_sup
from mrh.my_pyscf.fci.csfstring import CSFTransformer
from mrh.my_pyscf.fci.csfstring import ImpossibleSpinError
from mrh.my_pyscf.mcscf.productstate import ProductStateFCISolver
from mrh.my_pyscf.mcscf.productstate import ImpureProductStateFCISolver
import itertools

class SingleLASRootspace (object):
Expand Down Expand Up @@ -228,12 +228,17 @@ def excited_fragments (self, other):
idx_same = (dneleca==0) & (dnelecb==0) & (dsmults==0)
return ~idx_same

def get_lroots (self):
if not self.has_ci (): return None
lroots = []
for c, n in zip (self.ci, self.get_ndet ()):
c = np.asarray (c).reshape (-1, n[0], n[1])
lroots.append (c.shape[0])
return lroots


def table_printlog (self, lroots=None):
if lroots is None and self.has_ci ():
lroots = []
for c, n in zip (self.ci, self.get_ndet ()):
c = np.asarray (c).reshape (-1, n[0], n[1])
lroots.append (c.shape[0])
if lroots is None: lroots = self.get_lroots ()
log = logger.new_logger (self, self.verbose)
fmt_str = " {:4s} {:>11s} {:>4s} {:>3s}"
header = fmt_str.format ("Frag", "Nelec,Norb", "2S+1", "Ir")
Expand Down Expand Up @@ -287,9 +292,19 @@ def get_fcisolvers (self):
fcisolvers.append (solver)
return fcisolvers

def get_product_state_solver (self):
def get_product_state_solver (self, lroots=None, lweights='gs'):
fcisolvers = self.get_fcisolvers ()
return ProductStateFCISolver (fcisolvers, stdout=self.stdout, verbose=self.verbose)
if lroots is None: lroots = self.get_lroots ()
lw = [np.zeros (l) for l in lroots]
if 'gs' in lweights.lower ():
for l in lw: l[0] = 1.0
elif 'sa' in lweights.lower ():
for l in lw: l[:] = 1.0/len (l)
else:
raise RuntimeError ('valid lweights are "gs" and "sa"')
lweights=lw
return ImpureProductStateFCISolver (fcisolvers, stdout=self.stdout, lweights=lweights,
verbose=self.verbose)


def all_single_excitations (las, verbose=None):
Expand Down

0 comments on commit ac3097a

Please sign in to comment.