Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Aug 13, 2024
1 parent b4288af commit b4dfc93
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 112 deletions.
1 change: 0 additions & 1 deletion src/fairchem/core/models/finetune_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import errno
import logging
import os
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING

Expand Down
224 changes: 113 additions & 111 deletions tests/core/e2e/test_e2e_finetune_hydra.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@ def tutorial_val_src(tutorial_dataset_path):
return tutorial_dataset_path / "s2ef/val_20"


def make_checkpoint(data_source, seed) -> str:
def make_checkpoint(tempdir: str, data_source: Path, seed: int) -> str:
# first train a tiny eqv2 model to get a checkpoint
eqv2_yml = Path("tests/core/models/test_configs/test_equiformerv2_hydra.yml")
tempdir = tempfile.TemporaryDirectory()
ck_path = os.path.join(tempdir.name, "checkpoint.pt")
ck_path = os.path.join(tempdir, "checkpoint.pt")
_run_main(
tempdir.name,
tempdir,
eqv2_yml,
update_dict_with={
"optim": {
Expand All @@ -43,17 +42,17 @@ def make_checkpoint(data_source, seed) -> str:
world_size=0,
)
assert os.path.isfile(ck_path)
return ck_path, tempdir
return ck_path


def run_main_with_ft_hydra(tempdir: tempfile.TemporaryDirectory,
def run_main_with_ft_hydra(tempdir: str,
yaml: str,
data_src: str,
run_args: dict,
ft_config: str,
output_checkpoint: str):
_run_main(
tempdir.name,
tempdir,
yaml,
update_dict_with={
"optim": {
Expand Down Expand Up @@ -81,114 +80,117 @@ def run_main_with_ft_hydra(tempdir: tempfile.TemporaryDirectory,


def test_finetune_hydra_retain_backbone(tutorial_val_src):
starting_ckpt, original_tmpdir = make_checkpoint(tutorial_val_src, 0)
old_state_dict = torch.load(starting_ckpt)["state_dict"]
# now finetune the model with the checkpoint from the first job
tempdir2 = tempfile.TemporaryDirectory()
ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
ck_ft_path = os.path.join(tempdir2.name, "checkpoint_ft.pt")
ft_config = {
"mode": FineTuneMode.RETAIN_BACKBONE_ONLY.name,
"starting_checkpoint": starting_ckpt,
"heads": {
"energy": {
"module": "equiformer_v2_energy_head"
},
"forces": {
"module": "equiformer_v2_force_head"
with tempfile.TemporaryDirectory() as orig_ckpt_dir:
starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)
old_state_dict = torch.load(starting_ckpt)["state_dict"]
# now finetune the model with the checkpoint from the first job
with tempfile.TemporaryDirectory() as ft_temp_dir:
ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt")
ft_config = {
"mode": FineTuneMode.RETAIN_BACKBONE_ONLY.name,
"starting_checkpoint": starting_ckpt,
"heads": {
"energy": {
"module": "equiformer_v2_energy_head"
},
"forces": {
"module": "equiformer_v2_force_head"
}
}
}
}
}
run_main_with_ft_hydra(tempdir = tempdir2,
yaml = ft_yml,
data_src = tutorial_val_src,
run_args = {"seed": 1000},
ft_config = ft_config,
output_checkpoint = ck_ft_path)
assert os.path.isfile(ck_ft_path)
ft_ckpt = torch.load(ck_ft_path)
assert "config" in ft_ckpt
assert ft_ckpt["config"]["model"]["name"] == FTHYDRA_NAME
# check that the backbone weights are the same, and other weights are not the same
new_state_dict = ft_ckpt["state_dict"]
for key in new_state_dict:
if key.startswith("backbone"):
# backbone should be identical
assert torch.allclose(new_state_dict[key], old_state_dict[key])
elif key.startswith("output_heads") and key.endswith("weight"):
# heads weight should be different because the seeds are different
assert not torch.allclose(new_state_dict[key], old_state_dict[key])
run_main_with_ft_hydra(tempdir = ft_temp_dir,
yaml = ft_yml,
data_src = tutorial_val_src,
run_args = {"seed": 1000},
ft_config = ft_config,
output_checkpoint = ck_ft_path)
assert os.path.isfile(ck_ft_path)
ft_ckpt = torch.load(ck_ft_path)
assert "config" in ft_ckpt
assert ft_ckpt["config"]["model"]["name"] == FTHYDRA_NAME
# check that the backbone weights are the same, and other weights are not the same
new_state_dict = ft_ckpt["state_dict"]
for key in new_state_dict:
if key.startswith("backbone"):
# backbone should be identical
assert torch.allclose(new_state_dict[key], old_state_dict[key])
elif key.startswith("output_heads") and key.endswith("weight"):
# heads weight should be different because the seeds are different
assert not torch.allclose(new_state_dict[key], old_state_dict[key])


def test_finetune_hydra_data_only(tutorial_val_src):
starting_ckpt, original_tmpdir = make_checkpoint(tutorial_val_src, 0)
old_state_dict = torch.load(starting_ckpt)["state_dict"]
# now finetune the model with the checkpoint from the first job
tempdir2 = tempfile.TemporaryDirectory()
ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
ck_ft_path = os.path.join(tempdir2.name, "checkpoint_ft.pt")
ft_config = {
"mode": FineTuneMode.DATA_ONLY.name,
"starting_checkpoint": starting_ckpt,
}
run_main_with_ft_hydra(tempdir = tempdir2,
yaml = ft_yml,
data_src = tutorial_val_src,
run_args = {"seed": 1000},
ft_config = ft_config,
output_checkpoint = ck_ft_path)
assert os.path.isfile(ck_ft_path)
ft_ckpt = torch.load(ck_ft_path)
assert "config" in ft_ckpt
config_model = ft_ckpt["config"]["model"]
assert config_model["name"] == FTHYDRA_NAME
# check that the entire model weights are the same
new_state_dict = ft_ckpt["state_dict"]
assert len(new_state_dict) == len(old_state_dict)
for key in new_state_dict:
assert torch.allclose(new_state_dict[key], old_state_dict[key])
# check the new checkpoint contains a hydra model
assert FTConfig.STARTING_MODEL in config_model[FTConfig.FT_CONFIG_NAME]
with tempfile.TemporaryDirectory() as orig_ckpt_dir:
starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)
old_state_dict = torch.load(starting_ckpt)["state_dict"]
# now finetune the model with the checkpoint from the first job
with tempfile.TemporaryDirectory() as ft_temp_dir:
ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt")
ft_config = {
"mode": FineTuneMode.DATA_ONLY.name,
"starting_checkpoint": starting_ckpt,
}
run_main_with_ft_hydra(tempdir = ft_temp_dir,
yaml = ft_yml,
data_src = tutorial_val_src,
run_args = {"seed": 1000},
ft_config = ft_config,
output_checkpoint = ck_ft_path)
assert os.path.isfile(ck_ft_path)
ft_ckpt = torch.load(ck_ft_path)
assert "config" in ft_ckpt
config_model = ft_ckpt["config"]["model"]
assert config_model["name"] == FTHYDRA_NAME
# check that the entire model weights are the same
new_state_dict = ft_ckpt["state_dict"]
assert len(new_state_dict) == len(old_state_dict)
for key in new_state_dict:
assert torch.allclose(new_state_dict[key], old_state_dict[key])
# check the new checkpoint contains a hydra model
assert FTConfig.STARTING_MODEL in config_model[FTConfig.FT_CONFIG_NAME]


def test_finetune_from_finetunehydra(tutorial_val_src):
starting_ckpt, original_tmpdir = make_checkpoint(tutorial_val_src, 0)
# now finetune the model with the checkpoint from the first job
finetune_run1 = tempfile.TemporaryDirectory()
ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
ck_ft_path = os.path.join(finetune_run1.name, "checkpoint_ft.pt")
ft_config_1 = {
"mode": FineTuneMode.DATA_ONLY.name,
"starting_checkpoint": starting_ckpt,
}
run_main_with_ft_hydra(tempdir = finetune_run1,
yaml = ft_yml,
data_src = tutorial_val_src,
run_args = {"seed": 1000},
ft_config = ft_config_1,
output_checkpoint = ck_ft_path)
assert os.path.isfile(ck_ft_path)

# now that we have a second checkpoint, try finetuning again from this checkpoint
########################################################################################
finetune_run2 = tempfile.TemporaryDirectory()
ck_ft2_path = os.path.join(finetune_run2.name, "checkpoint_ft.pt")
ft_config_2 = {
"mode": FineTuneMode.DATA_ONLY.name,
"starting_checkpoint": ck_ft_path,
}
run_main_with_ft_hydra(tempdir = finetune_run2,
yaml = ft_yml,
data_src = tutorial_val_src,
run_args = {"seed": 1000},
ft_config = ft_config_2,
output_checkpoint = ck_ft2_path)
ft_ckpt2 = torch.load(ck_ft2_path)
assert "config" in ft_ckpt2
config_model = ft_ckpt2["config"]["model"]
assert config_model["name"] == FTHYDRA_NAME
old_state_dict = torch.load(ck_ft_path)["state_dict"]
new_state_dict = ft_ckpt2["state_dict"]
# the state dicts should still be identical because we made the LR = 0.0
for key in new_state_dict:
assert torch.allclose(new_state_dict[key], old_state_dict[key])
with tempfile.TemporaryDirectory() as orig_ckpt_dir:
starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)
# now finetune the model with the checkpoint from the first job
with tempfile.TemporaryDirectory() as finetune_run1_dir:
ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
ck_ft_path = os.path.join(finetune_run1_dir, "checkpoint_ft.pt")
ft_config_1 = {
"mode": FineTuneMode.DATA_ONLY.name,
"starting_checkpoint": starting_ckpt,
}
run_main_with_ft_hydra(tempdir = finetune_run1_dir,
yaml = ft_yml,
data_src = tutorial_val_src,
run_args = {"seed": 1000},
ft_config = ft_config_1,
output_checkpoint = ck_ft_path)
assert os.path.isfile(ck_ft_path)

# now that we have a second checkpoint, try finetuning again from this checkpoint
########################################################################################
with tempfile.TemporaryDirectory() as finetune_run2_dir:
ck_ft2_path = os.path.join(finetune_run2_dir, "checkpoint_ft.pt")
ft_config_2 = {
"mode": FineTuneMode.DATA_ONLY.name,
"starting_checkpoint": ck_ft_path,
}
run_main_with_ft_hydra(tempdir = finetune_run2_dir,
yaml = ft_yml,
data_src = tutorial_val_src,
run_args = {"seed": 1000},
ft_config = ft_config_2,
output_checkpoint = ck_ft2_path)
ft_ckpt2 = torch.load(ck_ft2_path)
assert "config" in ft_ckpt2
config_model = ft_ckpt2["config"]["model"]
assert config_model["name"] == FTHYDRA_NAME
old_state_dict = torch.load(ck_ft_path)["state_dict"]
new_state_dict = ft_ckpt2["state_dict"]
# the state dicts should still be identical because we made the LR = 0.0
for key in new_state_dict:
assert torch.allclose(new_state_dict[key], old_state_dict[key])

0 comments on commit b4dfc93

Please sign in to comment.