diff --git a/src/fairchem/core/models/finetune_hydra.py b/src/fairchem/core/models/finetune_hydra.py
index 00f02e7ba..6c271e24e 100644
--- a/src/fairchem/core/models/finetune_hydra.py
+++ b/src/fairchem/core/models/finetune_hydra.py
@@ -4,7 +4,6 @@
 import errno
 import logging
 import os
-from abc import ABC, abstractmethod
 from enum import Enum
 from typing import TYPE_CHECKING
 
diff --git a/tests/core/e2e/test_e2e_finetune_hydra.py b/tests/core/e2e/test_e2e_finetune_hydra.py
index 5f243c072..91f2abd49 100644
--- a/tests/core/e2e/test_e2e_finetune_hydra.py
+++ b/tests/core/e2e/test_e2e_finetune_hydra.py
@@ -16,13 +16,12 @@ def tutorial_val_src(tutorial_dataset_path):
     return tutorial_dataset_path / "s2ef/val_20"
 
 
-def make_checkpoint(data_source, seed) -> str:
+def make_checkpoint(tempdir: str, data_source: Path, seed: int) -> str:
     # first train a tiny eqv2 model to get a checkpoint
     eqv2_yml = Path("tests/core/models/test_configs/test_equiformerv2_hydra.yml")
-    tempdir = tempfile.TemporaryDirectory()
-    ck_path = os.path.join(tempdir.name, "checkpoint.pt")
+    ck_path = os.path.join(tempdir, "checkpoint.pt")
     _run_main(
-        tempdir.name,
+        tempdir,
         eqv2_yml,
         update_dict_with={
             "optim": {
@@ -43,17 +42,17 @@
         world_size=0,
     )
     assert os.path.isfile(ck_path)
-    return ck_path, tempdir
+    return ck_path
 
 
-def run_main_with_ft_hydra(tempdir: tempfile.TemporaryDirectory,
+def run_main_with_ft_hydra(tempdir: str,
                            yaml: str,
                            data_src: str,
                            run_args: dict,
                            ft_config: str,
                            output_checkpoint: str):
     _run_main(
-        tempdir.name,
+        tempdir,
         yaml,
         update_dict_with={
             "optim": {
@@ -81,114 +80,117 @@ def run_main_with_ft_hydra(tempdir: tempfile.TemporaryDirectory,
 
 
 def test_finetune_hydra_retain_backbone(tutorial_val_src):
-    starting_ckpt, original_tmpdir = make_checkpoint(tutorial_val_src, 0)
-    old_state_dict = torch.load(starting_ckpt)["state_dict"]
-    # now finetune a the model with the checkpoint from the first job
-    tempdir2 = tempfile.TemporaryDirectory()
-    ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
-    ck_ft_path = os.path.join(tempdir2.name, "checkpoint_ft.pt")
-    ft_config = {
-        "mode": FineTuneMode.RETAIN_BACKBONE_ONLY.name,
-        "starting_checkpoint": starting_ckpt,
-        "heads": {
-            "energy": {
-                "module": "equiformer_v2_energy_head"
-            },
-            "forces": {
-                "module": "equiformer_v2_force_head"
+    with tempfile.TemporaryDirectory() as orig_ckpt_dir:
+        starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)
+        old_state_dict = torch.load(starting_ckpt)["state_dict"]
+        # now finetune the model with the checkpoint from the first job
+        with tempfile.TemporaryDirectory() as ft_temp_dir:
+            ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
+            ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt")
+            ft_config = {
+                "mode": FineTuneMode.RETAIN_BACKBONE_ONLY.name,
+                "starting_checkpoint": starting_ckpt,
+                "heads": {
+                    "energy": {
+                        "module": "equiformer_v2_energy_head"
+                    },
+                    "forces": {
+                        "module": "equiformer_v2_force_head"
+                    }
+                }
             }
-        }
-    }
-    run_main_with_ft_hydra(tempdir = tempdir2,
-                           yaml = ft_yml,
-                           data_src = tutorial_val_src,
-                           run_args = {"seed": 1000},
-                           ft_config = ft_config,
-                           output_checkpoint = ck_ft_path)
-    assert os.path.isfile(ck_ft_path)
-    ft_ckpt = torch.load(ck_ft_path)
-    assert "config" in ft_ckpt
-    assert ft_ckpt["config"]["model"]["name"] == FTHYDRA_NAME
-    # check that the backbone weights are the same, and other weights are not the same
-    new_state_dict = ft_ckpt["state_dict"]
-    for key in new_state_dict:
-        if key.startswith("backbone"):
-            # backbone should be identical
-            assert torch.allclose(new_state_dict[key], old_state_dict[key])
-        elif key.startswith("output_heads") and key.endswith("weight"):
-            # heads weight should be different because the seeds are different
-            assert not torch.allclose(new_state_dict[key], old_state_dict[key])
+            run_main_with_ft_hydra(tempdir = ft_temp_dir,
+                                   yaml = ft_yml,
+                                   data_src = tutorial_val_src,
+                                   run_args = {"seed": 1000},
+                                   ft_config = ft_config,
+                                   output_checkpoint = ck_ft_path)
+            assert os.path.isfile(ck_ft_path)
+            ft_ckpt = torch.load(ck_ft_path)
+            assert "config" in ft_ckpt
+            assert ft_ckpt["config"]["model"]["name"] == FTHYDRA_NAME
+            # check that the backbone weights are the same, and other weights are not the same
+            new_state_dict = ft_ckpt["state_dict"]
+            for key in new_state_dict:
+                if key.startswith("backbone"):
+                    # backbone should be identical
+                    assert torch.allclose(new_state_dict[key], old_state_dict[key])
+                elif key.startswith("output_heads") and key.endswith("weight"):
+                    # heads weight should be different because the seeds are different
+                    assert not torch.allclose(new_state_dict[key], old_state_dict[key])
 
 
 def test_finetune_hydra_data_only(tutorial_val_src):
-    starting_ckpt, original_tmpdir = make_checkpoint(tutorial_val_src, 0)
-    old_state_dict = torch.load(starting_ckpt)["state_dict"]
-    # now finetune a the model with the checkpoint from the first job
-    tempdir2 = tempfile.TemporaryDirectory()
-    ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
-    ck_ft_path = os.path.join(tempdir2.name, "checkpoint_ft.pt")
-    ft_config = {
-        "mode": FineTuneMode.DATA_ONLY.name,
-        "starting_checkpoint": starting_ckpt,
-    }
-    run_main_with_ft_hydra(tempdir = tempdir2,
-                           yaml = ft_yml,
-                           data_src = tutorial_val_src,
-                           run_args = {"seed": 1000},
-                           ft_config = ft_config,
-                           output_checkpoint = ck_ft_path)
-    assert os.path.isfile(ck_ft_path)
-    ft_ckpt = torch.load(ck_ft_path)
-    assert "config" in ft_ckpt
-    config_model = ft_ckpt["config"]["model"]
-    assert config_model["name"] == FTHYDRA_NAME
-    # check that the entire model weights are the same
-    new_state_dict = ft_ckpt["state_dict"]
-    assert len(new_state_dict) == len(old_state_dict)
-    for key in new_state_dict:
-        assert torch.allclose(new_state_dict[key], old_state_dict[key])
-    # check the new checkpoint contains a hydra model
-    assert FTConfig.STARTING_MODEL in config_model[FTConfig.FT_CONFIG_NAME]
+    with tempfile.TemporaryDirectory() as orig_ckpt_dir:
+        starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)
+        old_state_dict = torch.load(starting_ckpt)["state_dict"]
+        # now finetune the model with the checkpoint from the first job
+        with tempfile.TemporaryDirectory() as ft_temp_dir:
+            ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
+            ck_ft_path = os.path.join(ft_temp_dir, "checkpoint_ft.pt")
+            ft_config = {
+                "mode": FineTuneMode.DATA_ONLY.name,
+                "starting_checkpoint": starting_ckpt,
+            }
+            run_main_with_ft_hydra(tempdir = ft_temp_dir,
+                                   yaml = ft_yml,
+                                   data_src = tutorial_val_src,
+                                   run_args = {"seed": 1000},
+                                   ft_config = ft_config,
+                                   output_checkpoint = ck_ft_path)
+            assert os.path.isfile(ck_ft_path)
+            ft_ckpt = torch.load(ck_ft_path)
+            assert "config" in ft_ckpt
+            config_model = ft_ckpt["config"]["model"]
+            assert config_model["name"] == FTHYDRA_NAME
+            # check that the entire model weights are the same
+            new_state_dict = ft_ckpt["state_dict"]
+            assert len(new_state_dict) == len(old_state_dict)
+            for key in new_state_dict:
+                assert torch.allclose(new_state_dict[key], old_state_dict[key])
+            # check the new checkpoint contains a hydra model
+            assert FTConfig.STARTING_MODEL in config_model[FTConfig.FT_CONFIG_NAME]
 
 
 def test_finetune_from_finetunehydra(tutorial_val_src):
-    starting_ckpt, original_tmpdir = make_checkpoint(tutorial_val_src, 0)
-    # now finetune a the model with the checkpoint from the first job
-    finetune_run1 = tempfile.TemporaryDirectory()
-    ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
-    ck_ft_path = os.path.join(finetune_run1.name, "checkpoint_ft.pt")
-    ft_config_1 = {
-        "mode": FineTuneMode.DATA_ONLY.name,
-        "starting_checkpoint": starting_ckpt,
-    }
-    run_main_with_ft_hydra(tempdir = finetune_run1,
-                           yaml = ft_yml,
-                           data_src = tutorial_val_src,
-                           run_args = {"seed": 1000},
-                           ft_config = ft_config_1,
-                           output_checkpoint = ck_ft_path)
-    assert os.path.isfile(ck_ft_path)
-
-    # now that we have a second checkpoint, try finetuning again from this checkpoint
-    ########################################################################################
-    finetune_run2 = tempfile.TemporaryDirectory()
-    ck_ft2_path = os.path.join(finetune_run2.name, "checkpoint_ft.pt")
-    ft_config_2 = {
-        "mode": FineTuneMode.DATA_ONLY.name,
-        "starting_checkpoint": ck_ft_path,
-    }
-    run_main_with_ft_hydra(tempdir = finetune_run2,
-                           yaml = ft_yml,
-                           data_src = tutorial_val_src,
-                           run_args = {"seed": 1000},
-                           ft_config = ft_config_2,
-                           output_checkpoint = ck_ft2_path)
-    ft_ckpt2 = torch.load(ck_ft2_path)
-    assert "config" in ft_ckpt2
-    config_model = ft_ckpt2["config"]["model"]
-    assert config_model["name"] == FTHYDRA_NAME
-    old_state_dict = torch.load(ck_ft_path)["state_dict"]
-    new_state_dict = ft_ckpt2["state_dict"]
-    # the state dicts should still be identical because we made the LR = 0.0
-    for key in new_state_dict:
-        assert torch.allclose(new_state_dict[key], old_state_dict[key])
+    with tempfile.TemporaryDirectory() as orig_ckpt_dir:
+        starting_ckpt = make_checkpoint(orig_ckpt_dir, tutorial_val_src, 0)
+        # now finetune the model with the checkpoint from the first job
+        with tempfile.TemporaryDirectory() as finetune_run1_dir:
+            ft_yml = Path("tests/core/models/test_configs/test_finetune_hydra.yml")
+            ck_ft_path = os.path.join(finetune_run1_dir, "checkpoint_ft.pt")
+            ft_config_1 = {
+                "mode": FineTuneMode.DATA_ONLY.name,
+                "starting_checkpoint": starting_ckpt,
+            }
+            run_main_with_ft_hydra(tempdir = finetune_run1_dir,
+                                   yaml = ft_yml,
+                                   data_src = tutorial_val_src,
+                                   run_args = {"seed": 1000},
+                                   ft_config = ft_config_1,
+                                   output_checkpoint = ck_ft_path)
+            assert os.path.isfile(ck_ft_path)
+
+            # now that we have a second checkpoint, try finetuning again from this checkpoint
+            ########################################################################################
+            with tempfile.TemporaryDirectory() as finetune_run2_dir:
+                ck_ft2_path = os.path.join(finetune_run2_dir, "checkpoint_ft.pt")
+                ft_config_2 = {
+                    "mode": FineTuneMode.DATA_ONLY.name,
+                    "starting_checkpoint": ck_ft_path,
+                }
+                run_main_with_ft_hydra(tempdir = finetune_run2_dir,
+                                       yaml = ft_yml,
+                                       data_src = tutorial_val_src,
+                                       run_args = {"seed": 1000},
+                                       ft_config = ft_config_2,
+                                       output_checkpoint = ck_ft2_path)
+                ft_ckpt2 = torch.load(ck_ft2_path)
+                assert "config" in ft_ckpt2
+                config_model = ft_ckpt2["config"]["model"]
+                assert config_model["name"] == FTHYDRA_NAME
+                old_state_dict = torch.load(ck_ft_path)["state_dict"]
+                new_state_dict = ft_ckpt2["state_dict"]
+                # the state dicts should still be identical because we made the LR = 0.0
+                for key in new_state_dict:
+                    assert torch.allclose(new_state_dict[key], old_state_dict[key])
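Note on the refactor above: the old make_checkpoint created a tempfile.TemporaryDirectory() internally and had to return the object alongside the checkpoint path just to keep the directory alive, leaving cleanup to garbage collection. The new version takes the directory as an argument and the tests wrap each run in a with tempfile.TemporaryDirectory() block, so the caller owns the lifetime and the directory is removed deterministically when the block exits, even if an assertion fails inside it. A minimal sketch of that ownership pattern, with hypothetical names (make_artifact, artifact.txt) that are not part of this change:

import os
import tempfile


def make_artifact(workdir: str) -> str:
    # Write into a directory owned by the caller instead of creating a
    # TemporaryDirectory inside the helper, so the returned path stays valid
    # for as long as the caller keeps the directory around.
    path = os.path.join(workdir, "artifact.txt")
    with open(path, "w") as handle:
        handle.write("checkpoint-like payload")
    return path


# The caller owns the lifetime: the directory and everything inside it are
# removed when the with-block exits, even if an assertion fails first.
with tempfile.TemporaryDirectory() as workdir:
    artifact = make_artifact(workdir)
    assert os.path.isfile(artifact)

This mirrors why make_checkpoint now accepts tempdir and returns only the checkpoint path instead of a (path, TemporaryDirectory) pair.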