diff --git a/src/mattersim/__init__.py b/src/mattersim/__init__.py new file mode 100644 index 0000000..e60ed73 --- /dev/null +++ b/src/mattersim/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +from .__version__ import __version__ # noqa: F401 diff --git a/src/mattersim/__version__.py b/src/mattersim/__version__.py new file mode 100644 index 0000000..20a4a15 --- /dev/null +++ b/src/mattersim/__version__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- +import pkg_resources + +# Get the version from setup.py +__version__ = pkg_resources.get_distribution("mattersim").version diff --git a/src/mattersim/applications/moldyn.py b/src/mattersim/applications/moldyn.py new file mode 100644 index 0000000..5d1efdf --- /dev/null +++ b/src/mattersim/applications/moldyn.py @@ -0,0 +1,145 @@ +# -*- coding: utf-8 -*- +from typing import Union + +from ase import Atoms, units +from ase.io import Trajectory +from ase.md.npt import NPT +from ase.md.nvtberendsen import NVTBerendsen +from ase.md.velocitydistribution import ( # noqa: E501 + MaxwellBoltzmannDistribution, + Stationary, +) + + +class MolecularDynamics: + """ + This class is used for Molecular Dynamics. + """ + + SUPPORTED_ENSEMBLE = ["NVT_BERENDSEN", "NVT_NOSE_HOOVER"] + + def __init__( + self, + atoms: Atoms, + ensemble: str = "nvt_nose_hoover", + temperature: float = 300, + timestep: float = 1.0, + taut: Union[float, None] = None, + trajectory: Union[str, Trajectory, None] = None, + logfile: Union[str, None] = "-", + loginterval: int = 10, + append_trajectory: bool = False, + ): + """ + Args: + atoms (Union[Atoms, Structure]): ASE atoms object contains + structure information and calculator. + ensemble (str, optional): Simulation ensemble choosen. Defaults + to nvt_nose_hoover' + temperature (float, optional): Simulation temperature, in Kelvin. + Defaults to 300 K. + timestep (float, optional): The simulation time step, in fs. Defaults + to 1 fs. + taut (float, optional): Characteristic timescale of the thermostat, + in fs. If is None, automatically set it to 1000 * timestep. + trajectory (Union[str, Trajectory], optional): Attach trajectory + object. If trajectory is a string a Trajectory will be constructed. + Defaults to None, which means for no trajectory. + logfile (str, optional): If logfile is a string, a file with that name + will be opened. Defaults to '-', which means output to stdout. + loginterval (int, optional): Only write a log line for every loginterval + time steps. Defaults to 10. + append_trajectory (bool, optional): If False the trajectory file to be + overwriten each time the dynamics is restarted from scratch. If True, + the new structures are appended to the trajectory file instead. + + """ + assert atoms.calc is not None, ( + "Molecular Dynamics simulation only accept " + "ase atoms with an attached calculator" + ) + if ensemble.upper() not in self.SUPPORTED_ENSEMBLE: + raise NotImplementedError( # noqa: E501 + f"Ensemble {ensemble} is not yet supported." + ) + + self.atoms = atoms + + self.ensemble = ensemble.upper() + self._temperature = temperature + self.timestep = timestep + + if taut is None: + taut = 1000 * timestep * units.fs + self.taut = taut + + self._trajectory = trajectory + self.logfile = logfile + self.loginterval = loginterval + self.append_trajectory = append_trajectory + + self._initialize_dynamics() + + def _initialize_dynamics(self): + """ + Initialize the Dynamic ensemble class. 
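+ Velocities are (re)drawn from a Maxwell-Boltzmann distribution at the target
+ temperature and the centre-of-mass motion is removed, before the ASE dynamics
+ object is (re)built: NVTBerendsen for "nvt_berendsen", or NPT with pfactor=None
+ (i.e. a Nose-Hoover thermostat at constant volume) for "nvt_nose_hoover".
+ This method is invoked again whenever the temperature or trajectory property
+ is reset.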
+ """ + MaxwellBoltzmannDistribution( + self.atoms, temperature_K=self._temperature, force_temp=True + ) + Stationary(self.atoms) + + if self.ensemble == "NVT_BERENDSEN": # noqa: E501 + self.dyn = NVTBerendsen( + self.atoms, + timestep=self.timestep * units.fs, + temperature_K=self._temperature, + taut=self.taut, + trajectory=self._trajectory, + logfile=self.logfile, + loginterval=self.loginterval, + append_trajectory=self.append_trajectory, + ) + elif self.ensemble == "NVT_NOSE_HOOVER": + self.dyn = NPT( + self.atoms, + timestep=self.timestep * units.fs, + temperature_K=self._temperature, + ttime=self.taut, + pfactor=None, + trajectory=self._trajectory, + logfile=self.logfile, + loginterval=self.loginterval, + append_trajectory=self.append_trajectory, + ) + else: + raise NotImplementedError( # noqa: E501 + f"Ensemble {self.ensemble} is not yet supported." + ) + + def run(self, n_steps: int = 1): + """ + Run Molecular Dynamic simulation. + + Args: + n_steps (int, optional): Number of steps to simulations. Defaults to 1. + """ + self.dyn.run(n_steps) + + @property + def temperature(self): + return self._temperature + + @temperature.setter + def temperature(self, temperature: float): + self._temperature = temperature + self._initialize_dynamics() + + @property + def trajectory(self): + return self._trajectory + + @trajectory.setter + def trajectory(self, trajectory: Union[str, Trajectory, None]): + self._trajectory = trajectory + self._initialize_dynamics() diff --git a/src/mattersim/applications/phonon.py b/src/mattersim/applications/phonon.py new file mode 100644 index 0000000..494b8cb --- /dev/null +++ b/src/mattersim/applications/phonon.py @@ -0,0 +1,228 @@ +# -*- coding: utf-8 -*- +import datetime +import os +from typing import Iterable, Union + +import numpy as np +from ase import Atoms +from phonopy import Phonopy +from tqdm import tqdm + +from mattersim.utils.phonon_utils import ( + get_primitive_cell, + to_ase_atoms, + to_phonopy_atoms, +) +from mattersim.utils.supercell_utils import get_supercell_parameters + + +class PhononWorkflow(object): + """ + This class is used to calculate the phonon dispersion relationship of + material using phonopy + """ + + def __init__( + self, + atoms: Atoms, + find_prim: bool = False, + work_dir: str = None, + amplitude: float = 0.01, + supercell_matrix: np.ndarray = None, + qpoints_mesh: np.ndarray = None, + max_atoms: int = None, + ): + """_summary + + Args: + atoms (Atoms): ASE atoms object contains structure information and + calculator. + find_prim (bool, optional): If find the primitive cell and use it + to calculate phonon. Default to False. + work_dir (str, optional): workplace path to contain phonon result. + Defaults to data + chemical_symbols + 'phonon' + amplitude (float, optional): Magnitude of the finite difference to + displace in force constant calculation, in Angstrom. Defaults + to 0.01 Angstrom. + supercell_matrix (nd.array, optional): Supercell matrix for constr + -uct supercell, priority over than max_atoms. Defaults to None. + qpoints_mesh (nd.array, optional): Qpoint mesh for IBZ integral, + priority over than max_atoms. Defaults to None. + max_atoms (int, optional): Maximum atoms number limitation for the + supercell generation. If not set, will automatic generate super + -cell based on symmetry. Defaults to None. 
+ """ + assert ( + atoms.calc is not None + ), "PhononWorkflow only accepts ase atoms with an attached calculator" + if find_prim: + self.atoms = get_primitive_cell(atoms) + self.atoms.calc = atoms.calc + else: + self.atoms = atoms + if work_dir is not None: + self.work_dir = work_dir + else: + current_datetime = datetime.datetime.now() + formatted_datetime = current_datetime.strftime("%Y-%m-%d-%H-%M") + self.work_dir = ( + f"{formatted_datetime}-{atoms.get_chemical_formula()}-phonon" + ) + self.amplitude = amplitude + if supercell_matrix is not None: + if supercell_matrix.shape == (3, 3): + self.supercell_matrix = supercell_matrix + elif supercell_matrix.shape == (3,): + self.supercell_matrix = np.diag(supercell_matrix) + else: + assert ( + False + ), "Supercell matrix must be an array (3,1) or a matrix (3,3)." + else: + self.supercell_matrix = supercell_matrix + + if qpoints_mesh is not None: + assert qpoints_mesh.shape == (3,), "Qpoints mesh must be an array (3,1)." + self.qpoints_mesh = qpoints_mesh + else: + self.qpoints_mesh = qpoints_mesh + + self.max_atoms = max_atoms + + def compute_force_constants(self, atoms: Atoms, nrep_second: np.ndarray): + """ + Calculate force constants + + Args: + atoms (Atoms): ASE atoms object to provide lattice informations. + nrep_second (np.ndarray): Supercell size used for 2nd force + constant calculations. + """ + print(f"Supercell matrix for 2nd force constants : \n{nrep_second}") + # Generate phonopy object + phonon = Phonopy( + to_phonopy_atoms(atoms), + supercell_matrix=nrep_second, + primitive_matrix="auto", + log_level=2, + ) + + # Generate displacements + phonon.generate_displacements(distance=self.amplitude) + + # Compute force constants + second_scs = phonon.supercells_with_displacements + second_force_sets = [] + print("\n") + print("Inferring forces for displaced atoms and computing fcs ...") + for disp_second in tqdm(second_scs): + pa_second = to_ase_atoms(disp_second) + pa_second.calc = self.atoms.calc + second_force_sets.append(pa_second.get_forces()) + + phonon.forces = np.array(second_force_sets) + phonon.produce_force_constants() + phonon.symmetrize_force_constants() + + return phonon + + @staticmethod + def compute_phonon_spectrum_dos( + atoms: Atoms, phonon: Phonopy, k_point_mesh: Union[int, Iterable[int]] + ): + """ + Calculate phonon spectrum and DOS based on force constant matrix in + phonon object + + Args: + atoms (Atoms): ASE atoms object to provide lattice information + phonon (Phonopy): Phonopy object which contains force constants matrix + k_point_mesh (Union[int, Iterable[int]]): The qpoints number in First + Brillouin Zone in three directions for DOS calculation. + """ + print(f"Qpoints mesh for Brillouin Zone integration : {k_point_mesh}") + phonon.run_mesh(k_point_mesh) + print( + "Dispersion relations using phonopy for " + + str(atoms.symbols) + + " ..." + + "\n" + ) + + # plot phonon spectrum + phonon.auto_band_structure(plot=True, write_yaml=True).savefig( + f"{str(atoms.symbols)}_phonon_band.png", dpi=300 + ) + phonon.auto_total_dos(plot=True, write_dat=True).savefig( + f"{str(atoms.symbols)}_phonon_dos.png", dpi=300 + ) + + # Save additional files + phonon.save(settings={"force_constants": True}) + + @staticmethod + def check_imaginary_freq(phonon: Phonopy): + """ + Check whether phonon has imaginary frequency. + + Args: + phonon (Phonopy): Phonopy object which contains phonon spectrum frequency. 
+ """ + band_dict = phonon.get_band_structure_dict() + frequencies = np.concatenate( + [np.array(freq).flatten() for freq in band_dict["frequencies"]], axis=None + ) + has_imaginary = False + if np.all(np.array(frequencies) >= -0.299): + pass + else: + print("Warning! Imaginary frequencies found!") + has_imaginary = True + + return has_imaginary + + def run(self): + """ + The entrypoint to start the workflow. + """ + current_path = os.path.abspath(".") + try: + # check folder exists + if not os.path.exists(self.work_dir): + os.makedirs(self.work_dir) + + os.chdir(self.work_dir) + + try: + # Generate supercell parameters based on optimized structures + nrep_second, k_point_mesh = get_supercell_parameters( + self.atoms, self.supercell_matrix, self.qpoints_mesh, self.max_atoms + ) + except Exception as e: + print("Error whille generating supercell parameters:", e) + raise + + try: + # Calculate 2nd force constants + phonon = self.compute_force_constants(self.atoms, nrep_second) + except Exception as e: + print("Error while computing force constants:", e) + raise + + try: + # Calculate phonon spectrum + self.compute_phonon_spectrum_dos(self.atoms, phonon, k_point_mesh) + # check whether has imaginary frequency + has_imaginary = self.check_imaginary_freq(phonon) + except Exception as e: + print("Error while computing phonon spectrum and dos:", e) + raise + + except Exception as e: + print("An error occurred during the Phonon workflow:", e) + raise + + finally: + os.chdir(current_path) + + return has_imaginary, phonon diff --git a/src/mattersim/applications/relax.py b/src/mattersim/applications/relax.py new file mode 100644 index 0000000..2c79ab2 --- /dev/null +++ b/src/mattersim/applications/relax.py @@ -0,0 +1,162 @@ +# -*- coding: utf-8 -*- +import warnings +from typing import Iterable, List, Tuple, Union + +from ase import Atoms +from ase.constraints import Filter, FixSymmetry +from ase.filters import ExpCellFilter, FrechetCellFilter +from ase.optimize import BFGS, FIRE +from ase.optimize.optimize import Optimizer + + +class Relaxer(object): + """Relaxer is a class for structural relaxation with fixed volume.""" + + SUPPORTED_OPTIMIZERS = {"BFGS": BFGS, "FIRE": FIRE} + SUPPORTED_FILTERS = { + "EXPCELLFILTER": ExpCellFilter, + "FRECHETCELLFILTER": FrechetCellFilter, + } + + def __init__( + self, + optimizer: Union[Optimizer, str] = "FIRE", + filter: Union[Filter, str, None] = None, + constrain_symmetry: bool = True, + fix_axis: Union[bool, Iterable[bool]] = False, + ) -> None: + """ + Args: + optimizer (Union[Optimizer, str]): The optimizer to use. + filter (Union[Filter, str, None]): The filter to use. + constrain_symmetry (bool): Whether to constrain the symmetry. + fix_axis (Union[bool, Iterable[bool]]): Whether to fix the axis. + """ + self.optimizer = ( + self.SUPPORTED_OPTIMIZERS[optimizer.upper()] + if isinstance(optimizer, str) + else optimizer + ) + self.relax_cell = filter is not None + if filter is not None: + self.filter = ( + self.SUPPORTED_FILTERS[filter.upper()] + if isinstance(filter, str) + else filter + ) + self.constrain_symmetry = constrain_symmetry + self.fix_axis = fix_axis + + def relax( + self, + atoms: Atoms, + steps: int = 500, + fmax: float = 0.01, + params_filter: dict = {}, + **kwargs + ) -> Atoms: + """ + Relax the atoms object. + + Args: + atoms (Atoms): The atoms object to relax. + steps (int): The maximum number of steps to take. + fmax (float): The maximum force allowed. + params_filter (dict): The parameters for the filter. 
+ kwargs: Additional keyword arguments for the optimizer. + """ + + if atoms.calc is None: + raise ValueError("Atoms object must have a calculator.") + + if self.constrain_symmetry: + atoms.set_constraint(FixSymmetry(atoms)) + + if self.relax_cell: + # Set the mask for the fixed axis + if isinstance(self.fix_axis, bool): + mask = [not self.fix_axis for i in range(6)] + else: + assert ( + len(self.fix_axis) == 6 + ), "The length of fix_axis list not equal 6." + mask = [not elem for elem in self.fix_axis] + + # check if the scalar_pressure is provided + if ( + "scalar_pressure" in params_filter + and params_filter["scalar_pressure"] > 1 + ): + warnings.warn( + "The scalar_pressure used in ExpCellFilter assumes " + "eV/A^3 unit and 1 eV/A^3 is already 160 GPa. " + "Please make sure you have converted your pressure " + "from GPa to eV/A^3 by dividing by 160.21766208." + ) + ecf = self.filter(atoms, mask=mask, **params_filter) + else: + ecf = atoms + optimizer = self.optimizer(ecf, **kwargs) + optimizer.run(fmax=fmax, steps=steps) + + converged = optimizer.get_number_of_steps() < steps + + if self.constrain_symmetry: + atoms.set_constraint(None) + + return converged, atoms + + @classmethod + def relax_structures( + cls, + atoms: Union[Atoms, Iterable[Atoms]], + optimizer: Union[Optimizer, str] = "FIRE", + filter: Union[Filter, str, None] = None, + constrain_symmetry: bool = False, + fix_axis: Union[bool, Iterable[bool]] = False, + pressure_in_GPa: Union[float, None] = None, + **kwargs + ) -> Union[Tuple[bool, Atoms], Tuple[List[bool], List[Atoms]]]: + """ + Args: + atoms: (Union[Atoms, Iterable[Atoms]]): + The Atoms object or an iterable of Atoms objetcs to relax. + optimizer (Union[Optimizer, str]): The optimizer to use. + filter (Union[Filter, str, None]): The filter to use. + constrain_symmetry (bool): Whether to constrain the symmetry. + fix_axis (Union[bool, Iterable[bool]]): Whether to fix the axis. + **kwargs: Additional keyword arguments for the relax method. 
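+ pressure_in_GPa (Union[float, None]): External isotropic pressure, in GPa.
+ It is converted internally to eV/A^3 and passed to the cell filter as
+ scalar_pressure; if no filter is specified, an ExpCellFilter is used.
+ Defaults to None.
+
+ Example:
+ A minimal sketch; `calc` is assumed to be a pre-built ASE calculator
+ and is not created here.
+
+ >>> from ase.build import bulk
+ >>> al = bulk("Al", cubic=True)
+ >>> al.calc = calc  # hypothetical, pre-built calculator
+ >>> converged, relaxed = Relaxer.relax_structures(al, filter="ExpCellFilter")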
+ Returns: + converged (Union[bool, List[bool]]): + Whether the relaxation converged or a list of them + Atoms (Union[Atoms, List[Atoms]]): + The relaxed atoms object or a list of them + """ + params_filter = {} + + if filter is None and pressure_in_GPa is None: + pass + elif filter is None and pressure_in_GPa is not None: + filter = "ExpCellFilter" + params_filter["scalar_pressure"] = pressure_in_GPa / 160.21766208 + elif filter is not None and pressure_in_GPa is None: + params_filter["scalar_pressure"] = 0.0 + else: + params_filter["scalar_pressure"] = pressure_in_GPa / 160.21766208 + + relaxer = Relaxer( + optimizer=optimizer, + filter=filter, + constrain_symmetry=constrain_symmetry, + fix_axis=fix_axis, + ) + + if isinstance(atoms, (list, tuple)): + relaxed_results = relaxed_results = [ + relaxer.relax(atom, params_filter=params_filter, **kwargs) + for atom in atoms + ] + converged, relaxed_atoms = zip(*relaxed_results) + return list(converged), list(relaxed_atoms) + else: + return relaxer.relax(atoms, params_filter=params_filter, **kwargs) diff --git a/src/mattersim/datasets/dataset.py b/src/mattersim/datasets/dataset.py new file mode 100644 index 0000000..b003e02 --- /dev/null +++ b/src/mattersim/datasets/dataset.py @@ -0,0 +1,109 @@ +# -*- coding: utf-8 -*- +from functools import lru_cache + +import numpy as np +import torch +from ase import Atoms +from torch_geometric.data import Data + + +@torch.jit.script +def convert_to_single_emb(x, offset: int = 512): + feature_num = x.size(1) if len(x.size()) > 1 else 1 + feature_offset = 1 + torch.arange(0, feature_num * offset, offset, dtype=torch.long) + x = x + feature_offset + return x + + +class AtomCalDataset: + def __init__( + self, + atom_list: list[Atoms], + energies: list[float] = None, + forces: list[np.ndarray] = None, + stresses: list[np.ndarray] = None, + finetune_task_label: list = None, + ): + self.data = self._preprocess( + atom_list, + energies, + forces, + stresses, + finetune_task_label, + ) + + def _preprocess( + self, + atom_list, + energies: list[float] = None, + forces: list[np.ndarray] = None, + stresses: list[np.ndarray] = None, + finetune_task_label: list = None, + use_ase_energy: bool = False, + use_ase_force: bool = False, + use_ase_stress: bool = False, + ): + data_list = [] + for i, (atom, energy, force, stress) in enumerate( + zip(atom_list, energies, forces, stresses) + ): + item_dict = atom.todict() + item_dict["info"] = {} + if energy is None: + energy = 0 + if force is None: + force = np.zeros([len(atom), 3]) + if stress is None: + stress = np.zeros([3, 3]) + try: + energy = atom.get_total_energy() if use_ase_energy else energy + force = ( + atom.get_forces(apply_constraint=False) if use_ase_force else force + ) + stress = atom.get_stress(voigt=False) if use_ase_stress else stress + except Exception as e: + RuntimeError(f"Error in {i}th data: {e}") + + if finetune_task_label is not None: + item_dict["finetune_task_label"] = finetune_task_label[i] + else: + item_dict["finetune_task_label"] = 0 + + item_dict["info"]["energy"] = energy + item_dict["info"]["stress"] = stress # * 160.2176621 + item_dict["forces"] = force + data_list.append(item_dict) + + return data_list + + @lru_cache(maxsize=16) + def __getitem__(self, idx): + item = self.data[idx] + return preprocess_atom_item(item, idx) + + def __len__(self): + return len(self.data) + + +def preprocess_atom_item(item, idx): + # numbers = item.pop("numbers") + numbers = item["numbers"] + item["x"] = torch.tensor(numbers, dtype=torch.long).unsqueeze(-1) + # 
positions = item.pop("positions") + positions = item["positions"] + item["pos"] = torch.tensor(positions, dtype=torch.float64) + item["cell"] = torch.tensor(item["cell"], dtype=torch.float64) + item["pbc"] = torch.tensor(item["pbc"], dtype=torch.bool) + item["idx"] = idx + item["y"] = torch.tensor([item["finetune_task_label"]]) + item["total_energy"] = torch.tensor([item["info"]["energy"]], dtype=torch.float64) + item["stress"] = torch.tensor(item["info"]["stress"], dtype=torch.float64) + item["forces"] = torch.tensor(item["forces"], dtype=torch.float64) + + item = Data(**item) + + x = item.x + + item.x = convert_to_single_emb(x) + + return item diff --git a/src/mattersim/datasets/utils/build.py b/src/mattersim/datasets/utils/build.py new file mode 100644 index 0000000..a194be2 --- /dev/null +++ b/src/mattersim/datasets/utils/build.py @@ -0,0 +1,360 @@ +# -*- coding: utf-8 -*- +import time +import warnings + +import numpy as np +import torch +from ase import Atoms +from torch.utils.data import DataLoader as DataLoader_torch +from torch_geometric.loader import DataLoader as DataLoader_pyg + +from mattersim.datasets.dataset import AtomCalDataset +from mattersim.datasets.utils.convertor import GraphConvertor + + +def build_dataloader( + atoms: list[Atoms] = None, + energies: list[float] = None, + forces: list[np.ndarray] = None, + stresses: list[np.ndarray] = None, + cutoff: float = 5.0, + threebody_cutoff: float = 4.0, + batch_size: int = 64, + model_type: str = "m3gnet", + shuffle=False, + only_inference: bool = False, + num_workers: int = 0, + pin_memory: bool = False, + multiprocessing: int = 0, + multithreading: int = 0, + dataset=None, + finetune_task_label: list = None, + **kwargs, +): + """ + Build a dataloader given a list of atoms + - atoms : a list of atoms in ase format + - energies, forces and stresses are necessary for training + - energies : a list of energy (float) with unit eV + - forces : a list of nx3 force matrix (np.ndarray) with unit eV/Å, + where n is the number of atom in each structure. 
+ - stresses : a list of 3x3 stress matrix (np.ndarray) with unit GPa + - only_inference : if True, energies, forces and stresses will be ignored + - num_workers : number of workers for dataloader + - pin_memory : if True, the datasets will be stored in GPU or CPU memory + - pin_memory_device : the device for pin_memory + - dataset : the dataset object for the dataloader + only used for graphormer and geomformer + """ + + convertor = GraphConvertor(model_type, cutoff, True, threebody_cutoff) + + preprocessed_data = [] + + if dataset is None: + if not only_inference: + assert ( + energies is not None + ), "energies must be provided if only_inference is False" + if stresses is not None: + assert np.array(stresses[0]).shape == ( + 3, + 3, + ), "stresses must be a list of 3x3 matrices" + + length = len(atoms) + if energies is None: + energies = [None] * length + if forces is None: + forces = [None] * length + if stresses is None: + stresses = [None] * length + + if model_type == "m3gnet": + if multiprocessing == 0 and multithreading == 0: + # start = time.time() + for graph, energy, force, stress in zip(atoms, energies, forces, stresses): + graph = convertor.convert(graph.copy(), energy, force, stress, **kwargs) + if graph is not None: + preprocessed_data.append(graph) + # print("Data preprocessing time: {:.2f} s".format(time.time() - start)) + elif multithreading > 0 and multiprocessing == 0: + from multiprocessing.pool import ThreadPool + + warnings.warn("multithreading is experimental") + warnings.warn("it may not be faster than single thread due to GIL.") + print("Using multithreading with {} threads".format(multithreading)) + start = time.time() + pool = ThreadPool(processes=multithreading) + preprocessed_data = pool.starmap( + convertor.convert, zip(atoms, energies, forces, stresses) + ) + pool.close() + print("Time elapsed: {:.2f} s".format(time.time() - start)) + elif multiprocessing > 0 and multithreading == 0: + import multiprocessing as mp + + warnings.warn("multiprocessing is experimental.") + print("Using multiprocessing with {} workers".format(multiprocessing)) + # torch.multiprocessing.set_sharing_strategy('file_system') + start = time.time() + pool = mp.Pool(multiprocessing) + results = [] + for i in range(multiprocessing): + left = int(i * length / multiprocessing) + right = int((i + 1) * length / multiprocessing) + results.append( + pool.apply_async(multiprocess_data, args=(atoms[left:right], 1)) + ) + pool.close() + pool.join() + for result in results: + graph = result.get() + if graph is not None: + preprocessed_data.extend(graph) + print("Time for multiprocessing: {:.2f} s".format(time.time() - start)) + else: + raise NotImplementedError + + return DataLoader_pyg( + preprocessed_data, + batch_size=batch_size, + shuffle=shuffle, + num_workers=num_workers, + pin_memory=pin_memory, + ) + + elif model_type == "graphormer" or model_type == "geomformer": + raise NotImplementedError + + +def multiprocess_data(atoms: list[Atoms], number): + convertor = GraphConvertor() + result = [] + for graph in atoms: + graph = convertor.convert( + graph, + graph.get_potential_energy(), + graph.get_forces(), + graph.get_stress(voigt=False) * 160.2, + ) + if graph is not None: + result.append(graph) + return result + + + +def pad_1d_unsqueeze(x, padlen): + x = x + 1 # pad id = 0 + xlen = x.size(0) + if xlen < padlen: + new_x = x.new_zeros([padlen], dtype=x.dtype) + new_x[:xlen] = x + x = new_x + return x.unsqueeze(0) + + +def pad_2d_unsqueeze(x, padlen): + x = x + 1 # pad id = 0 + xlen, xdim = 
x.size() + if xlen < padlen: + new_x = x.new_zeros([padlen, xdim], dtype=x.dtype) + new_x[:xlen, :] = x + x = new_x + return x.unsqueeze(0) + + +@torch.jit.script +def mask_after_k_persample(n_sample: int, n_len: int, persample_k: torch.Tensor): + assert persample_k.shape[0] == n_sample + assert persample_k.max() <= n_len + device = persample_k.device + mask = torch.zeros([n_sample, n_len + 1], device=device) + mask[torch.arange(n_sample, device=device), persample_k] = 1 + mask = mask.cumsum(dim=1)[:, :-1] + return mask.type(torch.bool) + + +def auto_cell(cell, cutoff=10.0): + # find max value in x, y, z direction + max_x = max(int(cutoff / torch.min(torch.abs(cell[:, 0, 0]))), 1) + max_y = max(int(cutoff / torch.min(torch.abs(cell[:, 1, 1]))), 1) + max_z = max(int(cutoff / torch.min(torch.abs(cell[:, 2, 2]))), 1) + # loop + cells = [] + for i in range(-max_x, max_x + 1): + for j in range(-max_y, max_y + 1): + for k in range(-max_z, max_z + 1): + if i == 0 and j == 0 and k == 0: + continue + cells.append([i, j, k]) + return cells + + +def cell_expand(pos, atoms, cell, cutoff=10.0): + batch_size, max_num_atoms = pos.size()[:2] + cells = auto_cell(cell, cutoff) + cell_tensor = ( + torch.tensor(cells, device=pos.device) + .to(cell.dtype) + .unsqueeze(0) + .expand(batch_size, -1, -1) + ) # batch_size, n_cell, 3 + offset = torch.bmm(cell_tensor, cell) # B x n_cell x 3 + expand_pos = pos.unsqueeze(1) + offset.unsqueeze(2) # B x n_cell x T x 3 + expand_pos = expand_pos.view(batch_size, -1, 3) # B x (n_cell x T) x 3 + expand_dist = torch.norm( + pos.unsqueeze(2) - expand_pos.unsqueeze(1), p=2, dim=-1 + ) # B x T x (8 x T) + expand_mask = expand_dist < cutoff # B x T x (8 x T) + expand_mask = torch.masked_fill(expand_mask, atoms.eq(0).unsqueeze(-1), False) + expand_mask = (torch.sum(expand_mask, dim=1) > 0) & ( + ~(atoms.eq(0).repeat(1, len(cells))) + ) # B x (8 x T) + expand_len = torch.sum(expand_mask, dim=-1) + max_expand_len = torch.max(expand_len) + outcell_index = torch.zeros( + [batch_size, max_expand_len], dtype=torch.long, device=pos.device + ) + expand_pos_compressed = torch.zeros( + [batch_size, max_expand_len, 3], dtype=pos.dtype, device=pos.device + ) + outcell_all_index = torch.arange( + max_num_atoms, dtype=torch.long, device=pos.device + ).repeat(len(cells)) + for i in range(batch_size): + outcell_index[i, : expand_len[i]] = outcell_all_index[expand_mask[i]] + expand_pos_compressed[i, : expand_len[i], :] = expand_pos[i, expand_mask[i], :] + return ( + expand_pos_compressed, + expand_len, + outcell_index, + mask_after_k_persample(batch_size, max_expand_len, expand_len), + ) + + +def pad_spatial_pos_unsqueeze(x, padlen): + x = x + 1 + xlen = x.size(0) + if xlen < padlen: + new_x = x.new_zeros([padlen, padlen], dtype=x.dtype) + new_x[:xlen, :xlen] = x + x = new_x + return x.unsqueeze(0) + + +@torch.jit.script +def convert_to_single_emb(x, offset: int = 512): + feature_num = x.size(1) if len(x.size()) > 1 else 1 + feature_offset = 1 + torch.arange(0, feature_num * offset, offset, dtype=torch.long) + x = x + feature_offset + return x + + +class BatchedDataDataset(torch.utils.data.Dataset): + # class BatchedDataDataset(torch.utils.data.IterableDataset): + def __init__( + self, + dataset, + max_node=512, + infer=False, + ): + super().__init__() + self.dataset = dataset + self.max_node = max_node + + self.infer = infer + + def __getitem__(self, index): + item = self.dataset[int(index)] + return item + + def __len__(self): + return len(self.dataset) + + def collate(self, samples): + return 
collator_ft( + samples, + max_node=self.max_node, + use_pbc=True, + ) + + +def pad_pos_unsqueeze(x, padlen): + xlen, xdim = x.size() + if xlen < padlen: + new_x = x.new_zeros([padlen, xdim], dtype=x.dtype) + new_x[:xlen, :] = x + x = new_x + return x.unsqueeze(0) + + +def collator_ft(items, max_node=512, use_pbc=True): + original_len = len(items) + items = [item for item in items if item is not None and item.x.size(0) <= max_node] + filtered_len = len(items) + if filtered_len < original_len: + pass + # print("warning: molecules with atoms more than %d are filtered" % max_node) + pos = None + max_node_num = max(item.x.size(0) for item in items if item is not None) + forces = None + stress = None + total_energy = None + + if hasattr(items[0], "pos") and items[0].pos is not None: + poses = [item.pos - item.pos.mean(dim=0, keepdim=True) for item in items] + # poses = [item.pos for item in items] + pos = torch.cat([pad_pos_unsqueeze(i, max_node_num) for i in poses]) + if hasattr(items[0], "forces") and items[0].forces is not None: + forcess = [item.forces for item in items] + forces = torch.cat([pad_pos_unsqueeze(i, max_node_num) for i in forcess]) + if hasattr(items[0], "stress") and items[0].stress is not None: + stress = torch.cat([item.stress.unsqueeze(0) for item in items], dim=0) + if hasattr(items[0], "total_energy") and items[0].cell is not None: + total_energy = torch.cat([item.total_energy for item in items]) + + items = [ + ( + item.idx, + item.x, + item.y, + (item.pbc if hasattr(item, "pbc") else torch.tensor([False, False, False])) + if use_pbc + else None, + (item.cell if hasattr(item, "cell") else torch.zeros([3, 3])) + if use_pbc + else None, + (int(item.num_atoms) if hasattr(item, "num_atoms") else item.x.size()[0]), + ) + for item in items + ] + ( + idxs, + xs, + ys, + pbcs, + cells, + natoms, + ) = zip(*items) + + y = torch.cat(ys) + x = torch.cat([pad_2d_unsqueeze(i, max_node_num) for i in xs]) + + pbc = torch.cat([i.unsqueeze(0) for i in pbcs], dim=0) if use_pbc else None + cell = torch.cat([i.unsqueeze(0) for i in cells], dim=0) if use_pbc else None + natoms = torch.tensor(natoms) if use_pbc else None + node_type_edge = None + return dict( + idx=torch.LongTensor(idxs), + x=x, + y=y, + pos=pos + 1e-5, + pbc=pbc, + cell=cell, + natoms=natoms, + total_energy=total_energy, + forces=forces, + stress=stress, + node_type_edge=node_type_edge, + ) diff --git a/src/mattersim/datasets/utils/convertor.py b/src/mattersim/datasets/utils/convertor.py new file mode 100644 index 0000000..ef85821 --- /dev/null +++ b/src/mattersim/datasets/utils/convertor.py @@ -0,0 +1,261 @@ +# -*- coding: utf-8 -*- +from typing import Optional, Tuple + +import ase +import numpy as np +import torch +from ase import Atoms +from pymatgen.optimization.neighbors import find_points_in_spheres +from torch_geometric.data import Data + +from .threebody_indices import compute_threebody as _compute_threebody + +""" +Supported Properties: + - "num_nodes"(set by default) ## int + - "num_edges"(set by default) ## int + - "num_atoms" ## int + - "num_bonds" ## int + - "atom_attr" ## tensor [num_atoms,atom_attr_dim=1] + - "atom_pos" ## tensor [num_atoms,3] + - "edge_length" ## tensor [num_edges,1] + - "edge_vector" ## tensor [num_edges,3] + - "edge_index" ## tensor [2,num_edges] + - "three_body_indices" ## tensor [num_three_body,2] + - "num_three_body" ## int + - "num_triple_ij" ## tensor [num_edges,1] + - "num_triple_i" ## tensor [num_atoms,1] + - "num_triple_s" ## tensor [1,1] + - "theta_jik" ## tensor [num_three_body,1] 
+ - "triple_edge_length" ## tensor [num_three_body,1] + - "phi" ## tensor [num_three_body,1] + - "energy" ## float + - "forces" ## tensor [num_atoms,3] + - "stress" ## tensor [3,3] +""" + +""" +Computing various graph based operations (M3GNet) +""" + + +def compute_threebody_indices( + bond_atom_indices: np.array, + bond_length: np.array, + n_atoms: int, + atomic_number: np.array, + threebody_cutoff: Optional[float] = None, +): + """ + Given a graph without threebody indices, add the threebody indices + according to a threebody cutoff radius + Args: + bond_atom_indices: np.array, [n_atoms, 2] + bond_length: np.array, [n_atoms] + n_atoms: int + atomic_number: np.array, [n_atoms] + threebody_cutoff: float, threebody cutoff radius + + Returns: + triple_bond_indices, n_triple_ij, n_triple_i, n_triple_s + + """ + n_atoms = np.array(n_atoms).reshape(1) + atomic_number = atomic_number.reshape(-1, 1) + n_bond = bond_atom_indices.shape[0] + if n_bond > 0 and threebody_cutoff is not None: + valid_three_body = bond_length <= threebody_cutoff + ij_reverse_map = np.where(valid_three_body)[0] + original_index = np.arange(n_bond)[valid_three_body] + bond_atom_indices = bond_atom_indices[valid_three_body, :] + else: + ij_reverse_map = None + original_index = np.arange(n_bond) + + if bond_atom_indices.shape[0] > 0: + bond_indices, n_triple_ij, n_triple_i, n_triple_s = _compute_threebody( + np.ascontiguousarray(bond_atom_indices, dtype="int32"), + np.array(n_atoms, dtype="int32"), + ) + if ij_reverse_map is not None: + n_triple_ij_ = np.zeros(shape=(n_bond,), dtype="int32") + n_triple_ij_[ij_reverse_map] = n_triple_ij + n_triple_ij = n_triple_ij_ + bond_indices = original_index[bond_indices] + bond_indices = np.array(bond_indices, dtype="int32") + else: + bond_indices = np.reshape(np.array([], dtype="int32"), [-1, 2]) + if n_bond == 0: + n_triple_ij = np.array([], dtype="int32") + else: + n_triple_ij = np.array([0] * n_bond, dtype="int32") + n_triple_i = np.array([0] * len(atomic_number), dtype="int32") + n_triple_s = np.array([0], dtype="int32") + return bond_indices, n_triple_ij, n_triple_i, n_triple_s + + +def get_fixed_radius_bonding( + structure: ase.Atoms, + cutoff: float = 5.0, + numerical_tol: float = 1e-8, + pbc: bool = True, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + Get graph representations from structure within cutoff + Args: + structure (pymatgen Structure or molecule) + cutoff (float): cutoff radius + numerical_tol (float): numerical tolerance + + Returns: + center_indices, neighbor_indices, images, distances + """ + if isinstance(structure, Atoms): + pbc_ = np.array(structure.pbc, dtype=int) + if np.all(pbc_ < 0.1) or not pbc: + lattice_matrix = np.array( + [[1000.0, 0.0, 0.0], [0.0, 1000.0, 0.0], [0.0, 0.0, 1000.0]], + dtype=float, + ) + pbc_ = np.array([0, 0, 0], dtype=int) + else: + lattice_matrix = np.ascontiguousarray( + structure.cell[:], dtype=float + ) # noqa: E501 + + cart_coords = np.ascontiguousarray( + np.array(structure.positions), dtype=float + ) # noqa: E501 + else: + raise ValueError("structure type not supported") + r = float(cutoff) + + ( + center_indices, + neighbor_indices, + images, + distances, + ) = find_points_in_spheres( # noqa: E501 + cart_coords, + cart_coords, + r=r, + pbc=pbc_, + lattice=lattice_matrix, + tol=numerical_tol, + ) + center_indices = center_indices.astype(np.int64) + neighbor_indices = neighbor_indices.astype(np.int64) + images = images.astype(np.int64) + distances = distances.astype(float) + exclude_self = (center_indices 
!= neighbor_indices) | ( + distances > numerical_tol + ) # noqa: E501 + return ( + center_indices[exclude_self], + neighbor_indices[exclude_self], + images[exclude_self], + distances[exclude_self], + ) + + +class GraphConvertor: + """ + Convert ase.Atoms to Graph + """ + + default_properties = ["num_nodes", "num_edges"] + + def __init__( + self, + model_type: str = "m3gnet", + twobody_cutoff: float = 5.0, + has_threebody: bool = True, + threebody_cutoff: float = 4.0, + ): + self.model_type = model_type + self.twobody_cutoff = twobody_cutoff + self.threebody_cutoff = threebody_cutoff + self.has_threebody = has_threebody + + def convert( + self, + atoms: Atoms, + energy=None, + forces=None, + stress=None, + pbc=True, + **kwargs, + ): + """ + Convert the structure into graph + Args: + pbc: bool, whether to use periodic boundary condition, default True + """ + # normalize the structure + scaled_pos = atoms.get_scaled_positions() + scaled_pos = np.mod(scaled_pos, 1) + atoms.set_scaled_positions(scaled_pos) + args = {} + if self.model_type == "m3gnet": + args["num_atoms"] = len(atoms) + args["num_nodes"] = len(atoms) + args["atom_attr"] = torch.FloatTensor( + atoms.get_atomic_numbers() + ).unsqueeze( # noqa: E501 + -1 + ) + args["atom_pos"] = torch.FloatTensor(atoms.get_positions()) + args["cell"] = torch.FloatTensor(np.array(atoms.cell)).unsqueeze(0) + ( + sent_index, + receive_index, + shift_vectors, + distances, + ) = get_fixed_radius_bonding(atoms, self.twobody_cutoff, pbc=pbc) + args["num_bonds"] = len(sent_index) + args["edge_index"] = torch.from_numpy( + np.array([sent_index, receive_index]) + ) # noqa: E501 + args["pbc_offsets"] = torch.FloatTensor(shift_vectors) + if self.has_threebody: + ( + triple_bond_index, + n_triple_ij, + n_triple_i, + n_triple_s, + ) = compute_threebody_indices( + bond_atom_indices=args["edge_index"] + .numpy() + .transpose(1, 0), # noqa: E501 + bond_length=distances, + n_atoms=atoms.positions.shape[0], + atomic_number=atoms.get_atomic_numbers(), + threebody_cutoff=self.threebody_cutoff, + ) + args["three_body_indices"] = torch.from_numpy( + triple_bond_index + ).to( # noqa: E501 + torch.long + ) # [num_three_body,2] + args["num_three_body"] = args["three_body_indices"].shape[0] + args["num_triple_ij"] = ( + torch.from_numpy(n_triple_ij).to(torch.long).unsqueeze(-1) + ) + else: + args["three_body_indices"] = None + args["num_three_body"] = None + args["num_triple_ij"] = None + if energy is not None: + args["energy"] = torch.FloatTensor([energy]) + if forces is not None: + args["forces"] = torch.FloatTensor(forces) + if stress is not None: + args["stress"] = torch.FloatTensor(stress).unsqueeze(0) + return Data(**args) + + elif self.model_type == "graphormer": + raise NotImplementedError + else: + raise NotImplementedError( + "model type {} not implemented".format(self.model_type) + ) diff --git a/src/mattersim/datasets/utils/regressor.py b/src/mattersim/datasets/utils/regressor.py new file mode 100644 index 0000000..af277a1 --- /dev/null +++ b/src/mattersim/datasets/utils/regressor.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- +""" +Nequip +""" + +import logging +from typing import Optional + +import numpy as np +import torch +from sklearn.gaussian_process import GaussianProcessRegressor +from sklearn.gaussian_process.kernels import DotProduct, Hyperparameter, Kernel + + +def solver( + X, y, regressor: Optional[str] = "NormalizedGaussianProcess", **kwargs +): # noqa: E501 + if regressor == "GaussianProcess": + return gp(X, y, **kwargs) + elif regressor == 
"NormalizedGaussianProcess": + return normalized_gp(X, y, **kwargs) + else: + raise NotImplementedError(f"{regressor} is not implemented") + + +def normalized_gp(X, y, **kwargs): + feature_rms = 1.0 / np.sqrt(np.average(X**2, axis=0)) + feature_rms = np.nan_to_num(feature_rms, 1) + y_mean = torch.sum(y) / torch.sum(X) + mean, std = base_gp( + X, + y - (torch.sum(X, axis=1) * y_mean).reshape(y.shape), + NormalizedDotProduct, + {"diagonal_elements": feature_rms}, + **kwargs, + ) + return mean + y_mean, std + + +def gp(X, y, **kwargs): + return base_gp( + X, y, DotProduct, {"sigma_0": 0, "sigma_0_bounds": "fixed"}, **kwargs + ) + + +def base_gp( + X, + y, + kernel, + kernel_kwargs, + alpha: Optional[float] = 0.1, + max_iteration: int = 20, + stride: Optional[int] = 1, +): + if len(y.shape) == 1: + y = y.reshape([-1, 1]) + + if stride is not None: + X = X[::stride] + y = y[::stride] + + not_fit = True + iteration = 0 + mean = None + std = None + while not_fit: + print(f"GP fitting iteration {iteration} {alpha}") + try: + _kernel = kernel(**kernel_kwargs) + gpr = GaussianProcessRegressor( + kernel=_kernel, random_state=0, alpha=alpha + ) # noqa: E501 + gpr = gpr.fit(X, y) + + vec = torch.diag(torch.ones(X.shape[1])) + mean, std = gpr.predict(vec, return_std=True) + + mean = torch.as_tensor( + mean, dtype=torch.get_default_dtype() + ).reshape( # noqa: E501 + [-1] + ) # noqa: E501 + # ignore all the off-diagonal terms + std = torch.as_tensor( + std, dtype=torch.get_default_dtype() + ).reshape( # noqa: E501 + [-1] + ) # noqa: E501 + likelihood = gpr.log_marginal_likelihood() + + res = torch.sqrt( + torch.square(torch.matmul(X, mean.reshape([-1, 1])) - y).mean() + ) + + print( + f"GP fitting: alpha {alpha}:\n" + f" residue {res}\n" + f" mean {mean} std {std}\n" + f" log marginal likelihood {likelihood}" + ) + not_fit = False + + except Exception as e: + print(f"GP fitting failed for alpha={alpha} and {e.args}") + if alpha == 0 or alpha is None: + logging.info("try a non-zero alpha") + not_fit = False + raise ValueError( + f"Please set the {alpha} to non-zero value. \n" + "The dataset energy is rank deficient to be solved with GP" + ) + else: + alpha = alpha * 2 + iteration += 1 + logging.debug(f" increase alpha to {alpha}") + + if iteration >= max_iteration or not_fit is False: + raise ValueError( + "Please set the per species shift and scale " + "to zeros and ones. \nThe dataset energy is " + "to diverge to be solved with GP" + ) + + return mean, std + + +class NormalizedDotProduct(Kernel): + r"""Dot-Product kernel. + .. math:: + k(x_i, x_j) = x_i \cdot A \cdot x_j + """ + + def __init__(self, diagonal_elements): + # TODO: check shape + self.diagonal_elements = diagonal_elements + self.A = np.diag(diagonal_elements) + + def __call__(self, X, Y=None, eval_gradient=False): + """Return the kernel k(X, Y) and optionally its gradient. + Parameters + ---------- + X : ndarray of shape (n_samples_X, n_features) + Left argument of the returned kernel k(X, Y) + Y : ndarray of shape (n_samples_Y, n_features), default=None + Right argument of the returned kernel k(X, Y). If None, k(X, X) + if evaluated instead. + eval_gradient : bool, default=False + Determines whether the gradient with respect to the log of + the kernel hyperparameter is computed. + Only supported when Y is None. 
+ Returns + ------- + K : ndarray of shape (n_samples_X, n_samples_Y) + Kernel k(X, Y) + K_gradient : ndarray of shape (n_samples_X, n_samples_X, n_dims),\ + optional + The gradient of the kernel k(X, X) with respect to the log of the + hyperparameter of the kernel. Only returned when `eval_gradient` + is True. + """ + X = np.atleast_2d(X) + if Y is None: + K = (X.dot(self.A)).dot(X.T) + else: + if eval_gradient: + raise ValueError( + "Gradient can only be evaluated when Y is None." # noqa: E501 + ) + K = (X.dot(self.A)).dot(Y.T) + + if eval_gradient: + return K, np.empty((X.shape[0], X.shape[0], 0)) + else: + return K + + def diag(self, X): + """Returns the diagonal of the kernel k(X, X). + The result of this method is identical to np.diag(self(X)); however, + it can be evaluated more efficiently since only the diagonal is + evaluated. + Parameters + ---------- + X : ndarray of shape (n_samples_X, n_features) + Left argument of the returned kernel k(X, Y). + Returns + ------- + K_diag : ndarray of shape (n_samples_X,) + Diagonal of kernel k(X, X). + """ + return np.einsum("ij,ij,jj->i", X, X, self.A) + + def __repr__(self): + return "" + + def is_stationary(self): + """Returns whether the kernel is stationary.""" + return False + + @property + def hyperparameter_diagonal_elements(self): + return Hyperparameter("diagonal_elements", "numeric", "fixed") diff --git a/src/mattersim/datasets/utils/setup.py b/src/mattersim/datasets/utils/setup.py new file mode 100644 index 0000000..b4895aa --- /dev/null +++ b/src/mattersim/datasets/utils/setup.py @@ -0,0 +1,15 @@ +# -*- coding: utf-8 -*- +from distutils.core import Extension, setup + +import numpy +from Cython.Build import cythonize + +package = Extension( + "threebody_indices", + ["threebody_indices.pyx"], + include_dirs=[numpy.get_include()], # noqa: E501 +) +setup(ext_modules=cythonize([package])) + +# usage: +# python setup.py build_ext --inplace diff --git a/src/mattersim/datasets/utils/threebody_indices.pyx b/src/mattersim/datasets/utils/threebody_indices.pyx new file mode 100644 index 0000000..f24d303 --- /dev/null +++ b/src/mattersim/datasets/utils/threebody_indices.pyx @@ -0,0 +1,91 @@ +# cython: boundscheck=False +# cython: wraparound=False +# cython: nonecheck=False +# cython: cdivision=True +# cython: profile=True +# cython: language_level=3 +# distutils: language = c +# distutils: define_macros=NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION + +cimport numpy as np +import numpy as np +from libc.stdlib cimport free, malloc +from libc.string cimport memset + + +def compute_threebody(const int[:, ::1] bond_atom_indices, + const int[::1] n_atoms): + """ + Calculate the three body indices from pair atom indices + Args: + bond_atom_indices (np.ndarray): pair atom indices + n_atoms (int): number of atoms + Returns: + triple_bond_indices (np.ndarray): bond indices that form three-body + py_n_triple_ij (np.ndarray): number of three-body angles for each bond + py_n_triple_i (np.ndarray): number of three-body angles each atom + py_n_triple_s (np.ndarray): number of three-body angles for each + structure + """ + cdef int i, j, k + cdef int n_bond = bond_atom_indices.shape[0] + cdef int n_atom = 0 + cdef int n_struct = n_atoms.shape[0] + for i in range(n_struct): + n_atom += n_atoms[i] + + cdef int* n_bond_per_atom = malloc(n_atom * sizeof(int)) + memset(n_bond_per_atom, 0, n_atom * sizeof(int)) + + for i in range(n_bond): + n_bond_per_atom[bond_atom_indices[i, 0]] += 1 + + cdef int* n_triple_i = malloc(n_atom * sizeof(int)) + cdef int* n_triple_ij = 
malloc(n_bond * sizeof(int)) + cdef int* n_triple_s = malloc(n_struct * sizeof(int)) + + memset(n_triple_s, 0, n_struct * sizeof(int)) + + cdef int n_triple = 0 + cdef int n_triple_temp + cdef int start = 0 + + for i in range(n_atom): + n_triple_temp = n_bond_per_atom[i] * (n_bond_per_atom[i] - 1) + for j in range(n_bond_per_atom[i]): + n_triple_ij[start + j] = n_bond_per_atom[i] - 1 + n_triple += n_triple_temp + n_triple_i[i] = n_triple_temp + start += n_bond_per_atom[i] + + cdef np.ndarray triple_bond_indices = np.empty(shape=(n_triple, 2), + dtype=np.int32) + + start = 0 + cdef int index = 0 + for i in range(n_atom): + for j in range(n_bond_per_atom[i]): + for k in range(n_bond_per_atom[i]): + if j != k: + triple_bond_indices[index, 0] = start + j + triple_bond_indices[index, 1] = start + k + index += 1 + start += n_bond_per_atom[i] + + start = 0 + cdef int end = start + cdef int n_atom_temp + for i in range(n_struct): + end += n_atoms[i] + for j in range(start, end): + n_triple_s[i] += n_triple_i[j] + start = end + py_n_triple_ij = np.array(n_triple_ij) + py_n_triple_i = np.array(n_triple_i) + py_n_triple_s = np.array(n_triple_s) + + free(n_triple_ij) + free(n_triple_i) + free(n_triple_s) + free(n_bond_per_atom) + return triple_bond_indices, py_n_triple_ij, py_n_triple_i, py_n_triple_s diff --git a/src/mattersim/forcefield/m3gnet/m3gnet.py b/src/mattersim/forcefield/m3gnet/m3gnet.py new file mode 100644 index 0000000..66a5f96 --- /dev/null +++ b/src/mattersim/forcefield/m3gnet/m3gnet.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +from typing import Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_runstats.scatter import scatter + +from mattersim.jit_compile_tools.jit import compile_mode + +from .modules import ( # noqa: F501 + MLP, + GatedMLP, + MainBlock, + SmoothBesselBasis, + SphericalBasisLayer, +) +from .scaling import AtomScaling + + +@compile_mode("script") +class M3Gnet(nn.Module): + """ + M3Gnet + """ + + def __init__( + self, + num_blocks: int = 4, + units: int = 128, + max_l: int = 4, + max_n: int = 4, + cutoff: float = 5.0, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + max_z: int = 94, + threebody_cutoff: float = 4.0, + **kwargs, + ): + super().__init__() + self.rbf = SmoothBesselBasis(r_max=cutoff, max_n=max_n) + self.sbf = SphericalBasisLayer(max_n=max_n, max_l=max_l, cutoff=cutoff) + self.edge_encoder = MLP( + in_dim=max_n, out_dims=[units], activation="swish", use_bias=False + ) + module_list = [ + MainBlock(max_n, max_l, cutoff, units, max_n, threebody_cutoff) + for i in range(num_blocks) + ] + self.graph_conv = nn.ModuleList(module_list) + self.final = GatedMLP( + in_dim=units, + out_dims=[units, units, 1], + activation=["swish", "swish", None], + ) + self.apply(self.init_weights) + self.atom_embedding = MLP( + in_dim=max_z + 1, out_dims=[units], activation=None, use_bias=False + ) + self.atom_embedding.apply(self.init_weights_uniform) + self.normalizer = AtomScaling(verbose=False, max_z=max_z) + self.max_z = max_z + self.device = device + self.model_args = { + "num_blocks": num_blocks, + "units": units, + "max_l": max_l, + "max_n": max_n, + "cutoff": cutoff, + "max_z": max_z, + "threebody_cutoff": threebody_cutoff, + } + + def forward( + self, + input: Dict[str, torch.Tensor], + dataset_idx: int = -1, + ) -> torch.Tensor: + # Exact data from input_dictionary + pos = input["atom_pos"] + cell = input["cell"] + pbc_offsets = input["pbc_offsets"].float() + atom_attr = input["atom_attr"] + edge_index = 
input["edge_index"].long() + three_body_indices = input["three_body_indices"].long() + num_three_body = input["num_three_body"] + num_bonds = input["num_bonds"] + num_triple_ij = input["num_triple_ij"] + num_atoms = input["num_atoms"] + num_graphs = input["num_graphs"] + batch = input["batch"] + + # -------------------------------------------------------------# + cumsum = torch.cumsum(num_bonds, dim=0) - num_bonds + index_bias = torch.repeat_interleave( # noqa: F501 + cumsum, num_three_body, dim=0 + ).unsqueeze(-1) + three_body_indices = three_body_indices + index_bias + + # === Refer to the implementation of M3GNet, === + # === we should re-compute the following attributes === + # edge_length, edge_vector(optional), triple_edge_length, theta_jik + atoms_batch = torch.repeat_interleave(repeats=num_atoms) + edge_batch = atoms_batch[edge_index[0]] + edge_vector = pos[edge_index[0]] - ( + pos[edge_index[1]] + + torch.einsum("bi, bij->bj", pbc_offsets, cell[edge_batch]) + ) + edge_length = torch.linalg.norm(edge_vector, dim=1) + vij = edge_vector[three_body_indices[:, 0].clone()] + vik = edge_vector[three_body_indices[:, 1].clone()] + rij = edge_length[three_body_indices[:, 0].clone()] + rik = edge_length[three_body_indices[:, 1].clone()] + cos_jik = torch.sum(vij * vik, dim=1) / (rij * rik) + # eps = 1e-7 avoid nan in torch.acos function + cos_jik = torch.clamp(cos_jik, min=-1.0 + 1e-7, max=1.0 - 1e-7) + triple_edge_length = rik.view(-1) + edge_length = edge_length.unsqueeze(-1) + atomic_numbers = atom_attr.squeeze(1).long() + + # featurize + atom_attr = self.atom_embedding(self.one_hot_atoms(atomic_numbers)) + edge_attr = self.rbf(edge_length.view(-1)) + edge_attr_zero = edge_attr # e_ij^0 + edge_attr = self.edge_encoder(edge_attr) + three_basis = self.sbf(triple_edge_length, torch.acos(cos_jik)) + + # Main Loop + for idx, conv in enumerate(self.graph_conv): + atom_attr, edge_attr = conv( + atom_attr, + edge_attr, + edge_attr_zero, + edge_index, + three_basis, + three_body_indices, + edge_length, + num_bonds, + num_triple_ij, + num_atoms, + ) + + energies_i = self.final(atom_attr).view(-1) # [batch_size*num_atoms] + energies_i = self.normalizer(energies_i, atomic_numbers) + energies = scatter(energies_i, batch, dim=0, dim_size=num_graphs) + + return energies # [batch_size] + + def init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + + def init_weights_uniform(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.uniform_(m.weight, a=-0.05, b=0.05) + + @torch.jit.export + def one_hot_atoms(self, species): + # one_hots = [] + # for i in range(species.shape[0]): + # one_hots.append( + # F.one_hot( + # species[i], + # num_classes=self.max_z+1).float().to(species.device) + # ) + # return torch.cat(one_hots, dim=0) + return F.one_hot(species, num_classes=self.max_z + 1).float() + + def print(self): + from prettytable import PrettyTable + + table = PrettyTable(["Modules", "Parameters"]) + total_params = 0 + for name, parameter in self.named_parameters(): + if not parameter.requires_grad: + continue + params = parameter.numel() + table.add_row([name, params]) + total_params += params + print(table) + print(f"Total Trainable Params: {total_params}") + + @torch.jit.export + def set_normalizer(self, normalizer: AtomScaling): + self.normalizer = normalizer + + def get_model_args(self): + return self.model_args diff --git a/src/mattersim/forcefield/m3gnet/m3gnet_multi_head.py b/src/mattersim/forcefield/m3gnet/m3gnet_multi_head.py new file mode 100644 index 
0000000..6b04635 --- /dev/null +++ b/src/mattersim/forcefield/m3gnet/m3gnet_multi_head.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- +from typing import Dict + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_runstats.scatter import scatter + +from .modules import ( # noqa: E501 + MLP, + GatedMLP, + MainBlock, + SmoothBesselBasis, + SphericalBasisLayer, +) +from .scaling import AtomScaling + + +class M3Gnet_multi_head(nn.Module): + """ + M3Gnet with no massage passing + """ + + def __init__( + self, + normalizer_list: list[AtomScaling], + num_blocks: int = 4, + units: int = 128, + max_l: int = 4, + max_n: int = 4, + cutoff: float = 5.0, + device: str = "cuda", + max_z: int = 94, + threebody_cutoff: float = 4.0, + **kwargs, + ): + super().__init__() + self.rbf = SmoothBesselBasis(r_max=cutoff, max_n=max_n) + self.sbf = SphericalBasisLayer(max_n=max_n, max_l=max_l, cutoff=cutoff) + self.edge_encoder = MLP( + in_dim=max_n, out_dims=[units], activation="swish", use_bias=False + ) + module_list = [ + MainBlock(max_n, max_l, cutoff, units, max_n, threebody_cutoff) + for i in range(num_blocks) + ] + self.graph_conv = nn.ModuleList(module_list) + if isinstance(normalizer_list, list): + self.normalizer_list = nn.ModuleList(normalizer_list) + elif isinstance(normalizer_list, nn.ModuleList): + self.normalizer_list = normalizer_list + else: + raise NotImplementedError + self.final_layer_list = nn.ModuleList( + [ + GatedMLP( + in_dim=units, + out_dims=[units, units, 1], + activation=["swish", "swish", None], + ) + for _ in range(len(normalizer_list)) + ] + ) + self.apply(self.init_weights) + self.max_z = max_z + self.device = device + self.atom_embedding = MLP( + in_dim=max_z + 1, out_dims=[units], activation=None, use_bias=False + ) + self.atom_embedding.apply(self.init_weights_uniform) + self.model_args = { + "num_blocks": num_blocks, + "units": units, + "max_l": max_l, + "max_n": max_n, + "cutoff": cutoff, + "normalizer_list": self.normalizer_list, + "max_z": max_z, + "threebody_cutoff": threebody_cutoff, + } + print("This model is specifically designed for multi tasks") + + def forward( + self, + input: Dict[str, torch.Tensor], + dataset_idx: int = -1, + ): + # Exact data from input_dictionary + pos = input["atom_pos"] + cell = input["cell"] + pbc_offsets = input["pbc_offsets"] + atom_attr = input["atom_attr"] + edge_index = input["edge_index"] + three_body_indices = input["three_body_indices"] + num_three_body = input["num_three_body"] + num_bonds = input["num_bonds"] + num_triple_ij = input["num_triple_ij"] + num_atoms = input["num_atoms"] + num_graphs = input["num_graphs"] + batch = input["batch"] + + cumsum = torch.cumsum(num_bonds, dim=0) - num_bonds + index_bias = torch.repeat_interleave( # noqa: E501 + cumsum, num_three_body, dim=0 + ).unsqueeze(-1) + three_body_indices = three_body_indices + index_bias + + # === Refer to the implementation of M3GNet, === + # === we should re-compute the following attributes === + # edge_length, edge_vector(optional), triple_edge_length, theta_jik + atoms_batch = torch.repeat_interleave(repeats=num_atoms) + edge_batch = atoms_batch[edge_index[0]] + edge_vector = pos[edge_index[0]] - ( + pos[edge_index[1]] + + torch.einsum("bi, bij->bj", pbc_offsets, cell[edge_batch]) + ) + edge_length = torch.linalg.norm(edge_vector, dim=1) + vij = edge_vector[three_body_indices[:, 0].clone()] + vik = edge_vector[three_body_indices[:, 1].clone()] + rij = edge_length[three_body_indices[:, 0].clone()] + rik = edge_length[three_body_indices[:, 
1].clone()] + cos_jik = torch.sum(vij * vik, dim=1) / (rij * rik) + # eps = 1e-7 avoid nan in torch.acos function + cos_jik = torch.clamp(cos_jik, min=-1.0 + 1e-7, max=1.0 - 1e-7) + triple_edge_length = rik.view(-1) + edge_length = edge_length.unsqueeze(-1) + atomic_numbers = atom_attr.squeeze(1).long() + + # featurize + atom_attr = self.atom_embedding(self.one_hot_atoms(atomic_numbers)) + edge_attr = self.rbf(edge_length.view(-1)) + edge_attr_zero = edge_attr # e_ij^0 + edge_attr = self.edge_encoder(edge_attr) + three_basis = self.sbf(triple_edge_length, torch.acos(cos_jik)) + + # feature_after_first_layer = None + + # Main Loop + for idx, conv in enumerate(self.graph_conv): + atom_attr, edge_attr = conv( + atom_attr, + edge_attr, + edge_attr_zero, + edge_index, + three_basis, + three_body_indices, + edge_length, + num_bonds, + num_triple_ij, + num_atoms, + ) + # if idx == 0: + # feature_after_first_layer = atom_attr.detach() + + # feature_before_branching_out = atom_attr.detach() + energies_i = self.final_layer_list[dataset_idx](atom_attr).view(-1) + if self.normalizer_list[dataset_idx] is not None: + energies_i = self.normalizer_list[dataset_idx]( + energies_i, atomic_numbers.view(-1) + ) + energies = scatter(energies_i, batch, dim=0, dim_size=num_graphs) + # return energies, + # feature_after_first_layer, + # feature_before_branching_out + return energies + + def init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_uniform_(m.weight) + + def init_weights_uniform(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.uniform_(m.weight, a=-0.05, b=0.05) + + def one_hot_atoms(self, species): + # one_hots = [] + # for i in range(species.shape[0]): + # one_hots.append( + # F.one_hot(species[i], + # num_classes=self.max_z+1 + # ).float().to(species.device) + # ) + # return torch.cat(one_hots, dim=0) + return F.one_hot(species, num_classes=self.max_z + 1).float() + + def print(self): + from prettytable import PrettyTable + + table = PrettyTable(["Modules", "Parameters"]) + total_params = 0 + for name, parameter in self.model.named_parameters(): + if not parameter.requires_grad: + continue + params = parameter.numel() + table.add_row([name, params]) + total_params += params + print(table) + print(f"Total Trainable Params: {total_params}") + + def get_model_args(self): + return self.model_args diff --git a/src/mattersim/forcefield/m3gnet/modules/__init__.py b/src/mattersim/forcefield/m3gnet/modules/__init__.py new file mode 100644 index 0000000..cfbf13a --- /dev/null +++ b/src/mattersim/forcefield/m3gnet/modules/__init__.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +from .angle_encoding import SphericalBasisLayer +from .edge_encoding import SmoothBesselBasis +from .layers import MLP, GatedMLP, LinearLayer, SwishLayer +from .message_passing import AtomLayer, EdgeLayer, MainBlock + +__all__ = [ + "SphericalBasisLayer", + "SmoothBesselBasis", + "GatedMLP", + "MLP", + "LinearLayer", + "SwishLayer", + "AtomLayer", + "EdgeLayer", + "MainBlock", +] diff --git a/src/mattersim/forcefield/m3gnet/modules/angle_encoding.py b/src/mattersim/forcefield/m3gnet/modules/angle_encoding.py new file mode 100644 index 0000000..91fb08d --- /dev/null +++ b/src/mattersim/forcefield/m3gnet/modules/angle_encoding.py @@ -0,0 +1,247 @@ +# -*- coding: utf-8 -*- +# """ +# Ref: +# - https://github.com/akirasosa/pytorch-dimenet +# - https://arxiv.org/abs/2003.03123 +# """ +# +import math + +import torch +import torch.nn as nn + + +@torch.jit.script +def _spherical_harmonics(lmax: int, x: torch.Tensor) 
-> torch.Tensor: + sh_0_0 = torch.ones_like(x) * 0.5 * math.sqrt(1.0 / math.pi) + if lmax == 0: + return torch.stack( + [ + sh_0_0, + ], + dim=-1, + ) + + sh_1_1 = math.sqrt(3.0 / (4.0 * math.pi)) * x + if lmax == 1: + return torch.stack([sh_0_0, sh_1_1], dim=-1) + + sh_2_2 = math.sqrt(5.0 / (16.0 * math.pi)) * (3.0 * x**2 - 1.0) + if lmax == 2: + return torch.stack([sh_0_0, sh_1_1, sh_2_2], dim=-1) + + sh_3_3 = math.sqrt(7.0 / (16.0 * math.pi)) * x * (5.0 * x**2 - 3.0) + if lmax == 3: + return torch.stack([sh_0_0, sh_1_1, sh_2_2, sh_3_3], dim=-1) + + raise ValueError("lmax must be less than 8") + + +class SphericalBasisLayer(nn.Module): + def __init__(self, max_n, max_l, cutoff): + super(SphericalBasisLayer, self).__init__() + + assert max_l <= 4, "lmax must be less than 5" + assert max_n <= 4, "max_n must be less than 5" + + self.max_n = max_n + self.max_l = max_l + self.cutoff = cutoff + + # retrieve formulas + self.register_buffer( + "factor", torch.sqrt(torch.tensor(2.0 / (self.cutoff**3))) + ) + self.coef = torch.zeros(4, 9, 4) + self.coef[0, 0, :] = torch.tensor( + [ + 3.14159274101257, + 6.28318548202515, + 9.42477798461914, + 12.5663709640503, + ] # noqa: E501 + ) + self.coef[1, :4, :] = torch.tensor( + [ + [ + -1.02446483277785, + -1.00834335996107, + -1.00419641763893, + -1.00252381898662, + ], + [ + 4.49340963363647, + 7.7252516746521, + 10.9041213989258, + 14.0661935806274, + ], # noqa: E501 + [ + 0.22799275301076, + 0.130525632358311, + 0.092093290316619, + 0.0712718627992818, + ], + [ + 4.49340963363647, + 7.7252516746521, + 10.9041213989258, + 14.0661935806274, + ], # noqa: E501 + ] + ) + self.coef[2, :6, :] = torch.tensor( + [ + [ + -1.04807944170731, + -1.01861796359391, + -1.01002272174988, + -1.00628955560036, + ], + [ + 5.76345920562744, + 9.09501171112061, + 12.322940826416, + 15.5146026611328, + ], # noqa: E501 + [ + 0.545547077361439, + 0.335992298618515, + 0.245888396928293, + 0.194582402961821, + ], + [ + 5.76345920562744, + 9.09501171112061, + 12.322940826416, + 15.5146026611328, + ], # noqa: E501 + [ + 0.0946561878721665, + 0.0369424811413594, + 0.0199537107571916, + 0.0125418876146463, + ], + [ + 5.76345920562744, + 9.09501171112061, + 12.322940826416, + 15.5146026611328, + ], # noqa: E501 + ] + ) + self.coef[3, :8, :] = torch.tensor( + [ + [ + 1.06942831392075, + 1.0292173312802, + 1.01650804843248, + 1.01069656069999, + ], # noqa: E501 + [ + 6.9879322052002, + 10.4171180725098, + 13.6980228424072, + 16.9236221313477, + ], # noqa: E501 + [ + 0.918235852195231, + 0.592803493701152, + 0.445250264272671, + 0.358326327374518, + ], + [ + 6.9879322052002, + 10.4171180725098, + 13.6980228424072, + 16.9236221313477, + ], # noqa: E501 + [ + 0.328507713452024, + 0.142266673367543, + 0.0812617757677838, + 0.0529328657590962, + ], + [ + 6.9879322052002, + 10.4171180725098, + 13.6980228424072, + 16.9236221313477, + ], # noqa: E501 + [ + 0.0470107184508114, + 0.0136570088173405, + 0.0059323726279831, + 0.00312775039221944, + ], + [ + 6.9879322052002, + 10.4171180725098, + 13.6980228424072, + 16.9236221313477, + ], # noqa: E501 + ] + ) + + def forward(self, r, theta_val): + r = r / self.cutoff + # Denote empty lists for rbf and cbf + rbfs = [] + + for j in range(self.max_l): + rbfs.append(torch.sin(self.coef[0, 0, j] * r) / r) + + if self.max_n > 1: + for j in range(self.max_l): + rbfs.append( + ( + self.coef[1, 0, j] + * r + * torch.cos(self.coef[1, 1, j] * r) # noqa: E501 + + self.coef[1, 2, j] + * torch.sin(self.coef[1, 3, j] * r) # noqa: E501 + ) + / r**2 + ) + + if 
self.max_n > 2: + for j in range(self.max_l): + rbfs.append( + ( + self.coef[2, 0, j] + * (r**2) + * torch.sin(self.coef[2, 1, j] * r) + - self.coef[2, 2, j] + * r + * torch.cos(self.coef[2, 3, j] * r) # noqa: E501 + + self.coef[2, 4, j] + * torch.sin(self.coef[2, 5, j] * r) # noqa: E501 + ) + / (r**3) + ) + + if self.max_n > 3: + for j in range(self.max_l): + rbfs.append( + ( + self.coef[3, 0, j] + * (r**3) + * torch.cos(self.coef[3, 1, j] * r) + - self.coef[3, 2, j] + * (r**2) + * torch.sin(self.coef[3, 3, j] * r) + - self.coef[3, 4, j] + * r + * torch.cos(self.coef[3, 5, j] * r) + + self.coef[3, 6, j] + * torch.sin(self.coef[3, 7, j] * r) # noqa: E501 + ) + / r**4 + ) + + rbfs = torch.stack(rbfs, dim=-1) + rbfs = rbfs * self.factor + + cbfs = _spherical_harmonics(self.max_l - 1, torch.cos(theta_val)) + cbfs = cbfs.repeat_interleave(self.max_n, dim=1) + + return rbfs * cbfs diff --git a/src/mattersim/forcefield/m3gnet/modules/edge_encoding.py b/src/mattersim/forcefield/m3gnet/modules/edge_encoding.py new file mode 100644 index 0000000..1b04bcb --- /dev/null +++ b/src/mattersim/forcefield/m3gnet/modules/edge_encoding.py @@ -0,0 +1,194 @@ +# -*- coding: utf-8 -*- +""" +Ref: + - https://github.com/mir-group/nequip + - https://www.nature.com/articles/s41467-022-29939-5 +""" + +import math +from typing import Optional + +import torch +from e3nn.math import soft_one_hot_linspace +from torch import nn + +from mattersim.jit_compile_tools.jit import compile_mode + + +class e3nn_basias(nn.Module): + def __init__( + self, + r_max: float, + r_min: Optional[float] = None, + e3nn_basis_name: str = "gaussian", + num_basis: int = 8, + ): + super().__init__() + self.r_max = r_max + self.r_min = r_min if r_min is not None else 0.0 + self.e3nn_basis_name = e3nn_basis_name + self.num_basis = num_basis + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return soft_one_hot_linspace( + x, + start=self.r_min, + end=self.r_max, + number=self.num_basis, + basis=self.e3nn_basis_name, + cutoff=True, + ) + + def _make_tracing_inputs(self, n: int): + return [{"forward": (torch.randn(5, 1),)} for _ in range(n)] + + +class BesselBasis(nn.Module): + def __init__(self, r_max, num_basis=8, trainable=True): + r"""Radial Bessel Basis, as proposed in + DimeNet: https://arxiv.org/abs/2003.03123 + + Parameters + ---------- + r_max : float + Cutoff radius + + num_basis : int + Number of Bessel Basis functions + + trainable : bool + Train the :math:`n \pi` part or not. + """ + super(BesselBasis, self).__init__() + + self.trainable = trainable + self.num_basis = num_basis + + self.r_max = float(r_max) + self.prefactor = 2.0 / self.r_max + + bessel_weights = ( + torch.linspace(start=1.0, end=num_basis, steps=num_basis) * math.pi + ) + if self.trainable: + self.bessel_weights = nn.Parameter(bessel_weights) + else: + self.register_buffer("bessel_weights", bessel_weights) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Evaluate Bessel Basis for input x. + + Parameters + ---------- + x : torch.Tensor + Input + """ + numerator = torch.sin( + self.bessel_weights * x.unsqueeze(-1) / self.r_max # noqa: E501 + ) + + return self.prefactor * (numerator / x.unsqueeze(-1)) + + +@compile_mode("script") +class SmoothBesselBasis(nn.Module): + def __init__(self, r_max, max_n=10): + r"""Smooth Radial Bessel Basis, as proposed + in DimeNet: https://arxiv.org/abs/2003.03123 + This is an orthogonal basis with first + and second derivative at the cutoff + equals to zero. 
The function was derived from + the order 0 spherical Bessel function, + and was expanded by the different zero roots + Ref: + https://arxiv.org/pdf/1907.02374.pdf + Args: + r_max: torch.Tensor distance tensor + max_n: int, max number of basis, expanded by the zero roots + Returns: expanded spherical harmonics with + derivatives smooth at boundary + Parameters + ---------- + """ + super(SmoothBesselBasis, self).__init__() + self.max_n = max_n + n = torch.arange(0, max_n).float()[None, :] + PI = 3.1415926535897 + SQRT2 = 1.41421356237 + fnr = ( + (-1) ** n + * SQRT2 + * PI + / r_max**1.5 + * (n + 1) + * (n + 2) + / torch.sqrt(2 * n**2 + 6 * n + 5) + ) + en = n**2 * (n + 2) ** 2 / (4 * (n + 1) ** 4 + 1) + dn = [torch.tensor(1.0).float()] + for i in range(1, max_n): + dn.append(1 - en[0, i] / dn[-1]) + dn = torch.stack(dn) + self.register_buffer("dn", dn) + self.register_buffer("en", en) + self.register_buffer("fnr_weights", fnr) + self.register_buffer( + "n_1_pi_cutoff", + ((torch.arange(0, max_n).float() + 1) * PI / r_max).reshape(1, -1), + ) + self.register_buffer( + "n_2_pi_cutoff", + ((torch.arange(0, max_n).float() + 2) * PI / r_max).reshape(1, -1), + ) + self.register_buffer("r_max", torch.tensor(r_max)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Evaluate Smooth Bessel Basis for input x. + + Parameters + ---------- + x : torch.Tensor + Input + """ + x_1 = x.unsqueeze(-1) * self.n_1_pi_cutoff + x_2 = x.unsqueeze(-1) * self.n_2_pi_cutoff + fnr = self.fnr_weights * (torch.sin(x_1) / x_1 + torch.sin(x_2) / x_2) + gn = [fnr[:, 0]] + for i in range(1, self.max_n): + gn.append( + 1 + / torch.sqrt(self.dn[i]) + * ( + fnr[:, i] + + torch.sqrt(self.en[0, i] / self.dn[i - 1]) * gn[-1] # noqa: E501 + ) + ) + return torch.transpose(torch.stack(gn), 1, 0) + + +# class GaussianBasis(nn.Module): +# r_max: float + +# def __init__(self, r_max, r_min=0.0, num_basis=8, trainable=True): +# super().__init__() + +# self.trainable = trainable +# self.num_basis = num_basis + +# self.r_max = float(r_max) +# self.r_min = float(r_min) + +# means = torch.linsspace(self.r_min, self.r_max, self.num_basis) +# stds = torch.full(size=means.size, fill_value=means[1] - means[0]) +# if self.trainable: +# self.means = nn.Parameter(means) +# self.stds = nn.Parameter(stds) +# else: +# self.register_buffer("means", means) +# self.register_buffer("stds", stds) + +# def forward(self, x: torch.Tensor) -> torch.Tensor: +# x = (x[..., None] - self.means) / self.stds +# x = x.square().mul(-0.5).exp() / self.stds # sqrt(2 * pi) diff --git a/src/mattersim/forcefield/m3gnet/modules/layers.py b/src/mattersim/forcefield/m3gnet/modules/layers.py new file mode 100644 index 0000000..dfd127b --- /dev/null +++ b/src/mattersim/forcefield/m3gnet/modules/layers.py @@ -0,0 +1,175 @@ +# -*- coding: utf-8 -*- +from typing import Union + +import torch.nn as nn + + +class LinearLayer(nn.Module): + def __init__( + self, + in_dim, + out_dim, + bias=True, + ): + super().__init__() + self.linear = nn.Linear(in_dim, out_dim, bias=bias) + + def forward(self, x): + return self.linear(x) + + +class SigmoidLayer(nn.Module): + def __init__( + self, + in_dim, + out_dim, + bias=True, + ): + super().__init__() + self.linear = nn.Linear(in_dim, out_dim, bias=bias) + self.sigmoid = nn.Sigmoid() + + def forward( + self, + x, + ): + return self.sigmoid(self.linear(x)) + + +class SwishLayer(nn.Module): + def __init__( + self, + in_dim, + out_dim, + bias=True, + ): + super().__init__() + self.linear = nn.Linear(in_dim, out_dim, bias=bias) + 
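+        # Swish (SiLU) activation layer: forward() below returns linear(x) * sigmoid(linear(x)).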
self.sigmoid = nn.Sigmoid() + + def forward( + self, + x, + ): + x = self.linear(x) + return x * self.sigmoid(x) + + +class ReLULayer(nn.Module): + def __init__( + self, + in_dim, + out_dim, + bias=True, + ): + super().__init__() + self.linear = nn.Linear(in_dim, out_dim, bias=bias) + self.relu = nn.ReLU() + + def forward( + self, + x, + ): + return self.relu(self.linear(x)) + + +class GatedMLP(nn.Module): + def __init__( + self, + in_dim: int, + out_dims: list, + activation: Union[list[Union[str, None]], str] = "swish", + use_bias: bool = True, + ): + super().__init__() + input_dim = in_dim + if isinstance(activation, str) or activation is None: + activation = [activation] * len(out_dims) + else: + assert len(activation) == len( + out_dims + ), "activation and out_dims must have the same length" + module_list_g = [] + for i in range(len(out_dims)): + if activation[i] == "swish": + module_list_g.append( # noqa: E501 + SwishLayer(input_dim, out_dims[i], bias=use_bias) + ) + elif activation[i] == "sigmoid": + module_list_g.append( + SigmoidLayer(input_dim, out_dims[i], bias=use_bias) + ) + elif activation[i] is None: + module_list_g.append( # noqa: E501 + LinearLayer(input_dim, out_dims[i], bias=use_bias) + ) + input_dim = out_dims[i] + module_list_sigma = [] + activation[-1] = "sigmoid" + input_dim = in_dim + for i in range(len(out_dims)): + if activation[i] == "swish": + module_list_sigma.append( + SwishLayer(input_dim, out_dims[i], bias=use_bias) + ) + elif activation[i] == "sigmoid": + module_list_sigma.append( + SigmoidLayer(input_dim, out_dims[i], bias=use_bias) + ) + elif activation[i] is None: + module_list_sigma.append( + LinearLayer(input_dim, out_dims[i], bias=use_bias) + ) + else: + raise NotImplementedError + input_dim = out_dims[i] + self.g = nn.Sequential(*module_list_g) + self.sigma = nn.Sequential(*module_list_sigma) + + def forward( + self, + x, + ): + return self.g(x) * self.sigma(x) + + +class MLP(nn.Module): + def __init__( + self, + in_dim: int, + out_dims: list, + activation: Union[list[Union[str, None]], str, None] = "swish", + use_bias: bool = True, + ): + super().__init__() + input_dim = in_dim + if isinstance(activation, str) or activation is None: + activation = [activation] * len(out_dims) + else: + assert len(activation) == len( + out_dims + ), "activation and out_dims must have the same length" + module_list = [] + for i in range(len(out_dims)): + if activation[i] == "swish": + module_list.append( + SwishLayer(input_dim, out_dims[i], bias=use_bias) # noqa: E501 + ) + elif activation[i] == "sigmoid": + module_list.append( + SigmoidLayer(input_dim, out_dims[i], bias=use_bias) # noqa: E501 + ) + elif activation[i] is None: + module_list.append( + LinearLayer(input_dim, out_dims[i], bias=use_bias) # noqa: E501 + ) + else: + raise NotImplementedError + input_dim = out_dims[i] + self.mlp = nn.Sequential(*module_list) + + def forward( + self, + x, + ): + return self.mlp(x) diff --git a/src/mattersim/forcefield/m3gnet/modules/message_passing.py b/src/mattersim/forcefield/m3gnet/modules/message_passing.py new file mode 100644 index 0000000..e4c068e --- /dev/null +++ b/src/mattersim/forcefield/m3gnet/modules/message_passing.py @@ -0,0 +1,258 @@ +# -*- coding: utf-8 -*- +import torch +import torch.nn as nn +from torch_runstats.scatter import scatter + +from .layers import GatedMLP, LinearLayer, SigmoidLayer, SwishLayer + + +def polynomial(r: torch.Tensor, cutoff: float) -> torch.Tensor: + """ + Polynomial cutoff function + Args: + r (tf.Tensor): radius distance tensor + 
cutoff (float): cutoff distance + Returns: polynomial cutoff functions + """ + ratio = torch.div(r, cutoff) + result = ( + 1 + - 6 * torch.pow(ratio, 5) + + 15 * torch.pow(ratio, 4) + - 10 * torch.pow(ratio, 3) + ) + return torch.clamp(result, min=0.0) + + +class ThreeDInteraction(nn.Module): + def __init__( + self, + max_n, + max_l, + cutoff, + units, + spherecal_dim, + threebody_cutoff, + ): + super().__init__() + # self.sbf = SphericalBesselFunction( + # max_l=max_l, max_n=max_n, cutoff=cutoff, smooth=smooth) + # self.shf = SphericalHarmonicsFunction(max_l=max_l, use_phi=use_phi) + self.atom_mlp = SigmoidLayer(in_dim=units, out_dim=spherecal_dim) + # Linyu have modified the self.edge_gate_mlp + # by adding swish activation and use_bias=False + self.edge_gate_mlp = GatedMLP( + in_dim=spherecal_dim, + out_dims=[units], + activation="swish", + use_bias=False, # noqa: E501 + ) + self.cutoff = cutoff + self.threebody_cutoff = threebody_cutoff + + def forward( + self, + edge_attr, + three_basis, + atom_attr, + edge_index, + three_body_index, + edge_length, + num_edges, + num_triple_ij, + ): + atom_mask = ( + self.atom_mlp(atom_attr)[edge_index[0][three_body_index[:, 1]]] + * polynomial( + edge_length[three_body_index[:, 0]], self.threebody_cutoff # noqa: E501 + ) + * polynomial( + edge_length[three_body_index[:, 1]], self.threebody_cutoff # noqa: E501 + ) + ) + three_basis = three_basis * atom_mask + index_map = torch.arange(torch.sum(num_edges).item()).to( + edge_length.device + ) # noqa: E501 + index_map = torch.repeat_interleave(index_map, num_triple_ij).to( + edge_length.device + ) + e_ij_tuda = scatter( + three_basis, + index_map, + dim=0, + reduce="sum", + dim_size=torch.sum(num_edges).item(), + ) + edge_attr_prime = edge_attr + self.edge_gate_mlp(e_ij_tuda) + return edge_attr_prime + + +class AtomLayer(nn.Module): + """ + v_i'=v_i+sum(phi(v+i,v_j,e_ij',u)W*e_ij^0) + """ + + def __init__( + self, + atom_attr_dim, + edge_attr_dim, + spherecal_dim, + ): + super().__init__() + self.gated_mlp = GatedMLP( + in_dim=2 * atom_attr_dim + spherecal_dim, + out_dims=[128, 64, atom_attr_dim], # noqa: E501 + ) # [2*atom_attr_dim+edge_attr_prime_dim] -> [atom_attr_dim] + self.edge_layer = LinearLayer( + in_dim=edge_attr_dim, out_dim=1 + ) # [atom_attr_dim] -> [1] + + def forward( + self, + atom_attr, + edge_attr, + edge_index, + edge_attr_prime, # [sum(num_edges),edge_attr_dim] + num_atoms, # [batch_size] + ): + feat = torch.concat( + [ + atom_attr[edge_index[0]], + atom_attr[edge_index[1]], + edge_attr_prime, + ], # noqa: E501 + dim=1, + ) + atom_attr_prime = self.gated_mlp(feat) * self.edge_layer(edge_attr) + atom_attr_prime = scatter( + atom_attr_prime, + edge_index[1], + dim=0, + dim_size=torch.sum(num_atoms).item(), # noqa: E501 + ) + return atom_attr_prime + atom_attr + + +class EdgeLayer(nn.Module): + """e_ij'=e_ij+phi(v_i,v_j,e_ij,u)W*e_ij^0""" + + def init( + self, + atom_attr_dim, + edge_attr_dim, + spherecal_dim, + ): + super().__init__() + self.gated_mlp = GatedMLP( + in_dim=2 * atom_attr_dim + spherecal_dim, + out_dims=[128, 64, edge_attr_dim], # noqa: E501 + ) + self.edge_layer = LinearLayer(in_dim=edge_attr_dim, out_dim=1) + + def forward( + self, + atom_attr, + edge_attr, + edge_index, + edge_attr_prime, # [sum(num_edges),edge_attr_dim] + ): + feat = torch.concat( + [ + atom_attr[edge_index[0]], + atom_attr[edge_index[1]], + edge_attr_prime, + ], # noqa: E501 + dim=1, + ) + edge_attr_prime = self.gated_mlp(feat) * self.edge_layer(edge_attr) + return edge_attr_prime + edge_attr + + 
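+# Summary comment (added for readability): MainBlock below performs one message-passing
+# step. Roughly, it (1) refines edge features through the three-body ThreeDInteraction,
+# (2) updates bonds as e_ij <- e_ij + GatedMLP([v_i, v_j, e_ij]) * W_e(e_ij^0), and
+# (3) updates atoms as v_i <- v_i + sum_j GatedMLP([v_i, v_j, e_ij]) * W_v(e_ij^0),
+# where e_ij^0 is the initial radial-basis expansion of the bond length and W_e, W_v
+# stand for the edge_layer_edge / edge_layer_atom projections defined in __init__.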
+class MainBlock(nn.Module): + """ + MainBlock for Message Passing in M3GNet + """ + + def __init__( + self, + max_n, + max_l, + cutoff, + units, + spherical_dim, + threebody_cutoff, + ): + super().__init__() + self.gated_mlp_atom = GatedMLP( + in_dim=2 * units + units, + out_dims=[units, units], + activation="swish", # noqa: E501 + ) # [2*atom_attr_dim+edge_attr_prime_dim] -> [units] + self.edge_layer_atom = SwishLayer( + in_dim=spherical_dim, out_dim=units, bias=False # noqa: E501 + ) # [spherecal_dim] -> [units] + self.gated_mlp_edge = GatedMLP( + in_dim=2 * units + units, + out_dims=[units, units], + activation="swish", # noqa: E501 + ) # [2*atom_attr_dim+edge_attr_prime_dim] -> [units] + self.edge_layer_edge = LinearLayer( + in_dim=spherical_dim, out_dim=units, bias=False + ) # [spherecal_dim] -> [units] + self.three_body = ThreeDInteraction( + max_n, max_l, cutoff, units, max_n * max_l, threebody_cutoff + ) + + def forward( + self, + atom_attr, + edge_attr, + edge_attr_zero, + edge_index, + three_basis, + three_body_index, + edge_length, + num_edges, + num_triple_ij, + num_atoms, + ): + # threebody interaction + edge_attr = self.three_body( + edge_attr, + three_basis, + atom_attr, + edge_index, + three_body_index, + edge_length, + num_edges, + num_triple_ij.view(-1), + ) + + # update bond feature + feat = torch.concat( + [atom_attr[edge_index[0]], atom_attr[edge_index[1]], edge_attr], + dim=1, # noqa: E501 + ) + edge_attr = edge_attr + self.gated_mlp_edge( + feat + ) * self.edge_layer_edge( # noqa: E501 + edge_attr_zero + ) + + # update atom feature + feat = torch.concat( + [atom_attr[edge_index[0]], atom_attr[edge_index[1]], edge_attr], + dim=1, # noqa: E501 + ) + atom_attr_prime = self.gated_mlp_atom(feat) * self.edge_layer_atom( + edge_attr_zero + ) + atom_attr = atom_attr + scatter( # noqa: E501 + atom_attr_prime, + edge_index[0], + dim=0, + dim_size=torch.sum(num_atoms).item(), # noqa: E501 + ) + + return atom_attr, edge_attr diff --git a/src/mattersim/forcefield/m3gnet/scaling.py b/src/mattersim/forcefield/m3gnet/scaling.py new file mode 100644 index 0000000..532de8a --- /dev/null +++ b/src/mattersim/forcefield/m3gnet/scaling.py @@ -0,0 +1,335 @@ +# -*- coding: utf-8 -*- +""" +Atomic scaling module. Used for predicting extensive properties. +""" + +from typing import Optional, Union + +import numpy as np +import torch +import torch.nn as nn +from ase import Atoms +from torch_runstats.scatter import scatter_mean + +from mattersim.datasets.utils.regressor import solver + +DATA_INDEX = { + "total_energy": 0, + "forces": 2, + "per_atom_energy": 1, + "per_species_energy": 0, +} + + +class AtomScaling(nn.Module): + """ + Atomic extensive property rescaling module + """ + + def __init__( + self, + atoms: list[Atoms] = None, + total_energy: list[float] = None, + forces: list[np.ndarray] = None, + atomic_numbers: list[np.ndarray] = None, + num_atoms: list[float] = None, + max_z: int = 94, + scale_key: str = None, + shift_key: str = None, + init_scale: Union[torch.Tensor, float] = None, + init_shift: Union[torch.Tensor, float] = None, + trainable_scale: bool = False, + trainable_shift: bool = False, + verbose: bool = False, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + **kwargs, + ): + """ + Args: + forces: a list of atomic forces (np.ndarray) in each graph + max_z: (int) maximum atomic number + - if scale_key or shift_key is specified, + max_z should be equal to the maximum atomic_number. 
+ scale_key: valid options are: + - total_energy_std + - per_atom_energy_std + - per_species_energy_std + - forces_rms + - per_species_forces_rms (default) + shift_key: valid options are: + - total_energy_mean + - per_atom_energy_mean + - per_species_energy_mean : + default option is gaussian regression (NequIP) + - per_species_energy_mean_linear_reg : + an alternative choice is linear regression (M3GNet) + init_scale (torch.Tensor or float) + init_shift (torch.Tensor or float) + """ + super().__init__() + + self.max_z = max_z + self.device = device + + # === Data preprocessing === + if scale_key or shift_key: + total_energy = torch.from_numpy(np.array(total_energy)) + forces = ( + torch.from_numpy(np.concatenate(forces, axis=0)) + if forces is not None + else None + ) + if atomic_numbers is None: + atomic_numbers = [atom.get_atomic_numbers() for atom in atoms] + atomic_numbers = ( + torch.from_numpy(np.concatenate(atomic_numbers, axis=0)) + .squeeze(-1) + .long() + ) # (num_atoms,) + # assert max_z == atomic_numbers.max().item(), + # "max_z should be equal to the maximum atomic_number" + if num_atoms is None: + num_atoms = [ # noqa: E501 + atom.positions.shape[0] for atom in atoms + ] # (N_GRAPHS, ) + num_atoms = torch.from_numpy(np.array(num_atoms)) + per_atom_energy = total_energy / num_atoms + data_list = [total_energy, per_atom_energy, forces] + + assert ( + num_atoms.size()[0] == total_energy.size()[0] + ), "num_atoms and total_energy should have the same size, " + f"but got {num_atoms.size()[0]} and {total_energy.size()[0]}" + if forces is not None: + assert ( + forces.size()[0] == atomic_numbers.size()[0] + ), "forces and atomic_numbers should have the same length, " + f"but got {forces.size()[0]} and {atomic_numbers.size()[0]}" + + # === Calculate the scaling factors === + if ( + scale_key == "per_species_energy_std" + and shift_key == "per_species_energy_mean" + and init_shift is None + and init_scale is None + ): + # Using gaussian regression two times + # to get the shift and scale is potentially unstable + init_shift, init_scale = self.get_gaussian_statistics( + atomic_numbers, num_atoms, total_energy + ) + else: + if shift_key and init_shift is None: + init_shift = self.get_statistics( + shift_key, max_z, data_list, atomic_numbers, num_atoms + ) + if scale_key and init_scale is None: + init_scale = self.get_statistics( + scale_key, max_z, data_list, atomic_numbers, num_atoms + ) + + # === initial values are given === + if init_scale is None: + init_scale = torch.ones(max_z + 1) + elif isinstance(init_scale, float): + init_scale = torch.tensor(init_scale).repeat(max_z + 1) + else: + assert init_scale.size()[0] == max_z + 1 + + if init_shift is None: + init_shift = torch.zeros(max_z + 1) + elif isinstance(init_shift, float): + init_shift = torch.tensor(init_shift).repeat(max_z + 1) + else: + assert init_shift.size()[0] == max_z + 1 + + init_shift = init_shift.float() + init_scale = init_scale.float() + if trainable_scale is True: + self.scale = torch.nn.Parameter(init_scale) + else: + self.register_buffer("scale", init_scale) + + if trainable_shift is True: + self.shift = torch.nn.Parameter(init_shift) + else: + self.register_buffer("shift", init_shift) + + if verbose is True: + print("Current scale: ", init_scale) + print("Current shift: ", init_shift) + + self.to(device) + + def transform( + self, atomic_energies: torch.Tensor, atomic_numbers: torch.Tensor + ) -> torch.Tensor: + """ + Take the origin values from model and get the transformed values + """ + curr_shift = 
self.shift[atomic_numbers] + curr_scale = self.scale[atomic_numbers] + normalized_energies = curr_scale * atomic_energies + curr_shift + return normalized_energies + + def inverse_transform( + self, atomic_energies: torch.Tensor, atomic_numbers: torch.Tensor + ) -> torch.Tensor: + """ + Take the transformed values and get the original values + """ + curr_shift = self.shift[atomic_numbers] + curr_scale = self.scale[atomic_numbers] + unnormalized_energies = (atomic_energies - curr_shift) / curr_scale + return unnormalized_energies + + def forward( + self, atomic_energies: torch.Tensor, atomic_numbers: torch.Tensor + ) -> torch.Tensor: + """ + Atomic_energies and atomic_numbers should have the same size + """ + return self.transform(atomic_energies, atomic_numbers) + + def get_statistics( + self, key, max_z, data_list, atomic_numbers, num_atoms + ) -> torch.Tensor: + """ + Valid key: + scale_key: valid options are: + - total_energy_mean + - per_atom_energy_mean + - per_species_energy_mean + - per_species_energy_mean_linear_reg : + an alternative choice is linear regression + shift_key: valid options are: + - total_energy_std + - per_atom_energy_std + - per_species_energy_std + - forces_rms + - per_species_forces_rms + """ + data = None + for data_key in DATA_INDEX: + if data_key in key: + data = data_list[DATA_INDEX[data_key]] + assert data is not None + + statistics = None + if "mean" in key: + if "per_species" in key: + n_atoms = torch.repeat_interleave(repeats=num_atoms) + if "linear_reg" in key: + features = bincount( + atomic_numbers, n_atoms, minlength=self.max_z + 1 + ).numpy() + # print(features[0], features.shape) + data = data.numpy() + assert features.ndim == 2 # [batch, n_type] + features = features[ + (features > 0).any(axis=1) + ] # deal with non-contiguous batch indexes + statistics = np.linalg.pinv(features.T.dot(features)).dot( + features.T.dot(data) + ) + statistics = torch.from_numpy(statistics) + else: + N = bincount( + atomic_numbers, + num_atoms, + minlength=self.max_z + 1, # noqa: E501 + ) + assert N.ndim == 2 # [batch, n_type] + # deal with non-contiguous batch indexes + N = N[(N > 0).any(dim=1)] + N = N.type(torch.get_default_dtype()) + statistics, _ = solver( + N, data, regressor="NormalizedGaussianProcess" + ) + else: + statistics = torch.mean(data).item() + elif "std" in key: + if "per_species" in key: + print( + "Warning: calculating per_species_energy_std for " + "full periodic table systems is risky, please use " + "per_species_forces_rms instead." 
+ ) + n_atoms = torch.repeat_interleave(repeats=num_atoms) + N = bincount(atomic_numbers, n_atoms, minlength=self.max_z + 1) + assert N.ndim == 2 # [batch, n_type] + # deal with non-contiguous batch indexes + N = N[(N > 0).any(dim=1)] + N = N.type(torch.get_default_dtype()) + _, statistics = solver( # noqa: E501 + N, data, regressor="NormalizedGaussianProcess" + ) + else: + statistics = torch.std(data).item() + elif "rms" in key: + if "per_species" in key: + square = scatter_mean( + data.square(), atomic_numbers, dim=0, dim_size=max_z + 1 + ) + statistics = square.mean(axis=-1) + else: + statistics = torch.sqrt(torch.mean(data.square())).item() + + if isinstance(statistics, torch.Tensor) is not True: + statistics = torch.tensor(statistics).repeat(max_z + 1) + + assert statistics.size()[0] == max_z + 1 + + return statistics + + def get_gaussian_statistics( + self, + atomic_numbers: torch.Tensor, + num_atoms: torch.Tensor, + total_energy: torch.Tensor, + ): + """ + Get the gaussian process mean and variance + """ + n_atoms = torch.repeat_interleave(repeats=num_atoms) + N = bincount(atomic_numbers, n_atoms, minlength=self.max_z + 1) + assert N.ndim == 2 # [batch, n_type] + N = N[(N > 0).any(dim=1)] # deal with non-contiguous batch indexes + N = N.type(torch.get_default_dtype()) + mean, std = solver( # noqa: E501 + N, total_energy, regressor="NormalizedGaussianProcess" + ) + assert mean.size()[0] == self.max_z + 1 + assert std.size()[0] == self.max_z + 1 + return mean, std + + +def bincount( + input: torch.Tensor, + batch: Optional[torch.Tensor] = None, + minlength: int = 0, # noqa: E501 +): + assert input.ndim == 1 + if batch is None: + return torch.bincount(input, minlength=minlength) + else: + assert batch.shape == input.shape + + length = input.max().item() + 1 + if minlength == 0: + minlength = length + if length > minlength: + raise ValueError( + f"minlength {minlength} too small for input " + f"with integers up to and including {length}" # noqa: E501 + ) + + # Flatten indexes + # Make each "class" in input into a per-input class. 
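+        # Illustrative example: with minlength=3, batch=[0, 0, 1] and input=[1, 2, 1],
+        # input_ becomes [1, 2, 4], so a single torch.bincount call reshaped to
+        # (num_batch, minlength) yields per-batch counts [[0, 1, 1], [0, 1, 0]].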
+ input_ = input + batch * minlength + + num_batch = batch.max() + 1 + + return torch.bincount(input_, minlength=minlength * num_batch).reshape( + num_batch, minlength + ) diff --git a/src/mattersim/forcefield/potential.py b/src/mattersim/forcefield/potential.py new file mode 100644 index 0000000..dfe4bf0 --- /dev/null +++ b/src/mattersim/forcefield/potential.py @@ -0,0 +1,1358 @@ +# -*- coding: utf-8 -*- +""" +Potential +""" +import os +import pickle +import random +import time +import warnings +from typing import Dict, List, Optional + +import numpy as np +import torch +import torch.distributed +import torch.nn as nn +from ase import Atoms +from ase.calculators.calculator import Calculator +from ase.constraints import full_3x3_to_voigt_6_stress +from torch.optim import Adam +from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR +from torch_ema import ExponentialMovingAverage +from torch_geometric.loader import DataLoader +from torchmetrics import MeanMetric + +from mattersim.datasets.utils.build import build_dataloader +from mattersim.forcefield.m3gnet.m3gnet import M3Gnet +from mattersim.forcefield.m3gnet.m3gnet_multi_head import M3Gnet_multi_head +from mattersim.jit_compile_tools.jit import compile_mode + + +@compile_mode("script") +class Potential(nn.Module): + """ + A wrapper class for the force field model + """ + + def __init__( + self, + model, + optimizer=None, + scheduler: str = "StepLR", + ema=None, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + allow_tf32=False, + **kwargs, + ): + """ + Args: + potential : a force field model + lr : learning rate + scheduler : a torch scheduler + normalizer : an energy normalization module + """ + super().__init__() + self.model = model + if optimizer is None: + self.optimizer = Adam( + self.model.parameters(), lr=kwargs.get("lr", 1e-3), eps=1e-7 + ) + else: + self.optimizer = optimizer + if not isinstance(scheduler, str): + self.scheduler = scheduler + elif scheduler == "StepLR": + step_size = kwargs.get("step_size", 10) + gamma = kwargs.get("gamma", 0.95) + self.scheduler = StepLR( + self.optimizer, step_size=step_size, gamma=gamma # noqa: E501 + ) + elif scheduler == "ReduceLROnPlateau": + factor = kwargs.get("factor", 0.8) + patience = kwargs.get("patience", 50) + self.scheduler = ReduceLROnPlateau( + self.optimizer, + mode="min", + factor=factor, + patience=patience, + verbose=False, + ) + else: + raise NotImplementedError + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + self.device = device + self.to(device) + + if ema is None: + self.ema = ExponentialMovingAverage( + self.model.parameters(), decay=kwargs.get("ema_decay", 0.99) + ) + else: + self.ema = ema + self.model_name = kwargs.get("model_name", "m3gnet") + self.validation_metrics = kwargs.get( + "validation_metrics", {"loss": 10000.0} # noqa: E501 + ) + self.last_epoch = kwargs.get("last_epoch", -1) + self.description = kwargs.get("description", "") + self.saved_name = ["loss", "MAE_energy", "MAE_force", "MAE_stress"] + self.best_metric = 10 + self.rank = None + + self.use_finetune_label_loss = kwargs.get("use_finetune_label_loss", False) + + def freeze_reset_model( + self, + finetune_layers: int = -1, + reset_head_for_finetune: bool = False, + ): + """ + Freeze the model in the fine-tuning process + """ + if finetune_layers == -1: + print("fine-tuning all layers") + elif finetune_layers >= 0 and finetune_layers < len( + self.model.node_head.unified_encoder_layers + ): + print(f"fine-tuning the last {finetune_layers} layers") + for name, param in 
self.model.named_parameters(): + param.requires_grad = False + + # for energy head + if finetune_layers > 0: + for name, param in self.model.node_head.unified_encoder_layers[ + -finetune_layers: + ].named_parameters(): + param.requires_grad = True + for ( + name, + param, + ) in self.model.node_head.unified_final_invariant_ln.named_parameters(): + param.requires_grad = True + for ( + name, + param, + ) in self.model.node_head.unified_output_layer.named_parameters(): + param.requires_grad = True + for name, param in self.model.layer_norm.named_parameters(): + param.requires_grad = True + for name, param in self.model.lm_head_transform_weight.named_parameters(): + param.requires_grad = True + for name, param in self.model.energy_out.named_parameters(): + param.requires_grad = True + if reset_head_for_finetune: + self.model.lm_head_transform_weight.reset_parameters() + self.model.energy_out.reset_parameters() + else: + raise ValueError( + "finetune_layers should be -1 or a positive integer,and less than the number of layers" # noqa: E501 + ) + + def finetune_mode( + self, + finetune_layers: int = -1, + finetune_head: nn.Module = None, + reset_head_for_finetune: bool = False, + finetune_task_mean: float = 0.0, + finetune_task_std: float = 1.0, + use_finetune_label_loss: bool = False, + ): + """ + Set the model to fine-tuning mode + finetune_layers: the layer to finetune, former layers will be frozen + if -1, all layers will be finetuned + finetune_head: the head to finetune + reset_head_for_finetune: whether to reset the original head + """ + if self.model_name not in ["graphormer", "geomformer"]: + print("Only graphormer and geomformer support freezing layers") + return + self.model.finetune_mode = True + if finetune_head is None: + print("No finetune head is provided, using the original energy head") + self.model.finetune_head = finetune_head + self.model.finetune_task_mean = finetune_task_mean + self.model.finetune_task_std = finetune_task_std + self.freeze_reset_model(finetune_layers, reset_head_for_finetune) + self.use_finetune_label_loss = use_finetune_label_loss + + def train_model( + self, + dataloader: Optional[list], + val_dataloader, + loss: torch.nn.modules.loss = torch.nn.MSELoss(), + include_energy: bool = True, + include_forces: bool = False, + include_stresses: bool = False, + force_loss_ratio: float = 1.0, + stress_loss_ratio: float = 0.1, + epochs: int = 100, + early_stop_patience: int = 100, + metric_name: str = "val_loss", + wandb=None, + save_checkpoint: bool = False, + save_path: str = "./results/", + ckpt_interval: int = 10, + multi_head: bool = False, + dataset_name_list: List[str] = None, + sampler=None, + is_distributed: bool = False, + need_to_load_data: bool = False, + **kwargs, + ): + """ + Train model + Args: + dataloader: training data loader + val_dataloader: validation data loader + loss (torch.nn.modules.loss): loss object + include_energy (bool) : whether to use energy as + optimization targets + include_forces (bool) : whether to use forces as + optimization targets + include_stresses (bool) : whether to use stresses as + optimization targets + force_loss_ratio (float): the ratio of forces in loss + stress_loss_ratio (float): the ratio of stress in loss + ckpt_interval (int): the interval to save checkpoints + early_stop_patience (int): the patience for early stopping + metric_name (str): the metric used for saving `best` checkpoints + and early stopping supported metrics: + `val_loss`, `val_mae_e`, + `val_mae_f`, `val_mae_s` + sampler: used in distributed 
training + is_distributed: whether to use DistributedDataParallel + need_to_load_data: whether to load data from disk + + """ + self.idx = ["val_loss", "val_mae_e", "val_mae_f", "val_mae_s"].index( + metric_name + ) + if is_distributed: + self.rank = torch.distributed.get_rank() + print( + f"Number of trainable parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad):,}" # noqa: E501 + ) + for epoch in range(self.last_epoch + 1, epochs): + print(f"Epoch: {epoch} / {epochs}") + if not multi_head: + if need_to_load_data: + assert isinstance(dataloader, list) + random.Random(kwargs.get("seed", 42) + epoch).shuffle( # noqa: E501 + dataloader + ) + for idx, data_path in enumerate(dataloader): + with open(data_path, "rb") as f: + start = time.time() + train_data = pickle.load(f) + print( + f"TRAIN: loading {data_path.split('/')[-2]}" + f"/{data_path.split('/')[-1]} dataset with " + f"{len(train_data)} data points, " + f"{len(train_data)} data points in total, " + f"time: {time.time() - start}" # noqa: E501 + ) + # Distributed Sampling + atoms_train_sampler = ( + torch.utils.data.distributed.DistributedSampler( + train_data, + seed=kwargs.get("seed", 42) + + idx * 131 + + epoch, # noqa: E501 + ) + ) + train_dataloader = DataLoader( + train_data, + batch_size=kwargs.get("batch_size", 32), + shuffle=(atoms_train_sampler is None), + num_workers=0, + sampler=atoms_train_sampler, + ) + self.train_one_epoch( + train_dataloader, + epoch, + loss, + include_energy, + include_forces, + include_stresses, + force_loss_ratio, + stress_loss_ratio, + wandb, + is_distributed, + mode="train", + **kwargs, + ) + del train_dataloader + del train_data + torch.cuda.empty_cache() + else: + self.train_one_epoch( + dataloader, + epoch, + loss, + include_energy, + include_forces, + include_stresses, + force_loss_ratio, + stress_loss_ratio, + wandb, + is_distributed, + mode="train", + **kwargs, + ) + metric = self.train_one_epoch( + val_dataloader, + epoch, + loss, + include_energy, + include_forces, + include_stresses, + force_loss_ratio, + stress_loss_ratio, + wandb, + is_distributed, + mode="val", + **kwargs, + ) + else: + assert dataset_name_list is not None + assert ( + need_to_load_data is False + ), "load_training_data is not supported for multi-head training" # noqa: E501 + self.train_one_epoch_multi_head( + dataloader, + dataset_name_list, + epoch, + loss, + include_energy, + include_forces, + include_stresses, + force_loss_ratio, + stress_loss_ratio, + wandb, + mode="train", + **kwargs, + ) + metric = self.train_one_epoch_multi_head( + val_dataloader, + dataset_name_list, + epoch, + loss, + include_energy, + include_forces, + include_stresses, + force_loss_ratio, + stress_loss_ratio, + wandb, + mode="val", + **kwargs, + ) + + if isinstance(self.scheduler, ReduceLROnPlateau): + self.scheduler.step(metric) + else: + self.scheduler.step() + + self.last_epoch = epoch + + self.validation_metrics = { + "loss": metric[0], + "MAE_energy": metric[1], + "MAE_force": metric[2], + "MAE_stress": metric[3], + } + if is_distributed: + # TODO 添加distributed训练早停 + if self.save_model_ddp( + epoch, + early_stop_patience, + save_path, + metric_name, + save_checkpoint, + metric, + ckpt_interval, + ): + break + else: + # return True时为早停 + if self.save_model( + epoch, + early_stop_patience, + save_path, + metric_name, + save_checkpoint, + metric, + ckpt_interval, + ): + break + + def save_model( + self, + epoch, + early_stop_patience, + save_path, + metric_name, + save_checkpoint, + metric, + ckpt_interval, + ): + 
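+        # Checkpointing helper (summary): when save_checkpoint is True, keeps
+        # best_model.pth whenever the monitored metric improves, writes
+        # ckpt_{epoch}.pth every ckpt_interval epochs and last_model.pth each call,
+        # and returns True (early stop) roughly when the best checkpoint is older
+        # than early_stop_patience epochs.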
with self.ema.average_parameters(): + try: + best_model = torch.load( + os.path.join(save_path, "best_model.pth") # noqa: E501 + ) + assert metric_name in [ + "val_loss", + "val_mae_e", + "val_mae_f", + "val_mae_s", + ], ( + f"`{metric_name}` metric name not supported. " + "supported metrics: `val_loss`, `val_mae_e`, " + "`val_mae_f`, `val_mae_s`" + ) + + if ( + save_checkpoint is True + and metric[self.idx] + < best_model["validation_metrics"][ + self.saved_name[self.idx] + ] # noqa: E501 + ): + self.save(os.path.join(save_path, "best_model.pth")) + if epoch > best_model["last_epoch"] + early_stop_patience: + print("Early stopping") + return True + del best_model + except BaseException: + if save_checkpoint is True: + self.save(os.path.join(save_path, "best_model.pth")) + + if save_checkpoint is True and epoch % ckpt_interval == 0: + self.save(os.path.join(save_path, f"ckpt_{epoch}.pth")) + if save_checkpoint is True: + self.save(os.path.join(save_path, "last_model.pth")) + return False + + def save_model_ddp( + self, + epoch, + early_stop_patience, + save_path, + metric_name, + save_checkpoint, + metric, + ckpt_interval, + ): + with self.ema.average_parameters(): + assert metric_name in [ + "val_loss", + "val_mae_e", + "val_mae_f", + "val_mae_s", + ], ( # noqa: E501 + f"`{metric_name}` metric name not supported. " + "supported metrics: `val_loss`, `val_mae_e`, " + "`val_mae_f`, `val_mae_s`" + ) + # Loading on multiple GPUs is too time consuming, + # so this operation should not be performed. + # Only save the model on GPU 0, + # the model on each GPU should be exactly the same. + + if metric[self.idx] < self.best_metric: + self.best_metric = metric[self.idx] + if save_checkpoint and self.rank == 0: + self.save(os.path.join(save_path, "best_model.pth")) + if self.rank == 0 and save_checkpoint: + if epoch % ckpt_interval == 0: + self.save(os.path.join(save_path, f"ckpt_{epoch}.pth")) + self.save(os.path.join(save_path, "last_model.pth")) + # torch.distributed.barrier() + return False + + def test_model( + self, + val_dataloader, + loss: torch.nn.modules.loss = torch.nn.MSELoss(), + include_energy: bool = True, + include_forces: bool = False, + include_stresses: bool = False, + wandb=None, + multi_head: bool = False, + **kwargs, + ): + """ + Test model performance on a given dataset + """ + if not multi_head: + return self.train_one_epoch( + val_dataloader, + 1, + loss, + include_energy, + include_forces, + include_stresses, + 1.0, + 0.1, + wandb=wandb, + mode="val", + ) + else: + return self.train_one_epoch_multi_head( + val_dataloader, + kwargs["dataset_name_list"], + 1, + loss, + include_energy, + include_forces, + include_stresses, + 1.0, + 0.1, + wandb=wandb, + mode="val", + ) + + def predict_properties( + self, + dataloader, + include_forces: bool = False, + include_stresses: bool = False, + **kwargs, + ): + """ + Predict properties (e.g., energies, forces) given a well-trained model + Return: results tuple + - results[0] (list[float]): a list of energies + - results[1] (list[np.ndarray]): a list of atomic forces + - results[2] (list[np.ndarray]): a list of stresses + """ + self.model.eval() + energies = [] + forces = [] + stresses = [] + for batch_idx, graph_batch in enumerate(dataloader): + if self.model_name == "graphormer" or self.model_name == "geomformer": + raise NotImplementedError + else: + graph_batch.to(self.device) + input = batch_to_dict(graph_batch) + result = self.forward( + input, + include_forces=include_forces, + include_stresses=include_stresses, # noqa: E501 + ) + 
if self.model_name == "graphormer" or self.model_name == "geomformer": + raise NotImplementedError + else: + energies.extend(result["total_energy"].cpu().tolist()) + if include_forces: + forces_tuple = torch.split( + result["forces"].cpu().detach(), + graph_batch.num_atoms.cpu().tolist(), + dim=0, + ) + for atomic_force in forces_tuple: + forces.append(np.array(atomic_force)) + if include_stresses: + stresses.extend(list(result["stresses"].cpu().detach().numpy())) + + return (energies, forces, stresses) + + # ============================ + + def train_one_epoch( + self, + dataloader, + epoch, + loss, + include_energy, + include_forces, + include_stresses, + loss_f, + loss_s, + wandb, + is_distributed=False, + mode="train", + log=True, + **kwargs, + ): + start_time = time.time() + loss_avg = MeanMetric().to(self.device) + train_e_mae = MeanMetric().to(self.device) + train_f_mae = MeanMetric().to(self.device) + train_s_mae = MeanMetric().to(self.device) + + # scaler = torch.cuda.amp.GradScaler() + + if mode == "train": + self.model.train() + elif mode == "val": + self.model.eval() + + for batch_idx, graph_batch in enumerate(dataloader): + if self.model_name == "graphormer" or self.model_name == "geomformer": + raise NotImplementedError + else: + graph_batch.to(self.device) + input = batch_to_dict(graph_batch) + if mode == "train": + result = self.forward( + input, + include_forces=include_forces, + include_stresses=include_stresses, + ) + elif mode == "val": + with self.ema.average_parameters(): + result = self.forward( + input, + include_forces=include_forces, + include_stresses=include_stresses, + ) + + loss_, e_mae, f_mae, s_mae = self.loss_calc( + graph_batch, + result, + loss, + include_energy, + include_forces, + include_stresses, + loss_f, + loss_s, + ) + + # loss backward + if mode == "train": + self.optimizer.zero_grad() + loss_.backward() + nn.utils.clip_grad_norm_( + self.model.parameters(), 1.0, norm_type=2 # noqa: E501 + ) + self.optimizer.step() + # scaler.scale(loss_).backward() + # scaler.step(self.optimizer) + # scaler.update() + self.ema.update() + + loss_avg.update(loss_.detach()) + if include_energy: + train_e_mae.update(e_mae.detach()) + if include_forces: + train_f_mae.update(f_mae.detach()) + if include_stresses: + train_s_mae.update(s_mae.detach()) + + loss_avg_ = loss_avg.compute().item() + if include_energy: + e_mae = train_e_mae.compute().item() + else: + e_mae = 0 + if include_forces: + f_mae = train_f_mae.compute().item() + else: + f_mae = 0 + if include_stresses: + s_mae = train_s_mae.compute().item() + else: + s_mae = 0 + + if log: + print( + "%s: Loss: %.4f, MAE(e): %.4f, MAE(f): %.4f, MAE(s): %.4f, Time: %.2fs, lr: %.8f\n" # noqa: E501 + % ( + mode, + loss_avg.compute().item(), + e_mae, + f_mae, + s_mae, + time.time() - start_time, + self.scheduler.get_last_lr()[0], + ), + end="", + ) + + if wandb and ((not is_distributed) or self.rank == 0): + wandb.log( + { + f"{mode}/loss": loss_avg_, + f"{mode}/mae_e": e_mae, + f"{mode}/mae_f": f_mae, + f"{mode}/mae_s": s_mae, + f"{mode}/lr": self.scheduler.get_last_lr()[0], + f"{mode}/mae_tot": e_mae + f_mae + s_mae, + }, + step=epoch, + ) + + if mode == "val": + return (loss_avg_, e_mae, f_mae, s_mae) + + def train_one_epoch_multi_head( + self, + dataloader_list, + dataset_name_list, + epoch, + loss, + include_energy=True, + include_forces=False, + include_stresses=False, + loss_f=1.0, + loss_s=0.1, + wandb=None, + mode="train", + **kwargs, + ): + start_time = time.time() + + metrics = {} + for dataset_name in 
dataset_name_list: + metrics_ = {} + metrics_["loss_avg"] = MeanMetric().to(self.device) + metrics_["train_e_mae"] = MeanMetric().to(self.device) + metrics_["train_f_mae"] = MeanMetric().to(self.device) + metrics_["train_s_mae"] = MeanMetric().to(self.device) + metrics[dataset_name] = metrics_ + + dataloader_iter = [ + dataloader.__iter__() for dataloader in dataloader_list # noqa: E501 + ] + if mode == "train": + self.model.train() + elif mode == "val": + self.model.eval() + + dataloader_len = [len(dataloader) for dataloader in dataloader_list] + for i in range(1, len(dataloader_len)): + dataloader_len[i] += dataloader_len[i - 1] + idx_list = list(range(dataloader_len[-1])) + random.shuffle(idx_list) + + for idx in idx_list: + for dataset_idx, bound in enumerate(dataloader_len): + if idx < bound: + break + + graph_batch = dataloader_iter[dataset_idx].__next__() + graph_batch.to(self.device) + input = batch_to_dict(graph_batch) + dataset_name = dataset_name_list[dataset_idx] + + if mode == "train": + result = self.forward( + input, + include_forces=include_forces, + include_stresses=include_stresses, + dataset_idx=dataset_idx, + ) + elif mode == "val": + with self.ema.average_parameters(): + result = self.forward( + input, + include_forces=include_forces, + include_stresses=include_stresses, + dataset_idx=dataset_idx, + ) + + loss_, e_mae, f_mae, s_mae = self.loss_calc( + graph_batch, + result, + loss, + include_energy, + include_forces, + include_stresses, + loss_f, + loss_s, + ) + + # loss backward + if mode == "train": + self.optimizer.zero_grad() + loss_.backward() + nn.utils.clip_grad_norm_( + self.model.parameters(), 1.0, norm_type=2 # noqa: E501 + ) + self.optimizer.step() + self.ema.update() + + metrics[dataset_name]["loss_avg"].update(loss_.detach()) + if include_energy: + metrics[dataset_name]["train_e_mae"].update(e_mae.detach()) + if include_forces: + metrics[dataset_name]["train_f_mae"].update(f_mae.detach()) + if include_stresses: + metrics[dataset_name]["train_s_mae"].update(s_mae.detach()) + + loss_all = 0 + e_mae = 0 + f_mae = 0 + s_mae = 0 + for dataset_name in dataset_name_list: + train_f_mae = train_s_mae = 0 + loss_avg = metrics[dataset_name]["loss_avg"].compute().item() + loss_all += loss_avg + if include_energy: + train_e_mae = metrics[dataset_name]["train_e_mae"].compute().item() + e_mae += train_e_mae + if include_forces and (dataset_name != "QM9"): + train_f_mae = ( + metrics[dataset_name]["train_f_mae"].compute().item() + ) # noqa: E501 + f_mae += train_f_mae + if include_stresses: + train_s_mae = ( + metrics[dataset_name]["train_s_mae"].compute().item() + ) # noqa: E501 + s_mae += train_s_mae + + print( + "%s %s: Loss: %.4f, MAE(e): %.4f, MAE(f): %.4f, MAE(s): %.4f, Time: %.2fs" # noqa: E501 + % ( + dataset_name, + mode, + loss_avg, + train_e_mae, + train_f_mae, + train_s_mae, + time.time() - start_time, + ) + ) + + if wandb: + wandb.log( + { + f"{dataset_name}/{mode}_loss": loss_avg, + f"{dataset_name}/{mode}_mae_e": train_e_mae, + f"{dataset_name}/{mode}_mae_f": train_f_mae, + f"{dataset_name}/{mode}_mae_s": train_s_mae, + }, + step=epoch, + ) + + if wandb: + wandb.log({"lr": self.scheduler.get_last_lr()[0]}, step=epoch) + + if mode == "val": + return (loss_all, e_mae, f_mae, s_mae) + + def loss_calc( + self, + graph_batch, + result, + loss, + include_energy, + include_forces, + include_stresses, + loss_f=1.0, + loss_s=0.1, + ): + e_mae = 0.0 + f_mae = 0.0 + s_mae = 0.0 + loss_ = torch.tensor(0.0, device=self.device, requires_grad=True) + + if self.model_name 
== "graphormer" or self.model_name == "geomformer": + raise NotImplementedError + else: + if include_energy: + e_gt = graph_batch.energy / graph_batch.num_atoms + e_pred = result["total_energy"] / graph_batch.num_atoms + loss_ = loss_ + loss(e_pred, e_gt) + e_mae = torch.nn.L1Loss()(e_pred, e_gt) + if include_forces: + f_gt = graph_batch.forces + f_pred = result["forces"] + loss_ = loss_ + loss(f_pred, f_gt) * loss_f + f_mae = torch.nn.L1Loss()(f_pred, f_gt) + # f_mae = torch.mean(torch.abs(f_pred - f_gt)).item() + if include_stresses: + s_gt = graph_batch.stress + s_pred = result["stresses"] + loss_ = loss_ + loss(s_pred, s_gt) * loss_s + s_mae = torch.nn.L1Loss()(s_pred, s_gt) + # s_mae = torch.mean(torch.abs((s_pred - s_gt))).item() + return loss_, e_mae, f_mae, s_mae + + def get_properties( + self, + graph_batch, + include_forces: bool = True, + include_stresses: bool = True, + **kwargs, + ): + """ + get energy, force and stress from a list of graph + Args: + graph_batch: + include_forces (bool): whether to include force + include_stresses (bool): whether to include stress + Returns: + results: a tuple, which consists of energies, forces and stress + """ + warnings.warn( + "This interface (get_properties) has been deprecated. " + "Please use Potential.forward(input, include_forces, " + "include_stresses) instead.", + DeprecationWarning, + ) + if self.model_name == "graphormer" or self.model_name == "geomformer": + raise NotImplementedError + else: + graph_batch.to(self.device) + input = batch_to_dict(graph_batch) + result = self.forward( + input, + include_forces=include_forces, + include_stresses=include_stresses, + **kwargs, + ) + # Warning: tuple + if not include_forces and not include_stresses: + return (result["total_energy"],) + elif include_forces and not include_stresses: + return (result["total_energy"], result["forces"]) + elif include_forces and include_stresses: + return (result["total_energy"], result["forces"], result["stresses"]) + + def forward( + self, + input: Dict[str, torch.Tensor], + include_forces: bool = True, + include_stresses: bool = True, + dataset_idx: int = -1, + ) -> Dict[str, torch.Tensor]: + """ + get energy, force and stress from a list of graph + Args: + input: a dictionary contains all necessary info. + The `batch_to_dict` method could convert a graph_batch from + pyg dataloader to the input dictionary. + include_forces (bool): whether to include force + include_stresses (bool): whether to include stress + dataset_idx (int): used for multi-head model, set to -1 by default + Returns: + results: a dictionary, which consists of energies, + forces and stresses + """ + output = {} + if self.model_name == "graphormer" or self.model_name == "geomformer": + raise NotImplementedError + else: + strain = torch.zeros_like(input["cell"], device=self.device) + volume = torch.linalg.det(input["cell"]) + if include_forces is True: + input["atom_pos"].requires_grad_(True) + if include_stresses is True: + strain.requires_grad_(True) + input["cell"] = torch.matmul( + input["cell"], + (torch.eye(3, device=self.device)[None, ...] + strain), + ) + strain_augment = torch.repeat_interleave( + strain, input["num_atoms"], dim=0 + ) + input["atom_pos"] = torch.einsum( + "bi, bij -> bj", + input["atom_pos"], + (torch.eye(3, device=self.device)[None, ...] 
+ strain_augment), + ) + volume = torch.linalg.det(input["cell"]) + + energies = self.model.forward(input, dataset_idx) + output["total_energy"] = energies + + # Only take first derivative if only force is required + if include_forces is True and include_stresses is False: + grad_outputs: List[Optional[torch.Tensor]] = [ + torch.ones_like( + energies, + ) + ] + grad = torch.autograd.grad( + outputs=[ + energies, + ], + inputs=[input["atom_pos"]], + grad_outputs=grad_outputs, + create_graph=self.model.training, + ) + + # Dump out gradient for forces + force_grad = grad[0] + if force_grad is not None: + forces = torch.neg(force_grad) + output["forces"] = forces + + # Take derivatives up to second order + # if both forces and stresses are required + if include_forces is True and include_stresses is True: + grad_outputs: List[Optional[torch.Tensor]] = [ + torch.ones_like( + energies, + ) + ] + grad = torch.autograd.grad( + outputs=[ + energies, + ], + inputs=[input["atom_pos"], strain], + grad_outputs=grad_outputs, + create_graph=self.model.training, + ) + + # Dump out gradient for forces and stresses + force_grad = grad[0] + stress_grad = grad[1] + + if force_grad is not None: + forces = torch.neg(force_grad) + output["forces"] = forces + + if stress_grad is not None: + stresses = 1 / volume[:, None, None] * stress_grad * 160.21766208 + output["stresses"] = stresses + + return output + + def save(self, save_path): + dir_name = os.path.dirname(save_path) + if not os.path.exists(dir_name): + os.makedirs(dir_name) + # 保存为单卡可加载的模型,多卡加载时需要先加载后放入DDP中 + checkpoint = { + "model_name": self.model_name, + "model": self.model.module.state_dict() + if hasattr(self.model, "module") + else self.model.state_dict(), + "model_args": self.model.module.get_model_args() + if hasattr(self.model, "module") + else self.model.get_model_args(), + "optimizer": self.optimizer.state_dict(), + "ema": self.ema.state_dict(), + "scheduler": self.scheduler.state_dict(), + "last_epoch": self.last_epoch, + "validation_metrics": self.validation_metrics, + "description": self.description, + } + torch.save(checkpoint, save_path) + + @staticmethod + def load( + model_name: str = "m3gnet", + load_path: str = None, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + args: Dict = None, + load_training_state: bool = True, + **kwargs, + ): + if load_path is None: + if model_name == "m3gnet": + print("Loading the pre-trained M3GNet model") + current_dir = os.path.dirname(__file__) + load_path = os.path.join( + current_dir, "m3gnet/pretrained/mpf/best_model.pth" + ) + elif model_name == "graphormer" or model_name == "geomformer": + raise NotImplementedError + else: + raise NotImplementedError + else: + print("Loading the model from %s" % load_path) + + checkpoint = torch.load(load_path, map_location=device) + + assert checkpoint["model_name"] == model_name + if model_name == "m3gnet": + model = M3Gnet(device=device, **checkpoint["model_args"]).to(device) + elif model_name == "m3gnet_multi_head": + model = M3Gnet_multi_head(device=device, **checkpoint["model_args"]).to( + device + ) + elif model_name == "graphormer" or model_name == "geomformer": + raise NotImplementedError + else: + raise NotImplementedError + model.load_state_dict(checkpoint["model"], strict=False) + + if load_training_state: + optimizer = Adam(model.parameters()) + scheduler = StepLR(optimizer, step_size=10, gamma=0.95) + try: + optimizer.load_state_dict(checkpoint["optimizer"]) + except BaseException: + try: + 
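+                # Restoring training state is best-effort: fall back to a stored
+                # optimizer/scheduler object's state_dict(), and finally to a fresh
+                # optimizer (None) / default "StepLR" if the checkpoint lacks them.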
optimizer.load_state_dict(checkpoint["optimizer"].state_dict()) + except BaseException: + optimizer = None + try: + scheduler.load_state_dict(checkpoint["scheduler"]) + except BaseException: + try: + scheduler.load_state_dict(checkpoint["scheduler"].state_dict()) + except BaseException: + scheduler = "StepLR" + try: + last_epoch = checkpoint["last_epoch"] + validation_metrics = checkpoint["validation_metrics"] + description = checkpoint["description"] + except BaseException: + last_epoch = -1 + validation_metrics = {"loss": 0.0} + description = "" + try: + ema = ExponentialMovingAverage(model.parameters(), decay=0.99) + ema.load_state_dict(checkpoint["ema"]) + except BaseException: + ema = None + else: + optimizer = None + scheduler = "StepLR" + last_epoch = -1 + validation_metrics = {"loss": 0.0} + description = "" + ema = None + + model.eval() + + del checkpoint + + return Potential( + model, + optimizer=optimizer, + ema=ema, + scheduler=scheduler, + device=device, + model_name=model_name, + last_epoch=last_epoch, + validation_metrics=validation_metrics, + description=description, + **kwargs, + ) + + @staticmethod + def load_from_multi_head_model( + model_name: str = "m3gnet", + head_index: int = -1, + load_path: str = None, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + **kwargs, + ): + """ + Load one head of the multi-head model. + Args: + head_index: + -1: reset the head (final layer and + energy normalization module) + """ + if load_path is None: + if model_name == "m3gnet": + print("Loading the pre-trained multi-head M3GNet model") + current_dir = os.path.dirname(__file__) + load_path = os.path.join( + current_dir, + "m3gnet/pretrained/Transition1x-MD17-MPF21-QM9-HME21-OC20/" + "best_model.pth", + ) + else: + raise NotImplementedError + else: + print("Loading the model from %s" % load_path) + if head_index == -1: + print("Reset the final layer and normalization module") + checkpoint = torch.load(load_path, map_location=device) + if model_name == "m3gnet": + model = M3Gnet(device=device, **checkpoint["model_args"]).to( + device + ) # noqa: E501 + ori_ckpt = checkpoint["model"].copy() + for key in ori_ckpt: + if "final_layer_list" in key: + if "final_layer_list.%d" % head_index in key: + checkpoint["model"][ + key.replace("_layer_list.%d" % head_index, "") + ] = ori_ckpt[key] + del checkpoint["model"][key] + if "normalizer_list" in key: + if "normalizer_list.%d" % head_index in key: + checkpoint["model"][ + key.replace("_list.%d" % head_index, "") + ] = ori_ckpt[key] + del checkpoint["model"][key] + if "sph_2" in key: + del checkpoint["model"][key] + model.load_state_dict(checkpoint["model"], strict=True) + else: + raise NotImplementedError + description = checkpoint["description"] + model.eval() + + del checkpoint + + return Potential( + model, + device=device, + model_name=model_name, + description=description, + **kwargs, + ) + + def load_model(self, **kwargs): + warnings.warn( + "The interface of loading M3GNet model has been deprecated. " + "Please use Potential.load() instead.", + DeprecationWarning, + ) + warnings.warn( + "It only supports loading the pre-trained M3GNet model. " + "For other models, please use Potential.load() instead." 
+ ) + current_dir = os.path.dirname(__file__) + load_path = os.path.join( + current_dir, "m3gnet/pretrained/mpf/best_model.pth" # noqa: E501 + ) + checkpoint = torch.load(load_path) + self.model.load_state_dict(checkpoint["model"]) + + def set_description(self, description): + self.description = description + + def get_description(self): + return self.description + + +def batch_to_dict(graph_batch, model_type="m3gnet", device="cuda"): + if model_type == "m3gnet": + # TODO: key_list + atom_pos = graph_batch.atom_pos + cell = graph_batch.cell + pbc_offsets = graph_batch.pbc_offsets + atom_attr = graph_batch.atom_attr + edge_index = graph_batch.edge_index + three_body_indices = graph_batch.three_body_indices + num_three_body = graph_batch.num_three_body + num_bonds = graph_batch.num_bonds + num_triple_ij = graph_batch.num_triple_ij + num_atoms = graph_batch.num_atoms + num_graphs = graph_batch.num_graphs + num_graphs = torch.tensor(num_graphs) + batch = graph_batch.batch + + # Resemble input dictionary + input = {} + input["atom_pos"] = atom_pos + input["cell"] = cell + input["pbc_offsets"] = pbc_offsets + input["atom_attr"] = atom_attr + input["edge_index"] = edge_index + input["three_body_indices"] = three_body_indices + input["num_three_body"] = num_three_body + input["num_bonds"] = num_bonds + input["num_triple_ij"] = num_triple_ij + input["num_atoms"] = num_atoms + input["num_graphs"] = num_graphs + input["batch"] = batch + elif model_type == "graphormer" or model_type == "geomformer": + raise NotImplementedError + else: + raise NotImplementedError + + return input + + +class DeepCalculator(Calculator): + """ + Deep calculator based on ase Calculator + """ + + implemented_properties = ["energy", "free_energy", "forces", "stress"] + + def __init__( + self, + potential: Potential, + args_dict: dict = {}, + compute_stress: bool = True, + stress_weight: float = 1.0, + device: str = "cuda" if torch.cuda.is_available() else "cpu", + **kwargs, + ): + """ + Args: + potential (Potential): m3gnet.models.Potential + compute_stress (bool): whether to calculate the stress + stress_weight (float): the stress weight. + **kwargs: + """ + super().__init__(**kwargs) + self.potential = potential + self.compute_stress = compute_stress + self.stress_weight = stress_weight + self.args_dict = args_dict + self.device = device + + def calculate( + self, + atoms: Optional[Atoms] = None, + properties: Optional[list] = None, + system_changes: Optional[list] = None, + ): + """ + Args: + atoms (ase.Atoms): ase Atoms object + properties (list): list of properties to calculate + system_changes (list): monitor which properties of atoms were + changed for new calculation. If not, the previous calculation + results will be loaded. 
+ Returns: + """ + + all_changes = [ + "positions", + "numbers", + "cell", + "pbc", + "initial_charges", + "initial_magmoms", + ] + + properties = properties or ["energy"] + system_changes = system_changes or all_changes + super().calculate( + atoms=atoms, properties=properties, system_changes=system_changes + ) + + self.args_dict["batch_size"] = 1 + self.args_dict["only_inference"] = 1 + dataloader = build_dataloader( + [atoms], model_type=self.potential.model_name, **self.args_dict + ) + for graph_batch in dataloader: + # Resemble input dictionary + if ( + self.potential.model_name == "graphormer" + or self.potential.model_name == "geomformer" + ): + raise NotImplementedError + else: + graph_batch = graph_batch.to(self.device) + input = batch_to_dict(graph_batch) + + result = self.potential.forward( + input, include_forces=True, include_stresses=self.compute_stress + ) + if ( + self.potential.model_name == "graphormer" + or self.potential.model_name == "geomformer" + ): + raise NotImplementedError + else: + self.results.update( + energy=result["total_energy"].detach().cpu().numpy()[0], + free_energy=result["total_energy"].detach().cpu().numpy()[0], + forces=result["forces"].detach().cpu().numpy(), + ) + if self.compute_stress: + self.results.update( + stress=self.stress_weight + * full_3x3_to_voigt_6_stress( + result["stresses"].detach().cpu().numpy()[0] + ) + ) diff --git a/src/mattersim/jit_compile_tools/jit.py b/src/mattersim/jit_compile_tools/jit.py new file mode 100644 index 0000000..e989214 --- /dev/null +++ b/src/mattersim/jit_compile_tools/jit.py @@ -0,0 +1,337 @@ +# -*- coding: utf-8 -*- +"""jit.py is used to compile model with jit and modified from + https://github.com/e3nn/e3nn/blob/main/e3nn/util/jit.py +""" + +import copy +import inspect +import warnings +from typing import Optional + +import torch +from opt_einsum_fx import jitable +from torch import fx + +_RL4CSP_COMPILE_MODE = "__rl4csp_compile_mode__" +_VALID_MODES = ("trace", "script", "unsupported", None) +_MAKE_TRACING_INPUTS = "_make_tracing_inputs" + + +def compile_mode(mode: str): + """Decorator to set the compile mode of a module. + + Parameters + ---------- + mode : str + 'script', 'trace', or None + """ + if mode not in _VALID_MODES: + raise ValueError("Invalid compile mode") + + def decorator(obj): + if not (inspect.isclass(obj) and issubclass(obj, torch.nn.Module)): + raise TypeError( + "@rl4csp.mattersim.forcefield.jit.compile_mode can only " + "decorate classes derived from torch.nn.Module" + ) + setattr(obj, _RL4CSP_COMPILE_MODE, mode) + return obj + + return decorator + + +def get_compile_mode(mod: torch.nn.Module) -> str: + """Get the compilation mode of a module. + + Parameters + ---------- + mod : torch.nn.Module + + Returns + ------- + 'script', 'trace', or None if the module was not decorated with + @compile_mode + """ + if hasattr(mod, _RL4CSP_COMPILE_MODE): + mode = getattr(mod, _RL4CSP_COMPILE_MODE) + else: + mode = getattr(type(mod), _RL4CSP_COMPILE_MODE, None) + if mode is None and isinstance(mod, fx.GraphModule): + mode = "script" + assert mode in _VALID_MODES, "Invalid compile mode `%r`" % mode + return mode + + +def compile( + mod: torch.nn.Module, + n_trace_checks: int = 1, + script_options: dict = None, + trace_options: dict = None, + in_place: bool = True, +): + """Recursively compile a module and all submodules according + to their decorators. + + (Sub)modules without decorators will be unaffected. + + Parameters + ---------- + mod : torch.nn.Module + The module to compile. 
The module will have its submodules + compiled replaced in-place. + n_trace_checks : int, default = 1 + How many random example inputs to generate when tracing a module. + Must be at least one in order to have a tracing input. + Extra example inputs will be pased to ``torch.jit.trace`` + to confirm that the traced copmute graph doesn't change. + script_options : dict, default = {} + Extra kwargs for ``torch.jit.script``. + trace_options : dict, default = {} + Extra kwargs for ``torch.jit.trace``. + + Returns + ------- + Returns the compiled module. + :param trace_options: + :param script_options: + :param n_trace_checks: + :param mod: + :param in_place: + """ + script_options = script_options or {} + trace_options = trace_options or {} + + mode = get_compile_mode(mod) + if mode == "unsupported": + raise NotImplementedError( + f"{type(mod).__name__} does not support TorchScript compilation" + ) + + if not in_place: + mod = copy.deepcopy(mod) + # TODO: debug logging + assert n_trace_checks >= 1 + # == recurse to children == + # This allows us to trace compile submodules of modules we are going to + # script + for submod_name, submod in mod.named_children(): + setattr( + mod, + submod_name, + compile( + submod, + n_trace_checks=n_trace_checks, + script_options=script_options, + trace_options=trace_options, + in_place=True, + # since we deepcopied the module above, we can do inplace + ), + ) + # == Compile this module now == + if mode == "script": + if isinstance(mod, fx.GraphModule): + mod = jitable(mod) + mod = torch.jit.script(mod, **script_options) + elif mode == "trace": + # These are always modules, so we're always using trace_module + # We need tracing inputs: + check_inputs = get_tracing_inputs( + mod, + n_trace_checks, + ) + assert len(check_inputs) >= 1, "Must have at least one tracing input." + # Do the actual trace + mod = torch.jit.trace_module( + mod, + inputs=check_inputs[0], + check_inputs=check_inputs, + **trace_options, # noqa: E501 + ) + return mod + + +def get_tracing_inputs( + mod: torch.nn.Module, + n: int = 1, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, +): + """Get random tracing inputs for ``mod``. + + First checks if ``mod`` has a ``_make_tracing_inputs`` method. + If so, calls it with ``n`` as the single argument and returns its results. + + Otherwise, attempts to infer the input signature of the module using + ``e3nn.util._argtools._get_io_irreps``. + + Parameters + ---------- + mod : torch.nn.Module + n : int, default = 1 + A hint for how many inputs are wanted. Usually n will be returned, + but modules don't necessarily have to. + device : torch.device + The device to do tracing on. If `None` (default), will be guessed. + dtype : torch.dtype + The dtype to trace with. If `None` (default), will be guessed. + + Returns + ------- + list of dict + Tracing inputs in the format of ``torch.jit.trace_module``: + dicts mapping method names like ``'forward'`` to tuples of arguments. 
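+
+    Examples
+    --------
+    A minimal sketch of how ``compile`` consumes these inputs (``mod`` stands
+    in for any ``torch.nn.Module`` handled by this utility)::
+
+        check_inputs = get_tracing_inputs(mod, n=2)
+        traced = torch.jit.trace_module(
+            mod, inputs=check_inputs[0], check_inputs=check_inputs
+        )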
+ """ + # Avoid circular imports + from ._argtools import ( + _get_device, + _get_floating_dtype, + _get_io_irreps, + _rand_args, + _to_device_dtype, + ) + + # - Get inputs - + if hasattr(mod, _MAKE_TRACING_INPUTS): + # This returns a trace_module style dict of method names to test inputs + trace_inputs = mod._make_tracing_inputs(n) + assert isinstance(trace_inputs, list) + for d in trace_inputs: + assert isinstance( + d, dict + ), "_make_tracing_inputs must return a list of dict[str, tuple]" + assert all( + isinstance(k, str) and isinstance(v, tuple) + for k, v in d.items() # noqa: E501 + ), "_make_tracing_inputs must return a list of dict[str, tuple]" + else: + # Try to infer. This will throw if it can't. + irreps_in, _ = _get_io_irreps( + mod, irreps_out=[None] + ) # we're only trying to infer inputs + trace_inputs = [{"forward": _rand_args(irreps_in)} for _ in range(n)] + # - Put them on the right device - + if device is None: + device = _get_device(mod) + if dtype is None: + dtype = _get_floating_dtype(mod) + # Move them + trace_inputs = _to_device_dtype(trace_inputs, device, dtype) + return trace_inputs + + +def trace_module( + mod: torch.nn.Module, + inputs: dict = None, + check_inputs: list = None, + in_place: bool = True, +): + """Trace a module. + + Identical signature to ``torch.jit.trace_module``, but first recursively + compiles ``mod`` using ``compile``. + + Parameters + ---------- + mod : torch.nn.Module + inputs : dict + check_inputs : list of dict + Returns + ------- + Traced module. + """ + check_inputs = check_inputs or [] + + # Set the compile mode for mod, temporarily + old_mode = getattr(mod, _RL4CSP_COMPILE_MODE, None) + if old_mode is not None and old_mode != "trace": + warnings.warn( + f"Trying to trace a module of type {type(mod).__name__} marked " + "with @compile_mode != 'trace', expect errors!" + ) + setattr(mod, _RL4CSP_COMPILE_MODE, "trace") + + # If inputs are provided, set make_tracing_input temporarily + old_make_tracing_input = None + if inputs is not None: + old_make_tracing_input = getattr(mod, _MAKE_TRACING_INPUTS, None) + setattr( + mod, + _MAKE_TRACING_INPUTS, + lambda num: ([inputs] + check_inputs), # noqa: E501 + ) + + # Compile + out = compile(mod, in_place=in_place) + + # Restore old values, if we had them + if old_mode is not None: + setattr(mod, _RL4CSP_COMPILE_MODE, old_mode) + if old_make_tracing_input is not None: + setattr(mod, _MAKE_TRACING_INPUTS, old_make_tracing_input) + return out + + +def trace( + mod: torch.nn.Module, + example_inputs: tuple = None, + check_inputs: list = None, + in_place: bool = True, +): + """Trace a module. + + Identical signature to ``torch.jit.trace``, but first recursively compiles + ``mod`` using :func:``compile``. + + Parameters + ---------- + mod : torch.nn.Module + example_inputs : tuple + check_inputs : list of tuple + Returns + ------- + Traced module. + """ + check_inputs = check_inputs or [] + + return trace_module( + mod=mod, + inputs=( + {"forward": example_inputs} + if example_inputs is not None + else None # noqa: E501 + ), + check_inputs=[{"forward": c} for c in check_inputs], + in_place=in_place, + ) + + +def script(mod, in_place: bool = True): + """Script a module. + + Like ``torch.jit.script``, but first recursively compiles ``mod`` + using :func:``compile``. + + Parameters + ---------- + mod : torch.nn.Module + Returns + ------- + Scripted module. 
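+
+    Examples
+    --------
+    A minimal sketch, assuming ``MyModule`` is a ``torch.nn.Module`` subclass
+    (optionally decorated with ``@compile_mode("script")``)::
+
+        mod = MyModule()
+        scripted = script(mod)
+        torch.jit.save(scripted, "my_module_scripted.pth")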
+ """ + # Set the compile mode for mod, temporarily + old_mode = getattr(mod, _RL4CSP_COMPILE_MODE, None) + if old_mode is not None and old_mode != "script": + warnings.warn( + f"Trying to script a module of type {type(mod).__name__} marked " + "with @compile_mode != 'script', expect errors! " + ) + setattr(mod, _RL4CSP_COMPILE_MODE, "script") + + # Compile + out = compile(mod, in_place=in_place) + + # Restore old values, if we had them + if old_mode is not None: + setattr(mod, _RL4CSP_COMPILE_MODE, old_mode) + + return out diff --git a/src/mattersim/jit_compile_tools/jit_compile.py b/src/mattersim/jit_compile_tools/jit_compile.py new file mode 100644 index 0000000..add39fc --- /dev/null +++ b/src/mattersim/jit_compile_tools/jit_compile.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- +import logging +import pathlib +import sys +from typing import Dict, Tuple, Union + +if sys.version_info[1] >= 8: + from typing import Final +else: + from typing_extensions import Final + +# This is a weird hack to avoid Intel MKL issues on the cluster +import ase.data + +# when this is called as a subprocess of a process that has itself initialized +# PyTorch. Since numpy gets imported later anyway for dataset stuff, +# this shouldn't affect performance. +import numpy as np # noqa: F401 +import torch + +from .jit import script + +# Denote meta_data_keys +TWO_BODY_CUTOFF: Final[str] = "two_body_cutoff" +HAS_THREE_BODY: Final[str] = "has_three_body" +THREE_BODY_CUTOFF: Final[str] = "three_body_cutoff" +N_SPECIES_KEY: Final[str] = "n_species" +TYPE_NAMES_KEY: Final[str] = "type_names" +JIT_BAILOUT_KEY: Final[str] = "_jit_bailout_depth" +JIT_FUSION_STRATEGY: Final[str] = "_jit_fusion_strategy" +TF32_KEY: Final[str] = "allow_tf32" + +_ALL_METADATA_KEYS = [ + TWO_BODY_CUTOFF, + HAS_THREE_BODY, + THREE_BODY_CUTOFF, + N_SPECIES_KEY, + TYPE_NAMES_KEY, + JIT_BAILOUT_KEY, + JIT_FUSION_STRATEGY, + TF32_KEY, +] + + +def _compile_for_deploy(model): + model.eval() + + if not isinstance(model, torch.jit.ScriptModule): + print("Non TorchScript model detected,JIT compiling the model ....") + model = script(model) + else: + print( + "Model provided is already a TorchScript model, " + "return as it is." # noqa: E501 + ) + return model + + +def load_deployed_model( + model_path: Union[pathlib.Path, str], + device: Union[str, torch.device] = "cpu", + freeze: bool = True, +) -> Tuple[torch.jit.ScriptModule, Dict[str, str]]: + r"""Load a deployed model. + Args: + model_path: the path to the deployed model's ``.pth`` file. + Returns: + model, metadata dictionary + """ + metadata = {k: "" for k in _ALL_METADATA_KEYS} + try: + model = torch.jit.load( + model_path, map_location=device, _extra_files=metadata + ) # noqa: E501 + except RuntimeError as e: + raise ValueError( + f"{model_path} does not seem to be a deployed RL4CSP model file. " + f"Did you forget to deploy it? 
\n\n(Underlying error: {e})"
+        )
+
+    # Confirm it's a TorchScript module
+    assert isinstance(model, torch.jit.ScriptModule)
+
+    # Make sure we're in eval mode
+    model.eval()
+    # Freeze on load:
+    if freeze and hasattr(model, "training"):
+        # hasattr is how torch checks whether model is unfrozen
+        # only freeze if already unfrozen
+        model = torch.jit.freeze(model)
+
+    # Everything we store right now is ASCII, so decode for printing
+    metadata = {k: v.decode("ascii") for k, v in metadata.items()}
+
+    # JIT strategy
+    strategy = metadata.get(JIT_FUSION_STRATEGY, "")
+
+    if strategy != "":
+        strategy = [e.split(",") for e in strategy.split(";")]
+        strategy = [(e[0], int(e[1])) for e in strategy]
+    else:
+        print(
+            "Missing information: JIT strategy, "
+            "loading deployed model fails!"  # noqa: E501
+        )
+        exit()
+
+    # JIT bailout
+    jit_bailout: int = metadata.get(JIT_BAILOUT_KEY, "")
+    if jit_bailout == "":
+        print(
+            "Missing information: JIT_BAILOUT_KEY, "
+            "loading deployed model fails!"  # noqa: E501
+        )
+        exit()
+
+    # JIT allow_tf32
+    jit_allow_tf32: int = metadata.get(TF32_KEY, "")
+    if jit_allow_tf32 == "":
+        print("Missing information: TF32_KEY, loading deployed model fails!")
+        exit()
+
+    return model, metadata
+
+
+def deploy(
+    model,
+    is_m3gnet_pretrained=False,
+    is_m3gnet_multi_head_pretrained=False,
+    metadata=None,
+    deployed_model_name="deployed.pth",
+    device="cpu",
+):
+    # Compile model
+    compiled_model = _compile_for_deploy(model)
+
+    # Use default metadata dictionary for pretrained models
+    if is_m3gnet_pretrained:
+        metadata = {}
+
+        # Use set differences to get the covered atomic numbers
+        full_atomic_numbers = set(np.arange(1, 95, 1))
+        discard_atomic_numbers = set(np.arange(84, 89, 1))
+        covered_atomic_numbers = list(
+            full_atomic_numbers.difference(discard_atomic_numbers)
+        )
+        type_names = []
+        for atomic_num in covered_atomic_numbers:
+            type_names.append(ase.data.chemical_symbols[atomic_num])
+        metadata[TWO_BODY_CUTOFF] = str(5.0)
+        metadata[HAS_THREE_BODY] = str(True)
+        metadata[THREE_BODY_CUTOFF] = str(4.0)
+        metadata[N_SPECIES_KEY] = str(89)
+        metadata[TYPE_NAMES_KEY] = " ".join(type_names)
+        metadata[JIT_BAILOUT_KEY] = str(2)
+        metadata[JIT_FUSION_STRATEGY] = ";".join(
+            "%s,%i" % e for e in [("DYNAMIC", 3)]  # noqa: E501
+        )
+        metadata[TF32_KEY] = str(int(0))
+
+    # TODO: Add default metadata keys for m3gnet_multi_head models
+    # elif is_m3gnet_multi_head_pretrained:
+
+    else:
+        # Missing metadata fields abort the deployment
+        metadata_keys = list(metadata.keys())
+        for _ALL_METADATA_KEY in _ALL_METADATA_KEYS:
+            if _ALL_METADATA_KEY not in metadata_keys:
+                logging.info(
+                    "Missing metadata key: "
+                    + _ALL_METADATA_KEY
+                    + ", model deploying fails!"
+                )
+                exit()
+        # Missing metadata values, other than JIT compile information,
+        # also abort the deployment
+        for i in range(len(metadata_keys) - 3):
+            if not metadata[metadata_keys[i]]:
+                logging.info(
+                    "metadata with key "
+                    + metadata_keys[i]
+                    + " not set, model deploying fails!"
+                )
+                exit()
+        # Set default JIT compile information if the values are not set
+        if (
+            not metadata.get(JIT_BAILOUT_KEY)
+            or not metadata.get(JIT_FUSION_STRATEGY)
+            or not metadata.get(TF32_KEY)
+        ):
+            metadata[JIT_BAILOUT_KEY] = str(2)
+            metadata[JIT_FUSION_STRATEGY] = ";".join(
+                "%s,%i" % e for e in [("DYNAMIC", 3)]
+            )
+            metadata[TF32_KEY] = str(int(0))
+
+    # Deploy model with full information
+    # Confirm it's a TorchScript module
+    assert isinstance(compiled_model, torch.jit.ScriptModule)
+    if device != "cuda":
+        compiled_model = compiled_model.cpu()
+
+    torch.jit.save(compiled_model, deployed_model_name, _extra_files=metadata)
+
+    return compiled_model, metadata
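+
+
+# A minimal deploy/load round trip (illustrative sketch only; assumes the
+# underlying network is TorchScript-compatible):
+#
+#     potential = Potential.load(model_name="m3gnet")
+#     deployed_model, metadata = deploy(
+#         potential.model,
+#         is_m3gnet_pretrained=True,
+#         deployed_model_name="deployed.pth",
+#     )
+#     loaded_model, metadata = load_deployed_model("deployed.pth", device="cpu")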
diff --git a/src/mattersim/utils/atoms_utils.py b/src/mattersim/utils/atoms_utils.py
new file mode 100644
index 0000000..ef67ec9
--- /dev/null
+++ b/src/mattersim/utils/atoms_utils.py
@@ -0,0 +1,84 @@
+# -*- coding: utf-8 -*-
+import os
+
+from ase import Atoms
+from ase.io import read as ase_read
+from mp_api.client import MPRester
+from pymatgen.core.structure import Structure
+from pymatgen.io.ase import AseAtomsAdaptor
+
+
+class AtomsAdaptor(object):
+    """
+    This class reads structures in different representations
+    and converts them to ASE Atoms objects.
+    """
+
+    def __init__(self):
+        pass
+
+    @classmethod
+    def from_ase_atoms(cls, atoms: Atoms):
+        """
+        Get Atoms from an ASE Atoms object.
+
+        Args:
+            atoms (Atoms): ASE Atoms object.
+        """
+        if not isinstance(atoms, Atoms):
+            raise TypeError("Input must be an ASE Atoms object.")
+        return atoms
+
+    @classmethod
+    def from_pymatgen_structure(cls, structure: Structure):
+        """
+        Get Atoms from a pymatgen Structure object.
+
+        Args:
+            structure (Structure): pymatgen Structure object.
+        """
+        if not isinstance(structure, Structure):
+            raise TypeError("Input must be a pymatgen Structure object.")
+        return AseAtomsAdaptor.get_atoms(structure, msonable=False)
+
+    @classmethod
+    def from_mp_id(cls, mp_id: str, api_key: str = None):
+        """
+        Get Atoms from a Materials Project ID.
+
+        Args:
+            mp_id (str): Materials Project ID of the structure.
+            api_key (str, optional): API key to access the Materials Project.
+                If not provided, it is read from the MP_API_KEY environment
+                variable.
+        """
+        mp_api_key = api_key or os.getenv("MP_API_KEY")
+        if not mp_api_key:
+            raise ValueError(
+                "An MP API key is required to fetch data from"
+                " Materials Project, but was not found in the"
+                " environment variables or provided."
+            )
+        with MPRester(mp_api_key) as m:
+            structure = m.get_structure_by_material_id(mp_id)
+        return AseAtomsAdaptor.get_atoms(structure, msonable=False)
+
+    @classmethod
+    def from_file(cls, filename: str, format: str = None):
+        """
+        Get Atoms from a structure file.
+
+        Args:
+            filename (str): name of the file that contains the structures.
+            format (str, optional): file format. If None, the format will
+                be guessed automatically.
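+
+        Returns:
+            list[Atoms]: a list of ASE Atoms objects read from the file.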
+ """ + if not os.path.exists(filename): + raise FileNotFoundError(f"File {filename} not found.") + + if format: + atoms_list = ase_read(filename, index=":", format=format) + else: + try: + atoms_list = ase_read(filename, index=":") + except Exception as e: + raise ValueError(f"Can not automately guess the file format: {e}") + + return atoms_list diff --git a/src/mattersim/utils/phonon_utils.py b/src/mattersim/utils/phonon_utils.py new file mode 100644 index 0000000..3e56a21 --- /dev/null +++ b/src/mattersim/utils/phonon_utils.py @@ -0,0 +1,49 @@ +# -*- coding: utf-8 -*- +from ase import Atoms +from phonopy import Phonopy +from phonopy.structure.atoms import PhonopyAtoms + + +def get_primitive_cell(atoms: Atoms): + """ + Get primitive cell from ASE atoms object + Args: + atoms (Atoms): ASE atoms object to provide lattice information + """ + phonopy_atoms = Phonopy( + to_phonopy_atoms(atoms), primitive_matrix="auto", log_level=2 + ) + primitive = phonopy_atoms.primitive + atoms = to_ase_atoms(primitive) + return atoms + + +def to_phonopy_atoms(atoms: Atoms): + """ + Transform ASE atoms object to Phonopy object + Args: + atoms (Atoms): ASE atoms object to provide lattice informations. + """ + phonopy_atoms = PhonopyAtoms( + symbols=atoms.get_chemical_symbols(), + cell=atoms.get_cell(), + masses=atoms.get_masses(), + positions=atoms.get_positions(), + ) + return phonopy_atoms + + +def to_ase_atoms(phonopy_atoms): + """ + Transform Phonopy object to ASE atoms object + Args: + phonopy_atoms (Phonopy): Phonopy object to provide lattice informations. + """ + atoms = Atoms( + symbols=phonopy_atoms.symbols, + cell=phonopy_atoms.cell, + masses=phonopy_atoms.masses, + positions=phonopy_atoms.positions, + pbc=True, + ) + return atoms diff --git a/src/mattersim/utils/supercell_utils.py b/src/mattersim/utils/supercell_utils.py new file mode 100644 index 0000000..6a2589a --- /dev/null +++ b/src/mattersim/utils/supercell_utils.py @@ -0,0 +1,201 @@ +# -*- coding: utf-8 -*- +import numpy as np +from ase import Atoms +from ase.spacegroup.symmetrize import check_symmetry + + +def auto_grid_detection( + atom: Atoms, + max_atoms: int, + ratio_tolerance: float = 1.1, + is_santity_check: bool = True, + is_verbose: bool = True, +): + """ + This function automates the detection of grid for a given atomic structure + and max_atoms. If lattice vectors lenght in three direction is same or the + difference is smaller than 0.1, the supercell vector element will has the + same value in three direction, the vaule is (max_atoms/atoms)^(1/3). Else + the supercell vector element will be set proportionally to make the three + supercell lattice vector length as same as possible. + + Args: + atom (Atoms): ASE atoms object to provide lattice informations. + max_atoms: (int): Maximum atom number limitation for supercell. + ratio_tolerance (float, optional): The tolerance for the ratio of the + lengths of the lattice vectors. Defaults to 1.1. + is_santity_check (bool, optional): If True, performs a sanity check to + ensure symmetry is preserved after replications. Defaults to True. + is_verbose (bool, optional): If True, prints detailed information about + the atomic structure and the replication process. Defaults to True. 
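+
+    Returns:
+        tuple: the (n_a, n_b, n_c) replication factors for the supercell,
+            or (1, 1, 1) if no symmetry-preserving replication is found.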
+ """ + # Get the cell length + lattice_vector_lengths = atom.cell.cellpar()[:3] + + # Base case, absolute the same length + if ( + lattice_vector_lengths[0] + == lattice_vector_lengths[1] + == lattice_vector_lengths[2] + ): + number_of_replicas = int(np.round(max_atoms / (len(atom))) ** (1 / 3)) + number_of_replicas = max(number_of_replicas, 1) + max_replication = ( + number_of_replicas, + number_of_replicas, + number_of_replicas, + ) + + # Case 1: Non-symmetry lengths within tolerance + else: + lattice_vector_lengths_argsort_indices = np.argsort(lattice_vector_lengths)[ + ::-1 + ] + sorted_lattice_vector_lengths = lattice_vector_lengths[ + lattice_vector_lengths_argsort_indices + ] + ratios = [ + sorted_lattice_vector_lengths[0] / sorted_lattice_vector_lengths[1], + sorted_lattice_vector_lengths[0] / sorted_lattice_vector_lengths[2], + ] + + # Variation in cell length with in the tolerance we still consider it + # as an N-N-N replications + if ratios[0] <= ratio_tolerance and ratios[1] <= ratio_tolerance: + number_of_replicas = int(np.round(max_atoms / (len(atom))) ** (1 / 3)) + number_of_replicas = max(number_of_replicas, 1) + max_replication = ( + number_of_replicas, + number_of_replicas, + number_of_replicas, + ) + + # Case 2: Non-symmetry lengths beyond tolerance + else: + # Compute the replica along the most asymmetric direction + asymmetric_replica = int( + (max_atoms / len(atom) / np.prod(ratios)) ** (1 / 3) + ) + asymmetric_replica = max(asymmetric_replica, 1) + + # Recover replica on the other two direction based on ratios + replica_r0 = max(int(np.round(asymmetric_replica * ratios[0])), 1) + replica_r1 = max(int(np.round(asymmetric_replica * ratios[1])), 1) + indices_to_recover_lattice_vector_order = np.argsort( + lattice_vector_lengths_argsort_indices + ) + max_replication_arr = np.array( + [asymmetric_replica, replica_r0, replica_r1] + )[indices_to_recover_lattice_vector_order] + max_replication = tuple(max_replication_arr) + + # Broad cast unit cell infomraiton + if is_verbose: + print("System:", atom) + print("Number of atoms in the unit cell: ", len(atom)) + print("Lattice vector and angles: ", atom.cell.cellpar()) + print( + "Space group: ", + check_symmetry(atom, 1e-3, verbose=False)["international"], + ) + + if is_santity_check: + symmetry_of_unit_cell = check_symmetry(atom, 1e-3, verbose=False)[ + "international" + ] + symmetry_of_replicated_supercell = check_symmetry( + atom.copy().repeat(max_replication), 1e-3, verbose=False + )["international"] + if symmetry_of_unit_cell == symmetry_of_replicated_supercell: + print( + "symmetry is preserved after replications, safely return " + "replication combination !\n" + ) + return max_replication + else: + print( + "Symmetry is lose after replications. No possible replications" + " can be found !\n" + ) + return (1, 1, 1) + + # Check if max_replication is still the initial value + if max_replication == (1, 1, 1): + print("No possible replications. Returning unit cell.") + return (1, 1, 1) + else: + return max_replication + + +def get_supercell_parameters( + atom: Atoms, + supercell_matrix: np.ndarray = None, + qpoints_mesh: np.ndarray = None, + max_atoms: int = None, +): + """ + Based on symmetry to get supercell setting parameters. + First, setting the maximum atoms number limitation for supercell. If max_atoms + is None, will automatic setting it, else use user assigned. If the lattice + parameters in three direction is same or approximately same, the max_atoms will + be set small, e.g. 
216 or 300; else some direction needed expand larger than + others, so the max_atoms also need more, e.g. 450. Then, based on max_atoms, + call auto_grid_dection function to obtain supercell matrix diagonal elements. + Finally setting the k_point_mesh used to integrate Brillouin Zone. Cause kpoints + in inverse space, smaller real space means the inverse space is larger, will + need more kpoints. + + Args: + atom (Atoms): ASE atoms object to provide lattice information. + supercell_matrix (nd.ndarray, optional): Supercell matrix for construct + supercell, prior than max_atoms. + qpoints_mesh (np.ndarray, optional): Qpoints mesh for IBZ integral, prio + over than max_atoms. + max_atoms (int, optional): If not None, will use user setting maximum + atoms number limitation for generate supercell, else automatic set. + Defaults to None. + """ + if supercell_matrix is not None: + nrep_second = np.diag(supercell_matrix) + if nrep_second[0] == nrep_second[1] == nrep_second[2]: + k_point_mesh = 6 * np.array(nrep_second) + else: + k_point_mesh = 3 * np.array(nrep_second) + + if qpoints_mesh is not None: + k_point_mesh = qpoints_mesh + + return supercell_matrix, k_point_mesh + + lattice_vector_lengths = atom.cell.cellpar()[:3] + lattice_vector_lengths_argsort_indices = np.argsort(lattice_vector_lengths)[::-1] + sorted_lattice_vector_lengths = lattice_vector_lengths[ + lattice_vector_lengths_argsort_indices + ] + ratios = [ + sorted_lattice_vector_lengths[0] / sorted_lattice_vector_lengths[1], + sorted_lattice_vector_lengths[0] / sorted_lattice_vector_lengths[2], + ] + if max_atoms: + pass + elif ( + check_symmetry(atom, 1e-3, verbose=False)["international"] == "Fd-3m" + or check_symmetry(atom, 1e-3, verbose=False)["international"] == "Fm-3m" + or check_symmetry(atom, 1e-3, verbose=False)["international"] == "F-43m" + ): + max_atoms = 216 + elif check_symmetry(atom, 1e-3, verbose=False)["international"] == "P6_3mc": + max_atoms = 450 + elif ratios[0] <= 1.1 and ratios[1] <= 1.1: + max_atoms = 300 + else: + max_atoms = 300 + + nrep_second = auto_grid_detection(atom, max_atoms, is_verbose=False) + + if nrep_second[0] == nrep_second[1] == nrep_second[2]: + k_point_mesh = 6 * np.array(nrep_second) + else: + k_point_mesh = 3 * np.array(nrep_second) + + return nrep_second, k_point_mesh diff --git a/tests/applications/test_phonon.py b/tests/applications/test_phonon.py new file mode 100644 index 0000000..146903e --- /dev/null +++ b/tests/applications/test_phonon.py @@ -0,0 +1,78 @@ +# -*- coding: utf-8 -*- +import unittest + +import numpy as np +from ase import Atoms +from ase.calculators.emt import EMT +from phonopy import Phonopy + +from mattersim.applications.phonon import PhononWorkflow + + +class PhononTestCase(unittest.TestCase): + def setUp(self): + # Create an example structure for testing + a = 1.786854996 + positions = [ + (1.78685500, 1.78685500, 1.78685500), + (2.68028249, 2.68028249, 2.68028249), + ] + cell = [(0, a, a), (a, 0, a), (a, a, 0)] + + self.atoms = Atoms("C2", positions=positions, cell=cell, pbc=True) + + # Create an conventional cell for testing + a2 = a * 2 + positions2 = [ + (0, 0, 0), + (0, a2 / 2, a2 / 2), + (a2 / 2, 0, a2 / 2), + (a2 / 2, a2 / 2, 0), + (a2 / 4, a2 / 4, a2 / 4), + (a2 / 4, 3 * a2 / 4, 3 * a2 / 4), + (3 * a2 / 4, a2 / 4, 3 * a2 / 4), + (3 * a2 / 4, 3 * a2 / 4, a2 / 4), + ] + + cell2 = [(a2, 0, 0), (0, a2, 0), (0, 0, a2)] + + self.atoms_conv = Atoms("C8", positions=positions2, cell=cell2, pbc=True) + + self.calculator = EMT() + self.atoms.calc = self.calculator + 
self.atoms_conv.calc = self.calculator + + def test_phonon(self): + phononworkflow = PhononWorkflow(self.atoms, work_dir="/tmp/diamond") + has_imaginary, phonon = phononworkflow.run() + + self.assertTrue(has_imaginary) + self.assertIsInstance(phonon, Phonopy) + + def test_phonon_supercell(self): + supercell_matrix = np.array([[4, 0, 0], [0, 4, 0], [0, 0, 4]]) + qpoints_mesh = np.array([12, 12, 12]) + phononworkflow = PhononWorkflow( + self.atoms, + work_dir="/tmp/diamond", + supercell_matrix=supercell_matrix, + qpoints_mesh=qpoints_mesh, + ) + has_imaginary, phonon = phononworkflow.run() + + self.assertTrue(has_imaginary) + self.assertIsInstance(phonon, Phonopy) + + def test_phonon_prim(self): + phononworkflow = PhononWorkflow( + self.atoms_conv, work_dir="/tmp/diamond_conv", find_prim=True + ) + has_imaginary, phonon = phononworkflow.run() + has_imaginary, phonon = phononworkflow.run() + + self.assertTrue(has_imaginary) + self.assertIsInstance(phonon, Phonopy) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/applications/test_relax.py b/tests/applications/test_relax.py new file mode 100644 index 0000000..df508f9 --- /dev/null +++ b/tests/applications/test_relax.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +import unittest + +from ase import Atoms +from ase.calculators.emt import EMT +from pymatgen.io.ase import AseAtomsAdaptor +from pymatgen.symmetry.analyzer import SpacegroupAnalyzer + +from mattersim.applications.relax import Relaxer + + +class RelaxerTestCase(unittest.TestCase): + def setUp(self): + # Create an example structure with displaced atoms for testing + a = 1.786854996 # Angstroms + positions = [ + (0, 0, 0), + (a / 4, a / 4, a / 4), + (a / 2, a / 2, 0), + (a / 2, 0, a / 2), + (0, a / 2, a / 2), + (a / 4, 3 * a / 4, 3 * a / 4.01), # displaced + (3 * a / 4, a / 4.01, 3 * a / 4), # displaced + (3 * a / 4, 3 * a / 4, a / 4), + ] + cell = [(a, 0, 0), (0, a, 0), (0, 0, a)] + self.atoms_displaced = Atoms( + "C8", positions=positions, cell=cell, pbc=True # noqa: E501 + ) + + # Create an example structure with expanded cell for testing + a = 1.786854996 * 1.2 + positions = [ + (0, 0, 0), + (a / 4, a / 4, a / 4), + (a / 2, a / 2, 0), + (a / 2, 0, a / 2), + (0, a / 2, a / 2), + (a / 4, 3 * a / 4, 3 * a / 4), + (3 * a / 4, a / 4, 3 * a / 4), + (3 * a / 4, 3 * a / 4, a / 4), + ] + cell = [(a, 0, 0), (0, a, 0), (0, 0, a)] + self.atoms_expanded = Atoms( + "C8", positions=positions, cell=cell, pbc=True # noqa: E501 + ) + + self.calculator = EMT() + + def test_default_relaxer(self): + relaxer = Relaxer() + atoms_displaced = self.atoms_displaced.copy() + atoms_displaced.set_calculator(self.calculator) + converged, relaxed_atoms = relaxer.relax( + atoms_displaced, fmax=0.1, steps=500 + ) # noqa: E501 + self.assertTrue(converged) + self.assertIsInstance(relaxed_atoms, Atoms) + + def test_relax_structures(self): + atoms_list = [ + self.atoms_displaced.copy(), + self.atoms_displaced.copy(), + self.atoms_displaced.copy(), + ] + for atoms in atoms_list: + atoms.set_calculator(self.calculator) + + converged_list, relaxed_atoms_list = Relaxer.relax_structures( + atoms_list, fmax=0.1, steps=500 + ) + self.assertIsInstance(converged_list, list) + for converged in converged_list: + self.assertTrue(converged) + + def test_relax_structures_under_pressure(self): + atoms_displaced = self.atoms_displaced.copy() + atoms_displaced.set_calculator(self.calculator) + init_volume = atoms_displaced.get_volume() + print(f"Initial volume: {init_volume}") + + # First, relax under 0 pressure + 
converged, relaxed_atoms = Relaxer.relax_structures( + atoms_displaced, + steps=500, + fmax=0.1, + filter="FrechetCellFilter", + pressure_in_GPa=0.0, + ) + intermediate_volume = relaxed_atoms.get_volume() + print(f"Intermediate volume: {intermediate_volume}") + self.assertTrue(converged) + + # Second, relax under 100 GPa + converged, relaxed_atoms = Relaxer.relax_structures( + relaxed_atoms, + steps=500, + fmax=0.1, + filter="FrechetCellFilter", + pressure_in_GPa=100.0, + ) + final_volume = relaxed_atoms.get_volume() + print(f"Final volume: {final_volume}") + self.assertTrue(converged) + self.assertLess(final_volume, intermediate_volume) + print(f"Final cell: {relaxed_atoms.cell}") + + def test_relax_with_filter_and_constrained_symmetry(self): + atoms_expanded = self.atoms_expanded.copy() + atoms_expanded.set_calculator(self.calculator) + init_volume = atoms_expanded.get_volume() + print(f"Initial volume: {init_volume}") + + init_analyzer = SpacegroupAnalyzer( + AseAtomsAdaptor.get_structure(self.atoms_expanded) + ) + init_spacegroup = init_analyzer.get_space_group_number() + + # First, relax under 0 pressure + converged, relaxed_atoms = Relaxer.relax_structures( + atoms_expanded, + steps=500, + fmax=0.1, + filter="FrechetCellFilter", + pressure_in_GPa=0.0, + constrain_symmetry=True, + ) + intermediate_volume = relaxed_atoms.get_volume() + print(f"Intermediate volume: {intermediate_volume}") + self.assertTrue(converged) + + # Second, relax under 100 GPa + converged, relaxed_atoms = Relaxer.relax_structures( + relaxed_atoms, + steps=500, + fmax=0.1, + filter="FrechetCellFilter", + pressure_in_GPa=100.0, + constrain_symmetry=True, + ) + final_volume = relaxed_atoms.get_volume() + print(f"Final volume: {final_volume}") + self.assertTrue(converged) + self.assertLess(final_volume, intermediate_volume) + + final_analyzer = SpacegroupAnalyzer( + AseAtomsAdaptor.get_structure(relaxed_atoms) + ) + final_spacegroup = final_analyzer.get_space_group_number() + self.assertEqual(init_spacegroup, final_spacegroup) + print(f"Final cell: {relaxed_atoms.cell}") + cell_a = relaxed_atoms.cell[0, 0] + cell_b = relaxed_atoms.cell[1, 1] + cell_c = relaxed_atoms.cell[2, 2] + self.assertAlmostEqual(cell_a, cell_b) + self.assertAlmostEqual(cell_a, cell_c) + + +if __name__ == "__main__": + unittest.main()
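
The relaxation interface exercised by these tests can also be driven directly. A minimal
sketch, assuming an ASE Atoms object with a calculator attached (EMT is only a
lightweight stand-in here, as in the tests; the pressure value is illustrative):

    from ase.build import bulk
    from ase.calculators.emt import EMT
    from mattersim.applications.relax import Relaxer

    atoms = bulk("Cu", "fcc", a=3.6)
    atoms.calc = EMT()
    converged, relaxed = Relaxer.relax_structures(
        atoms,
        fmax=0.1,
        steps=500,
        filter="FrechetCellFilter",
        pressure_in_GPa=10.0,
        constrain_symmetry=True,
    )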