From cf6fbd58d7c0e8184eaadad6b38a6528a60d14eb Mon Sep 17 00:00:00 2001 From: Michael Tynes Date: Fri, 8 Nov 2024 15:30:26 -0600 Subject: [PATCH] WIP: use mace as a target (#57) * WIP: use mace as a target * io changes for MACE * canonicalize atoms before saving to db * convert all results to float64 for consistency --- 2_proxima/0_run-serial-proxima.py | 22 ++++++++++++++-------- cascade/proxima/__init__.py | 4 ++-- cascade/utils.py | 5 +++++ 3 files changed, 21 insertions(+), 10 deletions(-) diff --git a/2_proxima/0_run-serial-proxima.py b/2_proxima/0_run-serial-proxima.py index e4f8dd7..88a2171 100644 --- a/2_proxima/0_run-serial-proxima.py +++ b/2_proxima/0_run-serial-proxima.py @@ -18,6 +18,7 @@ from ase.db import connect from ase.md import MDLogger, VelocityVerlet from chgnet.model import CHGNet +from mace.calculators import mace_mp from gitinfo import get_git_info from cascade.learning.chgnet import CHGNetInterface @@ -113,7 +114,7 @@ for initial_data in args.initial_data: main_logger.info(f'Adding data from {initial_data} to database') for frame in io.iread(initial_data): - db.write(frame) + db.write(canonicalize(frame)) with connect(db_path) as db: main_logger.info(f'Training database has {len(db)} entries.') @@ -171,7 +172,10 @@ # Set up the proxima calculator calc_dir = Path('cp2k-run') / params_hash calc_dir.mkdir(parents=True, exist_ok=True) - target_calc = make_calculator(args.calculator, directory=str(calc_dir)) + if args.calculator != 'mace_mp': + target_calc = make_calculator(args.calculator, directory=str(calc_dir)) + else: + target_calc = mace_mp('small', device=args.training_device) learning_calc = SerialLearningCalculator( target_calc=target_calc, @@ -212,16 +216,16 @@ def _log_proxima(): start_time = perf_counter() with open(run_dir / 'proxima-log.json', 'a') as fp: last_uncer, last_error = learning_calc.error_history[-1] - print(json.dumps({ + d = { 'step_time': step_time, 'energy': float(atoms.get_potential_energy()), 'maximum_force': float(np.linalg.norm(atoms.get_forces(), axis=1).max()), 'stress': atoms.get_stress().astype(float).tolist(), - 'temperature': atoms.get_temperature(), - 'volume': atoms.get_volume(), + 'temperature': float(atoms.get_temperature()), + 'volume': float(atoms.get_volume()), 'used_surrogate': bool(learning_calc.used_surrogate), - 'proxima_alpha': learning_calc.alpha, - 'proxima_threshold': learning_calc.threshold, + 'proxima_alpha': float(learning_calc.alpha) if learning_calc.alpha is not None else float(np.nan), + 'proxima_threshold': float(learning_calc.threshold) if learning_calc.threshold is not None else float(np.nan), 'proxima_blending_step': int(learning_calc.blending_step), 'proxima_lambda_target': float(learning_calc.lambda_target), 'last_uncer': float(last_uncer), @@ -229,7 +233,9 @@ def _log_proxima(): 'total_invocations': learning_calc.total_invocations, 'target_invocations': learning_calc.target_invocations, 'model_version': learning_calc.model_version - }), file=fp) + } + #print({k: type(v) for k,v in d.items()}) + print(json.dumps(d), file=fp) def _write_to_traj(): with Trajectory(traj_path, mode='a') as traj: diff --git a/cascade/proxima/__init__.py b/cascade/proxima/__init__.py index a9615aa..a8566ca 100644 --- a/cascade/proxima/__init__.py +++ b/cascade/proxima/__init__.py @@ -13,7 +13,7 @@ from cascade.learning.base import BaseLearnableForcefield from cascade.calculator import EnsembleCalculator -from cascade.utils import to_voigt +from cascade.utils import to_voigt, canonicalize logger = logging.getLogger(__name__) @@ -248,7 +248,7 @@ def calculate( db_atoms = atoms.copy() db_atoms.calc = target_calc with connect(self.parameters['db_path']) as db: - db.write(db_atoms) + db.write(canonicalize(db_atoms)) # Reset the model if the training frequency has been reached surrogate_forces = self.surrogate_calc.results['forces'] diff --git a/cascade/utils.py b/cascade/utils.py index 0db04aa..07ab732 100644 --- a/cascade/utils.py +++ b/cascade/utils.py @@ -60,6 +60,11 @@ def canonicalize(atoms: Atoms) -> Atoms: old_calc = atoms.calc out_atoms.calc = SinglePointCalculator(atoms) out_atoms.calc.results = old_calc.results.copy() + + for k, v in out_atoms.calc.results.items(): + if isinstance(v, np.ndarray): + out_atoms.calc.results[k] = v.astype(np.float64) + return out_atoms