From f9bcc7168e7f2e279276a8d2e5c9194f7b04939f Mon Sep 17 00:00:00 2001 From: Logan Ward Date: Fri, 8 Nov 2024 07:20:20 -0800 Subject: [PATCH] Shuffle data read from db files --- 1_ml-potential/mace/0_train-mace.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/1_ml-potential/mace/0_train-mace.py b/1_ml-potential/mace/0_train-mace.py index 8ec1daf..3dc8292 100644 --- a/1_ml-potential/mace/0_train-mace.py +++ b/1_ml-potential/mace/0_train-mace.py @@ -42,8 +42,15 @@ train_atoms: list[Atoms] = [] valid_atoms: list[Atoms] = [] train_hasher = sha256() + rng = np.random.RandomState(1) for file in args.train_files: my_atoms = read(file, slice(None)) + + # Shuffle the data if they are not from a traj file + if not file.endswith('.traj'): + rng.shuffle(my_atoms) + + # Hash dataset for reproducibility for atoms in my_atoms: train_hasher.update(atoms.positions.tobytes()) train_hasher.update(atoms.cell.tobytes())