Skip to content

Commit

Permalink
WIP: use mace as a target (#57)
Browse files Browse the repository at this point in the history
* WIP: use mace as a target

* io changes for MACE

* canonicalize atoms before saving to db

* convert all results to float64 for consistency
  • Loading branch information
miketynes authored Nov 8, 2024
1 parent d92140b commit cf6fbd5
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 10 deletions.
22 changes: 14 additions & 8 deletions 2_proxima/0_run-serial-proxima.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from ase.db import connect
from ase.md import MDLogger, VelocityVerlet
from chgnet.model import CHGNet
from mace.calculators import mace_mp
from gitinfo import get_git_info

from cascade.learning.chgnet import CHGNetInterface
Expand Down Expand Up @@ -113,7 +114,7 @@
for initial_data in args.initial_data:
main_logger.info(f'Adding data from {initial_data} to database')
for frame in io.iread(initial_data):
db.write(frame)
db.write(canonicalize(frame))

with connect(db_path) as db:
main_logger.info(f'Training database has {len(db)} entries.')
Expand Down Expand Up @@ -171,7 +172,10 @@
# Set up the proxima calculator
calc_dir = Path('cp2k-run') / params_hash
calc_dir.mkdir(parents=True, exist_ok=True)
target_calc = make_calculator(args.calculator, directory=str(calc_dir))
if args.calculator != 'mace_mp':
target_calc = make_calculator(args.calculator, directory=str(calc_dir))
else:
target_calc = mace_mp('small', device=args.training_device)

learning_calc = SerialLearningCalculator(
target_calc=target_calc,
Expand Down Expand Up @@ -212,24 +216,26 @@ def _log_proxima():
start_time = perf_counter()
with open(run_dir / 'proxima-log.json', 'a') as fp:
last_uncer, last_error = learning_calc.error_history[-1]
print(json.dumps({
d = {
'step_time': step_time,
'energy': float(atoms.get_potential_energy()),
'maximum_force': float(np.linalg.norm(atoms.get_forces(), axis=1).max()),
'stress': atoms.get_stress().astype(float).tolist(),
'temperature': atoms.get_temperature(),
'volume': atoms.get_volume(),
'temperature': float(atoms.get_temperature()),
'volume': float(atoms.get_volume()),
'used_surrogate': bool(learning_calc.used_surrogate),
'proxima_alpha': learning_calc.alpha,
'proxima_threshold': learning_calc.threshold,
'proxima_alpha': float(learning_calc.alpha) if learning_calc.alpha is not None else float(np.nan),
'proxima_threshold': float(learning_calc.threshold) if learning_calc.threshold is not None else float(np.nan),
'proxima_blending_step': int(learning_calc.blending_step),
'proxima_lambda_target': float(learning_calc.lambda_target),
'last_uncer': float(last_uncer),
'last_error': float(last_error),
'total_invocations': learning_calc.total_invocations,
'target_invocations': learning_calc.target_invocations,
'model_version': learning_calc.model_version
}), file=fp)
}
#print({k: type(v) for k,v in d.items()})
print(json.dumps(d), file=fp)

def _write_to_traj():
with Trajectory(traj_path, mode='a') as traj:
Expand Down
4 changes: 2 additions & 2 deletions cascade/proxima/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from cascade.learning.base import BaseLearnableForcefield
from cascade.calculator import EnsembleCalculator
from cascade.utils import to_voigt
from cascade.utils import to_voigt, canonicalize

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -248,7 +248,7 @@ def calculate(
db_atoms = atoms.copy()
db_atoms.calc = target_calc
with connect(self.parameters['db_path']) as db:
db.write(db_atoms)
db.write(canonicalize(db_atoms))

# Reset the model if the training frequency has been reached
surrogate_forces = self.surrogate_calc.results['forces']
Expand Down
5 changes: 5 additions & 0 deletions cascade/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def canonicalize(atoms: Atoms) -> Atoms:
old_calc = atoms.calc
out_atoms.calc = SinglePointCalculator(atoms)
out_atoms.calc.results = old_calc.results.copy()

for k, v in out_atoms.calc.results.items():
if isinstance(v, np.ndarray):
out_atoms.calc.results[k] = v.astype(np.float64)

return out_atoms


Expand Down

0 comments on commit cf6fbd5

Please sign in to comment.