
Commit

run tested fine wandb ok
zhiyil1230 committed Sep 7, 2024
1 parent 08c8e2c commit d4b0bc5
Showing 6 changed files with 52 additions and 37 deletions.
40 changes: 26 additions & 14 deletions finetune.py
@@ -25,7 +25,7 @@ def run(args):
config (DictConfig): Config for training loop.
"""
device = utils.init_device()
experiment.seed_everything(args.random_seed, utils.get_local_rank())
experiment.seed_everything(args.random_seed)

# Make sure to use this flag for matmuls on A100 and H100 GPUs.
torch.set_float32_matmul_precision("high")
@@ -39,7 +39,7 @@ def run(args):

# Move model to correct device.
model.to(device=device)
optimizer, lr_scheduler, ema = optim.get_optim(args.lr, model)
optimizer, lr_scheduler, ema = optim.get_optim(args.lr, args.max_epochs, model)

wandb_run = None
# Logger instantiation/configuration
@@ -69,11 +69,12 @@ def run(args):
)
logging.info("Starting training!")

num_steps = len(train_loader)
num_steps = args.num_steps

start_epoch = 0

for epoch in range(start_epoch, args.max_epochs):
print(f"Start epoch: {epoch} training...")
t1 = time.time()
avg_train_metrics = steps.fintune(
model=model,
@@ -91,25 +92,30 @@ def run(args):
train_times["avg_time_per_step"] = (t2 - t1) / num_steps
train_times["total_time"] = t2 - t1

if wandb.run is not None:
if args.wandb:
wandb.run.log(
experiment.prefix_keys(avg_train_metrics, "train"), commit=False
experiment.prefix_keys(avg_train_metrics, "finetune"), commit=False
)
wandb.run.log(
experiment.prefix_keys(train_times, "train", sep="-"),
experiment.prefix_keys(train_times, "finetune", sep="-"),
commit=False,
)
wandb.run.log({"epoch": epoch}, commit=True)

if epoch == args.max_epochs - 1:
checkpoint = {
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'lr_scheduler_state_dict': lr_scheduler.state_dict() if lr_scheduler else None,
'ema_state_dict': ema.state_dict() if ema else None,
"epoch": epoch,
"model_state_dict": model.state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"lr_scheduler_state_dict": lr_scheduler.state_dict()
if lr_scheduler
else None,
"ema_state_dict": ema.state_dict() if ema else None,
}
torch.save(checkpoint, args.checkpoint_path)
torch.save(
checkpoint,
os.path.join(args.checkpoint_path, f"checkpoint_epoch{epoch}.ckpt"),
)
logging.info(f"Checkpoint saved to {args.checkpoint_path}")

if wandb_run is not None:
@@ -127,7 +133,7 @@ def main():
)
parser.add_argument(
"--wandb",
default=False,
default=True,
action="store_true",
help="If the run is logged to wandb.",
)
@@ -153,9 +159,15 @@ def main():
type=int,
help="Maximum number of epochs to finetune.",
)
parser.add_argument(
"--num_steps",
default=100,
type=int,
help="Num steps of in each epoch.",
)
parser.add_argument(
"--checkpoint_path",
default=os.path.join(os.getcwd(), "checkpoints"),
default=os.getcwd(),
type=str,
help="Path to save the model checkpoint.",
)
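For reference, a minimal sketch of restoring a checkpoint written by the last-epoch branch above. Only the dictionary keys and the checkpoint_epoch{N}.ckpt naming come from this diff; the helper name is illustrative, and loading EMA state assumes the EMA class exposes load_state_dict alongside the state_dict saved here.

import torch


def load_finetune_checkpoint(path, model, optimizer, lr_scheduler=None, ema=None):
    """Restore state saved by finetune.py as checkpoint_epoch{N}.ckpt (sketch)."""
    checkpoint = torch.load(path, map_location="cpu")
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    if lr_scheduler is not None and checkpoint["lr_scheduler_state_dict"] is not None:
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
    if ema is not None and checkpoint["ema_state_dict"] is not None:
        ema.load_state_dict(checkpoint["ema_state_dict"])  # assumed EMA API
    return checkpoint["epoch"]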
20 changes: 11 additions & 9 deletions orb_models/finetune_utilities/experiment.py
@@ -30,7 +30,7 @@ def prefix_keys(
return {f"{prefix}{sep}{k}": v for k, v in dict_to_prefix.items()}


def seed_everything(seed: int, rank: int) -> None:
def seed_everything(seed: int, rank: int = 0) -> None:
"""Set the seed for all pseudo random number generators."""
random.seed(seed + rank)
numpy.random.seed(seed + rank)
@@ -39,21 +39,23 @@ def seed_everything(seed: int, rank: int) -> None:

def init_wandb_from_config(args, job_type: str) -> wandb_run.Run:
"""Initialise wandb from config."""
run_name = args.get("name")
project = args.get("project")
if not run_name:
if not hasattr(args, "wandb_name"):
run_name = f"{job_type}-test"
if not project:
else:
run_name = args.name
if not hasattr(args, "wandb_project"):
project = "orb-experiment"
else:
project = args.project

wandb.init( # type: ignore
job_type=job_type,
dir=os.path.join(os.getcwd(), "wandb"),
name=run_name,
project=project,
entity=args.entity,
mode=args.mode,
group=args.get("group"),
sync_tensorboard=True,
entity="orbitalmaterials",
mode="online",
sync_tensorboard=False,
)
assert wandb.run is not None
return wandb.run
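A small usage sketch of the updated helpers, assuming an argparse namespace like the one built in finetune.py (the --random_seed default here is illustrative). Without wandb_name/wandb_project attributes on the namespace, the run falls back to the "<job_type>-test" name and the "orb-experiment" project, with entity and mode now fixed to "orbitalmaterials" and "online".

import argparse

from orb_models.finetune_utilities import experiment

parser = argparse.ArgumentParser()
parser.add_argument("--random_seed", default=42, type=int)  # illustrative default
args = parser.parse_args([])

experiment.seed_everything(args.random_seed)  # rank now defaults to 0
wandb_run = experiment.init_wandb_from_config(args, job_type="finetune")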
18 changes: 9 additions & 9 deletions orb_models/finetune_utilities/optim.py
@@ -1,7 +1,7 @@
import re
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union

import hydra
import logging
import omegaconf
import torch
from orb_models.finetune_utilities.ema import ExponentialMovingAverage as EMA
@@ -92,35 +92,35 @@ def make_parameter_groups(
parameter_group_names[-1].add(name)

# log the remaining parameter groups
hydra.utils.log.info("Constructed parameter groups:")
logging.info("Constructed parameter groups:")
for k in range(len(parameter_groups)):
group_options = {
key: val for key, val in parameter_groups[k].items() if key != "params"
}
hydra.utils.log.info("Group %s, options: %s", k, group_options)
logging.info("Group %s, options: %s", k, group_options)
if verbose:
hydra.utils.log.info("Parameters: ")
logging.info("Parameters: ")
for p in list(parameter_group_names[k]):
hydra.utils.log.info(p)
logging.info(p)

# check for unused regex
for regex, count in regex_use_counts.items():
if count == 0:
hydra.utils.log.warning(
logging.warning(
"Parameter group regex %s does not match any parameter name.",
regex,
)
return parameter_groups


def get_optim(
lr: float, model: torch.nn.Module
lr: float, max_epoch: int, model: torch.nn.Module
) -> Tuple[
torch.optim.Optimizer,
Optional[torch.optim.lr_scheduler._LRScheduler],
Optional[EMA],
]:
"""Configure optimizers, LR schedulers and EMA from a Hydra config."""
"""Configure optimizers, LR schedulers and EMA."""
parameter_groups = [
{
"filter_string": "(.*bias|.*layer_norm.*|.*batch_norm.*)",
@@ -130,7 +130,7 @@ def get_optim(
params = make_parameter_groups(model, parameter_groups)
opt = torch.optim.Adam(params, lr=lr)

scheduler = torch.optim.lr_scheduler.CosineAnnealingLR
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=max_epoch)
ema_decay = 0.999
ema = EMA(model.parameters(), ema_decay)

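A sketch of how the reworked get_optim(lr, max_epoch, model) might be consumed. The toy model and the once-per-epoch scheduler step are assumptions; the per-batch optimizer and EMA calls are only indicated in comments, since they live in steps.fintune and finetune_utilities.ema.

import torch

from orb_models.finetune_utilities import optim

model = torch.nn.Linear(8, 1)  # stand-in for the finetuned forcefield model
max_epochs = 50
optimizer, lr_scheduler, ema = optim.get_optim(3e-4, max_epochs, model)

for epoch in range(max_epochs):
    # ... run one epoch of finetuning batches, calling optimizer.step()
    # (and an EMA update) per batch ...
    lr_scheduler.step()  # CosineAnnealingLR anneals the LR over T_max=max_epochs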
7 changes: 4 additions & 3 deletions orb_models/finetune_utilities/steps.py
@@ -1,6 +1,7 @@
from typing import List, Optional, Union, cast

import torch
import tqdm
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader

@@ -52,7 +53,7 @@ def fintune(
lr_scheduler: Optional[_LRScheduler] = None,
num_steps: Optional[int] = None,
clip_grad: Optional[float] = None,
log_freq: float = 100,
log_freq: float = 10,
device: torch.device = torch.device("cpu"),
epoch: int = 0,
):
@@ -96,7 +97,7 @@ def fintune(
except TypeError:
raise ValueError("Dataloader has no length, you must specify num_steps.")

batch_generator_tqdm = batch_generator
batch_generator_tqdm = tqdm.tqdm(batch_generator, total=num_training_batches)

i = 0
batch_iterator = iter(batch_generator_tqdm)
@@ -173,4 +174,4 @@ def fintune(
for h in hook_handles:
h.remove()

return metrics.get_metrics(sync_dist=True)
return metrics.get_metrics()
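The tqdm change above wraps the batch generator in a progress bar capped at num_training_batches. A minimal, standalone sketch of that pattern (the helper name is illustrative):

import tqdm


def iterate_with_progress(batch_generator, num_steps):
    """Yield at most num_steps batches, showing a tqdm progress bar."""
    bar = tqdm.tqdm(batch_generator, total=num_steps)
    for i, batch in enumerate(bar):
        if i >= num_steps:
            bar.close()
            break
        yield batch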
2 changes: 1 addition & 1 deletion orb_models/forcefield/atomic_system.py
@@ -252,7 +252,7 @@ def ase_fix_atoms_to_tensor(atoms: ase.Atoms) -> Optional[torch.Tensor]:
def make_property_definitions_from_config(
config: Optional[Dict] = None,
) -> PropertyConfig:
"""Get PropertyConfig object from hydra config."""
"""Get PropertyConfig object from config."""
if config is None:
return PropertyConfig()
assert all(
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -18,7 +18,7 @@ classifiers = [
dependencies = [
"cached_path>=1.6.2",
"ase>=3.23.0",
"numpy<2.0.0",
"numpy==1.26.4",
"scipy>=1.13.1",
"torch==2.2.0",
"dm-tree>=0.1.8",
