diff --git a/src/mattersim/forcefield/potential.py b/src/mattersim/forcefield/potential.py index 6647216..7d61d6e 100644 --- a/src/mattersim/forcefield/potential.py +++ b/src/mattersim/forcefield/potential.py @@ -2,6 +2,7 @@ """ Potential """ +import logging import os import pickle import random @@ -9,7 +10,6 @@ import warnings from typing import Dict, List, Optional -import logging import numpy as np import torch import torch.distributed @@ -17,6 +17,7 @@ from ase import Atoms from ase.calculators.calculator import Calculator from ase.constraints import full_3x3_to_voigt_6_stress +from ase.units import GPa from torch.optim import Adam from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR from torch_ema import ExponentialMovingAverage @@ -37,6 +38,7 @@ logging.basicConfig(level=logging.CRITICAL) logger = logging.getLogger(__name__) + @compile_mode("script") class Potential(nn.Module): """ @@ -178,9 +180,7 @@ def finetune_mode( return self.model.finetune_mode = True if finetune_head is None: - logger.info( - "No finetune head is provided, using the original energy head" - ) + logger.info("No finetune head is provided, using the original energy head") self.model.finetune_head = finetune_head self.model.finetune_task_mean = finetune_task_mean self.model.finetune_task_std = finetune_task_std @@ -1000,7 +1000,7 @@ def __init__( potential: Potential, args_dict: dict = {}, compute_stress: bool = True, - stress_weight: float = 1/160.21766208, + stress_weight: float = GPa, device: str = "cuda" if torch.cuda.is_available() else "cpu", **kwargs, ):