
Commit

remove ema
zhiyil1230 committed Sep 7, 2024
1 parent d4b0bc5 commit a80d4d4
Showing 21 changed files with 130 additions and 420 deletions.
7 changes: 2 additions & 5 deletions finetune.py
@@ -10,7 +10,6 @@
 from orb_models.finetune_utilities import experiment, optim
 from orb_models.dataset import data_loaders
 from orb_models.finetune_utilities import steps
-from orb_models import utils


 logging.basicConfig(
@@ -24,7 +23,7 @@ def run(args):
     Args:
         config (DictConfig): Config for training loop.
     """
-    device = utils.init_device()
+    device = experiment.init_device()
     experiment.seed_everything(args.random_seed)

     # Make sure to use this flag for matmuls on A100 and H100 GPUs.
@@ -39,7 +38,7 @@ def run(args):

     # Move model to correct device.
     model.to(device=device)
-    optimizer, lr_scheduler, ema = optim.get_optim(args.lr, args.max_epochs, model)
+    optimizer, lr_scheduler = optim.get_optim(args.lr, args.max_epochs, model)

     wandb_run = None
     # Logger instantiation/configuration
@@ -80,7 +79,6 @@ def run(args):
             model=model,
             optimizer=optimizer,
             dataloader=train_loader,
-            ema=ema,
             lr_scheduler=lr_scheduler,
             clip_grad=args.gradient_clip_val,
             device=device,
@@ -110,7 +108,6 @@ def run(args):
                 "lr_scheduler_state_dict": lr_scheduler.state_dict()
                 if lr_scheduler
                 else None,
-                "ema_state_dict": ema.state_dict() if ema else None,
             }
             torch.save(
                 checkpoint,
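With the EMA object gone, the checkpoints written by finetune.py no longer contain an "ema_state_dict" entry. Checkpoints saved before this commit may still carry that key; a minimal compatibility sketch (a hypothetical helper, not part of this commit) that simply drops it on load:

import torch

def load_finetune_checkpoint(path: str) -> dict:
    """Load a finetune checkpoint and discard the now-unused EMA entry."""
    checkpoint = torch.load(path, map_location="cpu")
    # Older checkpoints saved "ema_state_dict"; pop() is a no-op if the key is absent.
    checkpoint.pop("ema_state_dict", None)
    return checkpoint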
20 changes: 5 additions & 15 deletions orb_models/dataset/ase_dataset.py
@@ -1,20 +1,16 @@
 from pathlib import Path
-from typing import Dict, Literal, Optional, Tuple, Union
+from typing import Dict, Optional, Tuple, Union

 import ase
 import ase.db
 import ase.db.row
 import numpy as np
 import torch
 from ase.stress import voigt_6_to_full_3x3_stress
-import numpy as np
 from e3nn import o3


-from orb_models.forcefield import (
-    atomic_system,
-    property_definitions,
-)
 from torch.utils.data import Dataset

+from orb_models.forcefield import atomic_system, property_definitions
 from orb_models.forcefield.base import AtomGraphs
@@ -31,11 +27,7 @@ class AseSqliteDataset(Dataset):
            of the dataset.
        system_config: A config for controlling how an atomic system is represented.
        target_config: A config for regression/classification targets.
-       evaluation: Three modes: "eval_with_noise", "eval_no_noise", "train".
        augmentation: If random rotation augmentation is used.
-       limit_size: Limit the size of the dataset to this many samples. Useful for debugging.
-       masking_args: Arguments for masking function.
-       filter_indices_path: Path to a file containing a list of indices to include in the dataset.

    Returns:
        An AseSqliteDataset.
@@ -44,7 +36,7 @@ def __init__(
    def __init__(
        self,
        name: str,
-       path: str,
+       path: Union[str, Path],
        system_config: Optional[atomic_system.SystemConfig] = None,
        target_config: Optional[atomic_system.PropertyConfig] = None,
        augmentation: Optional[bool] = True,
@@ -205,15 +197,13 @@ def get_dataset(
    name: str,
    system_config: atomic_system.SystemConfig,
    target_config: atomic_system.PropertyConfig,
-   evaluation: Literal["eval_with_noise", "eval_no_noise", "train"] = "train",
 ) -> AseSqliteDataset:
    """Dataset factory function."""
    return AseSqliteDataset(
        path=path,
        name=name,
        system_config=system_config,
        target_config=target_config,
-       evaluation=evaluation,
    )
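After this change, get_dataset and AseSqliteDataset no longer take an evaluation mode, and path may be either a str or a pathlib.Path. A usage sketch under those assumptions; the SystemConfig/PropertyConfig construction is a placeholder, since their fields are not shown in this diff:

from pathlib import Path

from orb_models.dataset.ase_dataset import get_dataset
from orb_models.forcefield import atomic_system

# Placeholder configs: assuming default constructors exist, for illustration only.
system_config = atomic_system.SystemConfig()
target_config = atomic_system.PropertyConfig()

dataset = get_dataset(
    path=Path("data/train.db"),  # hypothetical file; a plain str also works now
    name="train",
    system_config=system_config,
    target_config=target_config,
)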
20 changes: 8 additions & 12 deletions orb_models/dataset/data_loaders.py
@@ -1,17 +1,13 @@
-import random
 import logging
-from typing import Dict, List, Optional
+import random
+from typing import Any, Optional

 import numpy as np
 import torch
-from torch.utils.data import (
-    BatchSampler,
-    DataLoader,
-    RandomSampler,
-)
+from torch.utils.data import BatchSampler, DataLoader, RandomSampler

-from orb_models.forcefield import base
 from orb_models.dataset.ase_dataset import AseSqliteDataset
+from orb_models.forcefield import base
 from orb_models.forcefield.atomic_system import make_property_definitions_from_config

 HAVE_PRINTED_WORKER_INFO = False
@@ -47,19 +43,19 @@ def build_train_loader(
    path: str,
    num_workers: int,
    batch_size: int,
-   augmentation: Optional[List[str]] = None,
-   target_config: Optional[Dict] = None,
+   augmentation: Optional[bool] = None,
+   target_config: Optional[Any] = None,
    **kwargs,
 ) -> DataLoader:
    """Builds the train dataloader from a config file.

    Args:
        dataset: The dataset name.
        path: Dataset path.
        num_workers: The number of workers for each dataset.
        batch_size: The batch_size config for each dataset.
-       temperature: The temperature for temperature sampling.
-           Default is None for using random sampler.
+       augmentation: If rotation augmentation is used.
+       target_config: The target config.

    Returns:
        The train Dataloader.
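build_train_loader now takes augmentation as a plain bool and target_config as an untyped config object. A call sketch using the parameter names visible in the diff and docstring; the values and the leading dataset argument are assumptions:

from orb_models.dataset import data_loaders

train_loader = data_loaders.build_train_loader(
    dataset="train",        # dataset name, per the docstring
    path="data/train.db",   # hypothetical sqlite path
    num_workers=4,
    batch_size=16,
    augmentation=True,      # a bool flag now, rather than a list of augmentations
    target_config=None,
)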
221 changes: 0 additions & 221 deletions orb_models/finetune_utilities/ema.py

This file was deleted.
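The deleted module presumably maintained an exponential moving average (EMA) of the model weights during finetuning, which is why the ema= argument and the ema_state_dict checkpoint entry disappear above. For context only, a minimal sketch of what such a helper generally looks like, not the removed orb_models code:

import torch

class ExponentialMovingAverage:
    """Keep a decayed running average of a model's parameters."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        # Detached copies of the parameters act as the running average.
        self.shadow = {
            name: p.detach().clone() for name, p in model.named_parameters()
        }

    @torch.no_grad()
    def update(self, model: torch.nn.Module) -> None:
        # shadow <- decay * shadow + (1 - decay) * current parameters
        for name, p in model.named_parameters():
            self.shadow[name].mul_(self.decay).add_(p.detach(), alpha=1 - self.decay)

    def state_dict(self) -> dict:
        return {"decay": self.decay, "shadow": self.shadow}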

