update readme
zhiyil1230 committed Sep 11, 2024
1 parent 843a81c · commit 6194b8c
Showing 8 changed files with 20 additions and 21 deletions.
7 changes: 7 additions & 0 deletions README.md
@@ -95,6 +95,13 @@ print("Optimized Energy:", atoms.get_potential_energy())
 ```
 
+
+### Finetuning
+You can finetune the model on your own dataset:
+```bash
+python finetune.py --dataset=<dataset_name> --data_path=<your_data_path>
+```
+After the model is finetuned, checkpoints will be saved. To use the finetuned model, load its checkpoint by modifying the `weights_path` in `pretrained.py`.
 
 ### Citing
 
 We are currently preparing a preprint for publication.
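A usage note, not part of the commit: since the updated `finetune.py` (below) saves checkpoints with `torch.save(model.state_dict(), ...)`, a finetuned checkpoint is a bare state dict. A minimal loading sketch, assuming you already have a model instance from `orb_models.forcefield.pretrained`; the helper name and checkpoint path are illustrative:

```python
# Minimal sketch -- helper name and path are illustrative, not from the repo.
import torch

def load_finetuned(model: torch.nn.Module, ckpt_path: str) -> torch.nn.Module:
    """Restore weights saved via torch.save(model.state_dict(), ...)."""
    state_dict = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.eval()

# e.g. load_finetuned(model, "ckpts/checkpoint_epoch1.ckpt")
```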
12 changes: 2 additions & 10 deletions finetune.py
@@ -266,20 +266,12 @@ def run(args):
         wandb.run.log({"epoch": epoch}, commit=True)
 
         # Save checkpoint from last epoch
-        if epoch == args.max_epochs - 1:
-            checkpoint = {
-                "epoch": epoch,
-                "model_state_dict": model.state_dict(),
-                "optimizer_state_dict": optimizer.state_dict(),
-                "lr_scheduler_state_dict": (
-                    lr_scheduler.state_dict() if lr_scheduler else None
-                ),
-            }
+        if epoch == 1:
             # create ckpts folder if it does not exist
             if not os.path.exists(args.checkpoint_path):
                 os.makedirs(args.checkpoint_path)
             torch.save(
-                checkpoint,
+                model.state_dict(),
                 os.path.join(args.checkpoint_path, f"checkpoint_epoch{epoch}.ckpt"),
             )
             logging.info(f"Checkpoint saved to {args.checkpoint_path}")
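An observation of my own, not the commit's: the old code saved a resumable training checkpoint (model, optimizer, and scheduler state), whereas the new code saves only the model weights, and it now triggers at epoch 1 rather than the last epoch. A sketch of a loader that tolerates both layouts, with a hypothetical helper name:

```python
# Hypothetical helper for illustration; accepts both checkpoint layouts.
import torch

def restore_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    ckpt = torch.load(ckpt_path, map_location="cpu")
    if isinstance(ckpt, dict) and "model_state_dict" in ckpt:
        # Old layout: full training checkpoint with optimizer/scheduler state.
        model.load_state_dict(ckpt["model_state_dict"])
    else:
        # New layout: the file is the state_dict itself.
        model.load_state_dict(ckpt)
```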
4 changes: 3 additions & 1 deletion orb_models/forcefield/atomic_system.py
@@ -95,7 +95,9 @@ def ase_atoms_to_atom_graphs(
     ),
     system_id: Optional[int] = None,
     brute_force_knn: Optional[bool] = None,
-    device: Optional[torch.device] = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+    device: Optional[torch.device] = torch.device(
+        "cuda" if torch.cuda.is_available() else "cpu"
+    ),
 ) -> AtomGraphs:
     """Generate AtomGraphs from an ase.Atoms object.
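A side note not made in the commit: a default argument like this is evaluated once at import time, so the CUDA check does not rerun on each call. A standalone illustration with assumed names:

```python
# Standalone illustration (assumed names, not from the repo). The default
# device is fixed when the module is imported, not when the function runs.
import torch
from typing import Optional

def pick_device(
    device: Optional[torch.device] = torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    ),
) -> torch.device:
    # Pass `device` explicitly to override the import-time default.
    return device
```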
3 changes: 2 additions & 1 deletion orb_models/forcefield/base.py
@@ -2,7 +2,8 @@
 
 from collections import defaultdict
 from copy import deepcopy
-from typing import Any, Dict, List, Mapping, NamedTuple, Optional, Sequence, Union
+from typing import (Any, Dict, List, Mapping, NamedTuple, Optional, Sequence,
+                    Union)
 
 import torch
 import tree
3 changes: 2 additions & 1 deletion orb_models/forcefield/calculator.py
@@ -3,7 +3,8 @@
 import torch
 from ase.calculators.calculator import Calculator, all_changes
 
-from orb_models.forcefield.atomic_system import SystemConfig, ase_atoms_to_atom_graphs
+from orb_models.forcefield.atomic_system import (SystemConfig,
+                                                 ase_atoms_to_atom_graphs)
 from orb_models.forcefield.graph_regressor import GraphRegressor
 
 
1 change: 0 additions & 1 deletion orb_models/forcefield/featurization_utilities.py
@@ -4,7 +4,6 @@
 
 import numpy as np
 import torch
-
 # TODO(Mark): Make pynanoflann optional
 from pynanoflann import KDTree as NanoKDTree
 from scipy.spatial import KDTree as SciKDTree
3 changes: 2 additions & 1 deletion orb_models/forcefield/graph_regressor.py
@@ -7,7 +7,8 @@
 from orb_models.forcefield import base, segment_ops
 from orb_models.forcefield.gns import _KEY, MoleculeGNS
 from orb_models.forcefield.nn_util import build_mlp
-from orb_models.forcefield.property_definitions import PROPERTIES, PropertyDefinition
+from orb_models.forcefield.property_definitions import (PROPERTIES,
+                                                        PropertyDefinition)
 from orb_models.forcefield.reference_energies import REFERENCE_ENERGIES
 
 global HAS_WARNED_FOR_TF32_MATMUL
8 changes: 2 additions & 6 deletions orb_models/forcefield/pretrained.py
@@ -6,12 +6,8 @@
 
 from orb_models.forcefield.featurization_utilities import get_device
 from orb_models.forcefield.gns import MoleculeGNS
-from orb_models.forcefield.graph_regressor import (
-    EnergyHead,
-    GraphHead,
-    GraphRegressor,
-    NodeHead,
-)
+from orb_models.forcefield.graph_regressor import (EnergyHead, GraphHead,
+                                                   GraphRegressor, NodeHead)
 from orb_models.forcefield.rbf import ExpNormalSmearing
 
 global HAS_MESSAGED_FOR_TF32_MATMUL
