From 50fe59125a06c4f9bcb98311d163b2004f9f93e3 Mon Sep 17 00:00:00 2001 From: brian Date: Wed, 25 Dec 2024 14:37:54 +0800 Subject: [PATCH] feat: Add finetune method for MatterSim --- pyproject.toml | 68 ++++--- script/finetune_mattersim.py | 248 ++++++++++++++++++++++++++ script/vasprun_to_xyz.py | 63 +++++++ src/mattersim/forcefield/potential.py | 50 +++--- src/mattersim/utils/logger_utils.py | 30 ++++ 5 files changed, 402 insertions(+), 57 deletions(-) create mode 100644 script/finetune_mattersim.py create mode 100644 script/vasprun_to_xyz.py create mode 100644 src/mattersim/utils/logger_utils.py diff --git a/pyproject.toml b/pyproject.toml index f113677..e24d5bf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,11 +1,5 @@ [build-system] -requires = [ - "Cython>=0.29.32", - "numpy<2", - "setuptools>=45", - "setuptools_scm", - "wheel", -] +requires = ["setuptools>=45", "wheel", "Cython>=0.29.32", "numpy<2", "setuptools_scm"] build-backend = "setuptools.build_meta" [project] @@ -13,68 +7,70 @@ name = "mattersim" dynamic = ["version"] description = "MatterSim: A Deep Learning Atomistic Model Across Elements, Temperatures and Pressures." authors = [ - { name = "Han Yang", email = "hanyang@microsoft.com" }, - { name = "Hongxia Hao", email = "hongxiahao@microsoft.com" }, - { name = "Jielan Li", email = "jielanli@microsoft.com" }, - { name = "Ziheng Lu", email = "zihenglu@microsoft.com" }, + {name = "Han Yang", email = "hanyang@microsoft.com"}, + {name = "Jielan Li", email = "jielanli@microsoft.com"}, + {name = "Hongxia Hao", email = "hongxiahao@microsoft.com"}, + {name = "Ziheng Lu", email = "zihenglu@microsoft.com"} ] readme = "README.md" requires-python = ">=3.9" classifiers = [ + "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", - "Programming Language :: Python :: 3", ] dependencies = [ "ase>=3.23.0", - "azure-identity", - "azure-storage-blob", - "deprecated", - "e3nn>=0.5.0", - "emmet-core>=0.84", - "loguru", - "mp-api", + "e3nn==0.5.0", + "seekpath", "numpy<2", - "opt_einsum_fx", - "pydantic>=2.9.2", "pymatgen", - "seekpath", - "torch-ema>=0.3", - "torch>=2.2.0", - "torch_geometric>=2.5.3", - "torch_runstats>=0.2.0", - "torchaudio>=2.2.0", + "loguru", + "torch==2.2.0", + "torchvision==0.17.0", + "torchaudio==2.2.0", + "torch_runstats==0.2.0", + "torch_geometric==2.5.3", "torchmetrics>=0.10.0", - "torchvision>=0.17.0", + "torch-ema==0.3", + "opt_einsum_fx", + "azure-storage-blob", + "azure-identity", + "mp-api", + "emmet-core<0.84", + "pydantic==2.9.2", + "deprecated", + "wandb" ] [project.optional-dependencies] dev = [ - "ipykernel", - "ipython", - "pre-commit", "pytest", "pytest-cov", "pytest-testmon", + "pre-commit", + "ipython", + "ipykernel" ] docs = [ - "nbconvert", - "nbsphinx", - "recommonmark", "sphinx", "sphinx-autodoc-typehints", - "sphinx-copybutton", "sphinx_book_theme", + "sphinx-copybutton", + "recommonmark", + "nbsphinx", + "nbconvert", ] + [project.urls] "Homepage" = "https://github.com/microsoft/mattersim" "Bug Tracker" = "https://github.com/microsoft/mattersim/issues" [tool.setuptools] -package-dir = { "" = "src" } +package-dir = {"" = "src"} [tool.setuptools.packages.find] where = ["src"] diff --git a/script/finetune_mattersim.py b/script/finetune_mattersim.py new file mode 100644 index 0000000..09e8a7d --- /dev/null +++ b/script/finetune_mattersim.py @@ -0,0 +1,248 @@ +# -*- coding: utf-8 -*- +import argparse +import os +import pickle as pkl +import random + +import numpy as np +import torch +import torch.distributed +import wandb +from ase.units import GPa + +from mattersim.datasets.utils.build import build_dataloader +from mattersim.forcefield.m3gnet.scaling import AtomScaling +from mattersim.forcefield.potential import Potential +from mattersim.utils.atoms_utils import AtomsAdaptor +from mattersim.utils.logger_utils import get_logger + +logger = get_logger() +torch.distributed.init_process_group(backend="nccl") +local_rank = int(os.environ["LOCAL_RANK"]) + + +def main(args): + args_dict = vars(args) + if args.wandb and local_rank == 0: + wandb_api_key = ( + args.wandb_api_key + if args.wandb_api_key is not None + else os.getenv("WANDB_API_KEY") + ) + wandb.login(key=wandb_api_key) + wandb.init( + project=args.wandb_project, + name=args.run_name, + config=args, + # id=args.run_name, + # resume="allow", + ) + + if args.wandb: + args_dict["wandb"] = wandb + + torch.distributed.barrier() + + # set random seed + random.seed(args.seed) + np.random.seed(args.seed) + torch.manual_seed(args.seed) + + torch.cuda.set_device(local_rank) + + if args.train_data_path.endswith(".pkl"): + with open(args.train_data_path, "rb") as f: + atoms_train = pkl.load(f) + else: + atoms_train = AtomsAdaptor.from_file(filename=args.train_data_path) + energies = [] + forces = [] if args.include_forces else None + stresses = [] if args.include_stresses else None + logger.info("Processing training datasets...") + for atoms in atoms_train: + energies.append(atoms.get_potential_energy()) + if args.include_forces: + forces.append(atoms.get_forces()) + if args.include_stresses: + stresses.append(atoms.get_stress(voigt=False) / GPa) # convert to GPa + + dataloader = build_dataloader( + atoms_train, + energies, + forces, + stresses, + shuffle=True, + pin_memory=True, + is_distributed=True, + **args_dict, + ) + + device = "cuda" if torch.cuda.is_available() else "cpu" + # build energy normalization module + if args.re_normalize: + scale = AtomScaling( + atoms=atoms_train, + total_energy=energies, + forces=forces, + verbose=True, + **args_dict, + ).to(device) + + if args.valid_data_path is not None: + if args.valid_data_path.endswith(".pkl"): + with open(args.valid_data_path, "rb") as f: + atoms_val = pkl.load(f) + else: + atoms_val = AtomsAdaptor.from_file(filename=args.train_data_path) + energies = [] + forces = [] if args.include_forces else None + stresses = [] if args.include_stresses else None + logger.info("Processing validation datasets...") + for atoms in atoms_val: + energies.append(atoms.get_potential_energy()) + if args.include_forces: + forces.append(atoms.get_forces()) + if args.include_stresses: + stresses.append(atoms.get_stress(voigt=False) / GPa) # convert to GPa + val_dataloader = build_dataloader( + atoms_val, + energies, + forces, + stresses, + pin_memory=True, + is_distributed=True, + **args_dict, + ) + else: + val_dataloader = None + + potential = Potential.from_checkpoint( + load_path=args.load_model_path, + load_training_state=False, + **args_dict, + ) + + if args.re_normalize: + potential.model.set_normalizer(scale) + + potential.model = torch.nn.parallel.DistributedDataParallel(potential.model) + torch.distributed.barrier() + + potential.train_model( + dataloader, + val_dataloader, + loss=torch.nn.HuberLoss(delta=0.01), + is_distributed=True, + **args_dict, + ) + + if local_rank == 0 and args.save_checkpoint: + wandb.save(os.path.join(args.save_path, "best_model.pth")) + + +if __name__ == "__main__": + # Some important arguments + parser = argparse.ArgumentParser() + + # path parameters + parser.add_argument( + "--run_name", type=str, default="example", help="name of the run" + ) + parser.add_argument( + "--train_data_path", type=str, default="./sample.xyz", help="train data path" + ) + parser.add_argument( + "--valid_data_path", type=str, default=None, help="valid data path" + ) + parser.add_argument( + "--load_model_path", + type=str, + default="mattersim-v1.0.0-1m", + help="path to load the model", + ) + parser.add_argument( + "--save_path", type=str, default="./results", help="path to save the model" + ) + parser.add_argument( + "--save_checkpoint", + type=bool, + default=False, + action=argparse.BooleanOptionalAction, + ) + parser.add_argument( + "--ckpt_interval", + type=int, + default=10, + help="save checkpoint every ckpt_interval epochs", + ) + parser.add_argument("--device", type=str, default="cuda") + + # model parameters + parser.add_argument("--cutoff", type=float, default=5.0, help="cutoff radius") + parser.add_argument( + "--threebody_cutoff", + type=float, + default=4.0, + help="cutoff radius for three-body term, which should be smaller than cutoff (two-body)", # noqa: E501 + ) + + # training parameters + parser.add_argument("--epochs", type=int, default=1000, help="number of epochs") + parser.add_argument("--batch_size", type=int, default=16) + parser.add_argument("--lr", type=float, default=2e-4) + parser.add_argument( + "--step_size", + type=int, + default=10, + help="step epoch for learning rate scheduler", + ) + parser.add_argument( + "--include_forces", + type=bool, + default=True, + action=argparse.BooleanOptionalAction, + ) + parser.add_argument( + "--include_stresses", + type=bool, + default=False, + action=argparse.BooleanOptionalAction, + ) + parser.add_argument("--force_loss_ratio", type=float, default=1.0) + parser.add_argument("--stress_loss_ratio", type=float, default=0.1) + parser.add_argument("--early_stop_patience", type=int, default=10) + parser.add_argument("--seed", type=int, default=42) + + # scaling parameters + parser.add_argument( + "--re_normalize", + type=bool, + default=False, + action=argparse.BooleanOptionalAction, + help="re-normalize the energy and forces according to the new data", + ) + parser.add_argument("--scale_key", type=str, default="per_species_forces_rms") + parser.add_argument( + "--shift_key", type=str, default="per_species_energy_mean_linear_reg" + ) + parser.add_argument("--init_scale", type=float, default=None) + parser.add_argument("--init_shift", type=float, default=None) + parser.add_argument( + "--trainable_scale", + type=bool, + default=False, + action=argparse.BooleanOptionalAction, + ) + parser.add_argument( + "--trainable_shift", + type=bool, + default=False, + action=argparse.BooleanOptionalAction, + ) + + # wandb parameters + parser.add_argument("--wandb", action="store_true") + parser.add_argument("--wandb_api_key", type=str, default=None) + parser.add_argument("--wandb_project", type=str, default="wandb_test") + args = parser.parse_args() + main(args) diff --git a/script/vasprun_to_xyz.py b/script/vasprun_to_xyz.py new file mode 100644 index 0000000..39a1e11 --- /dev/null +++ b/script/vasprun_to_xyz.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +import os +import random + +from ase.io import write + +from mattersim.utils.atoms_utils import AtomsAdaptor + +vasp_files = [ + "work/data/H/vasp/vasprun.xml", + "work/data/H/vasp_2/vasprun.xml", + "work/data/H/vasp_3/vasprun.xml", + "work/data/H/vasp_4/vasprun.xml", + "work/data/H/vasp_5/vasprun.xml", + "work/data/H/vasp_6/vasprun.xml", + "work/data/H/vasp_7/vasprun.xml", + "work/data/H/vasp_8/vasprun.xml", + "work/data/H/vasp_9/vasprun.xml", + "work/data/H/vasp_10/vasprun.xml", +] +train_ratio = 0.8 +validation_ratio = 0.1 +test_ratio = 0.1 + +save_dir = "./xyz_files" +os.makedirs(save_dir, exist_ok=True) + + +def main(): + atoms_train = [] + atoms_validation = [] + atoms_test = [] + + random.seed(42) + + for vasp_file in vasp_files: + atoms_list = AtomsAdaptor.from_file(filename=vasp_file) + random.shuffle(atoms_list) + num_atoms = len(atoms_list) + num_train = int(num_atoms * train_ratio) + num_validation = int(num_atoms * validation_ratio) + + atoms_train.extend(atoms_list[:num_train]) + atoms_validation.extend(atoms_list[num_train : num_train + num_validation]) + atoms_test.extend(atoms_list[num_train + num_validation :]) + + print( + f"Total number of atoms: {len(atoms_train) + len(atoms_validation) + len(atoms_test)}" # noqa: E501 + ) + + print(f"Number of atoms in the training set: {len(atoms_train)}") + print(f"Number of atoms in the validation set: {len(atoms_validation)}") + print(f"Number of atoms in the test set: {len(atoms_test)}") + + # Save the training, validation, and test datasets to xyz files + + write(f"{save_dir}/train.xyz", atoms_train) + write(f"{save_dir}/valid.xyz", atoms_validation) + write(f"{save_dir}/test.xyz", atoms_test) + + +if __name__ == "__main__": + main() diff --git a/src/mattersim/forcefield/potential.py b/src/mattersim/forcefield/potential.py index eb44261..8ae5ba8 100644 --- a/src/mattersim/forcefield/potential.py +++ b/src/mattersim/forcefield/potential.py @@ -18,7 +18,6 @@ from ase.constraints import full_3x3_to_voigt_6_stress from ase.units import GPa from deprecated import deprecated -from loguru import logger from torch.optim import Adam from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR from torch_ema import ExponentialMovingAverage @@ -29,8 +28,10 @@ from mattersim.forcefield.m3gnet.m3gnet import M3Gnet from mattersim.jit_compile_tools.jit import compile_mode from mattersim.utils.download_utils import download_checkpoint +from mattersim.utils.logger_utils import get_logger rank = int(os.getenv("RANK", 0)) +logger = get_logger() @compile_mode("script") @@ -101,7 +102,8 @@ def __init__( self.last_epoch = kwargs.get("last_epoch", -1) self.description = kwargs.get("description", "") self.saved_name = ["loss", "MAE_energy", "MAE_force", "MAE_stress"] - self.best_metric = 10 + self.best_metric = 10000 + self.best_metric_epoch = 0 self.rank = None self.use_finetune_label_loss = kwargs.get("use_finetune_label_loss", False) @@ -269,7 +271,7 @@ def train_model( num_workers=0, sampler=atoms_train_sampler, ) - self.train_one_epoch( + metric = self.train_one_epoch( train_dataloader, epoch, loss, @@ -287,7 +289,7 @@ def train_model( del train_data torch.cuda.empty_cache() else: - self.train_one_epoch( + metric = self.train_one_epoch( dataloader, epoch, loss, @@ -301,20 +303,21 @@ def train_model( mode="train", **kwargs, ) - metric = self.train_one_epoch( - val_dataloader, - epoch, - loss, - include_energy, - include_forces, - include_stresses, - force_loss_ratio, - stress_loss_ratio, - wandb, - is_distributed, - mode="val", - **kwargs, - ) + if val_dataloader is not None: + metric = self.train_one_epoch( + val_dataloader, + epoch, + loss, + include_energy, + include_forces, + include_stresses, + force_loss_ratio, + stress_loss_ratio, + wandb, + is_distributed, + mode="val", + **kwargs, + ) if isinstance(self.scheduler, ReduceLROnPlateau): self.scheduler.step(metric) @@ -330,7 +333,6 @@ def train_model( "MAE_stress": metric[3], } if is_distributed: - # TODO add distributed early stopping if self.save_model_ddp( epoch, early_stop_patience, @@ -426,9 +428,13 @@ def save_model_ddp( # so this operation should not be performed. # Only save the model on GPU 0, # the model on each GPU should be exactly the same. + if epoch > self.best_metric_epoch + early_stop_patience: + logger.info("Early stopping") + return True if metric[self.idx] < self.best_metric: self.best_metric = metric[self.idx] + self.best_metric_epoch = epoch if save_checkpoint and self.rank == 0: self.save(os.path.join(save_path, "best_model.pth")) if self.rank == 0 and save_checkpoint: @@ -638,8 +644,7 @@ def train_one_epoch( step=epoch, ) - if mode == "val": - return (loss_avg_, e_mae, f_mae, s_mae) + return (loss_avg_, e_mae, f_mae, s_mae) def loss_calc( self, @@ -894,6 +899,7 @@ def from_checkpoint( checkpoint = torch.load(load_path, map_location=device) assert checkpoint["model_name"] == model_name + checkpoint["model_args"].update(kwargs) model = M3Gnet(device=device, **checkpoint["model_args"]).to(device) model.load_state_dict(checkpoint["model"], strict=False) @@ -998,11 +1004,13 @@ def load( logger.info(f"Loading the pre-trained {os.path.basename(load_path)} model") else: logger.info("Loading the model from %s" % load_path) + assert os.path.exists(load_path), f"Model file {load_path} not found" checkpoint = torch.load(load_path, map_location=device) assert checkpoint["model_name"] == model_name + checkpoint["model_args"].update(kwargs) model = M3Gnet(device=device, **checkpoint["model_args"]).to(device) model.load_state_dict(checkpoint["model"], strict=False) diff --git a/src/mattersim/utils/logger_utils.py b/src/mattersim/utils/logger_utils.py new file mode 100644 index 0000000..3688e88 --- /dev/null +++ b/src/mattersim/utils/logger_utils.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +import os +import sys + +from loguru import logger + +handlers = {} + + +def get_logger(): + if not handlers: + logger.remove() + handlers["console"] = logger.add( + sys.stdout, + colorize=True, + filter=log_filter, + enqueue=True, + ) + + return logger + + +def log_filter(record): + if record["level"].name != "INFO": + return True + + if "RANK" not in os.environ or int(os.environ["RANK"]) == 0: + return True + else: + return False