Skip to content

Commit

Permalink
LASSI spin_shuffle_ci fn
Browse files Browse the repository at this point in the history
Generates sz-rotated CI vectors directly from origin CI vectors.
  • Loading branch information
MatthewRHermes committed Sep 21, 2023
1 parent abeebfe commit 1a735e3
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 4 deletions.
84 changes: 82 additions & 2 deletions my_pyscf/lassi/states.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
import numpy as np
from pyscf.fci.direct_spin1 import _unpack_nelec
from pyscf.fci import cistring
from pyscf.lib import logger
from pyscf.lo.orth import vec_lowdin
from mrh.my_pyscf.fci.spin_op import contract_sdown, contract_sup
from mrh.my_pyscf.fci.csfstring import CSFTransformer
from mrh.my_pyscf.fci.csfstring import ImpossibleSpinError
import itertools

class SingleLASRootspace (object):
def __init__(self, las, spins, smults, charges, weight, nlas=None, nelelas=None, stdout=None,
verbose=None):
verbose=None, ci=None):
if nlas is None: nlas = las.ncas_sub
if nelelas is None: nelelas = [sum (_unpack_nelec (x)) for x in las.nelecas_sub]
if stdout is None: stdout = las.stdout
Expand All @@ -19,7 +22,8 @@ def __init__(self, las, spins, smults, charges, weight, nlas=None, nelelas=None,
self.charges = np.asarray (charges)
self.weight = weight
self.stdout, self.verbose = stdout, verbose

self.ci = ci

self.nelec = self.nelelas - self.charges
self.neleca = (self.nelec + self.spins) // 2
self.nelecb = (self.nelec - self.spins) // 2
Expand Down Expand Up @@ -146,7 +150,45 @@ def gen_spin_shuffles (self):
yield SingleLASRootspace (self.las, spins, self.smults, self.charges, 0, nlas=self.nlas,
nelelas=self.nelelas, stdout=self.stdout, verbose=self.verbose)

def has_ci (self):
if self.ci is None: return False
return all ([c is not None for c in self.ci])

def get_ci_szrot (self):
'''Generate the sets of CI vectors in which each vector for each fragment
has the sz axis rotated in all possible ways.
Returns:
ci_sz: list of dict of type {integer: ndarray}
dict keys are integerified "spin" quantum numbers; i.e., neleca-nelecb.
dict vals are the corresponding CI vectors
'''
ci_sz = []
for ifrag in range (self.nfrag):
norb, sz, ci = self.nlas[ifrag], self.spins[ifrag], self.ci[ifrag]
nelec = self.neleca[ifrag], self.nelecb[ifrag]
smult = self.smults[ifrag]
ci_sz_ = {sz: ci}
ci1 = ci
nelec1 = nelec
for sz1 in range (sz-2, -(1+smult), -2):
ci1 = contract_sdown (ci1, norb, nelec1)
nelec1 = nelec1[0]-1, nelec1[1]+1
ci_sz_[sz1] = ci1
ci1 = ci
nelec1 = nelec
for sz1 in range (sz+2, (1+smult), 2):
ci1 = contract_sup (ci1, norb, nelec1)
nelec1 = nelec1[0]+1, nelec1[1]-1
ci_sz_[sz1] = ci1
ci_sz.append (ci_sz_)
return ci_sz

def get_ndet (self):
return [(cistring.num_strings (self.nlas[i], self.neleca[i]),
cistring.num_strings (self.nlas[i], self.nelecb[i]))
for i in range (self.nfrag)]

def all_single_excitations (las, verbose=None):
'''Add states characterized by one electron hopping from one fragment to another fragment
in all possible ways. Uses all states already present as reference states, so that calling
Expand Down Expand Up @@ -212,6 +254,44 @@ def spin_shuffle (las, verbose=None):
log.warn ("no spin-shuffling options found for given LAS states")
return las.state_average (weights=weights, charges=charges, spins=spins, smults=smults)

def spin_shuffle_ci (las, ci):
from mrh.my_pyscf.mcscf.lasci import get_space_info
spaces = [SingleLASRootspace (las, m, s, c, 0, ci=[c[ix] for c in ci])
for ix, (c, m, s, w) in enumerate (zip (*get_space_info (las)))]
old_ci_sz = []
old_idx = []
new_idx = []
nfrag = las.nfrags
for ix, space in enumerate (spaces):
if space.has_ci ():
old_idx.append (ix)
old_ci_sz.append (space.get_ci_szrot ())
else:
new_idx.append (ix)
def is_spin_shuffle_ref (sp1, sp2):
return (np.all (sp1.charges==sp2.charges) and
np.all (sp1.smults==sp2.smults))
for ix in new_idx:
ndet = spaces[ix].get_ndet ()
ci_ix = [np.zeros ((0,ndet[i][0],ndet[i][1]))
for i in range (nfrag)]
for ci_sz, jx in zip (old_ci_sz, old_idx):
if not is_spin_shuffle_ref (spaces[ix], spaces[jx]): continue
for ifrag in range (nfrag):
c = ci_sz[ifrag][spaces[ix].spins[ifrag]]
if c.ndim < 3: c = c[None,:,:]
ci_ix[ifrag] = np.append (ci_ix[ifrag], c, axis=0)
for ifrag in range (nfrag):
if ci_ix[ifrag].size==0:
ci[ifrag][ix] = None
continue
lroots, ndeti = ci_ix[ifrag].shape[0], ndet[ifrag]
if lroots > 1:
c = vec_lowdin (ci_ix[ifrag].reshape (lroots, ndeti[0]*ndeti[1]))
ci_ix[ifrag] = c.reshape (lroots, ndeti[0], ndeti[1])
ci[ifrag][ix] = ci_ix[ifrag]
return ci

def count_excitations (las0):
log = logger.new_logger (las0, las0.verbose)
t = (logger.process_clock(), logger.perf_counter ())
Expand Down
11 changes: 9 additions & 2 deletions tests/lassi/test_c2h4n4.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,10 @@ def test_singles_constructor (self):
self.assertEqual (las2.nroots, 33)

def test_spin_shuffle (self):
from mrh.my_pyscf.lassi.states import spin_shuffle
from mrh.my_pyscf.lassi.states import spin_shuffle, spin_shuffle_ci
mf = lsi._las._scf
las3 = LASSCF (mf, (4,2,4), (4,2,4), spin_sub=(5,3,5))
las3.lasci ()
las3 = spin_shuffle (las3)
las3.check_sanity ()
# The number of states is the number of graphs connecting one number
Expand All @@ -145,7 +146,13 @@ def test_spin_shuffle (self):
# and three paths each sum to -1, 0, +1. Each partial sum then has one
# remaining option to complete the path, so
# 2 + 3 + 3 + 3 + 2 = 13
self.assertEqual (las3.nroots, 13)
with self.subTest ("state construction"):
self.assertEqual (las3.nroots, 13)
las3.ci = spin_shuffle_ci (las3, las3.ci)
lsi2 = LASSI (las3).run ()
errvec = lsi2.s2 - np.around (lsi2.s2)
with self.subTest ("CI vector rotation"):
self.assertLess (np.amax (np.abs (errvec)), 1e-8)

if __name__ == "__main__":
print("Full Tests for SA-LASSI of c2h4n4 molecule")
Expand Down

0 comments on commit 1a735e3

Please sign in to comment.