Skip to content

Commit

Permalink
lassis "fullauto" bugfix counting spin states
Browse files Browse the repository at this point in the history
count permitted spin states based on m=s coordinate frame, not
whatever "spaces[0]" happens to be
  • Loading branch information
MatthewRHermes committed Oct 27, 2023
1 parent cf58b76 commit 0b6cb1f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
10 changes: 7 additions & 3 deletions my_pyscf/lassi/lassis.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,8 @@ def all_spin_halfexcitations (lsi, las, nspin=1):
f1 = f1[None,:,:] - np.tensordot (casdm1s, h2, axes=((1,2),(2,1)))
i = 0
auto_singles = isinstance (nspin, str) and 's' in nspin.lower ()
nup0 = np.minimum (spaces[0].nelecb, spaces[0].nholea)
ndn0 = np.minimum (spaces[0].neleca, spaces[0].nholeb)
nup0 = np.minimum (spaces[0].nelecd, spaces[0].nholeu)
ndn0 = np.minimum (spaces[0].nelecu, spaces[0].nholed)
if not auto_singles: # integer supplied by caller
nup0[:] = nspin
ndn0[:] = nspin
Expand Down Expand Up @@ -214,7 +214,11 @@ def cisolve (sm, nroots):
ifrag, nelec, norb, smult-2)
smults1_i.extend ([smult-2,]*(smult-2))
spins1_i.extend (list (range (smult-3, -(smult-3)-1, -2)))
ci1_i.extend (cisolve (smult-2, ndn0[ifrag]))
try:
ci1_i.extend (cisolve (smult-2, ndn0[ifrag]))
except ValueError as err:
print (ndn0[ifrag], nelec, norb, smult)
raise (err)
min_npair = max (0, nelec-norb)
max_smult = (nelec - 2*min_npair) + 1
if smult < max_smult: # spin-raised
Expand Down
6 changes: 6 additions & 0 deletions my_pyscf/lassi/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,12 @@ def __init__(self, las, spins, smults, charges, weight, nlas=None, nelelas=None,
self.nholea = self.nlas - self.neleca
self.nholeb = self.nlas - self.nelecb

# "u", "d": like "a", "b", but presuming spins+1==smults everywhere
self.nelecu = (self.nelec + (self.smults-1)) // 2
self.nelecd = (self.nelec + (self.smults-1)) // 2
self.nholeu = self.nlas - self.nelecu
self.nholed = self.nlas - self.nelecd

def __eq__(self, other):
if self.nfrag != other.nfrag: return False
return (np.all (self.spins==other.spins) and
Expand Down

0 comments on commit 0b6cb1f

Please sign in to comment.