diff --git a/jitterbug/model/dscribe/local.py b/jitterbug/model/dscribe/local.py
index 70f4c79..fb786e5 100644
--- a/jitterbug/model/dscribe/local.py
+++ b/jitterbug/model/dscribe/local.py
@@ -1,6 +1,7 @@
 """Create a PyTorch-based model which uses features for each atom"""
-from typing import Union, Optional, Callable
+from typing import Union, Optional, Callable, Sequence
+import ase
 from ase.calculators.calculator import Calculator, all_changes
 from dscribe.descriptors.descriptorlocal import DescriptorLocal
 from torch.utils.data import TensorDataset, DataLoader
@@ -22,15 +23,21 @@ class InducedKernelGPR(torch.nn.Module):
         inducing_x: Starting points for the reference points of the kernel
         use_ard: Whether to employ a different length scale parameter for each descriptor, a technique known as Automatic Relevance Detection (ARD)
+        initial_lengthscale: Initial value for the lengthscale parameter
     """

-    def __init__(self, inducing_x: torch.Tensor, use_ard: bool):
+    def __init__(self, inducing_x: torch.Tensor, use_ard: bool, initial_lengthscale: float = 10):
         super().__init__()
         n_points, n_desc = inducing_x.shape
         self.inducing_x = torch.nn.Parameter(inducing_x.clone())
         self.alpha = torch.nn.Parameter(torch.empty((n_points,), dtype=inducing_x.dtype))
         torch.nn.init.normal_(self.alpha)
-        self.lengthscales = torch.nn.Parameter(-torch.ones((n_desc,), dtype=inducing_x.dtype) if use_ard else -torch.ones((1,), dtype=inducing_x.dtype))
+
+        # Initial value
+        ls = np.log(initial_lengthscale)
+        self.lengthscales = torch.nn.Parameter(
+            -torch.ones((n_desc,), dtype=inducing_x.dtype) * ls if use_ard else
+            -torch.ones((1,), dtype=inducing_x.dtype) * ls)

     def forward(self, x) -> torch.Tensor:
         # Compute an RBF kernel
@@ -40,42 +47,130 @@ def forward(self, x) -> torch.Tensor:
         esd = torch.exp(-diff)

         # Return the sum
-        return torch.tensordot(self.alpha, esd, dims=([0], [0]))
+        return torch.tensordot(self.alpha, esd, dims=([0], [0]))[:, None]


-def make_gpr_model(train_descriptors: np.ndarray, num_inducing_points: int, use_ard_kernel: bool = False) -> InducedKernelGPR:
-    """Make the GPR model for a certain atomic system
+class PerElementModule(torch.nn.Module):
+    """Fit a different model for each element
+
+    Args:
+        models: Map of atomic number to the model used for atoms of that element
+    """
+
+    def __init__(self, models: dict[int, torch.nn.Module]):
+        super().__init__()
+        self.models = torch.nn.ModuleDict(
+            dict((str(k), v) for k, v in models.items())
+        )
+
+    def forward(self, element: torch.IntTensor, desc: torch.Tensor) -> torch.Tensor:
+        output = torch.empty_like(desc[:, 0])
+        for elem, model in self.models.items():
+            elem_id = int(elem)
+            mask = element == elem_id
+            output[mask] = model(desc[mask, :])[:, 0]
+        return output
+
+
+def make_nn_model(
+        elements: np.ndarray,
+        descriptors: np.ndarray,
+        hidden_layers: Sequence[int] = (),
+        activation: torch.nn.Module = torch.nn.Sigmoid()
+) -> PerElementModule:
+    """Make a neural network model for a certain atomic system

     Assumes that the descriptors have already been scaled

     Args:
-        train_descriptors: 3D array of all training points (num_configurations, num_atoms, num_descriptors)
-        num_inducing_points: Number of inducing points to use in the kernel. More points, more complex model
+        elements: Element for each atom in a structure (num_atoms,)
+        descriptors: 3D array of all training points (num_configurations, num_atoms, num_descriptors)
+        hidden_layers: Number of units in each hidden layer
+        activation: Activation function used for the hidden layers
+    """
+
+    # Detect the dtype
+    dtype = torch.from_numpy(descriptors[0, 0, :1]).dtype
+
+    # Make a model for each element type
+    models: dict[int, torch.nn.Sequential] = {}
+    element_types = np.unique(elements)
+
+    for element in element_types:
+        # Make the neural network
+        nn_layers = []
+        input_size = descriptors.shape[2]
+        for hidden_size in hidden_layers:
+            nn_layers.extend([
+                torch.nn.Linear(input_size, hidden_size, dtype=dtype),
+                activation
+            ])
+            input_size = hidden_size
+
+        # Make the last layer
+        nn_layers.append(torch.nn.Linear(input_size, 1, dtype=dtype))
+        models[element] = torch.nn.Sequential(*nn_layers)
+
+    return PerElementModule(models)
+
+
+def make_gpr_model(elements: np.ndarray,
+                   descriptors: np.ndarray,
+                   num_inducing_points: int,
+                   fix_inducing_points: bool = True,
+                   use_ard_kernel: bool = False,
+                   **kwargs) -> PerElementModule:
+    """Make the GPR model for a certain atomic system
+
+    Assumes that the descriptors have already been scaled.
+
+    Passes additional kwargs to the :class:`InducedKernelGPR` constructor.
+
+    Args:
+        elements: Element for each atom in a structure (num_atoms,)
+        descriptors: 3D array of all training points (num_configurations, num_atoms, num_descriptors)
+        num_inducing_points: Number of inducing points to use in the kernel for each model. More points yield a more flexible model
+        fix_inducing_points: Whether to fix the inducing points or allow them to be learned
         use_ard_kernel: Whether to use a different length scale parameter for each descriptor

     Returns:
         Model which can predict energy given descriptors for a single configuration
     """
-    # Select a set of inducing points randomly
-    descriptors = np.reshape(train_descriptors, (-1, train_descriptors.shape[-1]))
-    num_inducing_points = min(num_inducing_points, descriptors.shape[0])
-    inducing_inds = np.random.choice(descriptors.shape[0], size=(num_inducing_points,), replace=False)
-    inducing_points = descriptors[inducing_inds, :]
+    # Make a model for each element type
+    models: dict[int, InducedKernelGPR] = {}
+    element_types = np.unique(elements)
+
+    for element in element_types:
+        # Select a set of inducing points from the atoms of this element
+        # TODO (wardlt): Use a method which ensures diversity, like KMeans
+        mask = elements == element
+        masked_descriptors = descriptors[:, mask, :]
+        masked_descriptors = np.reshape(masked_descriptors, (-1, masked_descriptors.shape[-1]))
+        n_points = min(num_inducing_points, masked_descriptors.shape[0])  # Local count, so one scarce element does not shrink the budget for the rest
+        inducing_inds = np.random.choice(masked_descriptors.shape[0], size=(n_points,), replace=False)
+        inducing_points = masked_descriptors[inducing_inds, :]
+
+        # Make the model
+        model = InducedKernelGPR(
+            inducing_x=torch.from_numpy(inducing_points),
+            use_ard=use_ard_kernel,
+            **kwargs
+        )
+        model.inducing_x.requires_grad = not fix_inducing_points
+        models[element] = model

-    # Make the model
-    return InducedKernelGPR(
-        inducing_x=torch.from_numpy(inducing_points),
-        use_ard=use_ard_kernel,
-    )
+    return PerElementModule(models)


-def train_model(model: torch.nn.Module,
+def train_model(model: PerElementModule,
+                train_e: np.ndarray,
                 train_x: np.ndarray,
                 train_y: np.ndarray,
                 steps: int,
                 batch_size: int = 4,
                 learning_rate: float = 0.01,
                 device: Union[str, torch.device] = 'cpu',
+                patience: Optional[int] = None,
                 verbose: bool = False) -> pd.DataFrame:
     """Train the kernel model over many iterations
@@ -83,21 +178,27 @@ def train_model(model: torch.nn.Module,
     Args:
         model: Model to be trained
+        train_e: Elements for each atom in a configuration (num_atoms,)
         train_x: 3D array of all training points (num_configurations, num_atoms, num_descriptors)
         train_y: Energies for each training point
         steps: Number of iterations over all training points
         batch_size: Number of conformers per batch
         learning_rate: Learning rate used for the optimizer
         device: Which device to use for training
+        patience: If provided, stop training if the train loss fails to improve for this many iterations
         verbose: Whether to display a progress bar

     Returns:
         Mean loss over each iteration
     """
-    # Convert the data to Torches
+    # Convert the data to Tensors
     n_conf, n_atoms = train_x.shape[:2]
     train_x = torch.from_numpy(train_x)
     train_y = torch.from_numpy(train_y)
+    train_e = torch.from_numpy(train_e).to(device)
+
+    # Duplicate the elements per batch size
+    train_e = train_e.repeat(batch_size)

     # Make the data loader
     dataset = TensorDataset(train_x, train_y)
@@ -115,7 +216,10 @@ def train_model(model: torch.nn.Module,
     # Iterate over the data multiple times
     losses = []
-    for _ in tqdm(range(steps), disable=not verbose, leave=False):
+    iterator = tqdm(range(steps), disable=not verbose, leave=False)
+    no_improvement = 0  # Number of epochs w/o improvement
+    best_loss = np.inf
+    for _ in iterator:
         epoch_loss = 0
         for batch_x, batch_y in loader:
             # Prepare at the beginning of each batch
@@ -125,7 +229,7 @@ def train_model(model: torch.nn.Module,
             # Predict on all configurations
             batch_x = torch.reshape(batch_x, (-1, batch_x.shape[-1]))  # Flatten from (n_confs, n_atoms, n_desc) -> (n_confs * n_atoms, n_desc)
-            pred_y_per_atoms_flat = model(batch_x)
+            pred_y_per_atoms_flat = model(train_e, batch_x)

             # Get the mean sum for each atom
             pred_y_per_atoms = torch.reshape(pred_y_per_atoms_flat, (batch_size, n_atoms))
@@ -140,8 +244,16 @@ def train_model(model: torch.nn.Module,
             opt.step()
             epoch_loss += batch_loss.item()

+        # Update the best loss
+        no_improvement = 0 if epoch_loss < best_loss else no_improvement + 1
+        best_loss = min(best_loss, epoch_loss)
+        iterator.set_description(f'Loss: {epoch_loss:.2e} - Patience: {no_improvement}')
         losses.append(epoch_loss)

+        # Break if no improvement
+        if patience is not None and no_improvement > patience:
+            break
+
     # Pull the model back off the GPU
     model.to('cpu')
     return pd.DataFrame({'loss': losses})
@@ -151,7 +263,7 @@ class DScribeLocalCalculator(Calculator):
     """Calculator which uses descriptors for each atom and PyTorch to compute energy

     Keyword Args:
-        model (torch.nn.Module): A machine learning model which takes descriptors as inputs
+        model (PerElementModule): A machine learning model which takes atomic numbers and descriptors as inputs
         desc (DescriptorLocal): Tool used to compute the descriptors
         desc_scaling (tuple[np.ndarray, np.ndarray]): An offset and factor with which to adjust the energy per atom predictions, which are typically the mean and standard deviation of energy per atom across the training set.
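
# --- Editor's note (not part of the patch): a minimal, self-contained sketch of
# --- the per-element dispatch that PerElementModule implements and that
# --- train_model and DScribeLocalCalculator now rely on. The model sizes and
# --- tensors below are illustrative assumptions, not the library's API.
import torch

# One tiny model per atomic number; torch.nn.ModuleDict requires string keys,
# which is why PerElementModule stores atomic numbers as str(Z)
models = torch.nn.ModuleDict({
    '1': torch.nn.Linear(4, 1, dtype=torch.float64),  # hydrogen
    '6': torch.nn.Linear(4, 1, dtype=torch.float64),  # carbon
})
elements = torch.tensor([6, 1, 1])                    # atomic number of each atom
desc = torch.rand((3, 4), dtype=torch.float64)        # (num_atoms, num_descriptors)

# Route each atom's descriptor row to the model registered for its element
output = torch.empty_like(desc[:, 0])
for elem, model in models.items():
    mask = elements == int(elem)
    output[mask] = model(desc[mask, :])[:, 0]
assert output.shape == (3,)  # one energy contribution per atom
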
@@ -160,6 +272,8 @@ class DScribeLocalCalculator(Calculator): device (str | torch.device): Device to use for inference """ + # TODO (wardlt): Have scaling for descriptors and energies be per-element + implemented_properties = ['energy', 'forces', 'energies'] default_parameters = { 'model': None, @@ -169,32 +283,38 @@ class DScribeLocalCalculator(Calculator): 'device': 'cpu' } - def calculate(self, atoms=None, properties=('energy', 'forces', 'energies'), + def calculate(self, atoms: ase.Atoms = None, properties=('energy', 'forces', 'energies'), system_changes=all_changes): # Compute the descriptors for the atoms d_desc_d_pos, desc = self.parameters['desc'].derivatives(atoms, attach=True) # Scale the descriptors - offset, scale = self.parameters['desc_scaling'] - desc = (desc - offset) / scale - d_desc_d_pos /= scale + desc_offset, desc_scale = self.parameters['desc_scaling'] + desc = (desc - desc_offset) / desc_scale + d_desc_d_pos /= desc_scale # Convert to pytorch - desc = torch.from_numpy(desc.astype(np.float32)) + # TODO (wardlt): Make it possible to convert to float32 or lower + desc = torch.from_numpy(desc) desc.requires_grad = True - d_desc_d_pos = torch.from_numpy(d_desc_d_pos.astype(np.float32)) + d_desc_d_pos = torch.from_numpy(d_desc_d_pos) - # Move the model to device if need be - model: torch.nn.Module = self.parameters['model'] + # Make sure the model has all the required elements + model: PerElementModule = self.parameters['model'] + missing_elems = set(map(str, atoms.get_atomic_numbers())).difference(model.models.keys()) + if len(missing_elems) > 0: + raise ValueError(f'Model lacks parameters for elements: {", ".join(missing_elems)}') device = self.parameters['device'] model.to(device) # Run inference - offset, scale = self.parameters['energy_scaling'] + eng_offset, eng_scale = self.parameters['energy_scaling'] + elements = torch.from_numpy(atoms.get_atomic_numbers()) model.eval() # Ensure we're in eval mode + elements = elements.to(device) desc = desc.to(device) - pred_energies_dist = model(desc) - pred_energies = pred_energies_dist * scale + offset + pred_energies_dist = model(elements, desc) + pred_energies = pred_energies_dist * eng_scale + eng_offset pred_energy = torch.sum(pred_energies) self.results['energy'] = pred_energy.item() self.results['energies'] = pred_energies.detach().cpu().numpy() @@ -205,14 +325,19 @@ def calculate(self, atoms=None, properties=('energy', 'forces', 'energies'), # Derivatives for the descriptors are for each center (which is the input to the model) with respect to each atomic coordinate changing. # Energy is summed over the contributions from each center. 
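        # [Editor's aside, not part of the patch: the contraction added below can be
        #  sanity-checked with random tensors, assuming DScribe's derivative layout of
        #  (centers, atoms, xyz, descriptors) and autograd's (centers, descriptors):
        #      import torch
        #      n_centers, n_atoms, n_desc = 3, 3, 8
        #      d_desc_d_pos = torch.rand((n_centers, n_atoms, 3, n_desc))
        #      d_energy_d_desc = torch.rand((n_centers, n_desc))
        #      d_energy_d_center_d_pos = torch.einsum('ijkl,il->ijk', d_desc_d_pos, d_energy_d_desc)
        #      forces = -d_energy_d_center_d_pos.sum(dim=0)  # sum over centers
        #      assert forces.shape == (n_atoms, 3)  # one force vector per atom
        #  ]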
        # The total force is therefore a sum over the effect of an atom moving on all centers
+        # Note: Forces are scaled because pred_energy was scaled
         d_energy_d_desc = torch.autograd.grad(
             outputs=pred_energy,
             inputs=desc,
             grad_outputs=torch.ones_like(pred_energy),
         )[0]  # Derivative of the energy with respect to the descriptors for each center
         d_desc_d_pos = d_desc_d_pos.to(device)
-        d_energy_d_center_d_pos = torch.einsum('ijkl,il->ijk', d_desc_d_pos, d_energy_d_desc)  # Derivative for each center with respect to each atom
-        pred_forces = -d_energy_d_center_d_pos.sum(dim=0) * scale  # Total effect on each center from each atom
+
+        # Einsum is long-form for: dE_d(center:i from atom:j moving direction:k)
+        #  = sum_(descriptors:l) d(descriptor:l)/d(i,j,k) * dE(center:i)/d(l)
+        # I.e., use the chain rule to get the change in energy for each center
+        d_energy_d_center_d_pos = torch.einsum('ijkl,il->ijk', d_desc_d_pos, d_energy_d_desc)
+        pred_forces = -d_energy_d_center_d_pos.sum(dim=0)  # Sum over centers gives the total force on each atom

         # Store the results
         self.results['forces'] = pred_forces.detach().cpu().numpy()
@@ -238,7 +363,7 @@ class DScribeLocalEnergyModel(ASEEnergyModel):
     def __init__(self,
                  reference: Atoms,
                  descriptors: DescriptorLocal,
-                 model_fn: Callable[[np.ndarray], torch.nn.Module],
+                 model_fn: Callable[[np.ndarray], PerElementModule],
                  num_calculators: int,
                  device: str = 'cpu',
                  train_options: Optional[dict] = None):
@@ -249,26 +374,29 @@ def __init__(self,
         self.train_options = train_options or {'steps': 4}

     def train_calculator(self, data: list[Atoms]) -> Calculator:
-        # Train it using the user-provided options
+        # Get the elements
+        elements = data[0].get_atomic_numbers()
+
+        # Prepare the training set, scaling the input
         train_x = self.descriptors.create(data)
         offset_x = train_x.mean(axis=(0, 1))
-        scale_x = np.clip(train_x.std(axis=(0, 1)), a_min=1e-6, a_max=None)
         train_x -= offset_x
+        scale_x = np.clip(train_x.std(axis=(0, 1)), a_min=1e-6, a_max=None)
         train_x /= scale_x

-        train_y = np.array([a.get_potential_energy() for a in data])
-        scale_y, offset_y = np.std(train_y), np.mean(train_y)
-        train_y = (train_y - offset_y) / scale_y
+        train_y_per_atom = np.array([a.get_potential_energy() / len(a) for a in data])
+        scale, offset = train_y_per_atom.std(), train_y_per_atom.mean()
+        train_y = np.array([(a.get_potential_energy() - len(a) * offset) / scale for a in data])

-        # Make then train the model
+        # Make the model and train it
         model = self.model_fn(train_x)
-        train_model(model, train_x, train_y, device=self.device, **self.train_options)
+        train_model(model, elements, train_x, train_y, device=self.device, **self.train_options)

-        # Make the calculator
+        # Return a calculator that wraps the trained model
         return DScribeLocalCalculator(
             model=model,
             desc=self.descriptors,
             desc_scaling=(offset_x, scale_x),
-            energy_scaling=(offset_y, scale_y),
+            energy_scaling=(offset, scale),
             device=self.device
         )
diff --git a/notebooks/2_testing-fitting-strategies/2_fit-forcefield-using-soap.ipynb b/notebooks/2_testing-fitting-strategies/2_fit-forcefield-using-soap.ipynb
index 3f9011e..fd8c273 100644
--- a/notebooks/2_testing-fitting-strategies/2_fit-forcefield-using-soap.ipynb
+++ b/notebooks/2_testing-fitting-strategies/2_fit-forcefield-using-soap.ipynb
@@ -56,13 +56,19 @@
    },
    "outputs": [],
    "source": [
-    "db_path = '../1_explore-sampling-methods/data/along-axes/caffeine_pm7_None_d=5.00e-03-N=2.db'\n",
-    "device = 'cuda'"
+    "db_path = '../1_explore-sampling-methods/data/simple-uniform/caffeine_pm7_None_at_pm7_None_d=5.00e-03.db'\n",
+    "device = 
'cuda'\n", + "overwrite = True\n", + "inducing_points = 256\n", + "l_max = 3\n", + "n_max = 1\n", + "cutoff = 6\n", + "initial_lengthscale = 100" ] }, { "cell_type": "markdown", - "id": "8505d400-8427-45b9-b626-3f9ca557d0c8", + "id": "5271562d-729e-4de2-9ed9-2037e6e77b8e", "metadata": {}, "source": [ "Derived" @@ -71,482 +77,225 @@ { "cell_type": "code", "execution_count": null, - "id": "a8be3c37-bf1f-4ba4-ba8f-afff6d6bed7d", + "id": "89895ee3-ba4c-4ff6-8868-a13fc49b61ff", "metadata": { "tags": [] }, "outputs": [], "source": [ "run_name, sampling_options = Path(db_path).name[:-3].rsplit(\"_\", 1)\n", - "exact_path = Path('../data/exact/') / f'{run_name}-ase.json'\n", + "exact_path = Path('../0_create-test-set/data/exact/') / f'{run_name}_d=0.01-ase.json'\n", "sampling_name = Path(db_path).parent.name\n", - "out_name = '_'.join([run_name, sampling_name, sampling_options])" + "out_name = '_'.join([run_name, sampling_name, sampling_options])\n", + "out_dir = Path('data/soap/')" ] }, { "cell_type": "markdown", - "id": "de1f6aac-b93e-45a7-98e6-ffd5205916a6", + "id": "c6cdc7a6-a421-4b1c-acb1-ce65ae2fcef6", "metadata": {}, "source": [ - "## Read in the Data\n", - "Get all computations for the desired calculation and the exact solution" + "Skip if done" ] }, { "cell_type": "code", "execution_count": null, - "id": "797b96d8-050c-4bdf-9815-586cfb5bc311", + "id": "8ad5420a-ba73-4266-abfb-8cccf5f322fd", "metadata": { "tags": [] }, "outputs": [], "source": [ - "with connect(db_path) as db:\n", - " data = [a.toatoms() for a in db.select('')]\n", - "print(f'Loaded {len(data)} structures')" + "if (out_dir / f'{out_name}-full.json').exists() and not overwrite:\n", + " raise ValueError('Already done!')" ] }, { "cell_type": "markdown", - "id": "3fa7d5d6-f9ee-431f-b16b-dcc556cdeb49", + "id": "de1f6aac-b93e-45a7-98e6-ffd5205916a6", "metadata": {}, "source": [ - "Read in the exact Hessian" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7389208d-9323-492c-8fc5-d05a372206c6", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "with open(exact_path) as fp:\n", - " exact_vibs = VibrationsData.read(fp)" + "## Read in the Data\n", + "Get all computations for the desired calculation and the exact solution" ] }, { "cell_type": "code", "execution_count": null, - "id": "a9965595-532c-4067-ba24-7620bd977007", + "id": "797b96d8-050c-4bdf-9815-586cfb5bc311", "metadata": { "tags": [] }, "outputs": [], "source": [ - "exact_hess = exact_vibs.get_hessian_2d()\n", - "exact_zpe = exact_vibs.get_zero_point_energy()" + "with connect(db_path) as db:\n", + " data = [a.toatoms() for a in db.select('')]\n", + "print(f'Loaded {len(data)} structures')" ] }, { "cell_type": "markdown", - "id": "52d04ec1-6ecc-458a-a580-79de2c327c09", + "id": "3fa7d5d6-f9ee-431f-b16b-dcc556cdeb49", "metadata": {}, "source": [ - "## Start by Adjusting Hyperparameters\n", - "There are many layers of things we can adjust with SOAP, including\n", - "- The descriptors which are used. SOAP has at least 3 main adjustable parameters\n", - "- The complexity of the GPR model, which is mainly varied by the number of inducing points (more points -> more complexity)\n", - "- How the model is trained: E.g., batch size, learning rate\n", - "\n", - "Here, we adjust them for our particular problem and start with the descriptors. 
\n", - "\n", - "We'll start from a reasonable guess for all then tweak each" + "Read in the exact Hessian" ] }, { "cell_type": "code", "execution_count": null, - "id": "b3abeb98-ad43-4411-9b70-b86d28dcf0f4", + "id": "7389208d-9323-492c-8fc5-d05a372206c6", "metadata": { "tags": [] }, "outputs": [], "source": [ - "train_data, test_data = train_test_split(data, test_size=0.1)" - ] - }, - { - "cell_type": "markdown", - "id": "bbe4599a-3928-4420-9156-a4ee66adfc5b", - "metadata": {}, - "source": [ - "Get a baseline score" + "with open(exact_path) as fp:\n", + " exact_vibs = VibrationsData.read(fp)" ] }, { "cell_type": "code", "execution_count": null, - "id": "26c0f2c0-58fe-4ad8-8f99-22e29e2ef9a2", + "id": "a9965595-532c-4067-ba24-7620bd977007", "metadata": { "tags": [] }, "outputs": [], "source": [ - "test_y = np.array([a.get_potential_energy() for a in test_data])\n", - "baseline_y = np.abs(test_y - test_y.mean()).mean()\n", - "print(f'Baseline score: {baseline_y*1000:.2e} meV')" + "exact_hess = exact_vibs.get_hessian_2d()\n", + "exact_zpe = exact_vibs.get_zero_point_energy()" ] }, { "cell_type": "markdown", - "id": "a4989393-60cc-4cd1-b97d-1291a4cd6083", + "id": "04c60da8-4a1d-4ae3-b45d-b77e71fd598f", "metadata": {}, "source": [ - "Make a model testing function" + "## Fit a Hessian with All Data\n", + "Fit a model with the parameters tuned above" ] }, { "cell_type": "code", "execution_count": null, - "id": "da82d49d-62ad-4219-bbbc-e37ec9c0fba4", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "def test_soap_model(train_data: list[ase.Atoms],\n", - " test_data: list[ase.Atoms],\n", - " soap: SOAP,\n", - " num_inducing_points: int,\n", - " train_steps: int,\n", - " batch_size: int,\n", - " learning_rate: float,\n", - " fit_inducing_points: bool = False,\n", - " device: str = 'cpu',\n", - " verbose: bool = False):\n", - " \"\"\"Train a model then evaluate on a test set\n", - " \n", - " Args:\n", - " train_data: Training data\n", - " test_data: Test data\n", - " soap: SOAP descriptor computer\n", - " train_steps: Number of training steps\n", - " batch_size: Batch size to use for training\n", - " learning_rate: Learning rate to use for training\n", - " fit_inducing_points: Whether to fit inducing points during training\n", - " device: Device used for training\n", - " verbose: Whether to display progress bar\n", - " Returns:\n", - " - Training curve\n", - " - Predictions on each entry in the test set\n", - " - MAE on the test set\n", - " \"\"\"\n", - " \n", - " # Prepare the training set, scaling the input\n", - " train_x = soap.create(train_data).astype(np.float32)\n", - " offset_x = train_x.mean(axis=(0, 1))\n", - " train_x -= offset_x\n", - " scale_x = np.clip(train_x.std(axis=(0, 1)), a_min=1e-6, a_max=None)\n", - " train_x /= scale_x\n", - " \n", - " train_y_per_atom = np.array([a.get_potential_energy() / len(a) for a in train_data])\n", - " scale, offset = train_y_per_atom.std(), train_y_per_atom.mean()\n", - " train_y = np.array([(a.get_potential_energy() - len(a) * offset) / scale for a in train_data]).astype(np.float32)\n", - " \n", - " # Make the model and train it\n", - " model = make_gpr_model(train_x, num_inducing_points=num_inducing_points, use_ard_kernel=True)\n", - " model.inducing_x.requires_grad = fit_inducing_points\n", - " log = train_model(model, train_x, train_y, steps=train_steps, batch_size=batch_size, verbose=verbose, learning_rate=learning_rate, device=device)\n", - " \n", - " # Run it on the test set\n", - " calc = DScribeLocalCalculator(model=model, 
desc=soap, desc_scaling=(offset_x, scale_x), energy_scaling=(offset, scale), device=device)\n", - " test_preds = {'true': [], 'pred': []}\n", - " for atoms in test_data:\n", - " test_preds['true'].append(atoms.get_potential_energy())\n", - " atoms = atoms.copy()\n", - " test_preds['pred'].append(calc.get_potential_energy(atoms))\n", - " \n", - " # Get the MAE\n", - " preds = pd.DataFrame(test_preds)\n", - " mae = (preds['true'] - preds['pred']).abs().mean()\n", - " return log, preds, mae" - ] - }, - { - "cell_type": "markdown", - "id": "a72ba9ca-5776-478f-87a0-797cb3289cf6", + "id": "6015e696-d316-480e-b669-3b429fa146a4", "metadata": {}, - "source": [ - "Determine a good cutoff" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "61ac83ed-585c-409c-8ca9-a8368dd81fa9", - "metadata": { - "tags": [] - }, "outputs": [], "source": [ - "species = ['C', 'O', 'N', 'H']\n", - "n_max = 4\n", - "l_max = 4\n", - "cutoffs = np.arange(3., 6.01, 1)\n", - "inducing_points = 64\n", - "train_steps = 8\n", - "test_scores = []\n", - "for cutoff in tqdm(cutoffs):\n", - " soap = SOAP(\n", - " species=species,\n", - " n_max=n_max,\n", - " l_max=l_max,\n", - " periodic=False,\n", - " r_cut=cutoff\n", - " )\n", - " log, preds, mae = test_soap_model(train_data, test_data, soap, inducing_points, train_steps=train_steps, batch_size=2, learning_rate=0.01, device=device)\n", - " test_scores.append(mae)" + "soap = SOAP(\n", + " species=list(set(data[0].get_chemical_symbols())),\n", + " n_max=n_max,\n", + " l_max=l_max,\n", + " periodic=False,\n", + " r_cut=cutoff\n", + ")" ] }, { "cell_type": "code", "execution_count": null, - "id": "61393e22-7343-4064-bb3e-d92988fdfd31", + "id": "a29c67ad-dc76-4bfb-94f0-d567a3544a9f", "metadata": { "tags": [] }, "outputs": [], "source": [ - "cutoff = cutoffs[np.argmin(test_scores)]\n", - "print(f'Selected a maximum distance of {cutoff:.2f} A. Best score: {min(test_scores)*1000:.2e} meV')" + "model = DScribeLocalEnergyModel(\n", + " reference=data[0],\n", + " model_fn=lambda x: make_gpr_model(data[0].get_atomic_numbers(), x, \n", + " num_inducing_points=inducing_points,\n", + " fix_inducing_points=True,\n", + " use_ard_kernel=True,\n", + " initial_lengthscale=initial_lengthscale),\n", + " descriptors=soap,\n", + " num_calculators=1,\n", + " device=device,\n", + " train_options=dict(steps=1024, batch_size=128, learning_rate=0.01, patience=128, verbose=True),\n", + ")" ] }, { "cell_type": "markdown", - "id": "9633efde-b487-4976-a3a2-5a33a10127ce", + "id": "503240dd-b52c-4111-a024-ec44766940e5", "metadata": {}, "source": [ - "Determine a good descriptor complexity. 
We are going to optimize $n$ and $l$ together for simplicty, but they do describe very different types of orbitals" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "f8f3b5aa-012c-4a90-ae4f-1e84edfa17c2", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "nl_maxes = range(1, 7)\n", - "test_scores = []\n", - "for nl_max in tqdm(nl_maxes):\n", - " soap = SOAP(\n", - " species=species,\n", - " n_max=nl_max,\n", - " l_max=nl_max,\n", - " periodic=False,\n", - " r_cut=cutoff\n", - " )\n", - " log, preds, mae = test_soap_model(train_data, test_data, soap, inducing_points, train_steps=train_steps, batch_size=2, learning_rate=0.01, device=device)\n", - " test_scores.append(mae)" + "Plot the model performance" ] }, { "cell_type": "code", "execution_count": null, - "id": "80ff45c3-3dfb-4f9d-ba10-03a9b1593a71", + "id": "5749c977-51bf-46e1-a4fc-d22159daf2e5", "metadata": { "tags": [] }, "outputs": [], "source": [ - "nl_max = nl_maxes[np.argmin(test_scores)]\n", - "print(f'Selected a complexity of {nl_max}. Best score: {min(test_scores)*1000:.2e} meV')" + "%%time\n", + "with warnings.catch_warnings():\n", + " warnings.simplefilter(\"ignore\")\n", + " hess_models = model.train(data)" ] }, { "cell_type": "markdown", - "id": "5101fa9c-08f8-4ac5-a488-d9e47a3e14f4", + "id": "ffa60734-b8ae-4d6d-9219-0782b1308b69", "metadata": {}, "source": [ - "Determine a good model complexity, increaseing the number of steps to allow more complex models to train effectively" + "Compare the energies" ] }, { "cell_type": "code", "execution_count": null, - "id": "ff6fb4e5-e77b-419d-993a-0b18aa5d2fe3", - "metadata": { - "tags": [] - }, + "id": "adf93c3e-8d96-4f47-8b72-9718e07ab2a9", + "metadata": {}, "outputs": [], "source": [ - "inducing_pointss = [32, 64, 128, 256, 512]\n", - "train_steps *= 2\n", - "test_scores = []\n", - "for inducing_points in tqdm(inducing_pointss):\n", - " soap = SOAP(\n", - " species=species,\n", - " n_max=nl_max,\n", - " l_max=nl_max,\n", - " periodic=False,\n", - " r_cut=cutoff\n", - " )\n", - " log, preds, mae = test_soap_model(train_data, test_data, soap, inducing_points, train_steps=train_steps, batch_size=2, learning_rate=0.01, device=device)\n", - " test_scores.append(mae)" + "true_e = np.array([a.get_potential_energy() for a in data])" ] }, { "cell_type": "code", "execution_count": null, - "id": "244c59e1-799f-4d91-ba4f-5a4fa69b94dc", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "fig, axs = plt.subplots(1, 2, figsize=(4.5, 2.))\n", - "\n", - "ax = axs[0]\n", - "ax.semilogx(inducing_pointss, np.multiply(test_scores, 1000), '--o')\n", - "ax.set_xlabel('Inducing Points')\n", - "ax.set_ylabel('MAE (meV)')\n", - "\n", - "ax = axs[1]\n", - "ax.semilogy(log['loss'])\n", - "ax.set_xlabel('Epoch')\n", - "ax.set_ylabel('Loss')\n", - "\n", - "fig.tight_layout()" - ] - }, - { - "cell_type": "markdown", - "id": "401d908c-386c-4237-b150-79c90c4bcd01", + "id": "f4f9dc2a-1b55-426c-9a68-db849976ff34", "metadata": {}, - "source": [ - "At least 512 is fine, let's just increase the number of training steps" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8049f3e3-c51d-4b00-9d1d-0c61c6264bf0", - "metadata": { - "tags": [] - }, "outputs": [], "source": [ - "inducing_points = 512\n", - "train_steps = 128" + "pred_e = np.array([hess_models[0].get_potential_energy(a) for a in data])" ] }, { "cell_type": "code", "execution_count": null, - "id": "533109bb-d210-49f5-a97d-8598b4ea7cbc", - "metadata": { - "tags": [] - }, - "outputs": [], - 
"source": [ - "%%time\n", - "log, preds, mae = test_soap_model(train_data, test_data, soap, inducing_points, train_steps=train_steps, batch_size=2, learning_rate=0.01, device=device, verbose=True)" - ] - }, - { - "cell_type": "markdown", - "id": "2a4ea177-b258-412e-8405-08f3e372f345", + "id": "b72187eb-a8a7-49dc-bc2f-3e216e459cf4", "metadata": {}, - "source": [ - "Plot the learning curve of the final model" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3043cc87-30ef-4ebb-98d0-5d2deedf0a25", - "metadata": { - "tags": [] - }, "outputs": [], "source": [ - "print(f'Final MAE: {mae*1000:.2e} meV')" + "mae = np.abs(true_e - pred_e).mean()\n", + "print(f'MAE: {mae * 1000:.2e} meV')" ] }, { "cell_type": "code", "execution_count": null, - "id": "2edfad92-6cc6-48a7-8232-d8741a987363", - "metadata": { - "tags": [] - }, + "id": "4891c836-4655-42b3-9fbe-b424dc4941da", + "metadata": {}, "outputs": [], "source": [ - "fig, axs = plt.subplots(1, 2, figsize=(4.5, 2.))\n", + "fig, ax = plt.subplots(figsize=(3, 3))\n", "\n", - "ax = axs[0]\n", - "ax.semilogy(log['loss'])\n", - "ax.set_xlabel('Epoch')\n", - "ax.set_ylabel('Loss')\n", + "ax.scatter(1000 * (pred_e - true_e.min()), 1000 * (true_e - true_e.min()), s=5, alpha=0.8)\n", "\n", - "\n", - "ax = axs[1]\n", - "ax.scatter((preds['pred'] - preds['true'].min()) * 1000,\n", - " (preds['true'] - preds['true'].min()) * 1000, s=5)\n", - "ax.set_xlabel('Pred (eV)')\n", - "ax.set_ylabel('True (eV)')\n", "ax.set_xlim(ax.get_ylim())\n", "ax.set_ylim(ax.get_ylim())\n", "\n", - "ax.plot(ax.get_xlim(), ax.get_xlim(), 'k--', lw=1)\n", + "ax.plot(ax.get_xlim(), ax.get_xlim(), 'k--')\n", "\n", - "fig.tight_layout()" - ] - }, - { - "cell_type": "markdown", - "id": "04c60da8-4a1d-4ae3-b45d-b77e71fd598f", - "metadata": {}, - "source": [ - "## Fit a Hessian with All Data\n", - "Fit a model with the parameters tuned above" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a29c67ad-dc76-4bfb-94f0-d567a3544a9f", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "model = DScribeLocalEnergyModel(\n", - " reference=data[0],\n", - " model_fn=lambda x: make_gpr_model(x, num_inducing_points=512, use_ard_kernel=True),\n", - " descriptors=soap,\n", - " num_calculators=1,\n", - " device=device,\n", - " train_options=dict(steps=128, batch_size=2, learning_rate=0.01, verbose=True),\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "503240dd-b52c-4111-a024-ec44766940e5", - "metadata": {}, - "source": [ - "Plot the model performance" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "5749c977-51bf-46e1-a4fc-d22159daf2e5", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "%%time\n", - "with warnings.catch_warnings():\n", - " warnings.simplefilter(\"ignore\")\n", - " hess_models = model.train(data)" + "ax.set_xlabel('E, ML (meV)')\n", + "ax.set_ylabel('E, True (meV)')" ] }, { @@ -757,7 +506,7 @@ "source": [ "out_dir = Path('data/soap')\n", "out_dir.mkdir(exist_ok=True, parents=True)\n", - "with open(f'data/soap/{out_name}.json', 'w') as fp:\n", + "with open(f'data/soap/{out_name}-full.json', 'w') as fp:\n", " approx_vibs.write(fp)" ] }, @@ -805,35 +554,20 @@ "outputs": [], "source": [ "zpes = []\n", - "vib_data = []\n", - "for count in tqdm(steps):\n", - " with warnings.catch_warnings():\n", - " warnings.simplefilter(\"ignore\")\n", - " hess_model = model.train(data[:count])\n", - " \n", - " # Compute the approximate Hessian\n", - " approx_hessian = model.mean_hessian(hess_model)\n", 
- " approx_vibs = VibrationsData.from_2d(data[0], approx_hessian)\n", - " \n", - " # Save a ZPE and the JSON as a summary\n", - " \n", - " zpes.append(approx_vibs.get_zero_point_energy())\n", - " fp = StringIO()\n", - " approx_vibs.write(fp)\n", - " vib_data.append({'count': int(count), 'vib_data': fp.getvalue()})" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4b4563c7-f35c-458b-bb4f-36c466d59cd5", - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "with (out_dir / f'{out_name}-incremental.json').open('w') as fp:\n", - " json.dump(vib_data, fp)" + "with open(out_dir / f'{out_name}-increment.json', 'w') as fp:\n", + " for count in tqdm(steps):\n", + " with warnings.catch_warnings():\n", + " warnings.simplefilter(\"ignore\")\n", + " hess_model = model.train(data[:count])\n", + "\n", + " approx_hessian = model.mean_hessian(hess_model)\n", + " \n", + " # Save the incremental\n", + " print(json.dumps({'count': int(count), 'hessian': approx_hessian.tolist()}), file=fp)\n", + " \n", + " # Compute the ZPE\n", + " approx_vibs = VibrationsData.from_2d(data[0], approx_hessian)\n", + " zpes.append(approx_vibs.get_zero_point_energy())" ] }, { diff --git a/tests/models/test_soap.py b/tests/models/test_soap.py index d920137..338e202 100644 --- a/tests/models/test_soap.py +++ b/tests/models/test_soap.py @@ -1,9 +1,9 @@ import numpy as np import torch from dscribe.descriptors.soap import SOAP -from pytest import mark, fixture +from pytest import fixture -from jitterbug.model.dscribe.local import make_gpr_model, train_model, DScribeLocalCalculator, DScribeLocalEnergyModel +from jitterbug.model.dscribe.local import make_gpr_model, train_model, DScribeLocalCalculator, DScribeLocalEnergyModel, make_nn_model @fixture @@ -22,44 +22,66 @@ def descriptors(train_set, soap): return soap.create(train_set) -@mark.parametrize('use_adr', [True, False]) -def test_make_model(use_adr, descriptors, train_set): - model = make_gpr_model(descriptors, 4, use_ard_kernel=use_adr) +@fixture +def elements(train_set): + return train_set[0].get_atomic_numbers() + + +@fixture(params=['gpr-ard', 'gpr', 'nn']) +def model(elements, descriptors, request): + if request.param == 'gpr': + return make_gpr_model(elements, descriptors, 4, use_ard_kernel=False) + elif request.param == 'gpr-ard': + return make_gpr_model(elements, descriptors, 4, use_ard_kernel=True) + elif request.param == 'nn': + return make_nn_model(elements, descriptors, (16, 16)) + else: + raise NotImplementedError() + +def test_make_model(model, elements, descriptors, train_set): # Evaluate on a single point model.eval() - pred_y = model(torch.from_numpy(descriptors[0, :, :])) + pred_y = model( + torch.from_numpy(elements), + torch.from_numpy(descriptors[0, :, :]) + ) assert pred_y.shape == (3,) # 3 Atoms -@mark.parametrize('use_adr', [True, False]) -def test_train(descriptors, train_set, use_adr): +def test_train(model, elements, descriptors, train_set): # Make the model and the training set train_y = np.array([a.get_potential_energy() for a in train_set]) train_y -= train_y.min() - model = make_gpr_model(descriptors, 4, use_ard_kernel=use_adr) - model.inducing_x.requires_grad = False # Evaluate the untrained model model.eval() - pred_y = model(torch.from_numpy(descriptors.reshape((-1, descriptors.shape[-1])))) + pred_y = model( + torch.from_numpy(np.repeat(elements, descriptors.shape[0])), + torch.from_numpy(descriptors.reshape((-1, descriptors.shape[-1]))) + ) assert pred_y.dtype == torch.float64 + pred_y = torch.reshape(pred_y, 
[-1, elements.shape[0]]) error_y = pred_y.sum(axis=-1).detach().numpy() - train_y mae_untrained = np.abs(error_y).mean() # Train - losses = train_model(model, descriptors, train_y, 64) - assert len(losses) == 64 + losses = train_model(model, elements, descriptors, train_y, learning_rate=0.001, steps=8) + assert len(losses) == 8 # Run the evaluation model.eval() - pred_y = model(torch.from_numpy(descriptors.reshape((-1, descriptors.shape[-1])))) + pred_y = model( + torch.from_numpy(np.repeat(elements, descriptors.shape[0])), + torch.from_numpy(descriptors.reshape((-1, descriptors.shape[-1]))) + ) + pred_y = torch.reshape(pred_y, [-1, elements.shape[0]]) error_y = pred_y.sum(axis=-1).detach().numpy() - train_y mae_trained = np.abs(error_y).mean() - assert mae_trained < mae_untrained + assert mae_trained < mae_untrained * 1.1 -def test_calculator(descriptors, soap, train_set): +def test_calculator(elements, descriptors, soap, train_set): # Scale the input and outputs train_y = np.array([a.get_potential_energy() for a in train_set]) train_y -= train_y.mean() @@ -69,8 +91,8 @@ def test_calculator(descriptors, soap, train_set): descriptors = (descriptors - offset_x) / scale_x # Assemble and train for a few instances so that we get nonzero forces - model = make_gpr_model(descriptors, 32) - train_model(model, descriptors, train_y, 32) + model = make_gpr_model(elements, descriptors, 32) + train_model(model, elements, descriptors, train_y, 32) # Make the model calc = DScribeLocalCalculator( @@ -84,7 +106,8 @@ def test_calculator(descriptors, soap, train_set): forces = atoms.get_forces() energies.append(atoms.get_potential_energy()) numerical_forces = calc.calculate_numerical_forces(atoms, d=1e-4) - assert np.isclose(forces[:, :2], numerical_forces[:, :2], rtol=5e-1).all() # Make them agree w/i 50% (PES is not smooth) + force_mask = np.abs(numerical_forces) > 0.1 + assert np.isclose(forces[force_mask], numerical_forces[force_mask], rtol=0.1).all() # Agree w/i 10% assert np.std(energies) > 1e-6 @@ -93,13 +116,21 @@ def test_model(soap, train_set): model = DScribeLocalEnergyModel( reference=train_set[0], descriptors=soap, - model_fn=lambda x: make_gpr_model(x, num_inducing_points=32), + model_fn=lambda x: make_gpr_model(train_set[0].get_atomic_numbers(), x, num_inducing_points=32), num_calculators=4, ) # Run the fitting calcs = model.train(train_set) + # Make sure the energy is reasonable + eng = calcs[0].get_potential_energy(train_set[0]) + assert np.isclose(eng, train_set[0].get_potential_energy(), atol=1e-2) + + # Make sure they differ between entries + pred_e = [calcs[0].get_potential_energy(a) for a in train_set] + assert np.std(pred_e) > 1e-3 + # Test the mean hessian function mean_hess = model.mean_hessian(calcs) assert mean_hess.shape == (9, 9), 'Wrong shape'
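
# --- Editor's note (illustrative sketch, not part of the patch): the end-to-end
# --- workflow the updated tests exercise, shown on an assumed water toy problem.
# --- The hyperparameters, the H2O structures, and the zero placeholder energies
# --- are assumptions for brevity; real runs train on computed reference energies.
import numpy as np
from ase.build import molecule
from dscribe.descriptors.soap import SOAP
from jitterbug.model.dscribe.local import (
    DScribeLocalCalculator, make_gpr_model, train_model)

train_set = [molecule('H2O') for _ in range(4)]
soap = SOAP(species=['H', 'O'], n_max=2, l_max=2, r_cut=3., periodic=False)
descriptors = soap.create(train_set)           # (num_configs, num_atoms, num_features)
elements = train_set[0].get_atomic_numbers()   # atomic number of each atom

# Build one GPR per element, then train briefly on placeholder energies
model = make_gpr_model(elements, descriptors, num_inducing_points=8)
train_model(model, elements, descriptors, np.zeros(len(train_set)), steps=4)

# Wrap the trained model in an ASE calculator; (0., 1.) scaling means "no scaling"
calc = DScribeLocalCalculator(model=model, desc=soap,
                              desc_scaling=(0., 1.), energy_scaling=(0., 1.))
atoms = molecule('H2O')
atoms.calc = calc
print(atoms.get_potential_energy(), atoms.get_forces().shape)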