Skip to content

Commit

Permalink
get_pair_lasci safety commit
Browse files Browse the repository at this point in the history
Some syntax debugging
  • Loading branch information
MatthewRHermes committed Jul 23, 2024
1 parent ad712de commit e496a2b
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
6 changes: 5 additions & 1 deletion my_pyscf/mcscf/lasscf_async/combine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pyscf.lo import orth
from pyscf.scf.rohf import get_roothaan_fock
from mrh.my_pyscf.mcscf import lasci, _DFLASCI
from mrh.my_pyscf.mcscf.lasscf_async import keyframe
from mrh.my_pyscf.mcscf.lasscf_async import keyframe, crunch

# TODO: symmetry
def orth_orb (las, kf2_list, kf_ref=None):
Expand Down Expand Up @@ -222,6 +222,10 @@ def combine_pair (las, kf1, kf2, kf_ref=None):
kf3 = orth_orb (las, [kf1, kf2], kf_ref=kf_ref)
i, j = select_aa_block (las, kf1.frags, kf2.frags, kf3.fock1)
kf3 = relax (las, kf3, freeze_inactive=True, unfrozen_frags=(i,j))
#pair = crunch.get_pair_lasci (las, (i,j))
#pair._pull_keyframe_(kf3)
#pair.kernel ()
#kf3 = pair._push_keyframe (kf3)
kf3.frags = kf1.frags.union (kf2.frags)
return kf3

Expand Down
7 changes: 4 additions & 3 deletions my_pyscf/mcscf/lasscf_async/crunch.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ def orbital_response (self, kappa1, odm1s, ocm2, tdm1rs, tcm2, veff_prime):
kappa2 += w * (fock1 - fock1.T)
return kappa2

class ImpurityLASCI (lasci.LASCINoSymm):
class ImpurityLASCI (lasci.LASCINoSymm, ImpuritySolver):
_hop = ImpurityLASCI_HessianOperator

def get_grad_orb (las, mo_coeff=None, ci=None, h2eff_sub=None, veff=None, dm1s=None, hermi=-1):
Expand Down Expand Up @@ -980,7 +980,7 @@ def get_impurity_casscf (las, ifrag, imporb_builder=None):

def get_pair_lasci (las, frags):
stdout = getattr (las, '_flas_stdout', None)
if stdout is not None: stdout = stdout.get (unfrozen_frags, None)
if stdout is not None: stdout = stdout.get (frags, None)
output = getattr (las.mol, 'output', None)
if not ((output is None) or (output=='/dev/null')):
output = output + '.' + '.'.join ([str (s) for s in frags])
Expand All @@ -995,13 +995,14 @@ def get_pair_lasci (las, frags):
def imporb_builder (mo_coeff, dm1s, veff, fock1, **kwargs):
idx = np.zeros (mo_coeff.shape[1], dtype=bool)
for ix in frags:
i = ncore + sum (las.ncas_sub[:ix])
i = las.ncore + sum (las.ncas_sub[:ix])
j = i + las.ncas_sub[ix]
idx[i:j] = True
fo_coeff = mo_coeff[:,idx]
nelec_f = sum ([sum (n) for n in nelecas_sub])
return fo_coeff, nelec_f
ilas._imporb_builder = imporb_builder
ilas._ifrags = frags
params = getattr (las, 'relax_params', {})
glob = {key: val for key, val in params.items () if isinstance (key, str)}
glob = {key: val for key, val in glob.items () if key not in ('frozen', 'frozen_ci')}
Expand Down

0 comments on commit e496a2b

Please sign in to comment.