From b7d6b6ab311e4b99c0fdcd8392de348d7bfc3b4e Mon Sep 17 00:00:00 2001 From: Logan Ward Date: Wed, 13 Sep 2023 17:42:16 -0400 Subject: [PATCH] Add computing Hessians using MBTR forcefield (#8) * Add a ASE calculator for MBTR * Move back to Py3.9 for dscribe I get an error "from collections import Iterable" stemming from the sparse module used by dscribe * Add force computation * Add ability to compute Hessian from MBTR * Remove unused import * Add an example notebook with MBTR * Adjust the scale to be a unit normal Should reduce the likelihood of numerical issues * Update proof-of-concept with more data, better KRR --- envs/environment-cpu.yml | 5 +- jitterbug/model/base.py | 6 +- jitterbug/model/mbtr.py | 96 +++ .../1_compute-random-offsets.ipynb | 95 ++- .../2_approximate-hessians.ipynb | 73 +- .../3_approximate-hessians-with-mbtr.ipynb | 745 ++++++++++++++++++ pyproject.toml | 2 +- tests/models/test_mbtr.py | 31 + 8 files changed, 1001 insertions(+), 52 deletions(-) create mode 100644 jitterbug/model/mbtr.py create mode 100644 notebooks/proof-of-concept/3_approximate-hessians-with-mbtr.ipynb create mode 100644 tests/models/test_mbtr.py diff --git a/envs/environment-cpu.yml b/envs/environment-cpu.yml index f3719b2..7dc4eae 100644 --- a/envs/environment-cpu.yml +++ b/envs/environment-cpu.yml @@ -7,7 +7,7 @@ channels: - pytorch - conda-forge/label/libint_dev dependencies: - - python==3.10.* + - python==3.9.* # Standard data analysis tools - pandas==1.* @@ -20,6 +20,9 @@ dependencies: # Quantum chemistry - psi4==1.8.* + # Interatomic forcefields + - dscribe==2.1.* + # Use Conda PyTorch to avoid OpenMP disagreement with other libraries - pytorch==2.0.* - cpuonly diff --git a/jitterbug/model/base.py b/jitterbug/model/base.py index dc00870..48cd9b7 100644 --- a/jitterbug/model/base.py +++ b/jitterbug/model/base.py @@ -16,11 +16,13 @@ def train(self, data: list[Atoms]) -> object: """ raise NotImplementedError() - def mean_hessian(self, model: object) -> 
list[np.ndarray]: + def mean_hessian(self, model: object) -> np.ndarray: """Produce the most-likely Hessian given the model Args: - model: Model trained by this + model: Model trained by this class + Returns: + The most-likely Hessian given the model """ def sample_hessians(self, model: object, num_samples: int) -> list[np.ndarray]: diff --git a/jitterbug/model/mbtr.py b/jitterbug/model/mbtr.py new file mode 100644 index 0000000..f1e940c --- /dev/null +++ b/jitterbug/model/mbtr.py @@ -0,0 +1,96 @@ +"""Learn a potential energy surface with the MBTR representation + +MBTR is an easy route for learning a forcefield because it represents +a molecule as a single vector, which means we can cast the learning +problem as a simple "molecule->energy" learning problem. Other methods, +such as SOAP, provide atomic-level features that require an +extra step "molecule->atoms->energy/atom->energy". +""" +from shutil import rmtree + +from ase.calculators.calculator import Calculator, all_changes +from ase.vibrations import Vibrations +from ase import Atoms +from sklearn.linear_model import LinearRegression +from dscribe.descriptors import MBTR +import numpy as np + +from jitterbug.model.base import EnergyModel + + +class MBTRCalculator(Calculator): + """A learnable forcefield based on a scikit-learn regressor and fingerprints computed using DScribe""" + + implemented_properties = ['energy', 'forces'] + default_parameters = { + 'descriptor': MBTR( + species=["H", "C", "N", "O"], + geometry={"function": "inverse_distance"}, + grid={"min": 0, "max": 1, "n": 100, "sigma": 0.1}, + weighting={"function": "exp", "scale": 0.5, "threshold": 1e-3}, + periodic=False, + normalization="l2", + ), + 'model': LinearRegression(), + 'intercept': 0., # Normalizing parameters + 'scale': 0. 
+ } + + def calculate(self, atoms=None, properties=('energy', 'forces'), system_changes=all_changes): + # Compute the energy using the learned model + desc = self.parameters['descriptor'].create_single(atoms) + energy_no_int = self.parameters['model'].predict(desc[None, :]) + self.results['energy'] = energy_no_int[0] * self.parameters['scale'] + self.parameters['intercept'] + + # If desired, compute forces numerically + if 'forces' in properties: + # calculate_numerical_forces uses the calculator attached to the input atoms, + # even though it is a method of a calculator and not of the input atoms :shrug: + temp_atoms: Atoms = atoms.copy() + temp_atoms.calc = self + self.results['forces'] = self.calculate_numerical_forces(temp_atoms) + + def train(self, train_set: list[Atoms]): + """Train the embedded forcefield object + + Args: + train_set: List of Atoms objects containing at least the energy + """ + + # Determine the mean energy and subtract it off + energies = np.array([atoms.get_potential_energy() for atoms in train_set]) + self.parameters['intercept'] = energies.mean() + energies -= self.parameters['intercept'] + self.parameters['scale'] = energies.std() + energies /= self.parameters['scale'] + + # Compute the descriptors and use them to fit the model + desc = self.parameters['descriptor'].create(train_set) + self.parameters['model'].fit(desc, energies) + + +class MBTREnergyModel(EnergyModel): + """Use the MBTR representation to model the potential energy surface + + Args: + calc: Calculator used to fit the potential energy surface + reference: Reference structure at which we compute the Hessian + """ + + def __init__(self, calc: MBTRCalculator, reference: Atoms): + super().__init__() + self.calc = calc + self.reference = reference + + def train(self, data: list[Atoms]) -> MBTRCalculator: + self.calc.train(data) + return self.calc + + def mean_hessian(self, model: MBTRCalculator) -> np.ndarray: + self.reference.calc = model + try: + vib = 
Vibrations(self.reference, name='mbtr-temp') + vib.run() + return vib.get_vibrations().get_hessian_2d() + finally: + rmtree('mbtr-temp', ignore_errors=True) diff --git a/notebooks/proof-of-concept/1_compute-random-offsets.ipynb b/notebooks/proof-of-concept/1_compute-random-offsets.ipynb index 8bde676..10ae373 100644 --- a/notebooks/proof-of-concept/1_compute-random-offsets.ipynb +++ b/notebooks/proof-of-concept/1_compute-random-offsets.ipynb @@ -11,7 +11,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "id": "c6a28419-6831-4197-8973-00c5591e19cb", "metadata": { "tags": [] @@ -38,7 +38,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "id": "c6be56c5-a460-4acd-9b89-8c3d9c812f5f", "metadata": { "tags": [ @@ -51,7 +51,7 @@ "method = 'hf'\n", "basis = 'def2-svpd'\n", "threads = min(os.cpu_count(), 12)\n", - "step_size: float = 0.005 # Perturbation amount, used as maximum L2 norm" + "step_size: float = 0.01 # Perturbation amount, used as maximum L2 norm" ] }, { @@ -64,7 +64,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "id": "0b6794cd-477f-45a1-b96f-2332804ddb20", "metadata": {}, "outputs": [], @@ -83,12 +83,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "id": "ad9fd725-b1ba-4fec-ae41-959be0e540b3", "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "Atoms(symbols='O2N4C8H10', pbc=False, forces=..., calculator=SinglePointCalculator(...))" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "atoms = read(Path('data') / 'exact' / f'{run_name}.xyz')\n", "atoms" @@ -113,7 +124,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "id": "23502eea-0974-4248-8f19-e85447069c61", "metadata": { "tags": [] @@ -126,7 +137,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": 
"bf1366fc-d9a7-4a98-b9c9-cb3a0209b406", "metadata": { "tags": [] @@ -146,7 +157,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 7, "id": "d4f21e81-5ec3-4877-a4d1-402077be2ee8", "metadata": { "tags": [] @@ -168,12 +179,22 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 8, "id": "0915595d-133a-43df-84fc-4ff6a3b538ea", "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + " Memory set to 3.815 GiB by Python driver.\n", + " Threads set to 12 by Python driver.\n" + ] + } + ], "source": [ "calc = Psi4(method=method, basis=basis, num_threads=threads, memory='4096MB')" ] @@ -188,12 +209,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "id": "e2a28593-2634-4bb7-ae5b-8f557937bda1", "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Need to run 2701 calculations for full accuracy.\n" + ] + } + ], "source": [ "n_atoms = len(atoms)\n", "to_compute = 3 * n_atoms + 3 * n_atoms * (3 * n_atoms + 1) // 2 + 1\n", @@ -202,12 +231,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "id": "8bf40523-dcaa-4046-a9c6-74e35178e87f", "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Already done 427. 
2274 left to do.\n" + ] + } + ], "source": [ "with connect(db_path) as db:\n", " done = len(db)\n", @@ -219,7 +256,31 @@ "execution_count": null, "id": "a6fa1b33-defc-4b35-895d-052eb64453fb", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/2701 [00:00" ] @@ -280,7 +280,7 @@ }, { "cell_type": "code", - "execution_count": 124, + "execution_count": 13, "id": "00a10907-667a-413c-851d-d47f0eff092b", "metadata": {}, "outputs": [], @@ -298,7 +298,7 @@ }, { "cell_type": "code", - "execution_count": 125, + "execution_count": 14, "id": "d48893fd-df0d-4fa8-bfbe-0d04b71fbf1a", "metadata": { "tags": [] @@ -312,7 +312,7 @@ " [1.08009177e-03, 3.94902961e-03, 4.15881408e+00]])" ] }, - "execution_count": 125, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -323,7 +323,7 @@ }, { "cell_type": "code", - "execution_count": 126, + "execution_count": 15, "id": "9b311dea-5744-4211-81cb-40aa1183301e", "metadata": { "tags": [] @@ -337,7 +337,7 @@ " [8.59920621e-03, 6.10182117e-01, 2.73150623e+01]])" ] }, - "execution_count": 126, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -348,7 +348,7 @@ }, { "cell_type": "code", - "execution_count": 127, + "execution_count": 16, "id": "addd7bef-854a-4b9f-96e9-2aa01b652495", "metadata": { "tags": [] @@ -356,7 +356,7 @@ "outputs": [ { "data": { - "image/png": "", + "image/png": "", "text/plain": [ "
" ] @@ -389,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 128, + "execution_count": 17, "id": "abbbbfd6-7d17-4b93-880a-3352903b56c4", "metadata": { "tags": [] @@ -401,7 +401,7 @@ }, { "cell_type": "code", - "execution_count": 129, + "execution_count": 18, "id": "fdd80af3-8c18-40d8-b971-4a473bc91498", "metadata": {}, "outputs": [ @@ -411,7 +411,7 @@ "7.268532902551082" ] }, - "execution_count": 129, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -422,7 +422,7 @@ }, { "cell_type": "code", - "execution_count": 130, + "execution_count": 19, "id": "6b1af348-4bc9-4ced-9a12-44b3e49abe9c", "metadata": { "tags": [] @@ -434,7 +434,7 @@ "5.5067174465850215" ] }, - "execution_count": 130, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } @@ -462,7 +462,7 @@ }, { "cell_type": "code", - "execution_count": 131, + "execution_count": 20, "id": "bce41a81-6c88-4b0c-9d8d-0891d1832fd6", "metadata": { "tags": [] @@ -483,7 +483,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "id": "fe39ce86-1806-4367-8c86-e3ef58f81f84", "metadata": { "tags": [] @@ -493,7 +493,7 @@ "name": "stderr", "output_type": "stream", "text": [ - " 62%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏ | 10/16 [00:00<00:00, 16.03it/s]" + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [00:04<00:00, 3.98it/s]\n" ] } ], @@ -519,12 +519,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "id": "1c6706a9-a27f-448f-81d4-957939bb2ca8", "metadata": { "tags": [] }, - "outputs": [], + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "fig, ax = plt.subplots(figsize=(3.5, 2))\n", "\n", @@ -572,7 +583,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.9.17" } }, "nbformat": 4, diff --git a/notebooks/proof-of-concept/3_approximate-hessians-with-mbtr.ipynb b/notebooks/proof-of-concept/3_approximate-hessians-with-mbtr.ipynb new file mode 100644 index 0000000..5fd8f8c --- /dev/null +++ b/notebooks/proof-of-concept/3_approximate-hessians-with-mbtr.ipynb @@ -0,0 +1,745 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "201346ec-3f5a-4235-b8ef-4a0051373865", + "metadata": {}, + "source": [ + "# Generate Approximate Hessians\n", + "Like the previous notebook, fit an approximate model and use that to compute the Hessian. Instead of treating the Hessian parameters as separate, we try here to fit a forcefield using the data." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "ebbbc7f5-3007-420f-861a-9f65f84436be", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "from matplotlib import pyplot as plt\n", + "from jitterbug.model.mbtr import MBTREnergyModel, MBTRCalculator\n", + "from sklearn.linear_model import LinearRegression, ElasticNetCV\n", + "from sklearn.model_selection import GridSearchCV\n", + "from sklearn.kernel_ridge import KernelRidge\n", + "from dscribe.descriptors import MBTR\n", + "from ase.vibrations import VibrationsData\n", + "from ase.db import connect\n", + "from pathlib import Path\n", + "from tqdm import tqdm\n", + "import numpy as np\n", + "import warnings" + ] + }, + { + "cell_type": "markdown", + "id": "85a147c1-2758-465b-bc54-dc373d73a0f3", + "metadata": {}, + "source": [ + "Configuration" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "99bd4c92-9a7b-4e88-ac45-dbf30fbfc9e0", + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": 
[], + "source": [ + "molecule_name = 'caffeine'\n", + "method = 'hf'\n", + "basis = 'def2-svpd'\n", + "step_size: float = 0.01 # Perturbation amount, used as maximum L2 norm" + ] + }, + { + "cell_type": "markdown", + "id": "8505d400-8427-45b9-b626-3f9ca557d0c8", + "metadata": {}, + "source": [ + "Derived" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a8be3c37-bf1f-4ba4-ba8f-afff6d6bed7d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "run_name = f'{molecule_name}_{method}_{basis}'\n", + "out_dir = Path('data') / 'approx'\n", + "db_path = out_dir / f'{run_name}-random-d={step_size:.2e}.db'" + ] + }, + { + "cell_type": "markdown", + "id": "de1f6aac-b93e-45a7-98e6-ffd5205916a6", + "metadata": {}, + "source": [ + "## Read in the Data\n", + "Get all computations for the desired calculation and the exact solution" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "797b96d8-050c-4bdf-9815-586cfb5bc311", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded 1457 structures\n" + ] + } + ], + "source": [ + "with connect(db_path) as db:\n", + " data = [a.toatoms() for a in db.select('')]\n", + "print(f'Loaded {len(data)} structures')" + ] + }, + { + "cell_type": "markdown", + "id": "cb1a8e03-b045-49a4-95fd-61636a48fbad", + "metadata": {}, + "source": [ + "Read in the exact Hessian" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "7389208d-9323-492c-8fc5-d05a372206c6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "with open(f'data/exact/{run_name}-ase.json') as fp:\n", + " exact_vibs = VibrationsData.read(fp)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "a9965595-532c-4067-ba24-7620bd977007", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "exact_hess = exact_vibs.get_hessian_2d()\n", + "exact_zpe = exact_vibs.get_zero_point_energy()" + ] + }, + { + "cell_type": 
"markdown", + "id": "04c60da8-4a1d-4ae3-b45d-b77e71fd598f", + "metadata": {}, + "source": [ + "## Fit a Hessian with All Data\n", + "Fit a model which explains the energy data by fitting a Hessian matrix using compressed sensing (i.e., Lasso)." + ] + }, + { + "cell_type": "markdown", + "id": "fe72ad76-2772-4094-a9b7-065be9a356d4", + "metadata": {}, + "source": [ + "Make the MBTR calculator using half the available data" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "5a5b4a37-bd58-4855-bc3e-85d4258a25c8", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 4min 28s, sys: 7.75 s, total: 4min 35s\n", + "Wall time: 23.4 s\n" + ] + } + ], + "source": [ + "%%time\n", + "mbtr = MBTRCalculator(\n", + " model=GridSearchCV(\n", + " KernelRidge(kernel='rbf', alpha=1e-6), {\n", + " 'alpha': np.logspace(-10, -7, 8),\n", + " 'gamma': np.logspace(-5, 5, 32)\n", + " }),\n", + " descriptor=MBTR(\n", + " species=[\"H\", \"C\", \"N\", \"O\"],\n", + " geometry={\"function\": \"distance\"},\n", + " grid={\"min\": 0, \"max\": 6, \"n\": 64, \"sigma\": 0.05},\n", + " weighting={\"function\": \"exp\", \"scale\": 0.1, \"threshold\": 1e-3},\n", + " periodic=False,\n", + " )\n", + ")\n", + "with warnings.catch_warnings():\n", + " warnings.simplefilter('ignore')\n", + " mbtr.train(data[:len(data) // 2])" + ] + }, + { + "cell_type": "markdown", + "id": "503240dd-b52c-4111-a024-ec44766940e5", + "metadata": {}, + "source": [ + "Plot the model performance" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "c0038a85-5a70-4a4e-b830-c3c54e5a8efc", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "pred_energy= [mbtr.get_potential_energy(x) * 1000 for x in data[len(data) // 2:]]\n", + "true_energy = [x.get_potential_energy() * 1000 for x in data[len(data) // 2:]]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": 
"fba09717-7b2b-40a7-a6d3-543c40080d02", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0, 0.5, '$E-E_0$, True (meV)')" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(3.5, 3.5))\n", + "\n", + "base_energy = data[0].get_potential_energy() * 1000 # in meV\n", + "ax.scatter(np.subtract(pred_energy, base_energy), np.subtract(true_energy, base_energy), s=2)\n", + "\n", + "ax.set_xlim(ax.get_ylim())\n", + "ax.set_ylim(ax.get_ylim())\n", + "ax.plot(ax.get_xlim(), ax.get_xlim(), 'k--')\n", + "\n", + "ax.set_xlabel('$E-E_0$, ML (meV)')\n", + "ax.set_ylabel('$E-E_0$, True (meV)')" + ] + }, + { + "cell_type": "markdown", + "id": "71e8e883-2c4b-4929-a000-297233dabe96", + "metadata": {}, + "source": [ + "Build the ASE-compatible calculator" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "d6a7d756-37d3-44e0-b5e2-348d07c9d296", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "model = MBTREnergyModel(reference=data[0], calc=mbtr)\n", + "with warnings.catch_warnings():\n", + " warnings.simplefilter(\"ignore\")\n", + " hess_model = model.train(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "372a4089-88cb-47ae-a917-190bc26287ff", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'alpha': 1e-10, 'gamma': 2.1017480113324872e-05}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "hess_model.parameters['model'].best_params_" + ] + }, + { + "cell_type": "markdown", + "id": "aa509659-701d-4001-8cc7-980c9d999976", + "metadata": {}, + "source": [ + "Compare the forces estimated at a zero displacement to the true value" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "70d80f87-9983-4bd5-a6ae-b9c966b0d838", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "actual_forces = data[0].get_forces()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "f548b145-0aa8-47f7-802b-6b7232a74bd3", + "metadata": { + "tags": [] + }, + "outputs": [], + 
"source": [ + "pred_forces = hess_model.get_forces(data[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "d7cd7762-6e12-4dcd-b564-67a33b18d9e0", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Maximum force: 8.96e-03 eV/Angstrom\n" + ] + } + ], + "source": [ + "print(f'Maximum force: {np.abs(pred_forces).max():.2e} eV/Angstrom')" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "425b77a9-7fd7-40da-af6f-eaed197c9ab6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYUAAAC+CAYAAADa6ROSAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8pXeV/AAAACXBIWXMAAA9hAAAPYQGoP6dpAAATIklEQVR4nO3df2wb9f3H8ZftJBcnsV2a0IY2TqL1F999VX5MlbpWmgCxjShQoB1TmZgoHZoEncY0aUNCmmjFxqJ1K402aWzS+gPQtK3bAAFTq0E7GGNjP9pVbNpUCt/v2pS2pC0l56atE9uf7x9dPt+Zpuol9vWcu+dDOkW+2Of3x/mcXz7fJ5+LGWOMAACQFA+6AABA7SAUAAAWoQAAsAgFAIBFKAAALEIBAGARCgAAi1AAAFiEAgDAIhRCLBaL6dlnnw26DETY1q1bNW3atKDLmJCpWHM1EQpV8vvf/16JREI9PT0Telx3d7f6+/v9KQqYoHvuuUexWOy8xUu/Hq8vr1y5Um+++aZP1f6/qL+RV1Nd0AWExebNm/XFL35RP/rRj3Tw4EF1dnYGXRIwKT09PdqyZUvZOsdxJrWtZDKpZDJZjbJwiXCkUAXDw8Patm2b7r//ft1yyy3aunVr2e+fe+45LVq0SI2NjWpra9OKFSskSddff70OHDigL3/5y/YTmSStW7dO11xzTdk2+vv71d3dbW//+c9/1ic+8Qm1tbUpk8nouuuu0549e/xsJiLCcRy1t7eXLZdddpmkc32zs7NTjuNo1qxZeuCBByRduC9/8BP8WN/evHmzOjs71dLSovvvv1/FYlHr169Xe3u7ZsyYoUcffbSspscee0wLFy5Uc3Ozstms1qxZo1OnTkmSXn75Za1evVpDQ0P2udetWydJGhkZ0YMPPqjZs2erublZixcv1ssvv1y27a1bt6qzs1NNTU1avny5Tpw44cOrOnUQClXws5/9TAsWLNCCBQv02c9+Vlu2bNHY5LO/+tWvtGLFCt18883661//qp07d2rRokWSpKefflodHR165JFHdOTIER05csTzc+ZyOa1atUqvvvqqXn/9dc2bN0+9vb3K5XK+tBH4xS9+oY0bN+qHP/yh9u/fr2effVYLFy6UNLG+/Pbbb2v79u3asWOHfvKTn2jz5s26+eabdejQIb3yyiv61re+pa997Wt6/fXX7WPi8bi++93v6u9//7ueeOIJ7dq1Sw8++KAkaenSperv71c6nbbP/ZWvfEWStHr1ar322mv66U9/qjfeeEOf/vSn1dPTo/3790uS/vjHP+pzn/uc1qxZo7179+qGG27QN77xDb9ewqnBoGJLly41/f39xhhjRkdHTVtbm3nxxReNMcYsWbLE3HXXXRd8bF
dXl9m4cWPZurVr15qrr766bN3GjRtNV1fXBbdTKBRMKpUyzz//vF0nyTzzzDMTaguibdWqVSaRSJjm5uay5ZFHHjEbNmww8+fPNyMjI+M+dry+vGXLFpPJZOzttWvXmqamJuO6rl130003me7ublMsFu26BQsWmL6+vgvWuW3bNtPa2nrB5zHGmLfeesvEYjHzzjvvlK2/8cYbzUMPPWSMMeYzn/mM6enpKfv9ypUrz9tWlHBOoUL79u3Tn/70Jz399NOSpLq6Oq1cuVKbN2/Wxz/+ce3du1ef//znq/68g4ODevjhh7Vr1y69++67KhaLOn36tA4ePFj150K03HDDDXr88cfL1k2fPl3Dw8Pq7+/Xhz70IfX09Ki3t1fLli1TXd3E3ka6u7uVSqXs7ZkzZyqRSCgej5etGxwctLd/85vf6Jvf/Kb+8Y9/yHVdFQoFnT17VsPDw2pubh73efbs2SNjjObPn1+2Pp/Pq7W1VZL0z3/+U8uXLy/7/ZIlS7Rjx44JtSlMCIUKbdq0SYVCQbNnz7brjDGqr6/XyZMnJ3WSLR6P26+fxoyOjpbdvueee3Ts2DH19/erq6tLjuNoyZIlGhkZmVxDgH9rbm7W3Llzz1s/ffp07du3Ty+++KJeeuklrVmzRt/+9rf1yiuvqL6+3vP2P3jfWCw27rpSqSRJOnDggHp7e3Xffffp61//uqZPn67f/e53uvfee8/bL/5TqVRSIpHQ7t27lUgkyn7X0tIiSeftZyAUKlIoFPTkk09qw4YN+uQnP1n2u0996lP68Y9/rKuuuko7d+7U6tWrx91GQ0ODisVi2brLL79cR48elTHGnrDbu3dv2X1effVVff/731dvb68kaWBgQMePH69Sy4DxJZNJ3Xrrrbr11lv1hS98QVdeeaX+9re/6SMf+ci4fbka/vKXv6hQKGjDhg32aGLbtm1l9xnvua+99loVi0UNDg7qYx/72Ljb/vCHP1x27kLSebejhlCowAsvvKCTJ0/q3nvvVSaTKfvdHXfcoU2bNmnjxo268cYbNWfOHN15550qFAravn27PUnW3d2t3/72t7rzzjvlOI7a2tp0/fXX69ixY1q/fr3uuOMO7dixQ9u3b1c6nbbbnzt3rp566iktWrRIruvqq1/9KkP/UBX5fF5Hjx4tW1dXV6cXXnhBxWJRixcvVlNTk5566iklk0l1dXVJGr8vV8OcOXNUKBT0ve99T8uWLdNrr72mH/zgB2X36e7u1qlTp7Rz505dffXVampq0vz583XXXXfp7rvv1oYNG3Tttdfq+PHj2rVrlxYuXKje3l498MADWrp0qdavX6/bb79dv/71ryP91ZEkTjRX4pZbbjG9vb3j/m737t1Gktm9e7f55S9/aa655hrT0NBg2trazIoVK+z9/vCHP5irrrrKOI5j/vPP8fjjj5tsNmuam5vN3XffbR599NGyE8179uwxixYtMo7jmHnz5pmf//zn553oEyeaMUGrVq0yks5bFixYYJ555hmzePFik06nTXNzs/noRz9qXnrpJfvY8fryeCeaPziIYtWqVea2224rW3fdddeZL33pS/b2Y489Zq644gqTTCbNTTfdZJ588kkjyZw8edLe57777jOtra1Gklm7dq0xxpiRkRHz8MMPm+7ublNfX2/a29vN8uXLzRtvvGEft2nTJtPR0WGSyaRZtmyZ+c53vhPpE80xY/hSDQBwDv+nAACwCAUAgDWlQiGfz2vdunXK5/NBl+KbKLRRik47qykqr1kU2lnLbZxS5xRc11Umk9HQ0FDZSJwwiUIbpei0s5qi8ppFoZ213MYpdaQAAPAXoQAAsCb9z2ulUkmHDx9WKpWy/3XrN9d1y36GURTaKF36dhpjlMvlNGvWrLI5diaL/u+fKLQziDZ63QcmfU7h0KFDymazky4QCMLAwIA6Ojoq3g79H1PVxfaBSR8pjM1yuO/N/WUzHobRe2erP59LrWlLJi5+pyksl8
tp3rx5VeurY9vZvz/8/f+dUxeedC4sOlq8T+g3VeVyOc31sA94DoV8Pl82fGrsYi6pVKrmzp5X22h9+EMh3RTuUBgz2a96otz/h2LhD4V0KvyhMOZi+4DnL1f7+vqUyWTswqEzooT+j6jwfE7hg5+UXNdVNpvV4SNHQ/9J6cSZ8B8pXB7yIwXXddXe3j7pceEX6v9Hj4a//w/kwn+k0BmBIwXXdTXTwz7g+esjx3HkOE5VigOmGvo/oqLi6yk4Jw/KKbRUo5aaNbOlOvPC1zKjcP8N/Ro0Wv/eQdWPhvu162qeHnQJvjMm3EfKkiRT8nQ3/nkNAGARCgAAi1AAAFiEAgDAIhQAAFbFo4/ea+nQaCrc47QbE5dmwrMgMdhycqLQ/6fFR4IuwXcjJvyfj722MfyvBADAM0IBAGARCgAAi1AAAFiEAgDAqnj0UaohrnRDuLMlPnI66BJ8ZxJNQZcwJSVi55ZQKxaCrsB3DQlv8wJNZQ0x5j4CAEwQoQAAsAgFAIBFKAAALEIBAGBVPPoofjaneEM1Sqldo06457aRqtARIip9dlDp+jNBl+Gr95Mzgy7Bdw5zH1nhfyUAAJ4RCgAAi1AAAFiEAgDAIhQAAFbFg05ixihmTDVqqVn1oxGY+6iBuY8mw9Q1ydSH+7XLFHJBl+C7Ul0q6BJ8x9xHAIAJIxQAAJbnr4/y+bzy+by97bquLwUBtYj+j6jwfKTQ19enTCZjl2w262ddQE2h/yMqYsZ4O0s83ielbDarY//7ptLpcJ+kMYn6oEvwXdhPNLuuq5nt7RoaGlI6PfFpS6Lc/6Og1Bj+v6Hrupp5xayL7gOevz5yHEeO45y3vtTYEv4XNBb+Uy/hHj9Wefsu2P+TaZWS4Z4bKxaBKw96nRdoKmPuIwDAhBEKAACLUAAAWIQCAMAiFAAAFqEAALC4CqMXxttEUlNaBIbd+sKUQt8/TH1j0CX4rj4WdAX+q/e4i/NOAACwCAUAgEUoAAAsQgEAYBEKAACr4tFHR8+UNFwX7tEXHcVjQZfgu0L6iqBLmJL+lSuoRYWgy/DV3LrwX46z2NwadAk1gyMFAIBFKAAALEIBAGARCgAAi1AAAFgVjz5K1ceVagh3tgwWZwRdgu+mhfx6nCWf2tedblA63eDPxmuEO3pZ0CX4Lhny/i953wfC/W4OAJgQQgEAYBEKAACLUAAAWIQCAMCqePRRS+mMUqVwX8AtpXDP7SRJpVg66BJ8FffrylqlwrklxFoSQVfgPxOBKw963QfC/0oAADzz/BE/n88rn8/b267r+lIQUIvo/4gKz0cKfX19ymQydslms37WBdQU+j+iImaM8fR/buN9Uspmsxo88LbS6ZRvBdYEE4FzCo3hPqfguq7a29s1NDSkdHribb1Q/3/38KFJbQ+1xcTDfV5U8r4PeH4lHMeR4zhVKQ6Yauj/iIqK4zEXT0rxpmrUUrOK4T9Q0LSgC/CZX4OP8iahvAn38JyhfPh3gBnJ8Lcx5vEbD0YfAQAsQgEAYBEKAACLUAAAWIQCAMCqePTRcKGk+Gi4z9y3NwZdwaXA54PJOHamoDN14Z77qMsZDboE35lYuEdQSpI8zu/EOwEAwCIUAAAWoQAAsAgFAIBFKAAArIpHH80+fUjpREs1aqlZg/GuoEvw3fRIjLCqvq7CoNKF00GX4aujdbOCLsF3bUEXcAl4mg5bHCkAAP4DoQAAsAgFAIBFKAAALEIBAGBVPPrINDTKNCSrUUvNao2dCboE3xmFewQZJm9GfQTmPlK4r54neb/6IEcKAACLUAAAWIQCAMAiFAAAFqEAALAqHn30bt3lOl2frkYtNavNCf/IhMTwiaBL8FX8dM6X7R5vbFc+Ge7+P70uAv3fPRJ0Cb5L5LztAxwpAAAsz0cK+Xxe+Xze3nZd15eCgFpE/0dUeD5S6OvrUyaTsUs2m/WzLqCm0P8RFZ5D4aGHHtLQ0JBdBgYG/KwLqCn0f0SF56+PHMeR4z
h+1gLULPo/oqLi0UcziieULoxUo5ba5c/AlZpSSs0MugRflYr1vmy31QwrbUI+XiN3NugKfFdMXxF0Cb4rqtnT/ULemwEAE0EoAAAsQgEAYBEKAACLUAAAWBWPPhpMtOpMXbjnfmlLhn/ul6F8KegSfJXzqX3vJ1pUTKR82XatyCQzQZfgu2NnikGX4LucxzZypAAAsAgFAIBFKAAALEIBAGARCgAAq+LRR63JhNIhH53z3tnwj0xobQz33zDu+PP5p7k+rpb6cH+2ej/kI9Mk6fKQv4dJkjPqrY3h7s0AgAkhFAAAFqEAALAIBQCARSgAACxCAQBgVTwkdbR0bgmzVicWdAm+iw+fCLoEX8VPR+Caqj5JNYT/s2P89MmgS/Bd/Iy3fSD8f20AgGeEAgDAIhQAABahAACwCAUAgFXx6COEw3t1lwVdgq9yCX8mPCuUjAol48u2a0VDIvyj707Whf+SozmPf0eOFAAAlucjhXw+r3w+b2+7rutLQUAtov8jKjwfKfT19SmTydglm836WRdQU+j/iIqYMcbTF6LjfVLKZrMaOHxU6XTatwJrQUMs5P+yLen9kaAr8FfOdTW/a5aGhoYm1V8v1P8PvnMk/P0/AucU3JHw7+M519W8zovvA56/PnIcR47jVKU4YKqh/yMqKh591BA/t4RasRB0Bb6b5jQEXYKv/Locp6NRORr1Zdu1IjYS/v6faWgKugTfxTy+UYf97RwAMAGEAgDAIhQAABahAACwCAUAgFX53EfFkXNLmCXCPTJHkuqOvRV0Cb6qy53yZ8Ox+LklxEwERubUHf+foEvwndd9INy9GQAwIYQCAMAiFAAAFqEAALAIBQCAxZXXIEn6V2NX0CX4Kjfqz/UPYqWCYqVwzw1k4uF/mzjU1Bl0Cb7LFb3tAxwpAAAsQgEAYBEKAACLUAAAWJM+gzR2Fc9cLle1YmpWBKa5yJ0K94ViTv27n3q8+uxFRan/m7qQT2MjKXcm3IMFJO/7wKRDYWxnmHvlf092E8All8vllMlkqrIdSZrzXwsr3hZwKV1sH4iZSX50KpVKOnz4sFKplGKxS3Nh77GLpQ8MDIT2YulRaKN06dtpjFEul9OsWbMUj1f+rSn93z9RaGcQbfS6D0z6SCEej6ujo2OyD69IOp0ObWcZE4U2Spe2ndU4QhhD//dfFNp5qdvoZR/gRDMAwCIUAADWlAoFx3G0du1aOY4TdCm+iUIbpei0s5qi8ppFoZ213MZJn2gGAITPlDpSAAD4i1AAAFiEAgDAIhQAABahAACwCAUAgEUoAAAsQgEAYP0fFL2r+ylNNmIAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axs = plt.subplots(1, 2, figsize=(4, 2))\n", + "\n", + "for ax, l, h in zip(axs, ['Actual', 'Estimated'], [actual_forces, pred_forces]):\n", + " ax.matshow(h, vmin=-0.05, vmax=0.05, aspect='auto', cmap='RdBu')\n", + "\n", + " ax.set_xticklabels([])\n", + " ax.set_yticklabels([])\n", + " \n", + " ax.set_title(l, fontsize=10)\n", + "\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "46a0f2f8-f863-4de3-bd97-97ebd92676d4", + "metadata": {}, + "source": [ + "Get the mean Hessian" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "00a10907-667a-413c-851d-d47f0eff092b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 47.7 s, sys: 33.7 ms, total: 47.7 s\n", + "Wall time: 47.7 s\n" + ] + } + ], + "source": [ + "%%time\n", + "approx_hessian = model.mean_hessian(hess_model)" + ] + }, + { + "cell_type": "markdown", + "id": "f4de2e78-00c2-427f-b9bd-eb3ca881564b", + "metadata": {}, + "source": [ + "Compare to exact answer" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "d48893fd-df0d-4fa8-bfbe-0d04b71fbf1a", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1.96495560e+01, 2.28518485e+01, 1.08009177e-03],\n", + " [2.28518485e+01, 8.36964299e+01, 3.94902961e-03],\n", + " [1.08009177e-03, 3.94902961e-03, 4.15881408e+00]])" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "exact_hess[:3, :3]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "9b311dea-5744-4211-81cb-40aa1183301e", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[1.90039475e+01, 2.26533979e+01, 2.14746706e-02],\n", + " [2.26533979e+01, 8.32193023e+01, 1.32040896e-02],\n", + " [2.14746706e-02, 1.32040896e-02, 3.78143259e+00]])" 
+ ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "approx_hessian[:3, :3]" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "addd7bef-854a-4b9f-96e9-2aa01b652495", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, axs = plt.subplots(1, 2, figsize=(4, 2))\n", + "\n", + "for ax, l, h in zip(axs, ['Exact', 'Approximate'], [exact_hess, approx_hessian]):\n", + " ax.matshow(h, vmin=-100, vmax=100, cmap='RdBu')\n", + "\n", + " ax.set_xticklabels([])\n", + " ax.set_yticklabels([])\n", + " \n", + " ax.set_title(l, fontsize=10)\n", + "\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "b516bb4e-d27d-4ad6-ad4b-b873c81670ff", + "metadata": {}, + "source": [ + "Get the zero point energy" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "abbbbfd6-7d17-4b93-880a-3352903b56c4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "approx_vibs = VibrationsData.from_2d(data[0], approx_hessian)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "fdd80af3-8c18-40d8-b971-4a473bc91498", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "5.132369908274389" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "approx_vibs.get_zero_point_energy()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "6b1af348-4bc9-4ced-9a12-44b3e49abe9c", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "5.5067174465850215" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "exact_zpe" + ] + }, + { + "cell_type": "markdown", + "id": "ab6a6645-bf0e-4ed7-874e-6a345063e0b5", + "metadata": {}, + "source": [ + "The two differ, but I'm not sure how important the difference is." 
+ ] + }, + { + "cell_type": "markdown", + "id": "29a44b3d-cd3e-44af-9bc2-3e0164b22a38", + "metadata": {}, + "source": [ + "Save it to disk" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "40fd3d44-df72-4b9d-b7b0-f09fabe74c0d", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "with open(f'data/approx/{run_name}_d={step_size:.2e}_mbtr.json', 'w') as fp:\n", + " approx_vibs.write(fp)" + ] + }, + { + "cell_type": "markdown", + "id": "6489882c-acaf-4a07-bbe9-d643f7c5c882", + "metadata": {}, + "source": [ + "## Plot as a Function of Data\n", + "See what happens as we add more data to the training" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "bce41a81-6c88-4b0c-9d8d-0891d1832fd6", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Plotting at 16 steps: 5, 101, 198, 295, 392, ...\n" + ] + } + ], + "source": [ + "steps = np.linspace(5, len(data), 16, dtype=int)\n", + "print(f'Plotting at {len(steps)} steps: {\", \".join(map(str, steps[:5]))}, ...')" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "fe39ce86-1806-4367-8c86-e3ef58f81f84", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 16/16 [15:54<00:00, 59.66s/it]\n" + ] + } + ], + "source": [ + "zpes = []\n", + "for count in tqdm(steps):\n", + " with warnings.catch_warnings():\n", + " warnings.simplefilter(\"ignore\")\n", + " hess_model = model.train(data[:count])\n", + " \n", + " approx_hessian = model.mean_hessian(hess_model)\n", + " approx_vibs = VibrationsData.from_2d(data[0], approx_hessian)\n", + " zpes.append(approx_vibs.get_zero_point_energy())" + ] + }, + 
{ + "cell_type": "markdown", + "id": "c179c3ae-695f-44ad-b548-10002c4ff30b", + "metadata": {}, + "source": [ + "Plot it" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "1c6706a9-a27f-448f-81d4-957939bb2ca8", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(3.5, 2))\n", + "\n", + "ax.plot(steps, zpes)\n", + "\n", + "ax.set_xlim([0, steps.max()])\n", + "ax.plot(ax.get_xlim(), [exact_zpe]*2, 'k--')\n", + "\n", + "ax.set_xlabel('Energies')\n", + "ax.set_ylabel('ZPE (eV)')\n", + "\n", + "fig.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "id": "e8788f74-c208-4939-aa9b-3bbdfd8310ee", + "metadata": {}, + "source": [ + "We consistently underestimate the ZPE. Is it because we have too few oscillators?" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "384af4b3-5eb3-4eac-b176-160f19944853", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.17" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index 5189da8..03604e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ authors = [ ] description = 'Faster Hessians through machine learning' readme = "README.md" -requires-python = ">=3.10" +requires-python = ">=3.9" license = { file = "LICENSE" } keywords = ["HPC", "AI", "Workflows", "Quantum Chemistry", "Chemical Engineering"] classifiers = [ diff --git a/tests/models/test_mbtr.py b/tests/models/test_mbtr.py new file mode 100644 index 0000000..c500156 --- /dev/null +++ b/tests/models/test_mbtr.py @@ -0,0 +1,31 @@ +"""Tests for the MBTR-based calculator and energy model""" +import numpy as np + +from jitterbug.model.mbtr import MBTRCalculator, MBTREnergyModel + + +def test_model(train_set): + # Create the calculator, then fit it to the training set + calc = MBTRCalculator() + calc.train(train_set) 
+ + # Predict the energy of the first training structure (we should be close!) + test_atoms = train_set[0].copy() + test_atoms.calc = calc + energy = test_atoms.get_potential_energy() + assert np.isclose(energy, train_set[0].get_potential_energy()) + + # See if the force calculation works + forces = test_atoms.get_forces() + assert forces.shape == (3, 3) # Only the shape is checked because the values are iffy; NOTE(review): (3, 3) presumes a 3-atom structure - confirm against the train_set fixture + + +def test_hessian(train_set): + """Check that the energy model produces a Hessian of the expected shape""" + calc = MBTRCalculator() + model = MBTREnergyModel(calc, train_set[0]) + + # Fit the model, then compute the mean Hessian from the fit + hess_model = model.train(train_set) + hess = model.mean_hessian(hess_model) + assert hess.shape == (9, 9)