From 1442177972a3aecd3d5e4ea824d15871fe7b0055 Mon Sep 17 00:00:00 2001 From: Muhammed Shuaibi <45150244+mshuaibii@users.noreply.github.com> Date: Fri, 12 Jul 2024 14:32:17 -0700 Subject: [PATCH] Relax config (#758) * update example config per PR714 * no need to support this * y->energy --- configs/ocp_example.yml | 14 ++-- .../s2ef/2M/dimenet_plus_plus/dpp_relax.yml | 84 ------------------- .../core/common/relaxation/ase_utils.py | 2 +- 3 files changed, 8 insertions(+), 92 deletions(-) delete mode 100755 configs/s2ef/2M/dimenet_plus_plus/dpp_relax.yml diff --git a/configs/ocp_example.yml b/configs/ocp_example.yml index bbab42b51..b979a7a32 100644 --- a/configs/ocp_example.yml +++ b/configs/ocp_example.yml @@ -71,6 +71,13 @@ dataset: test: # Directory containing test set LMDBs src: data/s2ef/all/test_id/ + relax: + # Path to initial structures to run relaxations on. Same as the IS2RE set. + src: data/is2re/all/test_id/data.lmdb + # To shard a dataset into smaller subsets, define the total_shards desired + # and the shard a particular process to see. + total_shards: 1 # int (optional) + shard: 0 # int (optional) task: # This is an argument used for checkpoint loading. By default it is True and loads @@ -92,13 +99,6 @@ task: relaxation_fmax: 0.02 # Whether to save out the positions. write_pos: True # True or False - # Path to initial structures to run relaxations on. Same as the IS2RE set. - relax_dataset: - src: data/is2re/all/test_id/data.lmdb - # To shard a dataset into smaller subsets, define the total_shards desired - # and the shard a particular process to see. - total_shards: 1 # int (optional) - shard: 0 # int (optional) relax_opt: name: lbfgs maxstep: 0.2 diff --git a/configs/s2ef/2M/dimenet_plus_plus/dpp_relax.yml b/configs/s2ef/2M/dimenet_plus_plus/dpp_relax.yml deleted file mode 100755 index 9d37dd759..000000000 --- a/configs/s2ef/2M/dimenet_plus_plus/dpp_relax.yml +++ /dev/null @@ -1,84 +0,0 @@ -trainer: ocp - -loss_functions: - - energy: - fn: mae - coefficient: 1 - - forces: - fn: l2mae - coefficient: 50 - -dataset: - train: - format: lmdb - src: data/s2ef/2M/train/ - key_mapping: - y: energy - force: forces - transforms: - normalizer: - energy: - mean: -0.7554450631141663 - stdev: 2.887317180633545 - forces: - mean: 0 - stdev: 2.887317180633545 - val: - src: data/s2ef/all/val_id/ - -logger: wandb - -task: - dataset: trajectory_lmdb - description: "Regressing to energies and forces for DFT trajectories from OCP" - type: regression - metric: mae - labels: - - potential energy - grad_input: atomic forces - train_on_free_atoms: True - eval_on_free_atoms: True - relax_dataset: - src: data/is2re/all/test_id/data.lmdb - write_pos: True - relaxation_steps: 200 - relax_opt: - maxstep: 0.04 - memory: 50 - damping: 1.0 - alpha: 70.0 - traj_dir: "ml-relaxations/dpp-2M-test-id" - -model: - name: dimenetplusplus - hidden_channels: 192 - out_emb_channels: 192 - num_blocks: 3 - cutoff: 6.0 - num_radial: 6 - num_spherical: 7 - num_before_skip: 1 - num_after_skip: 2 - num_output_layers: 3 - regress_forces: True - use_pbc: True - -# *** Important note *** -# The total number of gpus used for this run was 32. -# If the global batch size (num_gpus * batch_size) is modified -# the lr_milestones and warmup_steps need to be adjusted accordingly. - -optim: - batch_size: 12 - eval_batch_size: 12 - eval_every: 10000 - num_workers: 8 - lr_initial: 0.0001 - lr_gamma: 0.1 - lr_milestones: # steps at which lr_initial <- lr_initial * lr_gamma - - 20833 - - 31250 - - 41666 - warmup_steps: 10416 - warmup_factor: 0.2 - max_epochs: 15 diff --git a/src/fairchem/core/common/relaxation/ase_utils.py b/src/fairchem/core/common/relaxation/ase_utils.py index 55a132e94..08efe08b6 100644 --- a/src/fairchem/core/common/relaxation/ase_utils.py +++ b/src/fairchem/core/common/relaxation/ase_utils.py @@ -43,7 +43,7 @@ def batch_to_atoms(batch): positions = torch.split(batch.pos, natoms) tags = torch.split(batch.tags, natoms) cells = batch.cell - energies = batch.y.view(-1).tolist() + energies = batch.energy.view(-1).tolist() atoms_objects = [] for idx in range(n_systems):