diff --git a/my_pyscf/lassi/lassi.py b/my_pyscf/lassi/lassi.py index 73b6df72..a32aa605 100644 --- a/my_pyscf/lassi/lassi.py +++ b/my_pyscf/lassi/lassi.py @@ -5,6 +5,7 @@ from mrh.my_pyscf.lassi import op_o1 from mrh.my_pyscf.lassi.citools import get_lroots from pyscf import lib, symm +from pyscf.scf.addons import canonical_orth_ from pyscf.lib.numpy_helper import tag_array from pyscf.fci.direct_spin1 import _unpack_nelec from itertools import combinations, product @@ -418,16 +419,18 @@ def _eig_block (las, e0, h1, h2, ci_blk, nelec_blk, rootsym, soc, orbsym, wfnsym # Error catch: linear dependencies in basis try: e, c = linalg.eigh (ham_blk, b=ovlp_blk) - except linalg.LinAlgError as e: + except linalg.LinAlgError as err: ovlp_det = linalg.det (ovlp_blk) lc = 'checking if LASSI basis has lindeps: |ovlp| = {:.6e}'.format (ovlp_det) - lib.logger.info (las, 'Caught error %s, %s', str (e), lc) + lib.logger.info (las, 'Caught error %s, %s', str (err), lc) if ovlp_det < LINDEP_THRESH: - err_str = ('LASSI basis appears to have linear dependencies; ' - 'double-check your state list.\n' - '|ovlp| = {:.6e}').format (ovlp_det) - raise RuntimeError (err_str) from e - else: raise (e) from None + x = canonical_orth_(ovlp_blk, thr=LINDEP_THRESH) + lib.logger.info (las, '%d/%d linearly independent model states', + x.shape[1], x.shape[0]) + xhx = x.conj ().T @ ham_blk @ x + e, c = linalg.eigh (xhx) + c = x @ c + else: raise (err) from None return e, c, s2_blk def make_stdm12s (las, ci=None, orbsym=None, soc=False, break_symmetry=False, opt=1): diff --git a/my_pyscf/lassi/lassis.py b/my_pyscf/lassi/lassis.py index ab642ef6..39fad336 100644 --- a/my_pyscf/lassi/lassis.py +++ b/my_pyscf/lassi/lassis.py @@ -12,6 +12,7 @@ from mrh.my_pyscf.mcscf.productstate import ProductStateFCISolver from mrh.my_pyscf.lassi.excitations import ExcitationPSFCISolver from mrh.my_pyscf.lassi.states import spin_shuffle, spin_shuffle_ci +from mrh.my_pyscf.lassi.states import _spin_shuffle, _spin_shuffle_ci_ from mrh.my_pyscf.lassi.states import all_single_excitations, SingleLASRootspace from mrh.my_pyscf.lassi.states import orthogonal_excitations, combine_orthogonal_excitations from mrh.my_pyscf.lassi.lassi import LASSI @@ -62,13 +63,24 @@ def prepare_states (lsi, ncharge=1, nspin=0, sa_heff=True, deactivate_vrv=False, ) else: converged, las2 = las1.converged, las1 + # TODO: make all_single_excitations and single_excitations_ci return spaces2 instead of + # las2, so that you can delete this + spaces2 = [SingleLASRootspace (las2, m, s, c, las2.weights[ix], ci=[c[ix] for c in las2.ci]) + for ix, (c, m, s, w) in enumerate (zip (*get_space_info (las2)))] if lsi.nfrags > 3: - las2 = charge_excitation_products (lsi, las2, las1) + spaces2 = charge_excitation_products (lsi, spaces2, las1) # 4. Spin excitations part 2 if nspin: - las3 = spin_flip_products (las2, spin_flips, nroots_ref=nroots_ref) + spaces3 = spin_flip_products (las1, spaces2, spin_flips, nroots_ref=nroots_ref) else: - las3 = las2 + spaces3 = spaces2 + weights = [space.weight for space in spaces3] + charges = [space.charges for space in spaces3] + spins = [space.spins for space in spaces3] + smults = [space.smults for space in spaces3] + ci3 = [[space.ci[ifrag] for space in spaces3] for ifrag in range (lsi.nfrags)] + las3 = las2.state_average (weights=weights, charges=charges, spins=spins, smults=smults, assert_no_dupes=False) + las3.ci = ci3 las3.lasci (_dry_run=True) log.timer ("LASSIS model space preparation", *t0) return converged, las3 @@ -306,71 +318,51 @@ def _spin_flip_products (spaces, spin_flips, nroots_ref=1, frozen_frags=None): spaces = [space for space in spaces if not ((space in seen) or seen.add (space))] return spaces -def spin_flip_products (las2, spin_flips, nroots_ref=1): +def spin_flip_products (las, spaces, spin_flips, nroots_ref=1): '''Inject spin-flips into las2 in all possible ways''' - log = logger.new_logger (las2, las2.verbose) - spaces = [SingleLASRootspace (las2, m, s, c, las2.weights[ix], ci=[c[ix] for c in las2.ci]) - for ix, (c, m, s, w) in enumerate (zip (*get_space_info (las2)))] + log = logger.new_logger (las, las.verbose) + las2_nroots = len (spaces) spaces = _spin_flip_products (spaces, spin_flips, nroots_ref=nroots_ref) nfrags = spaces[0].nfrag - weights = [space.weight for space in spaces] - charges = [space.charges for space in spaces] - spins = [space.spins for space in spaces] - smults = [space.smults for space in spaces] - ci3 = [[space.ci[ifrag] for space in spaces] for ifrag in range (nfrags)] - las3 = las2.state_average (weights=weights, charges=charges, spins=spins, smults=smults) - las3.ci = ci3 - if las3.nfrags > 2: # A second spin shuffle to get the coupled spin-charge excitations - las3 = spin_shuffle (las3) - las3.ci = spin_shuffle_ci (las3, las3.ci) - spaces = [SingleLASRootspace (las3, m, s, c, las3.weights[ix], ci=[c[ix] for c in las3.ci]) - for ix, (c, m, s, w) in enumerate (zip (*get_space_info (las3)))] - log.info ("LASSIS spin-excitation spaces: %d-%d", las2.nroots, las3.nroots-1) - for i, space in enumerate (spaces[las2.nroots:]): + spaces = _spin_shuffle (spaces) + spaces = _spin_shuffle_ci_(spaces) + log.info ("LASSIS spin-excitation spaces: %d-%d", las2_nroots, len (spaces)-1) + for i, space in enumerate (spaces[las2_nroots:]): if np.any (space.nelec != spaces[0].nelec): - log.info ("Spin/charge-excitation space %d:", i+las2.nroots) + log.info ("Spin/charge-excitation space %d:", i+las2_nroots) else: - log.info ("Spin-excitation space %d:", i+las2.nroots) + log.info ("Spin-excitation space %d:", i+las2_nroots) space.table_printlog () - return las3 + return spaces -def charge_excitation_products (lsi, las2, las1): +def charge_excitation_products (lsi, spaces, las1): t0 = (logger.process_clock (), logger.perf_counter ()) log = logger.new_logger (lsi, lsi.verbose) mol = lsi.mol nfrags = lsi.nfrags - spaces = [SingleLASRootspace (las2, m, s, c, las2.weights[ix], ci=[c[ix] for c in las2.ci]) - for ix, (c, m, s, w) in enumerate (zip (*get_space_info (las2)))] space0 = spaces[0] - i0, j0 = i, j = las1.nroots, las2.nroots - space1 = spaces[i:j] - for _ in range (1, nfrags//2): + i0, j0 = i, j = las1.nroots, len (spaces) + for space1 in spaces[i:j]: + space1.set_entmap_(space0) + for product_order in range (2, (nfrags//2)+1): seen = set () - for ip,iq in itertools.product (range (i,j), range (i0,j0)): - if ip <= iq: continue - p, q = spaces[ip], spaces[iq] - if not orthogonal_excitations (p, q, space0): continue - r = combine_orthogonal_excitations (p, q, space0) - if r in seen: - s = spaces[spaces.index (r)] - s.merge_(r, ref=space0) - continue - ir = len (spaces) - spaces.append (r) - seen.add (r) - i, j = j, len (spaces) - for ir in range (i, j): - log.info ("Electron hop product %d space %d:", _+1, ir) - spaces[ir].table_printlog () - weights = [space.weight for space in spaces] - charges = [space.charges for space in spaces] - spins = [space.spins for space in spaces] - smults = [space.smults for space in spaces] - ci3 = [[space.ci[ifrag] for space in spaces] for ifrag in range (nfrags)] - las3 = las2.state_average (weights=weights, charges=charges, spins=spins, smults=smults) - las3.ci = ci3 + for i_list in itertools.combinations (range (i,j), product_order): + p_list = [spaces[ip] for ip in i_list] + nonorth = False + for p, q in itertools.combinations (p_list, 2): + if not orthogonal_excitations (p, q, space0): + nonorth = True + break + if nonorth: continue + p = p_list[0] + for q in p_list[1:]: + p = combine_orthogonal_excitations (p, q, space0) + spaces.append (p) + log.info ("Electron hop product space %d (product of %s)", len (spaces) - 1, str (i_list)) + spaces[-1].table_printlog () + assert (len (spaces) == len (set (spaces))) log.timer ("LASSIS charge-hop product generation", *t0) - return las3 + return spaces def as_scanner(lsi): '''Generating a scanner for LASSIS PES. diff --git a/my_pyscf/lassi/op_o1.py b/my_pyscf/lassi/op_o1.py index 85d77239..61f73691 100644 --- a/my_pyscf/lassi/op_o1.py +++ b/my_pyscf/lassi/op_o1.py @@ -1,4 +1,5 @@ import numpy as np +from scipy import linalg from pyscf import lib, fci from pyscf.lib import logger from pyscf.fci.direct_spin1 import _unpack_nelec, trans_rdm12s, contract_1e @@ -371,7 +372,15 @@ def _init_crunch_(self): if not self.root_unique[j]: continue if self.nelec_r[i] != self.nelec_r[j]: continue if ci[i].shape != ci[j].shape: continue - if np.all (ci[i] == ci[j]): + isequal = False + if np.all (ci[i]==ci[j]): isequal = True + elif np.all (np.abs (ci[i]-ci[j]) < 1e-8): isequal=True + else: + ci_i = ci[i].reshape (lroots[i],-1) + ci_j = ci[j].reshape (lroots[j],-1) + ovlp = ci_i.conj () @ ci_j.T + isequal = np.allclose (ovlp.diagonal (), 1) + if isequal: self.root_unique[j] = False self.unique_root[j] = i self.onep_index[i] |= self.onep_index[j] diff --git a/my_pyscf/lassi/states.py b/my_pyscf/lassi/states.py index 066e72ae..89bbe9f8 100644 --- a/my_pyscf/lassi/states.py +++ b/my_pyscf/lassi/states.py @@ -43,6 +43,8 @@ def __init__(self, las, spins, smults, charges, weight, nlas=None, nelelas=None, self.nholeu = self.nlas - self.nelecu self.nholed = self.nlas - self.nelecd + self.entmap = tuple () + def __eq__(self, other): if self.nfrag != other.nfrag: return False return (np.all (self.spins==other.spins) and @@ -51,7 +53,7 @@ def __eq__(self, other): def __hash__(self): return hash (tuple ([self.nfrag,] + list (self.spins) + list (self.smults) - + list (self.charges))) + + list (self.charges) + list (self.entmap))) def possible_excitation (self, i, a, s): i, a, s = np.atleast_1d (i, a, s) @@ -159,8 +161,10 @@ def gen_spin_shuffles (self): idx_valid = np.all (spins_table>-self.smults[None,:], axis=1) spins_table = spins_table[idx_valid,:] for spins in spins_table: - yield SingleLASRootspace (self.las, spins, self.smults, self.charges, 0, nlas=self.nlas, - nelelas=self.nelelas, stdout=self.stdout, verbose=self.verbose) + sp = SingleLASRootspace (self.las, spins, self.smults, self.charges, 0, nlas=self.nlas, + nelelas=self.nelelas, stdout=self.stdout, verbose=self.verbose) + sp.entmap = self.entmap + yield sp def has_ci (self): if self.ci is None: return False @@ -246,6 +250,14 @@ def describe_single_excitation (self, other): lroots_s = min (self.nlas[src_frag], self.nlas[dest_frag]) return src_frag, dest_frag, e_spin, src_ds, dest_ds, lroots_s + def set_entmap_(self, ref): + idx = np.where (self.excited_fragments (ref))[0] + idx = tuple (set (idx)) + self.entmap = tuple ((idx,)) + #self.entmap[:,:] = 0 + #for i, j in itertools.combinations (idx, 2): + # self.entmap[i,j] = self.entmap[j,i] = 1 + def single_excitation_description_string (self, other): src, dest, e_spin, src_ds, dest_ds, lroots_s = self.describe_single_excitation (other) fmt_str = '{:d}({:s}) --{:s}--> {:d}({:s}) ({:d} lroots)' @@ -313,9 +325,11 @@ def single_fragment_spin_change (self, ifrag, new_smult, new_spin, ci=None): if ci is not None: ci1 = [c for c in self.ci] ci1[ifrag] = ci - return SingleLASRootspace (self.las, spins1, smults1, self.charges, 0, nlas=self.nlas, - nelelas=self.nelelas, stdout=self.stdout, verbose=self.verbose, - ci=ci1) + sp = SingleLASRootspace (self.las, spins1, smults1, self.charges, 0, nlas=self.nlas, + nelelas=self.nelelas, stdout=self.stdout, verbose=self.verbose, + ci=ci1) + sp.entmap = self.entmap + return sp def is_orthogonal_by_smult (self, other): if isinstance (other, (list, tuple)): @@ -397,6 +411,9 @@ def combine_orthogonal_excitations (exc1, exc2, ref): ref.las, spins, smults, charges, 0, ci=ci, nlas=ref.nlas, nelelas=ref.nelelas, stdout=ref.stdout, verbose=ref.verbose ) + product.entmap = tuple (set (exc1.entmap + exc2.entmap)) + #assert (np.amax (product.entmap) < 2) + assert (len (product.entmap) == len (set (product.entmap))) return product def all_single_excitations (las, verbose=None): @@ -418,16 +435,16 @@ def all_single_excitations (las, verbose=None): new_states.extend (ref_state.get_singles ()) seen = set (ref_states) all_states = ref_states + [state for state in new_states if not ((state in seen) or seen.add (state))] - weights = [state.weight for state in all_states] - charges = [state.charges for state in all_states] - spins = [state.spins for state in all_states] - smults = [state.smults for state in all_states] - #wfnsyms = [state.wfnsyms for state in all_states] log.info ('Built {} singly-excited LAS states from {} reference LAS states'.format ( len (all_states) - len (ref_states), len (ref_states))) if len (all_states) == len (ref_states): log.warn (("%d reference LAS states exhaust current active space specifications; " "no singly-excited states could be constructed"), len (ref_states)) + weights = [state.weight for state in all_states] + charges = [state.charges for state in all_states] + spins = [state.spins for state in all_states] + smults = [state.smults for state in all_states] + #wfnsyms = [state.wfnsyms for state in all_states] return las.state_average (weights=weights, charges=charges, spins=spins, smults=smults) def spin_shuffle (las, verbose=None, equal_weights=False): @@ -446,16 +463,8 @@ def spin_shuffle (las, verbose=None, equal_weights=False): raise NotImplementedError ("Point-group symmetry for LASSI state generator") ref_states = [SingleLASRootspace (las, m, s, c, 0) for c,m,s,w in zip (*get_space_info (las))] for weight, state in zip (las.weights, ref_states): state.weight = weight - seen = set (ref_states) - all_states = [state for state in ref_states] - for ref_state in ref_states: - for new_state in ref_state.gen_spin_shuffles (): - if not new_state in seen: - all_states.append (new_state) - seen.add (new_state) + all_states = _spin_shuffle (ref_states, equal_weights=equal_weights) weights = [state.weight for state in all_states] - if equal_weights: - weights = [1.0/len(all_states),]*len(all_states) charges = [state.charges for state in all_states] spins = [state.spins for state in all_states] smults = [state.smults for state in all_states] @@ -466,6 +475,19 @@ def spin_shuffle (las, verbose=None, equal_weights=False): log.warn ("no spin-shuffling options found for given LAS states") return las.state_average (weights=weights, charges=charges, spins=spins, smults=smults) +def _spin_shuffle (ref_spaces, equal_weights=False): + seen = set (ref_spaces) + all_spaces = [space for space in ref_spaces] + for ref_space in ref_spaces: + for new_space in ref_space.gen_spin_shuffles (): + if not new_space in seen: + all_spaces.append (new_space) + seen.add (new_space) + if equal_weights: + w = 1.0/len(all_spaces) + for space in all_spaces: space.weight = w + return all_spaces + def spin_shuffle_ci (las, ci): '''Fill out the CI vectors for rootspaces constructed by the spin_shuffle function. Unallocated CI vectors (None elements in ci) for rootspaces which have the same @@ -477,19 +499,26 @@ def spin_shuffle_ci (las, ci): from mrh.my_pyscf.mcscf.lasci import get_space_info spaces = [SingleLASRootspace (las, m, s, c, 0, ci=[c[ix] for c in ci]) for ix, (c, m, s, w) in enumerate (zip (*get_space_info (las)))] + spaces = _spin_shuffle_ci_(spaces) + ci = [[space.ci[ifrag] for space in spaces] for ifrag in range (las.nfrags)] + return ci + +def _spin_shuffle_ci_(spaces): old_ci_sz = [] old_idx = [] new_idx = [] - nfrag = las.nfrags + nfrag = spaces[0].nfrag for ix, space in enumerate (spaces): if space.has_ci (): old_idx.append (ix) old_ci_sz.append (space.get_ci_szrot ()) else: new_idx.append (ix) + space.ci = [None for ifrag in range (space.nfrag)] def is_spin_shuffle_ref (sp1, sp2): return (np.all (sp1.charges==sp2.charges) and - np.all (sp1.smults==sp2.smults)) + np.all (sp1.smults==sp2.smults) and + sp1.entmap==sp2.entmap) for ix in new_idx: ndet = spaces[ix].get_ndet () ci_ix = [np.zeros ((0,ndet[i][0],ndet[i][1])) @@ -502,7 +531,7 @@ def is_spin_shuffle_ref (sp1, sp2): ci_ix[ifrag] = np.append (ci_ix[ifrag], c, axis=0) for ifrag in range (nfrag): if ci_ix[ifrag].size==0: - ci[ifrag][ix] = None + spaces[ix].ci[ifrag] = None continue lroots, ndeti = ci_ix[ifrag].shape[0], ndet[ifrag] if lroots > 1: @@ -513,8 +542,8 @@ def is_spin_shuffle_ref (sp1, sp2): v = v[:,idx] / np.sqrt (w[idx])[None,:] c = (c @ v).T ci_ix[ifrag] = c.reshape (-1, ndeti[0], ndeti[1]) - ci[ifrag][ix] = ci_ix[ifrag] - return ci + spaces[ix].ci[ifrag] = ci_ix[ifrag] + return spaces def count_excitations (las0): log = logger.new_logger (las0, las0.verbose) diff --git a/pyscf-forge_version.txt b/pyscf-forge_version.txt index 4790f30f..1e196016 100644 --- a/pyscf-forge_version.txt +++ b/pyscf-forge_version.txt @@ -1 +1 @@ -git+https://github.com/pyscf/pyscf-forge.git@f817911cb0c984207d9309a1d7347f6dbb46a9c6 +git+https://github.com/pyscf/pyscf-forge.git@6fa7530498f434404a323bd7b116a04d8f7c1f12 diff --git a/pyscf_version.txt b/pyscf_version.txt index 630bfdc5..8864a398 100644 --- a/pyscf_version.txt +++ b/pyscf_version.txt @@ -1 +1 @@ -git+https://github.com/pyscf/pyscf.git@d57f1d6c89c723e11a7f0933380a6139ba372554 +git+https://github.com/pyscf/pyscf.git@6d3b24bb64e2a5edb7990b6e3304068981a33f54