Skip to content

Commit

Permalink
Compare sampling methods for the same molecule (#22)
Browse files Browse the repository at this point in the history
* Add script for running all methods for a single XYZ

* Script for running all sampling strategies

* Add comparison function

* Compute heat capacities too

* First notebook for comparing sampling methods

* Scale perturbations by number of atoms involved

* Measure the enthalpy too

* Add enthalpy to the comparison
  • Loading branch information
WardLT authored Nov 16, 2023
1 parent ea49fda commit d5f72c1
Show file tree
Hide file tree
Showing 12 changed files with 1,118 additions and 17 deletions.
85 changes: 85 additions & 0 deletions jitterbug/compare.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""Tools for assessing the quality of a Hessian compared to a true one"""
from dataclasses import dataclass

import ase
from ase import units
import numpy as np
from ase.vibrations import VibrationsData
from pmutt.statmech import StatMech, presets


@dataclass
class HessianQuality:
    """Measurements of the quality of an approximate Hessian relative to a known one

    The thermodynamic arrays (``cp``, ``cp_error``, ``h``, ``h_error``) share
    the same length and ordering as ``temps``. The vibrational arrays
    (``vib_freqs``, ``vib_errors``) share a length and ordering with each other.
    """

    # Thermodynamics
    zpe: float
    """Zero point energy (kcal/mol)"""
    zpe_error: float
    """Difference between the ZPE and the target one, approximate minus target (kcal/mol)"""
    cp: list[float]
    """Heat capacity as a function of temperature (units: kcal/mol/K)"""
    cp_error: list[float]
    """Difference between known and approximate heat capacity as a function of temperature (units: kcal/mol/K)"""
    h: list[float]
    """Enthalpy as a function of temperature (units: kcal/mol)"""
    h_error: list[float]
    """Error between known and approximate enthalpy as a function of temperature (units: kcal/mol)"""
    temps: list[float]
    """Temperatures at which the heat capacity and enthalpy were evaluated (units: K)"""

    # Vibrations
    vib_freqs: list[float]
    """Vibrational frequencies for our Hessian (units: cm^-1)"""
    vib_errors: list[float]
    """Absolute error between each frequency and the corresponding mode in the known Hessian (units: cm^-1)"""
    vib_mae: float
    """Mean absolute error for the vibrational modes (units: cm^-1)"""


def compare_hessians(atoms: ase.Atoms, known_hessian: np.ndarray, approx_hessian: np.ndarray) -> HessianQuality:
    """Compare two different Hessians for the same atomic structure

    Vibrational frequencies, zero point energy, heat capacity and enthalpy are
    computed from both Hessians and the differences are reported.

    Args:
        atoms: Structure
        known_hessian: 2D form of the target Hessian
        approx_hessian: 2D form of an approximate Hessian
    Returns:
        Collection of the performance metrics. Every sequence-valued field is
        a plain Python list so that the result can be serialized directly.
    """

    # Start by making a vibration data object
    known_vibs: VibrationsData = VibrationsData.from_2d(atoms, known_hessian)
    approx_vibs: VibrationsData = VibrationsData.from_2d(atoms, approx_hessian)

    # Compare the vibrational frequencies, restricted to the modes whose
    #  frequency in the *known* Hessian has no imaginary component.
    #  The same mask is applied to the approximate frequencies so that
    #  modes are compared index-by-index.
    known_freqs = known_vibs.get_frequencies()
    is_real = np.isreal(known_freqs)
    approx_freqs = approx_vibs.get_frequencies()
    freq_error = np.subtract(approx_freqs[is_real], known_freqs[is_real])
    # np.abs also collapses any imaginary component of an approximate frequency into a magnitude
    abs_freq_error = np.abs(freq_error)

    # Compare the enthalpy and heat capacity
    # TODO (wardlt): Might actually want to compute the symmetry number
    known_harm = StatMech(vib_wavenumbers=np.real(known_freqs[is_real]), atoms=atoms, symmetrynumber=1, **presets['harmonic'])
    approx_harm = StatMech(vib_wavenumbers=np.real(approx_freqs[is_real]), atoms=atoms, symmetrynumber=1, **presets['harmonic'])

    temps = np.linspace(1., 373, 128)
    known_cp = np.array([known_harm.get_Cp('kcal/mol/K', T=t) for t in temps])
    approx_cp = np.array([approx_harm.get_Cp('kcal/mol/K', T=t) for t in temps])
    known_h = np.array([known_harm.get_H('kcal/mol', T=t) for t in temps])
    approx_h = np.array([approx_harm.get_H('kcal/mol', T=t) for t in temps])

    # Evaluate each zero point energy once; both are reused below
    known_zpe = known_vibs.get_zero_point_energy()
    approx_zpe = approx_vibs.get_zero_point_energy()
    to_kcal_mol = units.mol / units.kcal

    # Assemble into a result object
    # NOTE: zpe_error is (approx - known) whereas cp_error and h_error are (known - approx).
    #  The sign conventions are left as they were so existing consumers are unaffected.
    return HessianQuality(
        zpe=approx_zpe * to_kcal_mol,
        zpe_error=(approx_zpe - known_zpe) * to_kcal_mol,
        vib_freqs=np.real(approx_freqs[is_real]).tolist(),
        vib_errors=abs_freq_error.tolist(),
        vib_mae=float(abs_freq_error.mean()),
        cp=approx_cp.tolist(),
        cp_error=(known_cp - approx_cp).tolist(),
        h=approx_h.tolist(),
        h_error=(known_h - approx_h).tolist(),
        temps=temps.tolist()
    )
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,19 @@
"name, method, basis = run_name.split(\"_\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1d549833-999a-4172-8bc2-689517e7c2a7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"if not Path(starting_geometry).exists():\n",
" raise ValueError('Cannot find file')"
]
},
{
"cell_type": "markdown",
"id": "cf9ff792-6b5b-46ce-9a78-78912e372912",
Expand Down Expand Up @@ -226,7 +239,7 @@
" # Sample a perturbation\n",
" disp = np.random.normal(0, 1, size=(n_atoms * 3))\n",
" disp /= np.linalg.norm(disp)\n",
" disp *= step_size\n",
" disp *= step_size * len(atoms) \n",
" disp = disp.reshape((-1, 3))\n",
" \n",
" # Subtract off any translation\n",
Expand Down Expand Up @@ -270,7 +283,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@
" disp = np.random.normal(0, 1, size=(n_atoms * 3))\n",
" disp /= np.linalg.norm(disp)\n",
" my_step_dist = np.random.exponential(scale=step_size)\n",
" disp *= my_step_dist\n",
" disp *= my_step_dist * len(atoms)\n",
" disp = disp.reshape((-1, 3))\n",
" \n",
" # Subtract off any translation\n",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@
" # Create the perturbation vector\n",
" disp = np.zeros((n_coords,))\n",
" for d in perturb:\n",
" disp[abs(d) - 1] = (1 if abs(d) > 0 else -1) * step_size\n",
" disp[abs(d) - 1] = (1 if abs(d) > 0 else -1) * step_size / perturbs_per_evaluation\n",
" disp = disp.reshape((-1, 3))\n",
" \n",
" # Make the new atoms\n",
Expand Down Expand Up @@ -387,7 +387,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.17"
"version": "3.9.18"
}
},
"nbformat": 4,
Expand Down
15 changes: 15 additions & 0 deletions notebooks/1_explore-sampling-methods/run-all-methods.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#! /bin/bash
# Execute every sampling notebook for a single starting geometry.
#
# Each notebook is run through papermill with its parameters injected via `-p`.
# The executed copy is always written to last.ipynb, so every run overwrites
# the output of the previous one.

geometry=../data/exact/caffeine_pm7_None.xyz
random_notebooks=(
    0_random-directions-same-distance.ipynb
    1_random-directions-variable-distance.ipynb
)
axes_notebook=2_displace-along-axes.ipynb

for delta in 0.02; do
    # Samplers which displace along random directions
    for nb in "${random_notebooks[@]}"; do
        papermill -p starting_geometry "$geometry" -p step_size "$delta" "$nb" last.ipynb
    done

    # Axis-aligned sampler, repeated for several values of perturbs_per_evaluation
    for count in 2 4 8; do
        papermill -p starting_geometry "$geometry" -p perturbs_per_evaluation "$count" -p step_size "$delta" "$axes_notebook" last.ipynb
    done
done
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,12 @@
"from dscribe.descriptors import MBTR\n",
"from ase.vibrations import VibrationsData\n",
"from ase.db import connect\n",
"from random import sample\n",
"from pathlib import Path\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import warnings\n",
"import json\n",
"import ase"
]
},
Expand All @@ -55,7 +57,9 @@
},
"outputs": [],
"source": [
"db_path = '../1_explore-sampling-methods/data/along-axes/caffeine_pm7_None_d=5.00e-03-N=2.db'"
"db_path: str = '../1_explore-sampling-methods/data/along-axes/caffeine_pm7_None_d=5.00e-03-N=2.db'\n",
"overwrite: bool = False\n",
"max_size: int = 10000"
]
},
{
Expand All @@ -78,7 +82,29 @@
"run_name, sampling_options = Path(db_path).name[:-3].rsplit(\"_\", 1)\n",
"exact_path = Path('../data/exact/') / f'{run_name}-ase.json'\n",
"sampling_name = Path(db_path).parent.name\n",
"out_name = '_'.join([run_name, sampling_name, sampling_options])"
"out_name = '_'.join([run_name, sampling_name, sampling_options])\n",
"out_dir = Path('data/mbtr/')"
]
},
{
"cell_type": "markdown",
"id": "41f31375-9cf5-412b-949c-406711358781",
"metadata": {},
"source": [
"Skip if done"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "75d22086-f020-40d7-8327-1154491b9821",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"if (out_dir / f'{out_name}-full.json').exists() and not overwrite:\n",
" raise ValueError('Already done!')"
]
},
{
Expand All @@ -104,6 +130,28 @@
"print(f'Loaded {len(data)} structures')"
]
},
{
"cell_type": "markdown",
"id": "0c8aae57-1863-4bad-a56b-31f7b8a6062b",
"metadata": {},
"source": [
"Downsample if desired"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e2dfe036-a173-41ff-817b-2e92349b9704",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"if max_size is not None and len(data) > max_size:\n",
" data = sample(data, max_size)\n",
" print(f'Downselected to {len(data)}')"
]
},
{
"cell_type": "markdown",
"id": "cb1a8e03-b045-49a4-95fd-61636a48fbad",
Expand Down Expand Up @@ -178,7 +226,7 @@
"source": [
"test_y = np.array([a.get_potential_energy() for a in test_data])\n",
"baseline_y = np.abs(test_y - test_y.mean()).mean()\n",
"print(f'Baseline score: {baseline_y*1000:.2e} meV/atom')`"
"print(f'Baseline score: {baseline_y*1000:.2e} meV/atom')"
]
},
{
Expand Down Expand Up @@ -679,7 +727,7 @@
"source": [
"out_dir = Path('data/mbtr')\n",
"out_dir.mkdir(exist_ok=True, parents=True)\n",
"with open(f'data/mbtr/{out_name}.json', 'w') as fp:\n",
"with open(f'data/mbtr/{out_name}-full.json', 'w') as fp:\n",
" approx_vibs.write(fp)"
]
},
Expand Down Expand Up @@ -715,14 +763,20 @@
"outputs": [],
"source": [
"zpes = []\n",
"for count in tqdm(steps):\n",
" with warnings.catch_warnings():\n",
" warnings.simplefilter(\"ignore\")\n",
" hess_model = model.train(data[:count])\n",
" \n",
" approx_hessian = model.mean_hessian(hess_model)\n",
" approx_vibs = VibrationsData.from_2d(data[0], approx_hessian)\n",
" zpes.append(approx_vibs.get_zero_point_energy())"
"with open(f'data/mbtr/{out_name}-increment.json', 'w') as fp:\n",
" for count in tqdm(steps):\n",
" with warnings.catch_warnings():\n",
" warnings.simplefilter(\"ignore\")\n",
" hess_model = model.train(data[:count])\n",
"\n",
" approx_hessian = model.mean_hessian(hess_model)\n",
" \n",
" # Save the incremental\n",
" print(json.dumps({'count': int(count), 'hessian': approx_hessian.tolist()}), file=fp)\n",
" \n",
" # Compute the ZPE\n",
" approx_vibs = VibrationsData.from_2d(data[0], approx_hessian)\n",
" zpes.append(approx_vibs.get_zero_point_energy())"
]
},
{
Expand Down
8 changes: 8 additions & 0 deletions notebooks/2_testing-fitting-strategies/run-all-dbs.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#! /bin/bash
# Run one fitting notebook against every caffeine sampling database.
#
# Usage: ./run-all-dbs.sh <notebook.ipynb>
#
# Every database matching caffeine_pm7_None*.db under
# ../1_explore-sampling-methods/data/ is handed to the notebook as `db_path`
# through papermill, with `max_size` fixed at 5000. The executed copy is always
# written to last.ipynb, so each run overwrites the output of the previous one.
#
# NOTE: `set -e` is intentionally not used. A papermill failure for one
# database must not prevent the remaining databases from being processed.

# Fail early with a usage message instead of invoking papermill with an empty notebook path
if [ -z "${1:-}" ]; then
    echo "Usage: $0 <notebook.ipynb>" >&2
    exit 1
fi
notebook=$1

# Paths are read NUL-delimited so that whitespace in a path cannot split it.
# They arrive on file descriptor 3 so papermill's stdin is left untouched and
# it cannot consume the remaining list of databases.
while IFS= read -r -d '' db <&3; do
    echo "$db"
    papermill -p db_path "$db" -p max_size 5000 "$notebook" last.ipynb
done 3< <(find ../1_explore-sampling-methods/data/ -name "caffeine_pm7_None*.db" -print0)
Loading

0 comments on commit d5f72c1

Please sign in to comment.