some review comments
zhiyil1230 committed Sep 10, 2024
1 parent 959b152 commit 7720c1a
Showing 19 changed files with 253 additions and 498 deletions.
6 changes: 3 additions & 3 deletions README.md
@@ -38,8 +38,8 @@ For more information on the models, please see the [MODELS.md](MODELS.md) file.

import ase
from ase.build import bulk
from orb_models.forcefield import pretrained
from orb_models.forcefield import atomic_system

from orb_models.forcefield import atomic_system, pretrained
from orb_models.forcefield.base import batch_graphs

orbff = pretrained.orb_v1()
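For context, the lines elided between these hunks featurize an ASE `Atoms` object and run a prediction. A minimal sketch of that flow, assuming the `ase_atoms_to_atom_graphs` / `predict` API from the orb-models README (exact output keys may differ between versions):

```python
import ase
from ase.build import bulk

from orb_models.forcefield import atomic_system, pretrained
from orb_models.forcefield.base import batch_graphs

orbff = pretrained.orb_v1()
atoms = bulk("Cu", "fcc", a=3.58, cubic=True)

# Featurize the ASE atoms object into the model's graph representation.
graph = atomic_system.ase_atoms_to_atom_graphs(atoms)

# Several graphs can be collated into one batch for a single forward pass.
batch = batch_graphs([graph, graph])

# Run inference; the result contains per-graph (e.g. energy) and
# per-node (e.g. forces) predictions.
result = orbff.predict(batch)
```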
@@ -65,10 +65,10 @@ atoms = atomic_system.atom_graphs_to_ase_atoms(
```python
import ase
from ase.build import bulk

from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator


device="cpu" # or device="cuda"
orbff = pretrained.orb_v1(device=device) # or choose another model using ORB_PRETRAINED_MODELS[model_name]()
calc = ORBCalculator(orbff, device=device)
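The calculator slots into the standard ASE workflow. A short usage sketch continuing the snippet above (the copper bulk structure is an arbitrary example):

```python
atoms = bulk("Cu", "fcc", a=3.58, cubic=True)
atoms.calc = calc

# Energy and forces are computed by the ORB force field under the hood.
print(atoms.get_potential_energy())
print(atoms.get_forces())
```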
170 changes: 145 additions & 25 deletions finetune.py
@@ -1,30 +1,160 @@
"""Finetuning loop."""

import os
import logging
import argparse
import time
import logging
import os
from typing import Optional, Union, cast

import torch
from orb_models.forcefield import pretrained
from orb_models.finetune_utilities import experiment, optim
from orb_models.dataset import data_loaders
from orb_models.finetune_utilities import steps
import tqdm
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

import wandb
from orb_models.dataset import data_loaders
from orb_models import utils
from orb_models.forcefield import pretrained
from wandb import wandb_run

logging.basicConfig(
level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
)


def finetune(
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
dataloader: DataLoader,
lr_scheduler: Optional[_LRScheduler] = None,
num_steps: Optional[int] = None,
clip_grad: Optional[float] = None,
log_freq: float = 10,
device: torch.device = torch.device("cpu"),
epoch: int = 0,
):
"""Train for a fixed number of steps.
Args:
model: The model to optimize.
optimizer: The optimizer for the model.
dataloader: A Pytorch Dataloader, which may be infinite if num_steps is passed.
lr_scheduler: Optional, a Learning rate scheduler for modifying the learning rate.
        num_steps: The number of training steps to take. This is required for distributed training,
            because controlling parallelism is easier if all processes take exactly the same number of steps
            (this particularly applies when using dynamic batching).
clip_grad: Optional, the gradient clipping threshold.
log_freq: The logging frequency for step metrics.
device: The device to use for training.
        epoch: The number of epochs for which the model has been finetuned.
    Returns:
        A dictionary of metrics.
"""
run: Optional[wandb_run.Run] = cast(Optional[wandb_run.Run], wandb.run)

if clip_grad is not None:
hook_handles = utils.gradient_clipping(model, clip_grad)

metrics = utils.ScalarMetricTracker()

# Set the model to "train" mode.
model.train()

# Get tqdm for the training batches
batch_generator = iter(dataloader)
num_training_batches: Union[int, float]
if num_steps is not None:
num_training_batches = num_steps
else:
try:
num_training_batches = len(dataloader)
except TypeError:
raise ValueError("Dataloader has no length, you must specify num_steps.")

batch_generator_tqdm = tqdm.tqdm(batch_generator, total=num_training_batches)

i = 0
batch_iterator = iter(batch_generator_tqdm)
while True:
if num_steps and i == num_steps:
break

optimizer.zero_grad(set_to_none=True)

step_metrics = {
"batch_size": 0.0,
"batch_num_edges": 0.0,
"batch_num_nodes": 0.0,
}

        # Reset metrics so that they report raw values for each step, while still
        # averaging over gradient accumulation.
if i % log_freq == 0:
metrics.reset()

batch = next(batch_iterator)
batch = batch.to(device)
step_metrics["batch_size"] += len(batch.n_node)
step_metrics["batch_num_edges"] += batch.n_edge.sum()
step_metrics["batch_num_nodes"] += batch.n_node.sum()

with torch.cuda.amp.autocast(enabled=False):
batch_outputs = model.loss(batch)
loss = batch_outputs.loss
metrics.update(batch_outputs.log)
if torch.isnan(loss):
raise ValueError("nan loss encountered")
loss.backward()

optimizer.step()

if lr_scheduler is not None:
lr_scheduler.step()

metrics.update(step_metrics)

if i != 0 and i % log_freq == 0:
metrics_dict = metrics.get_metrics()
if run is not None:
global_step = (epoch * num_training_batches) + i
if run.sweep_id is not None:
run.log(
{"loss": metrics_dict["loss"]},
commit=False,
)
run.log(
{"global_step": global_step},
commit=False,
)
run.log(
utils.prefix_keys(metrics_dict, "train_step"), commit=False
)
# Log learning rates.
run.log(
{
f"pg_{idx}": group["lr"]
for idx, group in enumerate(optimizer.param_groups)
},
)

# Finished a single full step!
i += 1

if clip_grad is not None:
for h in hook_handles:
h.remove()

return metrics.get_metrics()
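A note on `utils.gradient_clipping`: the hook handles removed at the end of `finetune` suggest per-parameter backward hooks. A minimal sketch of that pattern, assuming value-based clamping (the real implementation in `orb_models.utils` may clip by norm instead; the function below is hypothetical):

```python
from typing import List

import torch


def gradient_clipping_sketch(
    model: torch.nn.Module, clip_value: float
) -> List[torch.utils.hooks.RemovableHandle]:
    """Clamp each parameter's gradient in-place as backward produces it."""
    handles = []
    for param in model.parameters():
        if param.requires_grad:
            # Tensor.register_hook fires with the gradient during backward.
            handle = param.register_hook(
                lambda grad: torch.clamp(grad, -clip_value, clip_value)
            )
            handles.append(handle)
    return handles
```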


def run(args):
"""Training Loop.
Args:
config (DictConfig): Config for training loop.
"""
device = experiment.init_device()
experiment.seed_everything(args.random_seed)
device = utils.init_device()
utils.seed_everything(args.random_seed)

# Make sure to use this flag for matmuls on A100 and H100 GPUs.
torch.set_float32_matmul_precision("high")
@@ -39,22 +169,21 @@ def run(args):
# Move model to correct device.
model.to(device=device)
total_steps = args.max_epochs * args.num_steps
optimizer, lr_scheduler = optim.get_optim(args.lr, total_steps, model)
optimizer, lr_scheduler = utils.get_optim(args.lr, total_steps, model)

wandb_run = None
# Logger instantiation/configuration
if args.wandb:
import wandb

logging.info("Instantiating WandbLogger.")
wandb_run = experiment.init_wandb_from_config(args, job_type="finetuning")
wandb_run = utils.init_wandb_from_config(job_type="finetuning")

wandb.define_metric("global_step")
wandb.define_metric("epochs")
wandb.define_metric("train_step/*", step_metric="global_step")
wandb.define_metric("learning_rates/*", step_metric="global_step")
wandb.define_metric("finetune/*", step_metric="epochs")
wandb.define_metric("key-metrics/*", step_metric="epochs")

loader_args = dict(
dataset=args.dataset,
@@ -65,7 +194,7 @@
)
train_loader = data_loaders.build_train_loader(
**loader_args,
augmentation=getattr(args, "augmentation", True),
augmentation=True,
)
logging.info("Starting training!")

@@ -75,8 +204,7 @@

for epoch in range(start_epoch, args.max_epochs):
print(f"Start epoch: {epoch} training...")
t1 = time.time()
avg_train_metrics = steps.fintune(
avg_train_metrics = finetune(
model=model,
optimizer=optimizer,
dataloader=train_loader,
@@ -86,18 +214,10 @@
num_steps=num_steps,
epoch=epoch,
)
t2 = time.time()
train_times = {}
train_times["avg_time_per_step"] = (t2 - t1) / num_steps
train_times["total_time"] = t2 - t1

if args.wandb:
wandb.run.log(
experiment.prefix_keys(avg_train_metrics, "finetune"), commit=False
)
wandb.run.log(
experiment.prefix_keys(train_times, "finetune", sep="-"),
commit=False,
utils.prefix_keys(avg_train_metrics, "finetune"), commit=False
)
wandb.run.log({"epoch": epoch}, commit=True)

@@ -147,7 +267,7 @@ def main():
"--num_workers", default=8, type=int, help="Number of workers for data loader."
)
parser.add_argument(
"--batch_size", default=100, type=int, help="Batch size for finetuning."
"--batch_size", default=10, type=int, help="Batch size for finetuning."
)
parser.add_argument(
"--gradient_clip_val", default=0.5, type=float, help="Gradient clip value."
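A note on the `wandb.define_metric` calls in `run` above: they bind each metric namespace to its own x-axis, so subsequent `log` calls only need to include the matching step key. A small illustration with hypothetical values:

```python
import wandb

run = wandb.init(project="orb-finetune-demo")  # hypothetical project name

wandb.define_metric("global_step")
wandb.define_metric("train_step/*", step_metric="global_step")

# "train_step/loss" is now plotted against "global_step" rather than
# wandb's default internal step counter.
wandb.log({"global_step": 100, "train_step/loss": 0.42})
```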
11 changes: 5 additions & 6 deletions internal/check.py
@@ -1,14 +1,13 @@
"""Integration tests to check compatibility of outputs with internal OM models."""

import torch
import ase
import argparse

from orb_models.forcefield import pretrained
from orb_models.forcefield import atomic_system
from core.models import load
import ase
import torch
from core.dataset import atomic_system as core_atomic_system
from core.models import load

import argparse
from orb_models.forcefield import atomic_system, pretrained


def main(model: str, core_model: str):