Skip to content

Commit

Permalink
Add scaling for vibrational frequencies
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed Dec 5, 2023
1 parent 938be43 commit 8f84e55
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
24 changes: 21 additions & 3 deletions jitterbug/compare.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tools for assessing the quality of a Hessian compared to a true one"""
from dataclasses import dataclass
from typing import Optional

import ase
from ase import units
Expand All @@ -12,6 +13,10 @@
class HessianQuality:
"""Measurements of the quality of a Hessian"""

# Metadata
scale_factor: float
"""Scaling factor used for frequencies"""

# Thermodynamics
zpe: float
"""Zero point energy (kcal/mol)"""
Expand All @@ -37,13 +42,15 @@ class HessianQuality:
"""Mean absolute error for the vibrational modes"""


def compare_hessians(atoms: ase.Atoms, known_hessian: np.ndarray, approx_hessian: np.ndarray) -> HessianQuality:
def compare_hessians(atoms: ase.Atoms, known_hessian: np.ndarray, approx_hessian: np.ndarray, scale_factor: Optional[float] = 1.) -> HessianQuality:
"""Compare two different hessians for same atomic structure
Args:
atoms: Structure
known_hessian: 2D form of the target Hessian
approx_hessian: 2D form of an approximate Hessian
scale_factor: Factor by which to scale frequencies from approximate Hessian before comparison.
Set to ``None`` to use the median ratio between the approximate and known frequency from each mode.
Returns:
Collection of the performance metrics
"""
Expand All @@ -56,6 +63,12 @@ def compare_hessians(atoms: ase.Atoms, known_hessian: np.ndarray, approx_hessian
known_freqs = known_vibs.get_frequencies()
is_real = np.isreal(known_freqs)
approx_freqs = approx_vibs.get_frequencies()

# Scale, if desired
if scale_factor is None:
scale_factor = np.median(np.divide(known_freqs, approx_freqs))
approx_freqs *= scale_factor

freq_error = np.subtract(approx_freqs[is_real], known_freqs[is_real])
freq_mae = np.abs(freq_error).mean()

Expand All @@ -64,6 +77,10 @@ def compare_hessians(atoms: ase.Atoms, known_hessian: np.ndarray, approx_hessian
known_harm = StatMech(vib_wavenumbers=np.real(known_freqs[is_real]), atoms=atoms, symmetrynumber=1, **presets['harmonic'])
approx_harm = StatMech(vib_wavenumbers=np.real(approx_freqs[is_real]), atoms=atoms, symmetrynumber=1, **presets['harmonic'])

approx_zpe = approx_harm.vib_model.get_ZPE() * units.mol / units.kcal
known_zpe = known_harm.vib_model.get_ZPE() * units.mol / units.kcal
zpe_error = approx_zpe - known_zpe

temps = np.linspace(1., 373, 128)
known_cp = np.array([known_harm.get_Cp('kcal/mol/K', T=t) for t in temps])
approx_cp = np.array([approx_harm.get_Cp('kcal/mol/K', T=t) for t in temps])
Expand All @@ -72,8 +89,9 @@ def compare_hessians(atoms: ase.Atoms, known_hessian: np.ndarray, approx_hessian

# Assemble into a result object
return HessianQuality(
zpe=approx_vibs.get_zero_point_energy() * units.mol / units.kcal,
zpe_error=(approx_vibs.get_zero_point_energy() - known_vibs.get_zero_point_energy()) * units.mol / units.kcal,
scale_factor=scale_factor,
zpe=approx_zpe,
zpe_error=zpe_error,
vib_freqs=np.real(approx_freqs[is_real]).tolist(),
vib_errors=np.abs(freq_error),
vib_mae=freq_mae,
Expand Down
15 changes: 15 additions & 0 deletions tests/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,18 @@ def test_compare(example_hess):
assert comp.zpe_error == 0.
assert np.ndim(comp.cp_error) == 1
assert np.mean(comp.cp_error) == 0.


def test_scaling(example_hess):
# Make sure scaling the Hessian has the target effect
comp = compare_hessians(example_hess.get_atoms(), example_hess.get_hessian_2d(), example_hess.get_hessian_2d() * 1.1)
assert comp.zpe_error > 0

# ... and that it can be repaired
comp = compare_hessians(example_hess.get_atoms(), example_hess.get_hessian_2d(), example_hess.get_hessian_2d() * 1.1, scale_factor=1. / np.sqrt(1.1))
assert np.isclose(comp.zpe_error, 0)

# ... and that it can be repaired, automatically
comp = compare_hessians(example_hess.get_atoms(), example_hess.get_hessian_2d(), example_hess.get_hessian_2d() * 1.1, scale_factor=None)
assert np.isclose(comp.zpe_error, 0)
assert np.isclose(comp.scale_factor, 1. / np.sqrt(1.1))

0 comments on commit 8f84e55

Please sign in to comment.