Skip to content

Commit

Permalink
add hydra freeze backbone option
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Oct 28, 2024
1 parent 7c82af9 commit 10b2ac3
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 4 deletions.
5 changes: 5 additions & 0 deletions src/fairchem/core/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,7 @@ def __init__(
finetune_config: dict | None = None,
otf_graph: bool = True,
pass_through_head_outputs: bool = False,
freeze_backbone: bool = False,
):
super().__init__()
self.device = None
Expand Down Expand Up @@ -282,6 +283,10 @@ def __init__(
"Backbone not specified and not found in the starting checkpoint"
)

if freeze_backbone:
for param in self.backbone.parameters():
param.requires_grad = False

if heads is not None:
heads = copy.deepcopy(heads)
# Iterate through outputs_cfg and create heads
Expand Down
119 changes: 115 additions & 4 deletions tests/core/e2e/test_e2e_finetune_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@
from pathlib import Path

import pytest
from fairchem.core.scripts.convert_hydra_to_release import convert_fine_tune_checkpoint
import torch
from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths

from fairchem.core.scripts.convert_hydra_to_release import convert_fine_tune_checkpoint


@pytest.fixture()
def tutorial_val_src(tutorial_dataset_path):
Expand Down Expand Up @@ -104,12 +105,122 @@ def verify_release_checkpoint(release_yaml_fn, release_checkpoint_fn, ft_state_d
assert os.path.isfile(ck_release_ft_afterload_path)
ft_after_state_dict = torch.load(ck_release_ft_afterload_path)["state_dict"]
for key in ft_after_state_dict:
if key.startswith("module.backbone"):
assert torch.allclose(ft_after_state_dict[key], ft_state_dict[key])
elif key.startswith("module.output_heads") and key.endswith("weight"):
if (
key.startswith("module.backbone")
or key.startswith("module.output_heads")
and key.endswith("weight")
):
assert torch.allclose(ft_after_state_dict[key], ft_state_dict[key])


def _finetune_and_load_state_dict(
    starting_ckpt,
    tutorial_val_src,
    *,
    model_overrides=None,
    optim_overrides=None,
):
    """Fine-tune a hydra model from ``starting_ckpt`` and return its weights.

    Runs one epoch through ``_run_main`` in a throw-away directory, checks
    that a hydra checkpoint was written, and returns the ``state_dict``
    stored in that checkpoint.

    Args:
        starting_ckpt: Path of the checkpoint to fine-tune from.
        tutorial_val_src: Dataset path used for the train, val and test splits.
        model_overrides: Extra keys merged into the ``model`` config
            (e.g. ``{"freeze_backbone": True}``).
        optim_overrides: Extra keys merged into the ``optim`` config
            (e.g. ``{"lr_initial": 10.0}``).

    Returns:
        The ``state_dict`` entry of the fine-tuned checkpoint.
    """
    model_config = {
        "name": "hydra",
        "finetune_config": {"starting_checkpoint": starting_ckpt},
        "heads": {
            "energy": {"module": "equiformer_v2_energy_head"},
            "forces": {"module": "equiformer_v2_force_head"},
        },
        **(model_overrides or {}),
    }
    optim_config = {
        "max_epochs": 1,
        "eval_every": 8,
        "batch_size": 1,
        "num_workers": 0,
        **(optim_overrides or {}),
    }

    with tempfile.TemporaryDirectory() as ft_temp_dir:
        ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
        ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt")

        _run_main(
            ft_temp_dir,
            ft_yml,
            update_dict_with={
                "optim": optim_config,
                "dataset": oc20_lmdb_train_and_val_from_paths(
                    train_src=str(tutorial_val_src),
                    val_src=str(tutorial_val_src),
                    test_src=str(tutorial_val_src),
                    otf_norms=False,
                ),
                "model": model_config,
            },
            update_run_args_with={"seed": 1000},
            save_checkpoint_to=ck_ft_path,
            world_size=1,
        )

        # Load while the temporary directory still exists.
        assert os.path.isfile(ck_ft_path)
        ft_ckpt = torch.load(ck_ft_path)
        assert "config" in ft_ckpt
        assert ft_ckpt["config"]["model"]["name"] == "hydra"
        return ft_ckpt["state_dict"]


def test_finetune_hydra_freeze_backbone(tutorial_val_src):
    """Check that ``freeze_backbone`` keeps backbone weights fixed in fine-tuning.

    Two fine-tuning runs start from the same checkpoint:

    1. Without freezing: backbone ``.weight`` tensors and output-head weights
       must both differ from the starting checkpoint. This is the control
       showing that training really does move the backbone.
    2. With ``freeze_backbone=True``: every backbone tensor must match the
       starting checkpoint, while output-head weights must still differ.
    """
    with tempfile.TemporaryDirectory() as orig_ckpt_dir:
        starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)
        old_state_dict = torch.load(starting_ckpt)["state_dict"]

        # Control run: without freezing, the backbone weights must change.
        # The large learning rate is only passed for this run.
        ft_state_dict = _finetune_and_load_state_dict(
            starting_ckpt,
            tutorial_val_src,
            optim_overrides={"lr_initial": 10.0},
        )
        for key in ft_state_dict:
            if key.startswith("module.backbone") and ".weight" in key:
                # backbone weights were trained, so they must have moved
                assert not torch.allclose(ft_state_dict[key], old_state_dict[key])
            elif key.startswith("module.output_heads") and key.endswith("weight"):
                # head weights must differ from the starting checkpoint
                assert not torch.allclose(ft_state_dict[key], old_state_dict[key])

        # Frozen run: with freeze_backbone the backbone must be untouched.
        ft_state_dict = _finetune_and_load_state_dict(
            starting_ckpt,
            tutorial_val_src,
            model_overrides={"freeze_backbone": True},
        )
        for key in ft_state_dict:
            if key.startswith("module.backbone"):
                # backbone is frozen, so every tensor must be unchanged
                assert torch.allclose(ft_state_dict[key], old_state_dict[key])
            elif key.startswith("module.output_heads") and key.endswith("weight"):
                # head weights must still differ from the starting checkpoint
                assert not torch.allclose(ft_state_dict[key], old_state_dict[key])


def test_finetune_hydra_retain_backbone(tutorial_val_src):
with tempfile.TemporaryDirectory() as orig_ckpt_dir:
starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)
Expand Down

0 comments on commit 10b2ac3

Please sign in to comment.