Skip to content

Commit

Permalink
Add function which returns the MVN (#37)
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT authored Dec 13, 2023
1 parent cdc5087 commit 35a23ac
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 26 deletions.
46 changes: 31 additions & 15 deletions jitterbug/model/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pathlib import Path
import os

import ase
import numpy as np
from scipy import stats
from ase import Atoms
Expand All @@ -15,6 +16,9 @@
class EnergyModel:
"""Base class for functions which predict energy given molecular structure"""

reference: ase.Atoms
"""Structure for which we will be computing Hessians"""

def train(self, data: list[Atoms]) -> object:
"""Produce an energy model given observations of energies
Expand Down Expand Up @@ -43,6 +47,31 @@ def sample_hessians(self, model: object, num_samples: int) -> list[np.ndarray]:
Returns:
A list of 2D hessians
"""
n_params = len(self.reference) * 3
dist = self.get_hessian_distribution(model)
diag_ind = np.diag_indices(n_params)
triu_ind = np.triu_indices(n_params)
output = []
for sample in dist.rvs(size=num_samples):
# Fill in a 2D version
sample_hessian = np.zeros((n_params, n_params))
sample_hessian[triu_ind] = sample

# Make it symmetric
sample_hessian += sample_hessian.T
sample_hessian[diag_ind] /= 2

output.append(sample_hessian)
return output

def get_hessian_distribution(self, model: object) -> stats.multivariate_normal:
    """Describe the independent Hessian parameters as a multivariate normal distribution

    Subclasses are expected to override this method; this base version always raises.

    Args:
        model: Model trained by this class
    Returns:
        A MVN distribution for the upper triangle of the Hessian matrix (in row-major order)
    Raises:
        NotImplementedError: Always, because the base class provides no implementation
    """
    raise NotImplementedError


Expand Down Expand Up @@ -105,7 +134,7 @@ def mean_hessian(self, models: list[Calculator]) -> np.ndarray:
# Return the mean
return np.mean(hessians, axis=0)

def sample_hessians(self, models: list[Calculator], num_samples: int) -> list[np.ndarray]:
def get_hessian_distribution(self, models: list[Calculator]) -> list[np.ndarray]:
# Run all calculators
hessians = [self.compute_hessian(self.reference, calc) for calc in models]

Expand All @@ -116,17 +145,4 @@ def sample_hessians(self, models: list[Calculator], num_samples: int) -> list[np
hessian_covar = np.cov(hessians_flat, rowvar=False)

# Generate samples
hessian_mvn = stats.multivariate_normal(hessian_mean, hessian_covar, allow_singular=True)
diag_indices = np.diag_indices(hessians[0].shape[0])
output = []
for sample in hessian_mvn.rvs(size=num_samples):
# Fill in a 2D version
sample_hessian = np.zeros_like(hessians[0])
sample_hessian[triu_ind] = sample

# Make it symmetric
sample_hessian += sample_hessian.T
sample_hessian[diag_indices] /= 2

output.append(sample_hessian)
return output
return stats.multivariate_normal(hessian_mean, hessian_covar, allow_singular=True)
15 changes: 4 additions & 11 deletions jitterbug/model/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
"""
import numpy as np
from ase import Atoms
from scipy.stats import multivariate_normal
from sklearn.linear_model import ARDRegression
from sklearn.linear_model._base import LinearModel

Expand Down Expand Up @@ -83,7 +84,7 @@ def train(self, data: list[Atoms]) -> LinearModel:
def mean_hessian(self, model: LinearModel) -> np.ndarray:
return self._params_to_hessian(model.coef_)

def sample_hessians(self, model: LinearModel, num_samples: int) -> list[np.ndarray]:
def get_hessian_distribution(self, model: LinearModel) -> multivariate_normal:
# Get the covariance matrix
if not hasattr(model, 'sigma_'): # pragma: no-coverage
raise ValueError(f'Sampling only possible with Bayesian regressors. You trained a {type(model)}')
Expand All @@ -101,14 +102,8 @@ def sample_hessians(self, model: LinearModel, num_samples: int) -> list[np.ndarr
sigma = model.sigma_

# Sample the model parameters
params = np.random.multivariate_normal(model.coef_, sigma, size=num_samples)

# Assemble them into Hessians
output = []
for param in params:
hessian = self._params_to_hessian(param)
output.append(hessian)
return output
n_coords = len(self.reference) * 3
return multivariate_normal(model.coef_[n_coords:], sigma[n_coords:, n_coords:], allow_singular=True)

def _params_to_hessian(self, param: np.ndarray) -> np.ndarray:
"""Convert the parameters for the linear model into a Hessian
Expand All @@ -128,6 +123,4 @@ def _params_to_hessian(self, param: np.ndarray) -> np.ndarray:
hessian[triu_inds] = param[n_coords:] # The first n_coords terms are the linear part
hessian[off_diag_triu_inds] /= 2
hessian.T[triu_inds] = hessian[triu_inds]
# v = np.sqrt(self.reference.get_masses()).repeat(3).reshape(-1, 1)
# hessian /= np.dot(v, v.T)
return hessian
12 changes: 12 additions & 0 deletions jitterbug/model/linear_internals.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from ase import io as aseio
from geometric.molecule import Molecule
from geometric.internal import DelocalizedInternalCoordinates as DIC
from scipy import stats
from sklearn.linear_model import ARDRegression
from sklearn.linear_model._base import LinearModel
from .base import EnergyModel
Expand Down Expand Up @@ -175,6 +176,17 @@ def sample_hessians(self, model: LinearModel, num_samples: int) -> list[np.ndarr
output.append(hessian)
return output

def get_hessian_distribution(self, model: LinearModel, num_samples: int = 128) -> stats.multivariate_normal:
    """Approximate the distribution of the Cartesian Hessian with a multivariate normal

    The covariance of the internal-coordinate Hessian is not yet transformed
    analytically into that of the Cartesian Hessian. Instead, the distribution
    is refit to Hessians drawn with ``sample_hessians``.

    Args:
        model: Model trained by this class
        num_samples: Number of Hessians to draw when fitting the distribution
    Returns:
        A MVN distribution for the upper triangle of the Hessian matrix (in row-major order)
    Raises:
        ValueError: If fewer than two samples are requested
    """
    # A covariance matrix cannot be estimated from fewer than two observations
    if num_samples < 2:
        raise ValueError(f'At least 2 samples are needed to fit a covariance matrix. Requested: {num_samples}')

    # Keep only the independent (upper triangle) terms of each sampled Hessian
    triu_ind = np.triu_indices(len(self.reference) * 3)
    samples = np.array([hessian[triu_ind] for hessian in self.sample_hessians(model, num_samples)])

    # The covariance is rank-deficient whenever there are fewer samples than parameters
    return stats.multivariate_normal(
        samples.mean(axis=0),
        np.cov(samples, rowvar=False),
        allow_singular=True
    )

def _params_to_hessian(self, param: np.ndarray) -> np.ndarray:
"""Convert the parameters for the linear model into a Hessian
Expand Down
9 changes: 9 additions & 0 deletions tests/models/test_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ def test_linear_model(train_set, model_type, num_params):
assert len(hessians) == 32
assert np.isclose(hessians[0], hessians[0].T).all()

# Make sure the underlying distribution has the same mean
dist = model.get_hessian_distribution(hessian_model)
ind = np.triu_indices(9)
assert np.isclose(
np.mean(hessians, axis=0)[ind],
dist.mean,
atol=10 # Does not have to be close. Our sampling size is pretty small
).all()

# Only test accuracy with the IC harmonic model. The other model is too inaccurate
if isinstance(model, ICHarmonicModel):
vib_data = VibrationsData.from_2d(reference, hessians[0])
Expand Down

0 comments on commit 35a23ac

Please sign in to comment.