diff --git a/jitterbug/compare.py b/jitterbug/compare.py
new file mode 100644
index 0000000..1338229
--- /dev/null
+++ b/jitterbug/compare.py
@@ -0,0 +1,102 @@
+"""Tools for assessing the quality of a Hessian compared to a true one"""
+from dataclasses import dataclass
+
+import ase
+from ase import units
+import numpy as np
+from ase.vibrations import VibrationsData
+from pmutt.statmech import StatMech, presets
+
+
+@dataclass
+class HessianQuality:
+    """Measurements of the quality of an approximate Hessian relative to a known one
+
+    Sequence-valued fields hold plain Python lists and scalar fields plain floats,
+    matching the annotations, rather than NumPy arrays or scalars.
+
+    The sign conventions are not uniform: ``zpe_error`` is approximate minus known,
+    while ``cp_error`` and ``h_error`` are known minus approximate, so that
+    ``h + h_error`` gives the enthalpy of the known Hessian.
+    """
+
+    # Thermodynamics
+    zpe: float
+    """Zero point energy of the approximate Hessian (kcal/mol)"""
+    zpe_error: float
+    """Difference between the ZPE and the target one, approximate minus known (kcal/mol)"""
+    cp: list[float]
+    """Heat capacity as a function of temperature (units: kcal/mol/K)"""
+    cp_error: list[float]
+    """Difference between known and approximate heat capacity as a function of temperature (units: kcal/mol/K)"""
+    h: list[float]
+    """Enthalpy as a function of temperature (units: kcal/mol)"""
+    h_error: list[float]
+    """Difference between known and approximate enthalpy as a function of temperature (units: kcal/mol)"""
+    temps: list[float]
+    """Temperatures at which Cp and H were evaluated (units: K)"""
+
+    # Vibrations
+    vib_freqs: list[float]
+    """Real part of the vibrational frequencies for our hessian (units: cm^-1)"""
+    vib_errors: list[float]
+    """Absolute error between each frequency and the corresponding mode in the known hessian (units: cm^-1)"""
+    vib_mae: float
+    """Mean absolute error for the vibrational modes (units: cm^-1)"""
+
+
+def compare_hessians(atoms: ase.Atoms, known_hessian: np.ndarray, approx_hessian: np.ndarray) -> HessianQuality:
+    """Compare two different hessians for same atomic structure
+
+    Only modes whose frequency is real for the known Hessian are compared.
+    Heat capacity and enthalpy come from pMuTT's ``harmonic`` preset, evaluated
+    at 128 evenly spaced temperatures between 1 and 373 K.
+
+    Args:
+        atoms: Structure
+        known_hessian: 2D form of the target Hessian
+        approx_hessian: 2D form of an approximate Hessian
+    Returns:
+        Collection of the performance metrics
+    """
+
+    # Start by making a vibration data object
+    known_vibs: VibrationsData = VibrationsData.from_2d(atoms, known_hessian)
+    approx_vibs: VibrationsData = VibrationsData.from_2d(atoms, approx_hessian)
+
+    # Compare the vibrational frequencies on the modes that are real for the known Hessian
+    known_freqs = known_vibs.get_frequencies()
+    is_real = np.isreal(known_freqs)
+    approx_freqs = approx_vibs.get_frequencies()
+    # The approximate frequency of a selected mode may still be complex,
+    # in which case np.abs gives the modulus of the complex difference
+    freq_error = np.abs(np.subtract(approx_freqs[is_real], known_freqs[is_real]))
+
+    # Compare the enthalpy and heat capacity
+    # TODO (wardlt): Might actually want to compute the symmetry number
+    known_harm = StatMech(vib_wavenumbers=np.real(known_freqs[is_real]), atoms=atoms, symmetrynumber=1, **presets['harmonic'])
+    approx_harm = StatMech(vib_wavenumbers=np.real(approx_freqs[is_real]), atoms=atoms, symmetrynumber=1, **presets['harmonic'])
+
+    temps = np.linspace(1., 373, 128)
+    known_cp = np.array([known_harm.get_Cp('kcal/mol/K', T=t) for t in temps])
+    approx_cp = np.array([approx_harm.get_Cp('kcal/mol/K', T=t) for t in temps])
+    known_h = np.array([known_harm.get_H('kcal/mol', T=t) for t in temps])
+    approx_h = np.array([approx_harm.get_H('kcal/mol', T=t) for t in temps])
+
+    # Evaluate each zero point energy once
+    known_zpe = known_vibs.get_zero_point_energy()
+    approx_zpe = approx_vibs.get_zero_point_energy()
+
+    # Assemble into a result object, converting NumPy values to builtin types
+    return HessianQuality(
+        zpe=float(approx_zpe * units.mol / units.kcal),
+        zpe_error=float((approx_zpe - known_zpe) * units.mol / units.kcal),
+        vib_freqs=np.real(approx_freqs[is_real]).tolist(),
+        vib_errors=freq_error.tolist(),
+        vib_mae=float(freq_error.mean()),
+        cp=approx_cp.tolist(),
+        cp_error=(known_cp - approx_cp).tolist(),
+        h=approx_h.tolist(),
+        h_error=(known_h - approx_h).tolist(),
+        temps=temps.tolist()
+    )
diff --git a/notebooks/1_explore-sampling-methods/0_random-directions-same-distance.ipynb b/notebooks/1_explore-sampling-methods/0_random-directions-same-distance.ipynb index ce20d59..7d95bc5 100644 --- a/notebooks/1_explore-sampling-methods/0_random-directions-same-distance.ipynb +++ b/notebooks/1_explore-sampling-methods/0_random-directions-same-distance.ipynb @@ -71,6 +71,19 @@ "name, method, basis = run_name.split(\"_\")" ] }, + { + "cell_type": "code", +
"execution_count": null, + "id": "1d549833-999a-4172-8bc2-689517e7c2a7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "if not Path(starting_geometry).exists():\n", + " raise ValueError('Cannot find file')" + ] + }, { "cell_type": "markdown", "id": "cf9ff792-6b5b-46ce-9a78-78912e372912", @@ -226,7 +239,7 @@ " # Sample a perturbation\n", " disp = np.random.normal(0, 1, size=(n_atoms * 3))\n", " disp /= np.linalg.norm(disp)\n", - " disp *= step_size\n", + " disp *= step_size * len(atoms) \n", " disp = disp.reshape((-1, 3))\n", " \n", " # Subtract off any translation\n", @@ -270,7 +283,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.17" + "version": "3.9.18" } }, "nbformat": 4, diff --git a/notebooks/1_explore-sampling-methods/1_random-directions-variable-distance.ipynb b/notebooks/1_explore-sampling-methods/1_random-directions-variable-distance.ipynb index 9d8a915..4567a8c 100644 --- a/notebooks/1_explore-sampling-methods/1_random-directions-variable-distance.ipynb +++ b/notebooks/1_explore-sampling-methods/1_random-directions-variable-distance.ipynb @@ -227,7 +227,7 @@ " disp = np.random.normal(0, 1, size=(n_atoms * 3))\n", " disp /= np.linalg.norm(disp)\n", " my_step_dist = np.random.exponential(scale=step_size)\n", - " disp *= my_step_dist\n", + " disp *= my_step_dist * len(atoms)\n", " disp = disp.reshape((-1, 3))\n", " \n", " # Subtract off any translation\n", diff --git a/notebooks/1_explore-sampling-methods/2_displace-along-axes.ipynb b/notebooks/1_explore-sampling-methods/2_displace-along-axes.ipynb index 64179fa..ef96e40 100644 --- a/notebooks/1_explore-sampling-methods/2_displace-along-axes.ipynb +++ b/notebooks/1_explore-sampling-methods/2_displace-along-axes.ipynb @@ -344,7 +344,7 @@ " # Create the perturbation vector\n", " disp = np.zeros((n_coords,))\n", " for d in perturb:\n", - " disp[abs(d) - 1] = (1 if abs(d) > 0 else -1) * step_size\n", + " disp[abs(d) - 1] = (1 if abs(d) > 
0 else -1) * step_size / perturbs_per_evaluation\n", " disp = disp.reshape((-1, 3))\n", " \n", " # Make the new atoms\n", @@ -387,7 +387,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.17" + "version": "3.9.18" } }, "nbformat": 4,
diff --git a/notebooks/1_explore-sampling-methods/run-all-methods.sh b/notebooks/1_explore-sampling-methods/run-all-methods.sh
new file mode 100644
index 0000000..311dcdc
--- /dev/null
+++ b/notebooks/1_explore-sampling-methods/run-all-methods.sh
@@ -0,0 +1,18 @@
+#! /bin/bash
+# Run every sampling notebook on the caffeine PM7 geometry with papermill.
+# Each executed notebook is written to last.ipynb, overwriting the previous one.
+# There is no `set -e`: a notebook that fails does not stop the remaining runs.
+
+xyz=../data/exact/caffeine_pm7_None.xyz
+for step_size in 0.02; do
+    # Do the randomized methods
+    for method in 0_random-directions-same-distance.ipynb 1_random-directions-variable-distance.ipynb; do
+        papermill -p starting_geometry "$xyz" -p step_size "$step_size" "$method" last.ipynb
+    done
+
+    # Test with different reductions for "along axes"
+    notebook=2_displace-along-axes.ipynb
+    for n in 2 4 8; do
+        papermill -p starting_geometry "$xyz" -p perturbs_per_evaluation "$n" -p step_size "$step_size" "$notebook" last.ipynb
+    done
+done
diff --git a/notebooks/2_testing-fitting-strategies/1_fit-forcefield-using-mbtr.ipynb b/notebooks/2_testing-fitting-strategies/1_fit-forcefield-using-mbtr.ipynb index c6bde80..cb29ff0 100644 --- a/notebooks/2_testing-fitting-strategies/1_fit-forcefield-using-mbtr.ipynb +++ b/notebooks/2_testing-fitting-strategies/1_fit-forcefield-using-mbtr.ipynb @@ -29,10 +29,12 @@ "from dscribe.descriptors import MBTR\n", "from ase.vibrations import VibrationsData\n", "from ase.db import connect\n", + "from random import sample\n", "from pathlib import Path\n", "from tqdm import tqdm\n", "import numpy as np\n", "import warnings\n", + "import json\n", "import ase" ] }, @@ -55,7 +57,9 @@ }, "outputs": [], "source": [ - "db_path = '../1_explore-sampling-methods/data/along-axes/caffeine_pm7_None_d=5.00e-03-N=2.db'" + "db_path: str =
'../1_explore-sampling-methods/data/along-axes/caffeine_pm7_None_d=5.00e-03-N=2.db'\n", + "overwrite: bool = False\n", + "max_size: int = 10000" ] }, { @@ -78,7 +82,29 @@ "run_name, sampling_options = Path(db_path).name[:-3].rsplit(\"_\", 1)\n", "exact_path = Path('../data/exact/') / f'{run_name}-ase.json'\n", "sampling_name = Path(db_path).parent.name\n", - "out_name = '_'.join([run_name, sampling_name, sampling_options])" + "out_name = '_'.join([run_name, sampling_name, sampling_options])\n", + "out_dir = Path('data/mbtr/')" + ] + }, + { + "cell_type": "markdown", + "id": "41f31375-9cf5-412b-949c-406711358781", + "metadata": {}, + "source": [ + "Skip if done" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "75d22086-f020-40d7-8327-1154491b9821", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "if (out_dir / f'{out_name}-full.json').exists() and not overwrite:\n", + " raise ValueError('Already done!')" ] }, { @@ -104,6 +130,28 @@ "print(f'Loaded {len(data)} structures')" ] }, + { + "cell_type": "markdown", + "id": "0c8aae57-1863-4bad-a56b-31f7b8a6062b", + "metadata": {}, + "source": [ + "Downsample if desired" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2dfe036-a173-41ff-817b-2e92349b9704", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "if max_size is not None and len(data) > max_size:\n", + " data = sample(data, max_size)\n", + " print(f'Downselected to {len(data)}')" + ] + }, { "cell_type": "markdown", "id": "cb1a8e03-b045-49a4-95fd-61636a48fbad", @@ -178,7 +226,7 @@ "source": [ "test_y = np.array([a.get_potential_energy() for a in test_data])\n", "baseline_y = np.abs(test_y - test_y.mean()).mean()\n", - "print(f'Baseline score: {baseline_y*1000:.2e} meV/atom')`" + "print(f'Baseline score: {baseline_y*1000:.2e} meV/atom')" ] }, { @@ -679,7 +727,7 @@ "source": [ "out_dir = Path('data/mbtr')\n", "out_dir.mkdir(exist_ok=True, parents=True)\n", - "with 
open(f'data/mbtr/{out_name}.json', 'w') as fp:\n", + "with open(f'data/mbtr/{out_name}-full.json', 'w') as fp:\n", " approx_vibs.write(fp)" ] }, @@ -715,14 +763,20 @@ "outputs": [], "source": [ "zpes = []\n", - "for count in tqdm(steps):\n", - " with warnings.catch_warnings():\n", - " warnings.simplefilter(\"ignore\")\n", - " hess_model = model.train(data[:count])\n", - " \n", - " approx_hessian = model.mean_hessian(hess_model)\n", - " approx_vibs = VibrationsData.from_2d(data[0], approx_hessian)\n", - " zpes.append(approx_vibs.get_zero_point_energy())" + "with open(f'data/mbtr/{out_name}-increment.json', 'w') as fp:\n", + " for count in tqdm(steps):\n", + " with warnings.catch_warnings():\n", + " warnings.simplefilter(\"ignore\")\n", + " hess_model = model.train(data[:count])\n", + "\n", + " approx_hessian = model.mean_hessian(hess_model)\n", + " \n", + " # Save the incremental\n", + " print(json.dumps({'count': int(count), 'hessian': approx_hessian.tolist()}), file=fp)\n", + " \n", + " # Compute the ZPE\n", + " approx_vibs = VibrationsData.from_2d(data[0], approx_hessian)\n", + " zpes.append(approx_vibs.get_zero_point_energy())" ] }, {
diff --git a/notebooks/2_testing-fitting-strategies/run-all-dbs.sh b/notebooks/2_testing-fitting-strategies/run-all-dbs.sh
new file mode 100644
index 0000000..26e0fdc
--- /dev/null
+++ b/notebooks/2_testing-fitting-strategies/run-all-dbs.sh
@@ -0,0 +1,21 @@
+#! /bin/bash
+# Run a fitting notebook against every caffeine sampling database.
+#
+# Usage: run-all-dbs.sh NOTEBOOK.ipynb
+# Each executed notebook is written to last.ipynb, overwriting the previous one.
+
+if [ "$#" -ne 1 ]; then
+    echo "Usage: $0 NOTEBOOK.ipynb" >&2
+    exit 1
+fi
+notebook=$1
+
+# Read the paths NUL-delimited so that names containing whitespace or glob
+# characters reach papermill intact. There is no `set -e`: a run that fails
+# does not stop the loop from moving on to the remaining databases.
+find ../1_explore-sampling-methods/data/ -name "caffeine_pm7_None*.db" -print0 |
+    while IFS= read -r -d '' db; do
+        echo "$db"
+        # Keep papermill from consuming the list of paths on the loop's stdin
+        papermill -p db_path "$db" -p max_size 5000 "$notebook" last.ipynb < /dev/null
+    done
diff --git a/notebooks/3_consolidate-results/0_compare-sampling-strategies-with-mbtr.ipynb b/notebooks/3_consolidate-results/0_compare-sampling-strategies-with-mbtr.ipynb new file mode 100644 index 0000000..fccbb2b --- /dev/null +++ b/notebooks/3_consolidate-results/0_compare-sampling-strategies-with-mbtr.ipynb @@ -0,0 +1,903 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "b592b18c-a459-4563-99fd-2ecd66474ed5", + "metadata": {}, + "source": [ + "# Compare Sampling Strategies using an MBTR Forcefield\n", + "Here, we hold our learning strategy constant and vary the strategies used in sampling" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4f6bfbb4-9d13-45d0-88a6-7470111204fe", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "from matplotlib import pyplot as plt\n", + "from jitterbug.compare import compare_hessians\n", + "from ase.vibrations import VibrationsData\n", + "from dataclasses import asdict\n", + "from pathlib import Path\n", + "from ase.io import read\n", + "from tqdm import tqdm\n", + "import pandas as pd\n", + "import numpy as np\n", + "import json\n", + "import re" + ] + }, + { + "cell_type": "markdown", + "id": "f4e97ce7-31c1-4542-9aff-b06c83bbefd8", + "metadata": {}, + "source": [ + "Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "02e04cef-054c-47ad-9334-d1cf6d4412e0", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "target_mol = '../data/exact/caffeine_pm7_None.xyz'\n", + "target_method = '../2_testing-fitting-strategies/data/mbtr/'" + ] + }, + { + "cell_type": "markdown", + "id": "8874ea91-b4f3-432a-bd28-0d33b50e24ee", + "metadata": {}, + "source": [ + "## Load the Exact Result\n", + "The target molecule filename
determines which molecule we'll look for. The name includes both the molecule name and method used to evaluate the hessian" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "db22a33d-e70a-4e7b-aad7-39aa4e552804", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "target_mol = Path(target_mol)\n", + "mol_name = target_mol.name[:-4]\n", + "atoms = read(target_mol)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "11895c35-43e9-4880-af59-be157df37b55", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "exact_hess = VibrationsData.read(target_mol.parent / f'{mol_name}-ase.json')\n", + "exact_hess" + ] + }, + { + "cell_type": "markdown", + "id": "16c37359-0d6e-4299-8f40-8fc78708e691", + "metadata": {}, + "source": [ + "## Find All Fittings\n", + "Find the approximate hessians produced using each method" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "20b2213c-d98f-47d2-a2ed-b628f00e7cf7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "all_hessians = list(Path(target_method).glob(f\"{mol_name}_*-increment.json\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "2056b382-8566-4f2e-a574-961a3268d3c3", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/15 [00:00 list[dict[str, object], str]:\n", + " \"\"\"Load the Hessian and parse the metadata from the filename\n", + " \n", + " Args:\n", + " path: Path to the run path\n", + " Returns:\n", + " Dictionary the includes the metadata:\n", + " \"\"\"\n", + " \n", + " # Get some of the basic information\n", + " _, sampling_method, sampling_options_str = path.name[:-15].rsplit(\"_\", 2)\n", + " sampling_options = dict(x.split(\"=\") for x in re.split(\"-([^\\d]+=.*)\", sampling_options_str) if 
len(x) > 0)\n", + " \n", + " # For each, load the Hessian on the full dataset and compare to exact answer\n", + " output = []\n", + " with path.open() as fp:\n", + " for line in fp:\n", + " record = json.loads(line)\n", + " compare = compare_hessians(exact_hess.get_atoms(), exact_hess.get_hessian_2d(), record['hessian'])\n", + " output.append({\n", + " 'sampling_method': sampling_method,\n", + " 'options': sampling_options_str,\n", + " 'size': record['count'],\n", + " **sampling_options,\n", + " **asdict(compare)\n", + " }) \n", + " return output\n", + "all_results = pd.DataFrame(sum([load_hessian(path) for path in tqdm(all_hessians)], []))\n", + "print(f'Loaded {len(all_results)} approximate hessians')" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "db4ec32b-2b6b-44a5-8bdb-dc95feb6e395", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sampling_methodoptionssizedNzpezpe_errorcpcp_errorhh_errortempsvib_freqsvib_errorsvib_mae
0along-axesd=1.00e-02-N=851.00e-02814.200987-95.264864[0.000397593092606875, 0.002386964975551445, 0...[0.0036833759055666425, 0.0032014502665919545,...[14.201069696784895, 14.20544123215772, 14.216...[95.26810423257896, 95.27849857026263, 95.2846...[1.0, 3.9291338582677167, 6.858267716535433, 9...[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[365.26521782818406, 345.1085529599917, 194.17...1027.154434
1along-axesd=1.00e-02-N=8911.00e-0285.163324-104.302527[0.0019348995359107907, 0.008929018823987386, ...[0.002146069462262727, -0.0033406035818439863,...[5.163950489872683, 5.179473698114745, 5.21578...[104.30522343949119, 104.3044661043056, 104.28...[1.0, 3.9291338582677167, 6.858267716535433, 9...[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[136.3752974454976, 116.1372867519512, 97.3719...1113.227006
2along-axesd=1.00e-02-N=81771.00e-0286.876866-102.588985[0.0020382429986622493, 0.008341448274375083, ...[0.0020427259995112685, -0.0027530330322316837...[6.878002515584569, 6.8926516980488675, 6.9274...[102.59117141377929, 102.59128810437149, 102.5...[1.0, 3.9291338582677167, 6.858267716535433, 9...[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[96.75265542338498, 90.41924803705203, 74.0875...1092.500033
3along-axesd=1.00e-02-N=82631.00e-02889.958157-19.507693[0.0014020270591578593, 0.0023843729911753324,...[0.0026789419390156584, 0.003204042250968067, ...[89.9587466410539, 89.96445676642922, 89.97267...[19.510427288309955, 19.51948303599113, 19.528...[1.0, 3.9291338582677167, 6.858267716535433, 9...[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.446...[111.37319602907411, 108.7601947486991, 66.937...263.944079
4along-axesd=1.00e-02-N=83501.00e-02893.839129-15.626722[1.6382310852832182e-08, 0.0004830509673205082...[0.004080952615862665, 0.005105364274822891, 0...[93.83912881523764, 93.83954779804499, 93.8423...[15.630045114126219, 15.644392004375362, 15.65...[1.0, 3.9291338582677167, 6.858267716535433, 9...[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[118.00864162563397, 110.49254554846664, 100.8...237.597413
\n", + "
" + ], + "text/plain": [ + " sampling_method options size d N zpe zpe_error \\\n", + "0 along-axes d=1.00e-02-N=8 5 1.00e-02 8 14.200987 -95.264864 \n", + "1 along-axes d=1.00e-02-N=8 91 1.00e-02 8 5.163324 -104.302527 \n", + "2 along-axes d=1.00e-02-N=8 177 1.00e-02 8 6.876866 -102.588985 \n", + "3 along-axes d=1.00e-02-N=8 263 1.00e-02 8 89.958157 -19.507693 \n", + "4 along-axes d=1.00e-02-N=8 350 1.00e-02 8 93.839129 -15.626722 \n", + "\n", + " cp \\\n", + "0 [0.000397593092606875, 0.002386964975551445, 0... \n", + "1 [0.0019348995359107907, 0.008929018823987386, ... \n", + "2 [0.0020382429986622493, 0.008341448274375083, ... \n", + "3 [0.0014020270591578593, 0.0023843729911753324,... \n", + "4 [1.6382310852832182e-08, 0.0004830509673205082... \n", + "\n", + " cp_error \\\n", + "0 [0.0036833759055666425, 0.0032014502665919545,... \n", + "1 [0.002146069462262727, -0.0033406035818439863,... \n", + "2 [0.0020427259995112685, -0.0027530330322316837... \n", + "3 [0.0026789419390156584, 0.003204042250968067, ... \n", + "4 [0.004080952615862665, 0.005105364274822891, 0... \n", + "\n", + " h \\\n", + "0 [14.201069696784895, 14.20544123215772, 14.216... \n", + "1 [5.163950489872683, 5.179473698114745, 5.21578... \n", + "2 [6.878002515584569, 6.8926516980488675, 6.9274... \n", + "3 [89.9587466410539, 89.96445676642922, 89.97267... \n", + "4 [93.83912881523764, 93.83954779804499, 93.8423... \n", + "\n", + " h_error \\\n", + "0 [95.26810423257896, 95.27849857026263, 95.2846... \n", + "1 [104.30522343949119, 104.3044661043056, 104.28... \n", + "2 [102.59117141377929, 102.59128810437149, 102.5... \n", + "3 [19.510427288309955, 19.51948303599113, 19.528... \n", + "4 [15.630045114126219, 15.644392004375362, 15.65... \n", + "\n", + " temps \\\n", + "0 [1.0, 3.9291338582677167, 6.858267716535433, 9... \n", + "1 [1.0, 3.9291338582677167, 6.858267716535433, 9... \n", + "2 [1.0, 3.9291338582677167, 6.858267716535433, 9... \n", + "3 [1.0, 3.9291338582677167, 6.858267716535433, 9... 
\n", + "4 [1.0, 3.9291338582677167, 6.858267716535433, 9... \n", + "\n", + " vib_freqs \\\n", + "0 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "1 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "2 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "3 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.446... \n", + "4 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "\n", + " vib_errors vib_mae \n", + "0 [365.26521782818406, 345.1085529599917, 194.17... 1027.154434 \n", + "1 [136.3752974454976, 116.1372867519512, 97.3719... 1113.227006 \n", + "2 [96.75265542338498, 90.41924803705203, 74.0875... 1092.500033 \n", + "3 [111.37319602907411, 108.7601947486991, 66.937... 263.944079 \n", + "4 [118.00864162563397, 110.49254554846664, 100.8... 237.597413 " + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_results.head()" + ] + }, + { + "cell_type": "markdown", + "id": "0f243a23-ed89-4f5e-aeac-f23722ef10af", + "metadata": {}, + "source": [ + "Coerce columns I know should be numeric" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "2d315d44-7564-468f-bd03-8bbedf8b424c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "for col in ['d', 'N']:\n", + " all_results[col] = pd.to_numeric(all_results[col])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "eee03c3a-4a53-4e84-995a-c2191e6f6332", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
sampling_methodoptionssizedNzpezpe_errorcpcp_errorhh_errortempsvib_freqsvib_errorsvib_mae
0along-axesd=1.00e-02-N=850.018.014.200987-95.264864[0.000397593092606875, 0.002386964975551445, 0...[0.0036833759055666425, 0.0032014502665919545,...[14.201069696784895, 14.20544123215772, 14.216...[95.26810423257896, 95.27849857026263, 95.2846...[1.0, 3.9291338582677167, 6.858267716535433, 9...[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[365.26521782818406, 345.1085529599917, 194.17...1027.154434
1along-axesd=1.00e-02-N=8910.018.05.163324-104.302527[0.0019348995359107907, 0.008929018823987386, ...[0.002146069462262727, -0.0033406035818439863,...[5.163950489872683, 5.179473698114745, 5.21578...[104.30522343949119, 104.3044661043056, 104.28...[1.0, 3.9291338582677167, 6.858267716535433, 9...[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[136.3752974454976, 116.1372867519512, 97.3719...1113.227006
2along-axesd=1.00e-02-N=81770.018.06.876866-102.588985[0.0020382429986622493, 0.008341448274375083, ...[0.0020427259995112685, -0.0027530330322316837...[6.878002515584569, 6.8926516980488675, 6.9274...[102.59117141377929, 102.59128810437149, 102.5...[1.0, 3.9291338582677167, 6.858267716535433, 9...[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[96.75265542338498, 90.41924803705203, 74.0875...1092.500033
3along-axesd=1.00e-02-N=82630.018.089.958157-19.507693[0.0014020270591578593, 0.0023843729911753324,...[0.0026789419390156584, 0.003204042250968067, ...[89.9587466410539, 89.96445676642922, 89.97267...[19.510427288309955, 19.51948303599113, 19.528...[1.0, 3.9291338582677167, 6.858267716535433, 9...[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.446...[111.37319602907411, 108.7601947486991, 66.937...263.944079
4along-axesd=1.00e-02-N=83500.018.093.839129-15.626722[1.6382310852832182e-08, 0.0004830509673205082...[0.004080952615862665, 0.005105364274822891, 0...[93.83912881523764, 93.83954779804499, 93.8423...[15.630045114126219, 15.644392004375362, 15.65...[1.0, 3.9291338582677167, 6.858267716535433, 9...[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...[118.00864162563397, 110.49254554846664, 100.8...237.597413
\n", + "
" + ], + "text/plain": [ + " sampling_method options size d N zpe zpe_error \\\n", + "0 along-axes d=1.00e-02-N=8 5 0.01 8.0 14.200987 -95.264864 \n", + "1 along-axes d=1.00e-02-N=8 91 0.01 8.0 5.163324 -104.302527 \n", + "2 along-axes d=1.00e-02-N=8 177 0.01 8.0 6.876866 -102.588985 \n", + "3 along-axes d=1.00e-02-N=8 263 0.01 8.0 89.958157 -19.507693 \n", + "4 along-axes d=1.00e-02-N=8 350 0.01 8.0 93.839129 -15.626722 \n", + "\n", + " cp \\\n", + "0 [0.000397593092606875, 0.002386964975551445, 0... \n", + "1 [0.0019348995359107907, 0.008929018823987386, ... \n", + "2 [0.0020382429986622493, 0.008341448274375083, ... \n", + "3 [0.0014020270591578593, 0.0023843729911753324,... \n", + "4 [1.6382310852832182e-08, 0.0004830509673205082... \n", + "\n", + " cp_error \\\n", + "0 [0.0036833759055666425, 0.0032014502665919545,... \n", + "1 [0.002146069462262727, -0.0033406035818439863,... \n", + "2 [0.0020427259995112685, -0.0027530330322316837... \n", + "3 [0.0026789419390156584, 0.003204042250968067, ... \n", + "4 [0.004080952615862665, 0.005105364274822891, 0... \n", + "\n", + " h \\\n", + "0 [14.201069696784895, 14.20544123215772, 14.216... \n", + "1 [5.163950489872683, 5.179473698114745, 5.21578... \n", + "2 [6.878002515584569, 6.8926516980488675, 6.9274... \n", + "3 [89.9587466410539, 89.96445676642922, 89.97267... \n", + "4 [93.83912881523764, 93.83954779804499, 93.8423... \n", + "\n", + " h_error \\\n", + "0 [95.26810423257896, 95.27849857026263, 95.2846... \n", + "1 [104.30522343949119, 104.3044661043056, 104.28... \n", + "2 [102.59117141377929, 102.59128810437149, 102.5... \n", + "3 [19.510427288309955, 19.51948303599113, 19.528... \n", + "4 [15.630045114126219, 15.644392004375362, 15.65... \n", + "\n", + " temps \\\n", + "0 [1.0, 3.9291338582677167, 6.858267716535433, 9... \n", + "1 [1.0, 3.9291338582677167, 6.858267716535433, 9... \n", + "2 [1.0, 3.9291338582677167, 6.858267716535433, 9... \n", + "3 [1.0, 3.9291338582677167, 6.858267716535433, 9... 
\n", + "4 [1.0, 3.9291338582677167, 6.858267716535433, 9... \n", + "\n", + " vib_freqs \\\n", + "0 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "1 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "2 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "3 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.446... \n", + "4 [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ... \n", + "\n", + " vib_errors vib_mae \n", + "0 [365.26521782818406, 345.1085529599917, 194.17... 1027.154434 \n", + "1 [136.3752974454976, 116.1372867519512, 97.3719... 1113.227006 \n", + "2 [96.75265542338498, 90.41924803705203, 74.0875... 1092.500033 \n", + "3 [111.37319602907411, 108.7601947486991, 66.937... 263.944079 \n", + "4 [118.00864162563397, 110.49254554846664, 100.8... 237.597413 " + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_results.head()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a238ef1e-eae0-4195-a254-7c34fc63cc8d", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "along-axes 133\n", + "random-dir-same-dist 48\n", + "random-dir-variable-dist 48\n", + "Name: sampling_method, dtype: int64" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "all_results['sampling_method'].value_counts()" + ] + }, + { + "cell_type": "markdown", + "id": "c04c6df3-882c-4ab8-954b-ff316ff1134c", + "metadata": {}, + "source": [ + "## Compute Performance Metrics\n", + "Get things like the error in ZPE or vibrational frequencies" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "38d7ff93-3cd6-4489-96fb-5a1f6a917465", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "subset = all_results.query('d==0.01 and (sampling_method != \"along-axes\" or N == 4)')" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "16cdf7cf-4d09-415e-8996-3222ee33307b", + 
"metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(3.5, 2.5))\n", + "\n", + "for gid, group in subset.groupby('sampling_method'):\n", + " ax.plot(group['size'], group['vib_mae'], '--o', label=gid)\n", + " \n", + "ax.legend()\n", + "\n", + "ax.set_yscale('log')\n", + "\n", + "ax.set_xlabel('Training Size')\n", + "ax.set_ylabel('Vibration MAE (cm$^{-1}$)')\n", + "\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "7dc273cc-6b8e-47a9-8f58-31bd385368da", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(3.5, 2.5))\n", + "\n", + "for gid, group in subset.groupby('sampling_method'):\n", + " ax.plot(group['size'], group['zpe_error'].abs(), '--o', label=gid)\n", + " \n", + "ax.legend()\n", + "\n", + "ax.set_yscale('log')\n", + "\n", + "ax.set_xlabel('Training Size')\n", + "ax.set_ylabel('ZPE Error (kcal/mol)')\n", + "\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "3d1cd9a6-ec4d-4dbe-9e92-188c401dd548", + "metadata": {}, + "source": [ + "It seems like random sampling is preferred, and it is especially stable if we sample random directions" + ] + }, + { + "cell_type": "markdown", + "id": "9704b823-18f4-4195-884d-12fadd602993", + "metadata": {}, + "source": [ + "## Explore Effect of Sampling Size\n", + "What is the best magnitude?" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "edbf9938-54b0-4182-b9e2-8de6ec96404e", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "best_strategy = all_results.query('sampling_method==\"random-dir-variable-dist\"')" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "b0e1309b-a6dd-4728-9583-e8ec5d8a765e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axs = plt.subplots(3, 1, figsize=(3.5, 3.8), sharex=True)\n", + "\n", + "for gid, group in best_strategy.groupby('d'):\n", + " axs[0].plot(group['size'], group['vib_mae'], '--o', label=f'd={gid}')\n", + " axs[1].plot(group['size'], group['zpe_error'].abs(), '--o', label=f'd={gid}')\n", + " axs[2].plot(group['size'], group['cp_error'].apply(np.array).apply(np.abs).apply(np.mean), '--o', label=f'd={gid}')\n", + " \n", + " \n", + "# Labels\n", + "axs[0].legend()\n", + "axs[0].set_ylabel('Vibration MAE\\n(cm$^{-1}$)')\n", + "axs[1].set_ylabel('ZPE Error\\n(kcal/mol)')\n", + "axs[2].set_ylabel('$C_p$ Error\\n(kcal/mol/K)')\n", + "\n", + "for ax in axs:\n", + " ax.set_yscale('log')\n", + "\n", + "axs[-1].set_xlabel('Training Size')\n", + "\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "0ee09462-b4e4-4a3a-9f14-dea85668da0b", + "metadata": {}, + "source": [ + "Evaluate the enthalpy differences" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "bc042000-02a6-4b07-ade7-849c74a4fadf", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "best_model = best_strategy.query('d==0.01').sort_values('size').tail().iloc[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "137b0999-98fe-414b-ae6c-c7c39f37314a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0, 0.5, '$H$ (kcal/mol)')" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(3.5, 2.5))\n", + "\n", + "ax.plot(best_model['temps'], best_model['h'], 'r', label='Approx')\n", + "ax.plot(best_model['temps'], np.add(best_model['h'], best_model['h_error']), '--k', label='True')\n", + "\n", + "ax.legend()\n", + "ax.set_xlabel('Temp (K)')\n", + "ax.set_ylabel('$H$ (kcal/mol)')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "95680e61-587f-4cff-937b-583196fe6193", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/notebooks/3_consolidate-results/README.md b/notebooks/3_consolidate-results/README.md new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index 0104d3c..738120b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ "colmena==0.5.*", "parsl>=2023.04", "ase>3.22", + "pmutt>=1.4.9", "tqdm" ] diff --git a/tests/files/water-hessian.json b/tests/files/water-hessian.json new file mode 100644 index 0000000..344dbfa --- /dev/null +++ b/tests/files/water-hessian.json @@ -0,0 +1 @@ +{"atoms": {"numbers": {"__ndarray__": [[3], "int64", [8, 1, 1]]}, "positions": {"__ndarray__": [[3, 3], "float64", [2.537095624894952, -0.33431088490549643, 4.820167419625937e-10, 3.2968336037162858, 0.2449868198104976, 0.0, 1.7768707898269531, 0.2443240650949987, 0.0]]}, "cell": {"__ndarray__": [[3, 3], "float64", [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]}, "pbc": {"__ndarray__": [[3], "bool", [false, false, false]]}, 
"__ase_objtype__": "atoms"}, "hessian": {"__ndarray__": [[3, 3, 3, 3], "float64", [34.30044462022372, 0.007332869969590315, 1.0841025975148397e-06, -17.140700636997888, -13.067917980191837, 0.0005453036065499644, -17.15972988989207, 13.060585110222245, -0.0005463877091474738, 0.007332869969590315, 20.84404714703323, -0.0009952061845186228, -8.62325018884728, -10.433245119503493, 0.0004986871948568262, 8.615917318877688, -10.412678609126038, 0.0004976030922593114, 1.0841025975148397e-06, -0.0009952061845186228, 0.0034994831847778725, 0.0005843313000605094, 0.000497603092259306, -0.0007339374585175403, -0.0005843313000604877, 0.0004976030922593006, -0.0025866687976703858, -17.140700636997888, -8.62325018884728, 0.0005843313000605094, 21.417006210845795, 10.847734402021727, -0.0005149487338195705, -4.276315330771283, -2.224277149578323, -6.93825662409389e-05, -13.067917980191837, -10.433245119503493, 0.000497603092259306, 10.847734402021727, 10.200150051806633, -0.00047158462991895255, 2.2203299320207717, 0.23403390054630657, -2.6018462340356152e-05, 0.0005453036065499644, 0.0004986871948568262, -0.0007339374585175403, -0.0005149487338195705, -0.00047158462991895255, 0.0025324636677946442, -3.035487273041551e-05, -2.6018462340356152e-05, -0.00188742262227332, -17.15972988989207, 8.615917318877688, -0.0005843313000604877, -4.276315330771283, 2.2203299320207717, -3.035487273041551e-05, 21.43603654784257, -10.836455398597181, 0.0006146861727909033, 13.060585110222245, -10.412678609126038, 0.0004976030922593006, -2.224277149578323, 0.23403390054630657, -2.6018462340356152e-05, -10.836455398597181, 10.179582457326582, -0.0004705005273214404, -0.0005463877091474738, 0.0004976030922593114, -0.0025866687976703858, -6.93825662409389e-05, -2.6018462340356152e-05, -0.00188742262227332, 0.0006146861727909033, -0.0004705005273214404, 0.004384110904349974]]}, "indices": null, "__ase_objtype__": "vibrationsdata"} \ No newline at end of file diff --git a/tests/test_compare.py 
b/tests/test_compare.py new file mode 100644 index 0000000..8bde1e1 --- /dev/null +++ b/tests/test_compare.py @@ -0,0 +1,21 @@ +from pathlib import Path + +from ase.vibrations import VibrationsData +from pytest import fixture +import numpy as np + +from jitterbug.compare import compare_hessians + +_test_files = Path(__file__).parent / 'files' + + +@fixture() +def example_hess() -> VibrationsData: + return VibrationsData.read(_test_files / 'water-hessian.json') + + +def test_compare(example_hess): + comp = compare_hessians(example_hess.get_atoms(), example_hess.get_hessian_2d(), example_hess.get_hessian_2d()) + assert comp.zpe_error == 0. + assert np.ndim(comp.cp_error) == 1 + assert np.mean(comp.cp_error) == 0.