Skip to content

Commit

Permalink
Support MACE models in Proxima script
Browse files Browse the repository at this point in the history
  • Loading branch information
WardLT committed Sep 13, 2024
1 parent 5a515be commit e32e981
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions 2_proxima/0_run-serial-proxima.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
from ase.db import connect
from ase.md import MDLogger, VelocityVerlet
from chgnet.model import CHGNet
from mace.calculators import mace_mp
from gitinfo import get_git_info

from cascade.learning.chgnet import CHGNetInterface
from cascade.proxima import SerialLearningCalculator
from cascade.calculator import make_calculator
from cascade.learning.torchani import TorchANI
from cascade.learning.torchani.build import make_output_nets, make_aev_computer
from cascade.learning.mace import MACEInterface
from cascade.proxima import SerialLearningCalculator
from cascade.calculator import make_calculator
from cascade.utils import canonicalize
import cascade

Expand All @@ -50,7 +52,7 @@
group.add_argument('--seed', type=int, default=1, help='Random seed used to start dynamics')

group = parser.add_argument_group(title="Learner Details", description="Configure the surrogate model")
group.add_argument('--model-type', choices=['ani', 'chgnet'], help='Which type of machine learning model to train.')
group.add_argument('--model-type', choices=['ani', 'chgnet', 'mace'], help='Which type of machine learning model to train.')
group.add_argument('--initial-model', help='Path to initial model in message format. Code will generate a network with default settings if none provided')
group.add_argument('--initial-data', nargs='*', default=(), help='Path to data files (e.g., ASE .traj and .db) containing initial training data')
group.add_argument('--ensemble-size', type=int, default=2, help='Number of models to train on different data segments')
Expand Down Expand Up @@ -120,6 +122,7 @@
learner = {
'ani': TorchANI(),
'chgnet': CHGNetInterface(),
'mace': MACEInterface()
}[args.model_type]
main_logger.info(f'Ready to train a {args.model_type} model')

Expand All @@ -138,6 +141,9 @@
elif args.model_type == 'chgnet':
models = [CHGNet.load()] * args.ensemble_size
logger.info('Loaded the pretrained weights for CHGNet')
elif args.model_type == 'mace':
models = [mace_mp('small').models[0]] * args.ensemble_size
logger.info('Loaded the MACE-MP small model.')
else:
raise NotImplemented(f'Default models not implemented for {args.model_type}')

Expand Down

0 comments on commit e32e981

Please sign in to comment.