Skip to content

Commit

Permalink
lasscf_rdm microiteration divergence catching
Browse files Browse the repository at this point in the history
Copy the logic from lasci_sync.py over to lasscf_rdm.py
  • Loading branch information
MatthewRHermes committed Sep 27, 2023
1 parent 4261a60 commit 7791a05
Showing 1 changed file with 52 additions and 10 deletions.
62 changes: 52 additions & 10 deletions my_pyscf/mcscf/lasscf_rdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np
from scipy import linalg, sparse
from mrh.my_pyscf.mcscf import lasscf_sync_o0, lasci, lasci_sync, _DFLASCI
from mrh.my_pyscf.mcscf.lasci_sync import MicroIterInstabilityException
from mrh.my_pyscf.fci import csf_solver
from pyscf import lib, gto, ao2mo
from pyscf.fci.direct_spin1 import _unpack_nelec
Expand Down Expand Up @@ -182,9 +183,15 @@ def kernel (las, mo_coeff=None, casdm1frs=None, casdm2fr=None, conv_tol_grad=1e-
# if I'm already converged I don't want to waste the cycles
t1 = log.timer ('LASSCF Hessian constructor', *t1)
microit = [0]
last_x = [0]
first_norm_x = [None]
def my_callback (x):
microit[0] += 1
norm_xorb = linalg.norm (x) if x.size else 0.0
addr_max = np.argmax (np.abs (x))
id_max = ugg.addr2idstr (addr_max)
x_max = x[addr_max]/np.pi
log.debug ('Maximum step vector element x[{}] = {}*pi ({})'.format (addr_max, x_max, id_max))
if las.verbose > lib.logger.INFO:
Hx = H_op._matvec (x) # This doubles the price of each iteration!!
resid = g_vec + Hx
Expand All @@ -194,17 +201,52 @@ def my_callback (x):
microit[0], Ecall, norm_gorb, norm_xorb)
else:
log.info ('LASSCF micro %d : |x_orb| = %.15g', microit[0], norm_xorb)

if abs(x_max)>.5: # Nonphysical step vector element
if last_x[0] is 0:
x[np.abs (x)>.5*np.pi] = 0
last_x[0] = x
raise MicroIterInstabilityException ("|x[i]| > pi/2")
norm_x = linalg.norm (x)
if first_norm_x[0] is None:
first_norm_x[0] = norm_x
elif norm_x > 10*first_norm_x[0]:
raise MicroIterInstabilityException ("||x(n)|| > 10*||x(0)||")
last_x[0] = x.copy ()

my_tol = max (conv_tol_grad, norm_gx/10)
x, info_int = sparse.linalg.cg (H_op, -g_vec, x0=x0, atol=my_tol, maxiter=las.max_cycle_micro,
callback=my_callback, M=prec_op)
t1 = log.timer ('LASSCF {} microcycles'.format (microit[0]), *t1)
mo_coeff, h2eff_sub = H_op.update_mo_eri (x, h2eff_sub)
t1 = log.timer ('LASSCF Hessian update', *t1)

veff = las.get_veff (dm1s = las.make_rdm1 (mo_coeff=mo_coeff, casdm1s_sub=casdm1fs))
veff = las.split_veff (veff, h2eff_sub, mo_coeff=mo_coeff, casdm1s_sub=casdm1fs)
t1 = log.timer ('LASSCF get_veff after secondorder', *t1)
try:
x, info_int = sparse.linalg.cg (H_op, -g_vec, x0=x0, atol=my_tol,
maxiter=las.max_cycle_micro,
callback=my_callback, M=prec_op)
t1 = log.timer ('LASSCF {} microcycles'.format (microit[0]), *t1)
mo_coeff, h2eff_sub = H_op.update_mo_eri (x, h2eff_sub)
t1 = log.timer ('LASSCF Hessian update', *t1)

veff = las.get_veff (dm1s = las.make_rdm1 (mo_coeff=mo_coeff, casdm1s_sub=casdm1fs))
veff = las.split_veff (veff, h2eff_sub, mo_coeff=mo_coeff, casdm1s_sub=casdm1fs)
t1 = log.timer ('LASSCF get_veff after secondorder', *t1)
except MicroIterInstabilityException as e:
log.info ('Unstable microiteration aborted: %s', str (e))
t1 = log.timer ('LASSCF {} microcycles'.format (microit[0]), *t1)
x = last_x[0]
for i in range (3): # Make up to 3 attempts to scale-down x if necessary
mo2, h2eff_sub2 = H_op.update_mo_eri (x, h2eff_sub)
t1 = log.timer ('LASCF Hessian update', *t1)
veff2 = las.get_veff (dm1s = las.make_rdm1 (mo_coeff=mo2, casdm1s_sub=casdm1fs))
veff2 = las.split_veff (veff2, h2eff_sub2, mo_coeff=mo2, casdm1s_sub=casdm1fs)
t1 = log.timer ('LASSCF get_veff after secondorder', *t1)
e2 = las.energy_nuc () + las.energy_elec (mo_coeff=mo2, h2eff=h2eff_sub2,
casdm1frs=casdm1frs,
casdm2fr=casdm2fr,
veff=veff2)
if e2 < H_op.e_tot:
break
log.info ('New energy ({}) is higher than keyframe energy ({})'.format (
e2, H_op.e_tot))
log.info ('Attempt {} of 3 to scale down trial step vector'.format (i+1))
x *= .5
mo_coeff, h2eff_sub, veff = mo2, h2eff_sub2, veff2


t2 = log.timer ('LASSCF {} macrocycles'.format (it), *t2)

Expand Down

0 comments on commit 7791a05

Please sign in to comment.