Finetune Hydra (#797)
* add basic function

* support retain backbone mode

* fix config

* add hydra interface

* run linter

* run ruff

* fix test

* update main, add configs

* add tests

* test double finetune

* format

* fix few comments

* remove finetuneinterface

* update tests

* remove configs
rayg1234 authored Aug 13, 2024
1 parent 8143ccb commit 8fb16d6
Showing 9 changed files with 660 additions and 183 deletions.
29 changes: 29 additions & 0 deletions src/fairchem/core/common/utils.py
@@ -1102,6 +1102,35 @@ def _report_incompat_keys(
    return missing_keys, unexpected_keys


def match_state_dict(
    model_state_dict: Mapping[str, torch.Tensor],
    checkpoint_state_dict: Mapping[str, torch.Tensor],
) -> dict:
    # Match the model's state dict against the checkpoint's and return a new
    # dict whose keys are compatible with the model.

    # Match the "module." count in the keys of the model and checkpoint state_dicts.
    # A DataParallel model has 1 "module.", DistributedDataParallel has 2 "module.",
    # and a model wrapped in neither has no "module." prefix.

    ckpt_key_count = next(iter(checkpoint_state_dict)).count("module.")
    mod_key_count = next(iter(model_state_dict)).count("module.")
    key_count_diff = mod_key_count - ckpt_key_count

    if key_count_diff > 0:
        new_dict = {
            key_count_diff * "module." + k: v for k, v in checkpoint_state_dict.items()
        }
    elif key_count_diff < 0:
        new_dict = {
            k[len("module.") * abs(key_count_diff) :]: v
            for k, v in checkpoint_state_dict.items()
        }
    else:
        new_dict = checkpoint_state_dict
    return new_dict


def load_state_dict(
    module: nn.Module,
    state_dict: Mapping[str, torch.Tensor],
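For illustration (not part of the commit): a minimal sketch of how match_state_dict reconciles "module." prefixes between a plain model and a checkpoint saved from a DataParallel-wrapped one, assuming only the two helpers shown in this diff.

from torch import nn

from fairchem.core.common.utils import load_state_dict, match_state_dict

plain = nn.Linear(4, 2)
wrapped = nn.DataParallel(nn.Linear(4, 2))

# A checkpoint saved from the wrapped model has keys like "module.weight".
ckpt_state = wrapped.state_dict()

# match_state_dict strips the extra "module." so the keys line up with `plain`,
# then load_state_dict applies them strictly.
matched = match_state_dict(plain.state_dict(), ckpt_state)
load_state_dict(plain, matched, strict=True)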
25 changes: 23 additions & 2 deletions src/fairchem/core/models/base.py
@@ -7,8 +7,9 @@

from __future__ import annotations

import copy
import logging
from abc import ABCMeta, abstractmethod
from abc import ABC, ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING

@@ -227,8 +228,19 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
        return


class HydraInterface(ABC):
    # A hydra model has a backbone and heads.
    @abstractmethod
    def get_backbone(self) -> BackboneInterface:
        raise NotImplementedError

    @abstractmethod
    def get_heads(self) -> dict[str, HeadInterface]:
        raise NotImplementedError


@registry.register_model("hydra")
class HydraModel(nn.Module, GraphModelMixin):
class HydraModel(nn.Module, GraphModelMixin, HydraInterface):
    def __init__(
        self,
        backbone: dict,
@@ -237,6 +249,9 @@ def __init__
    ):
        super().__init__()
        self.otf_graph = otf_graph
        # make a copy so we don't modify the original config
        backbone = copy.deepcopy(backbone)
        heads = copy.deepcopy(heads)

        backbone_model_name = backbone.pop("model")
        self.backbone: BackboneInterface = registry.get_model_class(
@@ -272,3 +287,9 @@ def forward(self, data: Batch):
            out.update(self.output_heads[k](data, emb))

        return out

    def get_backbone(self) -> BackboneInterface:
        return self.backbone

    def get_heads(self) -> dict[str, HeadInterface]:
        return self.output_heads
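For context, a hedged sketch of how a HydraModel is assembled from config dicts. The backbone and head names below are hypothetical placeholders for classes registered in the registry, and the head config is assumed to use a "module" key as the finetune heads in this commit do.

from fairchem.core.models.base import HydraModel

backbone_config = {"model": "my_backbone"}  # "my_backbone" is a placeholder name
heads_config = {
    "energy": {"module": "my_energy_head"},  # placeholder registered head classes
    "forces": {"module": "my_force_head"},
}
model = HydraModel(backbone=backbone_config, heads=heads_config, otf_graph=True)
backbone = model.get_backbone()  # BackboneInterface, via HydraInterface
heads = model.get_heads()        # dict[str, HeadInterface]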
177 changes: 177 additions & 0 deletions src/fairchem/core/models/finetune_hydra.py
@@ -0,0 +1,177 @@
from __future__ import annotations

import copy
import errno
import logging
import os
from enum import Enum
from typing import TYPE_CHECKING

import torch
from torch import nn

from fairchem.core.common.registry import registry
from fairchem.core.common.utils import load_state_dict, match_state_dict
from fairchem.core.models.base import BackboneInterface, HeadInterface, HydraInterface

if TYPE_CHECKING:
from torch_geometric.data import Batch

FTHYDRA_NAME = "finetune_hydra"

class FineTuneMode(Enum):
    # in DATA_ONLY mode, we load the entire model and only finetune on new data
    DATA_ONLY = 1
    # in RETAIN_BACKBONE_ONLY mode, we load only the backbone and feed its
    # output to newly specified heads
    RETAIN_BACKBONE_ONLY = 2


def get_model_config_from_checkpoint(checkpoint_path: str) -> dict:
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            errno.ENOENT, "Checkpoint file not found", checkpoint_path
        )
    checkpoint = torch.load(checkpoint_path)
    return checkpoint["config"]["model"]


def load_hydra_model(checkpoint_path: str) -> HydraInterface:
    if not os.path.isfile(checkpoint_path):
        raise FileNotFoundError(
            errno.ENOENT, "Checkpoint file not found", checkpoint_path
        )
    logging.info(f"Loading checkpoint from: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path)
    config = checkpoint["config"]["model"]
    name = config.pop("name")
    hydra_model = registry.get_model_class(name)(**config)
    assert isinstance(
        hydra_model, HydraInterface
    ), "Can only load models with the HydraInterface"
    matched_dict = match_state_dict(hydra_model.state_dict(), checkpoint["state_dict"])
    load_state_dict(hydra_model, matched_dict, strict=True)
    return hydra_model


class FTConfig:
    FT_CONFIG_NAME = "finetune_config"
    STARTING_CHECKPOINT = "starting_checkpoint"
    STARTING_MODEL = "starting_model"
    MODE = "mode"
    HEADS = "heads"

    def __init__(self, config: dict):
        self.config = config
        assert FTConfig.MODE in self.config
        self._mode = FineTuneMode[self.config[FTConfig.MODE]]
        assert (
            (FTConfig.STARTING_CHECKPOINT in self.config)
            or (FTConfig.STARTING_MODEL in self.config)
        ), "Either a starting checkpoint or a starting model must be provided!"
        if self._mode == FineTuneMode.RETAIN_BACKBONE_ONLY:
            # in this mode, we keep the backbone but attach new output heads
            # specified in the head config
            assert (
                FTConfig.HEADS in self.config
            ), "heads cannot be empty when using RETAIN_BACKBONE_ONLY mode!"

    def load_model(self) -> nn.Module:
        # If a starting hydra config is provided, build the model from it.
        # This assumes the weights will be loaded later from the state_dict in
        # the checkpoint.pt file, so no actual weights are loaded here.
        if FTConfig.STARTING_MODEL in self.config:
            # register model from hydra_config
            config_copy = copy.deepcopy(self.config[FTConfig.STARTING_MODEL])
            name = config_copy.pop("name")
            hydra_model = registry.get_model_class(name)(**config_copy)
        # If a starting checkpoint is provided, load the model and its weights
        # from that checkpoint. This happens at the beginning of a finetuning run.
        elif FTConfig.STARTING_CHECKPOINT in self.config:
            hydra_model: HydraInterface = load_hydra_model(
                self.config[FTConfig.STARTING_CHECKPOINT]
            )
        assert isinstance(hydra_model, HydraInterface)

        num_params = sum(p.numel() for p in hydra_model.parameters())
        logging.info(f"Loaded original hydra model with {num_params} params")
        return hydra_model

    def get_standalone_config(self) -> dict:
        # Replace a config that references a checkpoint with one that embeds the
        # model config directly. This is required for standalone prediction (so
        # we don't need to ship the original checkpoint), multi-round finetuning,
        # and better robustness.
        standalone_config = {
            "name": FTHYDRA_NAME,
            FTConfig.FT_CONFIG_NAME: self.config,
        }
        if FTConfig.STARTING_CHECKPOINT in self.config:
            # store the original model config inside the model attrs so we
            # don't need the original checkpoint again when loading this one
            new_config = copy.deepcopy(self.config)
            new_config[FTConfig.STARTING_MODEL] = get_model_config_from_checkpoint(
                self.config[FTConfig.STARTING_CHECKPOINT]
            )
            standalone_config[FTConfig.FT_CONFIG_NAME] = new_config
        return standalone_config

    @property
    def mode(self) -> FineTuneMode:
        return self._mode

    @property
    def head_config(self) -> dict:
        return copy.deepcopy(self.config[FTConfig.HEADS])


@registry.register_model(FTHYDRA_NAME)
class FineTuneHydra(nn.Module, HydraInterface):
    def __init__(self, finetune_config: dict):
        super().__init__()
        ft_config = FTConfig(finetune_config)
        logging.info(f"Initializing FineTuneHydra model in {ft_config.mode} mode")
        hydra_model: HydraInterface = ft_config.load_model()
        self.backbone: BackboneInterface = hydra_model.get_backbone()

        if ft_config.mode == FineTuneMode.DATA_ONLY:
            # in this mode, we use the model as-is and train on it with new data
            self.output_heads: dict[str, HeadInterface] = hydra_model.get_heads()
        elif ft_config.mode == FineTuneMode.RETAIN_BACKBONE_ONLY:
            # in this mode, we keep the backbone but attach new output heads
            # specified in the head config
            self.output_heads: dict[str, HeadInterface] = {}
            heads_config = ft_config.head_config
            head_names_sorted = sorted(heads_config.keys())
            for head_name in head_names_sorted:
                head_config = heads_config[head_name]
                if "module" not in head_config:
                    raise ValueError(
                        f"{head_name} head does not specify module to use for the head"
                    )

                module_name = head_config.pop("module")
                self.output_heads[head_name] = registry.get_model_class(module_name)(
                    self.backbone,
                    **head_config,
                )
                num_params = sum(
                    p.numel() for p in self.output_heads[head_name].parameters()
                )
                logging.info(
                    f"Attaching new output head: {module_name} with {num_params} params"
                )
        self.output_heads = torch.nn.ModuleDict(self.output_heads)

    def forward(self, data: Batch):
        emb = self.backbone(data)
        out = {}
        for k in self.output_heads:
            out.update(self.output_heads[k](data, emb))
        return out

    def get_backbone(self) -> BackboneInterface:
        return self.backbone

    def get_heads(self) -> dict[str, HeadInterface]:
        return self.output_heads
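To make the two modes concrete, a hedged sketch of the finetune_config shapes FTConfig accepts; the checkpoint path and head module name are placeholders.

from fairchem.core.models.finetune_hydra import FineTuneHydra, FTConfig

# DATA_ONLY: reuse the whole model (backbone + heads) and train on new data.
data_only_config = {
    FTConfig.MODE: "DATA_ONLY",
    FTConfig.STARTING_CHECKPOINT: "/path/to/hydra_checkpoint.pt",  # placeholder
}

# RETAIN_BACKBONE_ONLY: keep the backbone, attach freshly initialized heads.
retain_backbone_config = {
    FTConfig.MODE: "RETAIN_BACKBONE_ONLY",
    FTConfig.STARTING_CHECKPOINT: "/path/to/hydra_checkpoint.pt",  # placeholder
    FTConfig.HEADS: {
        "energy": {"module": "my_energy_head"},  # hypothetical registered head
    },
}

model = FineTuneHydra(finetune_config=retain_backbone_config)

# At save time, FTConfig.get_standalone_config() embeds the starting model's
# config under STARTING_MODEL, so the new checkpoint loads without the original.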
57 changes: 24 additions & 33 deletions src/fairchem/core/trainers/base_trainer.py
@@ -39,10 +39,12 @@
    get_commit_hash,
    get_loss_module,
    load_state_dict,
    match_state_dict,
    save_checkpoint,
    update_config,
)
from fairchem.core.datasets.base_dataset import create_dataset
from fairchem.core.models.finetune_hydra import FineTuneHydra, FTConfig
from fairchem.core.modules.evaluator import Evaluator
from fairchem.core.modules.exponential_moving_average import ExponentialMovingAverage
from fairchem.core.modules.loss import DDPLoss
@@ -115,8 +117,7 @@ def __init__(
        self.config = {
            "task": task,
            "trainer": name,
            "model": aii(model.pop("name"), str),
            "model_attributes": model,
            "model": model,
            "outputs": outputs,
            "optim": optimizer,
            "loss_functions": loss_functions,
@@ -297,9 +298,7 @@ def get_dataloader(self, dataset, sampler) -> DataLoader:
        )

    def load_datasets(self) -> None:
        self.ocp_collater = OCPCollater(
            self.config["model_attributes"].get("otf_graph", False)
        )
        self.ocp_collater = OCPCollater(self.config["model"].get("otf_graph", False))
        self.train_loader = None
        self.val_loader = None
        self.test_loader = None
@@ -511,16 +510,20 @@ def load_task(self):
    def load_model(self) -> None:
        # Build model
        if distutils.is_master():
            logging.info(f"Loading model: {self.config['model']}")
            logging.info(f"Loading model: {self.config['model']['name']}")

        self.model = registry.get_model_class(self.config["model"])(
            **self.config["model_attributes"],
        model_config_copy = copy.deepcopy(self.config["model"])
        model_name = model_config_copy.pop("name")
        self.model = registry.get_model_class(model_name)(
            **model_config_copy,
        ).to(self.device)

        num_params = sum(p.numel() for p in self.model.parameters())

        if distutils.is_master():
            logging.info(
                f"Loaded {self.model.__class__.__name__} with "
                f"{self.model.num_params} parameters."
                f"{num_params} parameters."
            )

        if self.logger is not None:
@@ -530,11 +533,12 @@ def load_model(self) -> None:
                self.logger.watch(
                    self.model, log_freq=int(self.config["logger"]["watch"])
                )
            self.logger.log_summary({"num_params": self.model.num_params})
            self.logger.log_summary({"num_params": num_params})

        if distutils.initialized() and not self.config["noddp"]:
            self.model = DistributedDataParallel(
                self.model, device_ids=None if self.cpu else [self.device]
                self.model,
                device_ids=None if self.cpu else [self.device],
            )

    @property
@@ -561,28 +565,8 @@ def load_checkpoint(
        self.best_val_metric = checkpoint.get("best_val_metric", None)
        self.primary_metric = checkpoint.get("primary_metric", None)

        # Match the "module." count in the keys of model and checkpoint state_dict
        # DataParallel model has 1 "module.", DistributedDataParallel has 2 "module."
        # Not using either of the above two would have no "module."

        ckpt_key_count = next(iter(checkpoint["state_dict"])).count("module")
        mod_key_count = next(iter(self.model.state_dict())).count("module")
        key_count_diff = mod_key_count - ckpt_key_count

        if key_count_diff > 0:
            new_dict = {
                key_count_diff * "module." + k: v
                for k, v in checkpoint["state_dict"].items()
            }
        elif key_count_diff < 0:
            new_dict = {
                k[len("module.") * abs(key_count_diff) :]: v
                for k, v in checkpoint["state_dict"].items()
            }
        else:
            new_dict = checkpoint["state_dict"]

        strict = self.config["task"].get("strict_load", True)
        new_dict = match_state_dict(self.model.state_dict(), checkpoint["state_dict"])
        strict = self.config.get("task", {}).get("strict_load", True)
        load_state_dict(self.model, new_dict, strict=strict)

if "optimizer" in checkpoint:
@@ -723,6 +707,13 @@ def save(
        training_state: bool = True,
    ) -> str | None:
        if not self.is_debug and distutils.is_master():
            # If we are using a finetune-able model, we need to modify the config
            # to replace the reference to the original starting checkpoint, so
            # the saved checkpoint can be loaded standalone.
            if isinstance(self.model, FineTuneHydra):
                self.config["model"] = FTConfig(
                    self.config["model"][FTConfig.FT_CONFIG_NAME]
                ).get_standalone_config()

            state = {
                "state_dict": self.model.state_dict(),
                "normalizers": {
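The trainer change above also flattens the model config: the registered name now lives inside the single "model" dict rather than a separate "model_attributes" entry. A sketch of the before/after shape (values are placeholders):

# Before this commit: name and attributes were split across two keys.
old_style = {
    "model": "hydra",
    "model_attributes": {"otf_graph": True},
}

# After: one dict; load_model() deep-copies it and pops "name" for the registry lookup.
new_style = {
    "model": {"name": "hydra", "otf_graph": True},
}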
2 changes: 1 addition & 1 deletion tests/core/datasets/test_create_dataset.py
@@ -36,7 +36,7 @@ def get_dataloader(self, *args, **kwargs):
        return None

    config = {
        "model_attributes": {},
        "model": {},
        "optim": {"batch_size": 0},
        "dataset": {
            "format": "ase_db",