diff --git a/README.md b/README.md index bc69f0a..1c70f6e 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ from mattersim.datasets.utils.build import build_dataloader device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" print(f"Running MatterSim on {device}") -potential = Potential.load(load_path="pretrained_models/mattersim-v1.0.0-1M.pth", device=device) +potential = Potential.load() si = bulk("Si", "diamond", a=5.43) dataloader = build_dataloader([si], only_inference=True) @@ -53,7 +53,7 @@ print(predictions) We kindly request that users of MatterSim version 1.0.0 cite our preprint available on arXiv: ``` @article{yang2024mattersim, - title={MatterSim: A Deep Learning Atomistic Model Across Elements, Temperatures and Pressures}, + title={MatterSim: A Deep Learning Atomistic Model Across Elements, Temperatures and Pressures}, author={Han Yang and Chenxi Hu and Yichi Zhou and Xixian Liu and Yu Shi and Jielan Li and Guanzhi Li and Zekun Chen and Shuizhou Chen and Claudio Zeni and Matthew Horton and Robert Pinsler and Andrew Fowler and Daniel Zügner and Tian Xie and Jake Smith and Lixin Sun and Qian Wang and Lingyu Kong and Chang Liu and Hongxia Hao and Ziheng Lu}, year={2024}, eprint={2405.04967}, diff --git a/src/mattersim/applications/relax.py b/src/mattersim/applications/relax.py index 2c79ab2..f18ed27 100644 --- a/src/mattersim/applications/relax.py +++ b/src/mattersim/applications/relax.py @@ -7,6 +7,7 @@ from ase.filters import ExpCellFilter, FrechetCellFilter from ase.optimize import BFGS, FIRE from ase.optimize.optimize import Optimizer +from ase.units import GPa class Relaxer(object): @@ -53,7 +54,7 @@ def relax( steps: int = 500, fmax: float = 0.01, params_filter: dict = {}, - **kwargs + **kwargs, ) -> Atoms: """ Relax the atoms object. @@ -115,7 +116,7 @@ def relax_structures( constrain_symmetry: bool = False, fix_axis: Union[bool, Iterable[bool]] = False, pressure_in_GPa: Union[float, None] = None, - **kwargs + **kwargs, ) -> Union[Tuple[bool, Atoms], Tuple[List[bool], List[Atoms]]]: """ Args: @@ -138,11 +139,15 @@ def relax_structures( pass elif filter is None and pressure_in_GPa is not None: filter = "ExpCellFilter" - params_filter["scalar_pressure"] = pressure_in_GPa / 160.21766208 + params_filter["scalar_pressure"] = ( + pressure_in_GPa * GPa + ) # GPa = 1 / 160.21766208 elif filter is not None and pressure_in_GPa is None: params_filter["scalar_pressure"] = 0.0 else: - params_filter["scalar_pressure"] = pressure_in_GPa / 160.21766208 + params_filter["scalar_pressure"] = ( + pressure_in_GPa * GPa + ) # GPa = / 160.21766208 relaxer = Relaxer( optimizer=optimizer, diff --git a/src/mattersim/forcefield/m3gnet/m3gnet_multi_head.py b/src/mattersim/forcefield/m3gnet/m3gnet_multi_head.py deleted file mode 100644 index 6b04635..0000000 --- a/src/mattersim/forcefield/m3gnet/m3gnet_multi_head.py +++ /dev/null @@ -1,201 +0,0 @@ -# -*- coding: utf-8 -*- -from typing import Dict - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch_runstats.scatter import scatter - -from .modules import ( # noqa: E501 - MLP, - GatedMLP, - MainBlock, - SmoothBesselBasis, - SphericalBasisLayer, -) -from .scaling import AtomScaling - - -class M3Gnet_multi_head(nn.Module): - """ - M3Gnet with no massage passing - """ - - def __init__( - self, - normalizer_list: list[AtomScaling], - num_blocks: int = 4, - units: int = 128, - max_l: int = 4, - max_n: int = 4, - cutoff: float = 5.0, - device: str = "cuda", - max_z: int = 94, - threebody_cutoff: float = 4.0, - **kwargs, - ): - super().__init__() - self.rbf = SmoothBesselBasis(r_max=cutoff, max_n=max_n) - self.sbf = SphericalBasisLayer(max_n=max_n, max_l=max_l, cutoff=cutoff) - self.edge_encoder = MLP( - in_dim=max_n, out_dims=[units], activation="swish", use_bias=False - ) - module_list = [ - MainBlock(max_n, max_l, cutoff, units, max_n, threebody_cutoff) - for i in range(num_blocks) - ] - self.graph_conv = nn.ModuleList(module_list) - if isinstance(normalizer_list, list): - self.normalizer_list = nn.ModuleList(normalizer_list) - elif isinstance(normalizer_list, nn.ModuleList): - self.normalizer_list = normalizer_list - else: - raise NotImplementedError - self.final_layer_list = nn.ModuleList( - [ - GatedMLP( - in_dim=units, - out_dims=[units, units, 1], - activation=["swish", "swish", None], - ) - for _ in range(len(normalizer_list)) - ] - ) - self.apply(self.init_weights) - self.max_z = max_z - self.device = device - self.atom_embedding = MLP( - in_dim=max_z + 1, out_dims=[units], activation=None, use_bias=False - ) - self.atom_embedding.apply(self.init_weights_uniform) - self.model_args = { - "num_blocks": num_blocks, - "units": units, - "max_l": max_l, - "max_n": max_n, - "cutoff": cutoff, - "normalizer_list": self.normalizer_list, - "max_z": max_z, - "threebody_cutoff": threebody_cutoff, - } - print("This model is specifically designed for multi tasks") - - def forward( - self, - input: Dict[str, torch.Tensor], - dataset_idx: int = -1, - ): - # Exact data from input_dictionary - pos = input["atom_pos"] - cell = input["cell"] - pbc_offsets = input["pbc_offsets"] - atom_attr = input["atom_attr"] - edge_index = input["edge_index"] - three_body_indices = input["three_body_indices"] - num_three_body = input["num_three_body"] - num_bonds = input["num_bonds"] - num_triple_ij = input["num_triple_ij"] - num_atoms = input["num_atoms"] - num_graphs = input["num_graphs"] - batch = input["batch"] - - cumsum = torch.cumsum(num_bonds, dim=0) - num_bonds - index_bias = torch.repeat_interleave( # noqa: E501 - cumsum, num_three_body, dim=0 - ).unsqueeze(-1) - three_body_indices = three_body_indices + index_bias - - # === Refer to the implementation of M3GNet, === - # === we should re-compute the following attributes === - # edge_length, edge_vector(optional), triple_edge_length, theta_jik - atoms_batch = torch.repeat_interleave(repeats=num_atoms) - edge_batch = atoms_batch[edge_index[0]] - edge_vector = pos[edge_index[0]] - ( - pos[edge_index[1]] - + torch.einsum("bi, bij->bj", pbc_offsets, cell[edge_batch]) - ) - edge_length = torch.linalg.norm(edge_vector, dim=1) - vij = edge_vector[three_body_indices[:, 0].clone()] - vik = edge_vector[three_body_indices[:, 1].clone()] - rij = edge_length[three_body_indices[:, 0].clone()] - rik = edge_length[three_body_indices[:, 1].clone()] - cos_jik = torch.sum(vij * vik, dim=1) / (rij * rik) - # eps = 1e-7 avoid nan in torch.acos function - cos_jik = torch.clamp(cos_jik, min=-1.0 + 1e-7, max=1.0 - 1e-7) - triple_edge_length = rik.view(-1) - edge_length = edge_length.unsqueeze(-1) - atomic_numbers = atom_attr.squeeze(1).long() - - # featurize - atom_attr = self.atom_embedding(self.one_hot_atoms(atomic_numbers)) - edge_attr = self.rbf(edge_length.view(-1)) - edge_attr_zero = edge_attr # e_ij^0 - edge_attr = self.edge_encoder(edge_attr) - three_basis = self.sbf(triple_edge_length, torch.acos(cos_jik)) - - # feature_after_first_layer = None - - # Main Loop - for idx, conv in enumerate(self.graph_conv): - atom_attr, edge_attr = conv( - atom_attr, - edge_attr, - edge_attr_zero, - edge_index, - three_basis, - three_body_indices, - edge_length, - num_bonds, - num_triple_ij, - num_atoms, - ) - # if idx == 0: - # feature_after_first_layer = atom_attr.detach() - - # feature_before_branching_out = atom_attr.detach() - energies_i = self.final_layer_list[dataset_idx](atom_attr).view(-1) - if self.normalizer_list[dataset_idx] is not None: - energies_i = self.normalizer_list[dataset_idx]( - energies_i, atomic_numbers.view(-1) - ) - energies = scatter(energies_i, batch, dim=0, dim_size=num_graphs) - # return energies, - # feature_after_first_layer, - # feature_before_branching_out - return energies - - def init_weights(self, m): - if isinstance(m, nn.Linear): - torch.nn.init.xavier_uniform_(m.weight) - - def init_weights_uniform(self, m): - if isinstance(m, nn.Linear): - torch.nn.init.uniform_(m.weight, a=-0.05, b=0.05) - - def one_hot_atoms(self, species): - # one_hots = [] - # for i in range(species.shape[0]): - # one_hots.append( - # F.one_hot(species[i], - # num_classes=self.max_z+1 - # ).float().to(species.device) - # ) - # return torch.cat(one_hots, dim=0) - return F.one_hot(species, num_classes=self.max_z + 1).float() - - def print(self): - from prettytable import PrettyTable - - table = PrettyTable(["Modules", "Parameters"]) - total_params = 0 - for name, parameter in self.model.named_parameters(): - if not parameter.requires_grad: - continue - params = parameter.numel() - table.add_row([name, params]) - total_params += params - print(table) - print(f"Total Trainable Params: {total_params}") - - def get_model_args(self): - return self.model_args diff --git a/src/mattersim/forcefield/potential.py b/src/mattersim/forcefield/potential.py index dfe4bf0..b6edce9 100644 --- a/src/mattersim/forcefield/potential.py +++ b/src/mattersim/forcefield/potential.py @@ -2,6 +2,7 @@ """ Potential """ +import logging import os import pickle import random @@ -16,6 +17,7 @@ from ase import Atoms from ase.calculators.calculator import Calculator from ase.constraints import full_3x3_to_voigt_6_stress +from ase.units import GPa from torch.optim import Adam from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR from torch_ema import ExponentialMovingAverage @@ -24,9 +26,18 @@ from mattersim.datasets.utils.build import build_dataloader from mattersim.forcefield.m3gnet.m3gnet import M3Gnet -from mattersim.forcefield.m3gnet.m3gnet_multi_head import M3Gnet_multi_head from mattersim.jit_compile_tools.jit import compile_mode +rank = int(os.getenv("RANK", 0)) + +if rank == 0: + logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s" + ) +else: + logging.basicConfig(level=logging.CRITICAL) +logger = logging.getLogger(__name__) + @compile_mode("script") class Potential(nn.Module): @@ -91,7 +102,7 @@ def __init__( self.ema = ema self.model_name = kwargs.get("model_name", "m3gnet") self.validation_metrics = kwargs.get( - "validation_metrics", {"loss": 10000.0} # noqa: E501 + "validation_metrics", {"loss": 10000000.0} # noqa: E501 ) self.last_epoch = kwargs.get("last_epoch", -1) self.description = kwargs.get("description", "") @@ -110,11 +121,11 @@ def freeze_reset_model( Freeze the model in the fine-tuning process """ if finetune_layers == -1: - print("fine-tuning all layers") + logger.info("fine-tuning all layers") elif finetune_layers >= 0 and finetune_layers < len( self.model.node_head.unified_encoder_layers ): - print(f"fine-tuning the last {finetune_layers} layers") + logger.info(f"fine-tuning the last {finetune_layers} layers") for name, param in self.model.named_parameters(): param.requires_grad = False @@ -165,11 +176,11 @@ def finetune_mode( reset_head_for_finetune: whether to reset the original head """ if self.model_name not in ["graphormer", "geomformer"]: - print("Only graphormer and geomformer support freezing layers") + logger.warning("Only graphormer and geomformer support freezing layers") return self.model.finetune_mode = True if finetune_head is None: - print("No finetune head is provided, using the original energy head") + logger.info("No finetune head is provided, using the original energy head") self.model.finetune_head = finetune_head self.model.finetune_task_mean = finetune_task_mean self.model.finetune_task_std = finetune_task_std @@ -193,9 +204,6 @@ def train_model( save_checkpoint: bool = False, save_path: str = "./results/", ckpt_interval: int = 10, - multi_head: bool = False, - dataset_name_list: List[str] = None, - sampler=None, is_distributed: bool = False, need_to_load_data: bool = False, **kwargs, @@ -230,64 +238,45 @@ def train_model( ) if is_distributed: self.rank = torch.distributed.get_rank() - print( + logger.info( f"Number of trainable parameters: {sum(p.numel() for p in self.model.parameters() if p.requires_grad):,}" # noqa: E501 ) for epoch in range(self.last_epoch + 1, epochs): - print(f"Epoch: {epoch} / {epochs}") - if not multi_head: - if need_to_load_data: - assert isinstance(dataloader, list) - random.Random(kwargs.get("seed", 42) + epoch).shuffle( # noqa: E501 - dataloader + logger.info(f"Epoch: {epoch} / {epochs}") + if need_to_load_data: + assert isinstance(dataloader, list) + random.Random(kwargs.get("seed", 42) + epoch).shuffle( # noqa: E501 + dataloader + ) + for idx, data_path in enumerate(dataloader): + with open(data_path, "rb") as f: + start = time.time() + train_data = pickle.load(f) + logger.info( + f"TRAIN: loading {data_path.split('/')[-2]}" + f"/{data_path.split('/')[-1]} dataset with " + f"{len(train_data)} data points, " + f"{len(train_data)} data points in total, " + f"time: {time.time() - start}" # noqa: E501 ) - for idx, data_path in enumerate(dataloader): - with open(data_path, "rb") as f: - start = time.time() - train_data = pickle.load(f) - print( - f"TRAIN: loading {data_path.split('/')[-2]}" - f"/{data_path.split('/')[-1]} dataset with " - f"{len(train_data)} data points, " - f"{len(train_data)} data points in total, " - f"time: {time.time() - start}" # noqa: E501 - ) - # Distributed Sampling - atoms_train_sampler = ( - torch.utils.data.distributed.DistributedSampler( - train_data, - seed=kwargs.get("seed", 42) - + idx * 131 - + epoch, # noqa: E501 - ) - ) - train_dataloader = DataLoader( + # Distributed Sampling + atoms_train_sampler = ( + torch.utils.data.distributed.DistributedSampler( train_data, - batch_size=kwargs.get("batch_size", 32), - shuffle=(atoms_train_sampler is None), - num_workers=0, - sampler=atoms_train_sampler, - ) - self.train_one_epoch( - train_dataloader, - epoch, - loss, - include_energy, - include_forces, - include_stresses, - force_loss_ratio, - stress_loss_ratio, - wandb, - is_distributed, - mode="train", - **kwargs, + seed=kwargs.get("seed", 42) + + idx * 131 + + epoch, # noqa: E501 ) - del train_dataloader - del train_data - torch.cuda.empty_cache() - else: + ) + train_dataloader = DataLoader( + train_data, + batch_size=kwargs.get("batch_size", 32), + shuffle=(atoms_train_sampler is None), + num_workers=0, + sampler=atoms_train_sampler, + ) self.train_one_epoch( - dataloader, + train_dataloader, epoch, loss, include_energy, @@ -300,28 +289,12 @@ def train_model( mode="train", **kwargs, ) - metric = self.train_one_epoch( - val_dataloader, - epoch, - loss, - include_energy, - include_forces, - include_stresses, - force_loss_ratio, - stress_loss_ratio, - wandb, - is_distributed, - mode="val", - **kwargs, - ) + del train_dataloader + del train_data + torch.cuda.empty_cache() else: - assert dataset_name_list is not None - assert ( - need_to_load_data is False - ), "load_training_data is not supported for multi-head training" # noqa: E501 - self.train_one_epoch_multi_head( + self.train_one_epoch( dataloader, - dataset_name_list, epoch, loss, include_energy, @@ -330,23 +303,24 @@ def train_model( force_loss_ratio, stress_loss_ratio, wandb, + is_distributed, mode="train", **kwargs, ) - metric = self.train_one_epoch_multi_head( - val_dataloader, - dataset_name_list, - epoch, - loss, - include_energy, - include_forces, - include_stresses, - force_loss_ratio, - stress_loss_ratio, - wandb, - mode="val", - **kwargs, - ) + metric = self.train_one_epoch( + val_dataloader, + epoch, + loss, + include_energy, + include_forces, + include_stresses, + force_loss_ratio, + stress_loss_ratio, + wandb, + is_distributed, + mode="val", + **kwargs, + ) if isinstance(self.scheduler, ReduceLROnPlateau): self.scheduler.step(metric) @@ -362,7 +336,7 @@ def train_model( "MAE_stress": metric[3], } if is_distributed: - # TODO 添加distributed训练早停 + # TODO add distributed early stopping if self.save_model_ddp( epoch, early_stop_patience, @@ -374,7 +348,6 @@ def train_model( ): break else: - # return True时为早停 if self.save_model( epoch, early_stop_patience, @@ -421,7 +394,7 @@ def save_model( ): self.save(os.path.join(save_path, "best_model.pth")) if epoch > best_model["last_epoch"] + early_stop_patience: - print("Early stopping") + logger.info("Early stopping") return True del best_model except BaseException: @@ -479,39 +452,24 @@ def test_model( include_forces: bool = False, include_stresses: bool = False, wandb=None, - multi_head: bool = False, **kwargs, ): """ Test model performance on a given dataset """ - if not multi_head: - return self.train_one_epoch( - val_dataloader, - 1, - loss, - include_energy, - include_forces, - include_stresses, - 1.0, - 0.1, - wandb=wandb, - mode="val", - ) - else: - return self.train_one_epoch_multi_head( - val_dataloader, - kwargs["dataset_name_list"], - 1, - loss, - include_energy, - include_forces, - include_stresses, - 1.0, - 0.1, - wandb=wandb, - mode="val", - ) + return self.train_one_epoch( + val_dataloader, + 1, + loss, + include_energy, + include_forces, + include_stresses, + 1.0, + 0.1, + wandb=wandb, + mode="val", + **kwargs, + ) def predict_properties( self, @@ -527,6 +485,9 @@ def predict_properties( - results[1] (list[np.ndarray]): a list of atomic forces - results[2] (list[np.ndarray]): a list of stresses """ + logger.warning( + "The unit of stress is GPa when using the predict_properties function." + ) self.model.eval() energies = [] forces = [] @@ -657,7 +618,7 @@ def train_one_epoch( s_mae = 0 if log: - print( + logger.info( "%s: Loss: %.4f, MAE(e): %.4f, MAE(f): %.4f, MAE(s): %.4f, Time: %.2fs, lr: %.8f\n" # noqa: E501 % ( mode, @@ -668,7 +629,6 @@ def train_one_epoch( time.time() - start_time, self.scheduler.get_last_lr()[0], ), - end="", ) if wandb and ((not is_distributed) or self.rank == 0): @@ -687,153 +647,6 @@ def train_one_epoch( if mode == "val": return (loss_avg_, e_mae, f_mae, s_mae) - def train_one_epoch_multi_head( - self, - dataloader_list, - dataset_name_list, - epoch, - loss, - include_energy=True, - include_forces=False, - include_stresses=False, - loss_f=1.0, - loss_s=0.1, - wandb=None, - mode="train", - **kwargs, - ): - start_time = time.time() - - metrics = {} - for dataset_name in dataset_name_list: - metrics_ = {} - metrics_["loss_avg"] = MeanMetric().to(self.device) - metrics_["train_e_mae"] = MeanMetric().to(self.device) - metrics_["train_f_mae"] = MeanMetric().to(self.device) - metrics_["train_s_mae"] = MeanMetric().to(self.device) - metrics[dataset_name] = metrics_ - - dataloader_iter = [ - dataloader.__iter__() for dataloader in dataloader_list # noqa: E501 - ] - if mode == "train": - self.model.train() - elif mode == "val": - self.model.eval() - - dataloader_len = [len(dataloader) for dataloader in dataloader_list] - for i in range(1, len(dataloader_len)): - dataloader_len[i] += dataloader_len[i - 1] - idx_list = list(range(dataloader_len[-1])) - random.shuffle(idx_list) - - for idx in idx_list: - for dataset_idx, bound in enumerate(dataloader_len): - if idx < bound: - break - - graph_batch = dataloader_iter[dataset_idx].__next__() - graph_batch.to(self.device) - input = batch_to_dict(graph_batch) - dataset_name = dataset_name_list[dataset_idx] - - if mode == "train": - result = self.forward( - input, - include_forces=include_forces, - include_stresses=include_stresses, - dataset_idx=dataset_idx, - ) - elif mode == "val": - with self.ema.average_parameters(): - result = self.forward( - input, - include_forces=include_forces, - include_stresses=include_stresses, - dataset_idx=dataset_idx, - ) - - loss_, e_mae, f_mae, s_mae = self.loss_calc( - graph_batch, - result, - loss, - include_energy, - include_forces, - include_stresses, - loss_f, - loss_s, - ) - - # loss backward - if mode == "train": - self.optimizer.zero_grad() - loss_.backward() - nn.utils.clip_grad_norm_( - self.model.parameters(), 1.0, norm_type=2 # noqa: E501 - ) - self.optimizer.step() - self.ema.update() - - metrics[dataset_name]["loss_avg"].update(loss_.detach()) - if include_energy: - metrics[dataset_name]["train_e_mae"].update(e_mae.detach()) - if include_forces: - metrics[dataset_name]["train_f_mae"].update(f_mae.detach()) - if include_stresses: - metrics[dataset_name]["train_s_mae"].update(s_mae.detach()) - - loss_all = 0 - e_mae = 0 - f_mae = 0 - s_mae = 0 - for dataset_name in dataset_name_list: - train_f_mae = train_s_mae = 0 - loss_avg = metrics[dataset_name]["loss_avg"].compute().item() - loss_all += loss_avg - if include_energy: - train_e_mae = metrics[dataset_name]["train_e_mae"].compute().item() - e_mae += train_e_mae - if include_forces and (dataset_name != "QM9"): - train_f_mae = ( - metrics[dataset_name]["train_f_mae"].compute().item() - ) # noqa: E501 - f_mae += train_f_mae - if include_stresses: - train_s_mae = ( - metrics[dataset_name]["train_s_mae"].compute().item() - ) # noqa: E501 - s_mae += train_s_mae - - print( - "%s %s: Loss: %.4f, MAE(e): %.4f, MAE(f): %.4f, MAE(s): %.4f, Time: %.2fs" # noqa: E501 - % ( - dataset_name, - mode, - loss_avg, - train_e_mae, - train_f_mae, - train_s_mae, - time.time() - start_time, - ) - ) - - if wandb: - wandb.log( - { - f"{dataset_name}/{mode}_loss": loss_avg, - f"{dataset_name}/{mode}_mae_e": train_e_mae, - f"{dataset_name}/{mode}_mae_f": train_f_mae, - f"{dataset_name}/{mode}_mae_s": train_s_mae, - }, - step=epoch, - ) - - if wandb: - wandb.log({"lr": self.scheduler.get_last_lr()[0]}, step=epoch) - - if mode == "val": - return (loss_all, e_mae, f_mae, s_mae) - def loss_calc( self, graph_batch, @@ -1008,7 +821,9 @@ def forward( output["forces"] = forces if stress_grad is not None: - stresses = 1 / volume[:, None, None] * stress_grad * 160.21766208 + stresses = ( + 1 / volume[:, None, None] * stress_grad / GPa + ) # 1/GPa = 160.21766208 output["stresses"] = stresses return output @@ -1017,7 +832,6 @@ def save(self, save_path): dir_name = os.path.dirname(save_path) if not os.path.exists(dir_name): os.makedirs(dir_name) - # 保存为单卡可加载的模型,多卡加载时需要先加载后放入DDP中 checkpoint = { "model_name": self.model_name, "model": self.model.module.state_dict() @@ -1037,40 +851,42 @@ def save(self, save_path): @staticmethod def load( - model_name: str = "m3gnet", load_path: str = None, + *, + model_name: str = "m3gnet", device: str = "cuda" if torch.cuda.is_available() else "cpu", args: Dict = None, load_training_state: bool = True, **kwargs, ): - if load_path is None: - if model_name == "m3gnet": - print("Loading the pre-trained M3GNet model") - current_dir = os.path.dirname(__file__) - load_path = os.path.join( - current_dir, "m3gnet/pretrained/mpf/best_model.pth" - ) - elif model_name == "graphormer" or model_name == "geomformer": - raise NotImplementedError - else: - raise NotImplementedError + if model_name.lower() != "m3gnet": + raise NotImplementedError + + current_dir = os.path.dirname(__file__) + if ( + load_path is None + or load_path.lower() == "mattersim-v1.0.0-1m.pth" + or load_path.lower() == "mattersim-v1.0.0-1m" + ): + load_path = os.path.join( + current_dir, "..", "pretrained_models/mattersim-v1.0.0-1M.pth" + ) + logger.info(f"Loading the pre-trained {os.path.basename(load_path)} model") + elif ( + load_path.lower() == "mattersim-v1.0.0-5m.pth" + or load_path.lower() == "mattersim-v1.0.0-5m" + ): + load_path = os.path.join( + current_dir, "..", "pretrained_models/mattersim-v1.0.0-5M.pth" + ) else: - print("Loading the model from %s" % load_path) + logger.info("Loading the model from %s" % load_path) + assert os.path.exists(load_path), f"Model file {load_path} not found" checkpoint = torch.load(load_path, map_location=device) assert checkpoint["model_name"] == model_name - if model_name == "m3gnet": - model = M3Gnet(device=device, **checkpoint["model_args"]).to(device) - elif model_name == "m3gnet_multi_head": - model = M3Gnet_multi_head(device=device, **checkpoint["model_args"]).to( - device - ) - elif model_name == "graphormer" or model_name == "geomformer": - raise NotImplementedError - else: - raise NotImplementedError + model = M3Gnet(device=device, **checkpoint["model_args"]).to(device) model.load_state_dict(checkpoint["model"], strict=False) if load_training_state: @@ -1128,90 +944,6 @@ def load( **kwargs, ) - @staticmethod - def load_from_multi_head_model( - model_name: str = "m3gnet", - head_index: int = -1, - load_path: str = None, - device: str = "cuda" if torch.cuda.is_available() else "cpu", - **kwargs, - ): - """ - Load one head of the multi-head model. - Args: - head_index: - -1: reset the head (final layer and - energy normalization module) - """ - if load_path is None: - if model_name == "m3gnet": - print("Loading the pre-trained multi-head M3GNet model") - current_dir = os.path.dirname(__file__) - load_path = os.path.join( - current_dir, - "m3gnet/pretrained/Transition1x-MD17-MPF21-QM9-HME21-OC20/" - "best_model.pth", - ) - else: - raise NotImplementedError - else: - print("Loading the model from %s" % load_path) - if head_index == -1: - print("Reset the final layer and normalization module") - checkpoint = torch.load(load_path, map_location=device) - if model_name == "m3gnet": - model = M3Gnet(device=device, **checkpoint["model_args"]).to( - device - ) # noqa: E501 - ori_ckpt = checkpoint["model"].copy() - for key in ori_ckpt: - if "final_layer_list" in key: - if "final_layer_list.%d" % head_index in key: - checkpoint["model"][ - key.replace("_layer_list.%d" % head_index, "") - ] = ori_ckpt[key] - del checkpoint["model"][key] - if "normalizer_list" in key: - if "normalizer_list.%d" % head_index in key: - checkpoint["model"][ - key.replace("_list.%d" % head_index, "") - ] = ori_ckpt[key] - del checkpoint["model"][key] - if "sph_2" in key: - del checkpoint["model"][key] - model.load_state_dict(checkpoint["model"], strict=True) - else: - raise NotImplementedError - description = checkpoint["description"] - model.eval() - - del checkpoint - - return Potential( - model, - device=device, - model_name=model_name, - description=description, - **kwargs, - ) - - def load_model(self, **kwargs): - warnings.warn( - "The interface of loading M3GNet model has been deprecated. " - "Please use Potential.load() instead.", - DeprecationWarning, - ) - warnings.warn( - "It only supports loading the pre-trained M3GNet model. " - "For other models, please use Potential.load() instead." - ) - current_dir = os.path.dirname(__file__) - load_path = os.path.join( - current_dir, "m3gnet/pretrained/mpf/best_model.pth" # noqa: E501 - ) - checkpoint = torch.load(load_path) - self.model.load_state_dict(checkpoint["model"]) - def set_description(self, description): self.description = description @@ -1270,7 +1002,7 @@ def __init__( potential: Potential, args_dict: dict = {}, compute_stress: bool = True, - stress_weight: float = 1.0, + stress_weight: float = GPa, device: str = "cuda" if torch.cuda.is_available() else "cpu", **kwargs, ): diff --git a/pretrained_models/mattersim-v1.0.0-1M.pth b/src/mattersim/pretrained_models/mattersim-v1.0.0-1M.pth similarity index 100% rename from pretrained_models/mattersim-v1.0.0-1M.pth rename to src/mattersim/pretrained_models/mattersim-v1.0.0-1M.pth