Skip to content

Commit

Permalink
Merge branch 'dev' into lassis_ncharge_norb
Browse files Browse the repository at this point in the history
  • Loading branch information
MatthewRHermes committed Apr 24, 2024
2 parents 522d9b5 + fec6f8a commit c5c013a
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 89 deletions.
17 changes: 10 additions & 7 deletions my_pyscf/lassi/lassi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from mrh.my_pyscf.lassi import op_o1
from mrh.my_pyscf.lassi.citools import get_lroots
from pyscf import lib, symm
from pyscf.scf.addons import canonical_orth_
from pyscf.lib.numpy_helper import tag_array
from pyscf.fci.direct_spin1 import _unpack_nelec
from itertools import combinations, product
Expand Down Expand Up @@ -418,16 +419,18 @@ def _eig_block (las, e0, h1, h2, ci_blk, nelec_blk, rootsym, soc, orbsym, wfnsym
# Error catch: linear dependencies in basis
try:
e, c = linalg.eigh (ham_blk, b=ovlp_blk)
except linalg.LinAlgError as e:
except linalg.LinAlgError as err:
ovlp_det = linalg.det (ovlp_blk)
lc = 'checking if LASSI basis has lindeps: |ovlp| = {:.6e}'.format (ovlp_det)
lib.logger.info (las, 'Caught error %s, %s', str (e), lc)
lib.logger.info (las, 'Caught error %s, %s', str (err), lc)
if ovlp_det < LINDEP_THRESH:
err_str = ('LASSI basis appears to have linear dependencies; '
'double-check your state list.\n'
'|ovlp| = {:.6e}').format (ovlp_det)
raise RuntimeError (err_str) from e
else: raise (e) from None
x = canonical_orth_(ovlp_blk, thr=LINDEP_THRESH)
lib.logger.info (las, '%d/%d linearly independent model states',
x.shape[1], x.shape[0])
xhx = x.conj ().T @ ham_blk @ x
e, c = linalg.eigh (xhx)
c = x @ c
else: raise (err) from None
return e, c, s2_blk

def make_stdm12s (las, ci=None, orbsym=None, soc=False, break_symmetry=False, opt=1):
Expand Down
100 changes: 46 additions & 54 deletions my_pyscf/lassi/lassis.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mrh.my_pyscf.mcscf.productstate import ProductStateFCISolver
from mrh.my_pyscf.lassi.excitations import ExcitationPSFCISolver
from mrh.my_pyscf.lassi.states import spin_shuffle, spin_shuffle_ci
from mrh.my_pyscf.lassi.states import _spin_shuffle, _spin_shuffle_ci_
from mrh.my_pyscf.lassi.states import all_single_excitations, SingleLASRootspace
from mrh.my_pyscf.lassi.states import orthogonal_excitations, combine_orthogonal_excitations
from mrh.my_pyscf.lassi.lassi import LASSI
Expand Down Expand Up @@ -62,13 +63,24 @@ def prepare_states (lsi, ncharge=1, nspin=0, sa_heff=True, deactivate_vrv=False,
)
else:
converged, las2 = las1.converged, las1
# TODO: make all_single_excitations and single_excitations_ci return spaces2 instead of
# las2, so that you can delete this
spaces2 = [SingleLASRootspace (las2, m, s, c, las2.weights[ix], ci=[c[ix] for c in las2.ci])
for ix, (c, m, s, w) in enumerate (zip (*get_space_info (las2)))]
if lsi.nfrags > 3:
las2 = charge_excitation_products (lsi, las2, las1)
spaces2 = charge_excitation_products (lsi, spaces2, las1)
# 4. Spin excitations part 2
if nspin:
las3 = spin_flip_products (las2, spin_flips, nroots_ref=nroots_ref)
spaces3 = spin_flip_products (las1, spaces2, spin_flips, nroots_ref=nroots_ref)
else:
las3 = las2
spaces3 = spaces2
weights = [space.weight for space in spaces3]
charges = [space.charges for space in spaces3]
spins = [space.spins for space in spaces3]
smults = [space.smults for space in spaces3]
ci3 = [[space.ci[ifrag] for space in spaces3] for ifrag in range (lsi.nfrags)]
las3 = las2.state_average (weights=weights, charges=charges, spins=spins, smults=smults, assert_no_dupes=False)
las3.ci = ci3
las3.lasci (_dry_run=True)
log.timer ("LASSIS model space preparation", *t0)
return converged, las3
Expand Down Expand Up @@ -306,71 +318,51 @@ def _spin_flip_products (spaces, spin_flips, nroots_ref=1, frozen_frags=None):
spaces = [space for space in spaces if not ((space in seen) or seen.add (space))]
return spaces

def spin_flip_products (las2, spin_flips, nroots_ref=1):
def spin_flip_products (las, spaces, spin_flips, nroots_ref=1):
'''Inject spin-flips into las2 in all possible ways'''
log = logger.new_logger (las2, las2.verbose)
spaces = [SingleLASRootspace (las2, m, s, c, las2.weights[ix], ci=[c[ix] for c in las2.ci])
for ix, (c, m, s, w) in enumerate (zip (*get_space_info (las2)))]
log = logger.new_logger (las, las.verbose)
las2_nroots = len (spaces)
spaces = _spin_flip_products (spaces, spin_flips, nroots_ref=nroots_ref)
nfrags = spaces[0].nfrag
weights = [space.weight for space in spaces]
charges = [space.charges for space in spaces]
spins = [space.spins for space in spaces]
smults = [space.smults for space in spaces]
ci3 = [[space.ci[ifrag] for space in spaces] for ifrag in range (nfrags)]
las3 = las2.state_average (weights=weights, charges=charges, spins=spins, smults=smults)
las3.ci = ci3
if las3.nfrags > 2: # A second spin shuffle to get the coupled spin-charge excitations
las3 = spin_shuffle (las3)
las3.ci = spin_shuffle_ci (las3, las3.ci)
spaces = [SingleLASRootspace (las3, m, s, c, las3.weights[ix], ci=[c[ix] for c in las3.ci])
for ix, (c, m, s, w) in enumerate (zip (*get_space_info (las3)))]
log.info ("LASSIS spin-excitation spaces: %d-%d", las2.nroots, las3.nroots-1)
for i, space in enumerate (spaces[las2.nroots:]):
spaces = _spin_shuffle (spaces)
spaces = _spin_shuffle_ci_(spaces)
log.info ("LASSIS spin-excitation spaces: %d-%d", las2_nroots, len (spaces)-1)
for i, space in enumerate (spaces[las2_nroots:]):
if np.any (space.nelec != spaces[0].nelec):
log.info ("Spin/charge-excitation space %d:", i+las2.nroots)
log.info ("Spin/charge-excitation space %d:", i+las2_nroots)
else:
log.info ("Spin-excitation space %d:", i+las2.nroots)
log.info ("Spin-excitation space %d:", i+las2_nroots)
space.table_printlog ()
return las3
return spaces

def charge_excitation_products (lsi, las2, las1):
def charge_excitation_products (lsi, spaces, las1):
t0 = (logger.process_clock (), logger.perf_counter ())
log = logger.new_logger (lsi, lsi.verbose)
mol = lsi.mol
nfrags = lsi.nfrags
spaces = [SingleLASRootspace (las2, m, s, c, las2.weights[ix], ci=[c[ix] for c in las2.ci])
for ix, (c, m, s, w) in enumerate (zip (*get_space_info (las2)))]
space0 = spaces[0]
i0, j0 = i, j = las1.nroots, las2.nroots
space1 = spaces[i:j]
for _ in range (1, nfrags//2):
i0, j0 = i, j = las1.nroots, len (spaces)
for space1 in spaces[i:j]:
space1.set_entmap_(space0)
for product_order in range (2, (nfrags//2)+1):
seen = set ()
for ip,iq in itertools.product (range (i,j), range (i0,j0)):
if ip <= iq: continue
p, q = spaces[ip], spaces[iq]
if not orthogonal_excitations (p, q, space0): continue
r = combine_orthogonal_excitations (p, q, space0)
if r in seen:
s = spaces[spaces.index (r)]
s.merge_(r, ref=space0)
continue
ir = len (spaces)
spaces.append (r)
seen.add (r)
i, j = j, len (spaces)
for ir in range (i, j):
log.info ("Electron hop product %d space %d:", _+1, ir)
spaces[ir].table_printlog ()
weights = [space.weight for space in spaces]
charges = [space.charges for space in spaces]
spins = [space.spins for space in spaces]
smults = [space.smults for space in spaces]
ci3 = [[space.ci[ifrag] for space in spaces] for ifrag in range (nfrags)]
las3 = las2.state_average (weights=weights, charges=charges, spins=spins, smults=smults)
las3.ci = ci3
for i_list in itertools.combinations (range (i,j), product_order):
p_list = [spaces[ip] for ip in i_list]
nonorth = False
for p, q in itertools.combinations (p_list, 2):
if not orthogonal_excitations (p, q, space0):
nonorth = True
break
if nonorth: continue
p = p_list[0]
for q in p_list[1:]:
p = combine_orthogonal_excitations (p, q, space0)
spaces.append (p)
log.info ("Electron hop product space %d (product of %s)", len (spaces) - 1, str (i_list))
spaces[-1].table_printlog ()
assert (len (spaces) == len (set (spaces)))
log.timer ("LASSIS charge-hop product generation", *t0)
return las3
return spaces

def as_scanner(lsi):
'''Generating a scanner for LASSIS PES.
Expand Down
11 changes: 10 additions & 1 deletion my_pyscf/lassi/op_o1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
from scipy import linalg
from pyscf import lib, fci
from pyscf.lib import logger
from pyscf.fci.direct_spin1 import _unpack_nelec, trans_rdm12s, contract_1e
Expand Down Expand Up @@ -371,7 +372,15 @@ def _init_crunch_(self):
if not self.root_unique[j]: continue
if self.nelec_r[i] != self.nelec_r[j]: continue
if ci[i].shape != ci[j].shape: continue
if np.all (ci[i] == ci[j]):
isequal = False
if np.all (ci[i]==ci[j]): isequal = True
elif np.all (np.abs (ci[i]-ci[j]) < 1e-8): isequal=True
else:
ci_i = ci[i].reshape (lroots[i],-1)
ci_j = ci[j].reshape (lroots[j],-1)
ovlp = ci_i.conj () @ ci_j.T
isequal = np.allclose (ovlp.diagonal (), 1)
if isequal:
self.root_unique[j] = False
self.unique_root[j] = i
self.onep_index[i] |= self.onep_index[j]
Expand Down
79 changes: 54 additions & 25 deletions my_pyscf/lassi/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ def __init__(self, las, spins, smults, charges, weight, nlas=None, nelelas=None,
self.nholeu = self.nlas - self.nelecu
self.nholed = self.nlas - self.nelecd

self.entmap = tuple ()

def __eq__(self, other):
if self.nfrag != other.nfrag: return False
return (np.all (self.spins==other.spins) and
Expand All @@ -51,7 +53,7 @@ def __eq__(self, other):

def __hash__(self):
return hash (tuple ([self.nfrag,] + list (self.spins) + list (self.smults)
+ list (self.charges)))
+ list (self.charges) + list (self.entmap)))

def possible_excitation (self, i, a, s):
i, a, s = np.atleast_1d (i, a, s)
Expand Down Expand Up @@ -159,8 +161,10 @@ def gen_spin_shuffles (self):
idx_valid = np.all (spins_table>-self.smults[None,:], axis=1)
spins_table = spins_table[idx_valid,:]
for spins in spins_table:
yield SingleLASRootspace (self.las, spins, self.smults, self.charges, 0, nlas=self.nlas,
nelelas=self.nelelas, stdout=self.stdout, verbose=self.verbose)
sp = SingleLASRootspace (self.las, spins, self.smults, self.charges, 0, nlas=self.nlas,
nelelas=self.nelelas, stdout=self.stdout, verbose=self.verbose)
sp.entmap = self.entmap
yield sp

def has_ci (self):
if self.ci is None: return False
Expand Down Expand Up @@ -246,6 +250,14 @@ def describe_single_excitation (self, other):
lroots_s = min (self.nlas[src_frag], self.nlas[dest_frag])
return src_frag, dest_frag, e_spin, src_ds, dest_ds, lroots_s

def set_entmap_(self, ref):
idx = np.where (self.excited_fragments (ref))[0]
idx = tuple (set (idx))
self.entmap = tuple ((idx,))
#self.entmap[:,:] = 0
#for i, j in itertools.combinations (idx, 2):
# self.entmap[i,j] = self.entmap[j,i] = 1

def single_excitation_description_string (self, other):
src, dest, e_spin, src_ds, dest_ds, lroots_s = self.describe_single_excitation (other)
fmt_str = '{:d}({:s}) --{:s}--> {:d}({:s}) ({:d} lroots)'
Expand Down Expand Up @@ -313,9 +325,11 @@ def single_fragment_spin_change (self, ifrag, new_smult, new_spin, ci=None):
if ci is not None:
ci1 = [c for c in self.ci]
ci1[ifrag] = ci
return SingleLASRootspace (self.las, spins1, smults1, self.charges, 0, nlas=self.nlas,
nelelas=self.nelelas, stdout=self.stdout, verbose=self.verbose,
ci=ci1)
sp = SingleLASRootspace (self.las, spins1, smults1, self.charges, 0, nlas=self.nlas,
nelelas=self.nelelas, stdout=self.stdout, verbose=self.verbose,
ci=ci1)
sp.entmap = self.entmap
return sp

def is_orthogonal_by_smult (self, other):
if isinstance (other, (list, tuple)):
Expand Down Expand Up @@ -397,6 +411,9 @@ def combine_orthogonal_excitations (exc1, exc2, ref):
ref.las, spins, smults, charges, 0, ci=ci,
nlas=ref.nlas, nelelas=ref.nelelas, stdout=ref.stdout, verbose=ref.verbose
)
product.entmap = tuple (set (exc1.entmap + exc2.entmap))
#assert (np.amax (product.entmap) < 2)
assert (len (product.entmap) == len (set (product.entmap)))
return product

def all_single_excitations (las, verbose=None):
Expand All @@ -418,16 +435,16 @@ def all_single_excitations (las, verbose=None):
new_states.extend (ref_state.get_singles ())
seen = set (ref_states)
all_states = ref_states + [state for state in new_states if not ((state in seen) or seen.add (state))]
weights = [state.weight for state in all_states]
charges = [state.charges for state in all_states]
spins = [state.spins for state in all_states]
smults = [state.smults for state in all_states]
#wfnsyms = [state.wfnsyms for state in all_states]
log.info ('Built {} singly-excited LAS states from {} reference LAS states'.format (
len (all_states) - len (ref_states), len (ref_states)))
if len (all_states) == len (ref_states):
log.warn (("%d reference LAS states exhaust current active space specifications; "
"no singly-excited states could be constructed"), len (ref_states))
weights = [state.weight for state in all_states]
charges = [state.charges for state in all_states]
spins = [state.spins for state in all_states]
smults = [state.smults for state in all_states]
#wfnsyms = [state.wfnsyms for state in all_states]
return las.state_average (weights=weights, charges=charges, spins=spins, smults=smults)

def spin_shuffle (las, verbose=None, equal_weights=False):
Expand All @@ -446,16 +463,8 @@ def spin_shuffle (las, verbose=None, equal_weights=False):
raise NotImplementedError ("Point-group symmetry for LASSI state generator")
ref_states = [SingleLASRootspace (las, m, s, c, 0) for c,m,s,w in zip (*get_space_info (las))]
for weight, state in zip (las.weights, ref_states): state.weight = weight
seen = set (ref_states)
all_states = [state for state in ref_states]
for ref_state in ref_states:
for new_state in ref_state.gen_spin_shuffles ():
if not new_state in seen:
all_states.append (new_state)
seen.add (new_state)
all_states = _spin_shuffle (ref_states, equal_weights=equal_weights)
weights = [state.weight for state in all_states]
if equal_weights:
weights = [1.0/len(all_states),]*len(all_states)
charges = [state.charges for state in all_states]
spins = [state.spins for state in all_states]
smults = [state.smults for state in all_states]
Expand All @@ -466,6 +475,19 @@ def spin_shuffle (las, verbose=None, equal_weights=False):
log.warn ("no spin-shuffling options found for given LAS states")
return las.state_average (weights=weights, charges=charges, spins=spins, smults=smults)

def _spin_shuffle (ref_spaces, equal_weights=False):
seen = set (ref_spaces)
all_spaces = [space for space in ref_spaces]
for ref_space in ref_spaces:
for new_space in ref_space.gen_spin_shuffles ():
if not new_space in seen:
all_spaces.append (new_space)
seen.add (new_space)
if equal_weights:
w = 1.0/len(all_spaces)
for space in all_spaces: space.weight = w
return all_spaces

def spin_shuffle_ci (las, ci):
'''Fill out the CI vectors for rootspaces constructed by the spin_shuffle function.
Unallocated CI vectors (None elements in ci) for rootspaces which have the same
Expand All @@ -477,19 +499,26 @@ def spin_shuffle_ci (las, ci):
from mrh.my_pyscf.mcscf.lasci import get_space_info
spaces = [SingleLASRootspace (las, m, s, c, 0, ci=[c[ix] for c in ci])
for ix, (c, m, s, w) in enumerate (zip (*get_space_info (las)))]
spaces = _spin_shuffle_ci_(spaces)
ci = [[space.ci[ifrag] for space in spaces] for ifrag in range (las.nfrags)]
return ci

def _spin_shuffle_ci_(spaces):
old_ci_sz = []
old_idx = []
new_idx = []
nfrag = las.nfrags
nfrag = spaces[0].nfrag
for ix, space in enumerate (spaces):
if space.has_ci ():
old_idx.append (ix)
old_ci_sz.append (space.get_ci_szrot ())
else:
new_idx.append (ix)
space.ci = [None for ifrag in range (space.nfrag)]
def is_spin_shuffle_ref (sp1, sp2):
return (np.all (sp1.charges==sp2.charges) and
np.all (sp1.smults==sp2.smults))
np.all (sp1.smults==sp2.smults) and
sp1.entmap==sp2.entmap)
for ix in new_idx:
ndet = spaces[ix].get_ndet ()
ci_ix = [np.zeros ((0,ndet[i][0],ndet[i][1]))
Expand All @@ -502,7 +531,7 @@ def is_spin_shuffle_ref (sp1, sp2):
ci_ix[ifrag] = np.append (ci_ix[ifrag], c, axis=0)
for ifrag in range (nfrag):
if ci_ix[ifrag].size==0:
ci[ifrag][ix] = None
spaces[ix].ci[ifrag] = None
continue
lroots, ndeti = ci_ix[ifrag].shape[0], ndet[ifrag]
if lroots > 1:
Expand All @@ -513,8 +542,8 @@ def is_spin_shuffle_ref (sp1, sp2):
v = v[:,idx] / np.sqrt (w[idx])[None,:]
c = (c @ v).T
ci_ix[ifrag] = c.reshape (-1, ndeti[0], ndeti[1])
ci[ifrag][ix] = ci_ix[ifrag]
return ci
spaces[ix].ci[ifrag] = ci_ix[ifrag]
return spaces

def count_excitations (las0):
log = logger.new_logger (las0, las0.verbose)
Expand Down
2 changes: 1 addition & 1 deletion pyscf-forge_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
git+https://github.com/pyscf/pyscf-forge.git@f817911cb0c984207d9309a1d7347f6dbb46a9c6
git+https://github.com/pyscf/pyscf-forge.git@6fa7530498f434404a323bd7b116a04d8f7c1f12
2 changes: 1 addition & 1 deletion pyscf_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
git+https://github.com/pyscf/pyscf.git@d57f1d6c89c723e11a7f0933380a6139ba372554
git+https://github.com/pyscf/pyscf.git@6d3b24bb64e2a5edb7990b6e3304068981a33f54

0 comments on commit c5c013a

Please sign in to comment.