fix gemnet scaling factors fit.py and add a test #819

Merged 3 commits on Aug 22, 2024
64 changes: 28 additions & 36 deletions src/fairchem/core/modules/scaling/fit.py
@@ -37,18 +37,9 @@ def _train_batch(trainer: BaseTrainer, batch) -> None:
     del out, loss
 
 
-def main(*, num_batches: int = 16) -> None:
-    # region args/config setup
-    setup_logging()
-
-    parser = flags.get_parser()
-    args, override_args = parser.parse_known_args()
-    _config = build_config(args, override_args)
-    _config["logger"] = "wandb"
-    # endregion
+def compute_scaling_factors(config, num_batches: int = 16) -> None:
 
-    assert not args.distributed, "This doesn't work with DDP"
-    with new_trainer_context(args=args, config=_config) as ctx:
+    with new_trainer_context(config=config) as ctx:
         config = ctx.config
         trainer = ctx.trainer
@@ -61,8 +52,8 @@ def main(*, num_batches: int = 16) -> None:
         logging.info(f"Input checkpoint path: {ckpt_file}, {ckpt_file.exists()=}")
 
         model: nn.Module = trainer.model
-        val_loader = trainer.val_loader
-        assert val_loader is not None, "Val dataset is required for making predictions"
+        data_loader = trainer.train_loader
+        assert data_loader is not None, "Train set required to load batches"
 
         if ckpt_file.exists():
             trainer.load_checkpoint(checkpoint_path=str(ckpt_file))
@@ -122,15 +113,8 @@ def main(*, num_batches: int = 16) -> None:
             sys.exit(-1)
         # endregion
 
-        # region get the output path
-        out_path = Path(
-            _prefilled_input(
-                "Enter output path for fitted scale factors: ",
-                prefill=str(ckpt_file),
-            )
-        )
-        if out_path.exists():
-            logging.warning(f"Already found existing file: {out_path}")
+        if ckpt_file.exists():
+            logging.warning(f"Already found existing file: {ckpt_file}")
             flag = input(
                 "Do you want to continue and overwrite existing file (1), "
                 "or exit (2)? "
@@ -142,7 +126,7 @@ def index_fn(name: str = name) -> None:
                 sys.exit()
 
         logging.info(
-            f"Output path for fitted scale factors: {out_path}, {out_path.exists()=}"
+            f"Output path for fitted scale factors: {ckpt_file}, {ckpt_file.exists()=}"
         )
         # endregion
 
@@ -175,7 +159,7 @@ def index_fn(name: str = name) -> None:
             module.initialize_(index_fn=index_fn)
 
         # single pass through network
-        _train_batch(trainer, next(iter(val_loader)))
+        _train_batch(trainer, next(iter(data_loader)))
 
         # sort the scale factors by their computation order
         sorted_factors = sorted(
@@ -200,7 +184,7 @@ def index_fn(name: str = name) -> None:
 
             logging.info(f"Fitting {name}...")
             with module.fit_context_():
-                for batch in islice(val_loader, num_batches):
+                for batch in islice(data_loader, num_batches):
                     _train_batch(trainer, batch)
                 stats, ratio, value = module.fit_()
 
@@ -216,19 +200,27 @@ def index_fn(name: str = name) -> None:
             assert module.fitted, f"{name} is not fitted"
 
         # region save the scale factors to the checkpoint file
-        trainer.config["cmd"]["checkpoint_dir"] = out_path.parent
+        trainer.config["cmd"]["checkpoint_dir"] = ckpt_file.parent
         trainer.is_debug = False
-        out_file = trainer.save(
-            metrics=None,
-            checkpoint_file=out_path.name,
-            training_state=False,
+
+        torch.save(
+            {
+                x[0].replace(".scale_factor", ""): x[1]
+                for x in trainer.model.to("cpu").named_parameters()
+                if ".scale_" in x[0]
+            },
+            str(ckpt_file),
         )
-        assert out_file is not None, "Failed to save checkpoint"
-        out_file = Path(out_file)
-        assert out_file.exists(), f"Failed to save checkpoint to {out_file}"
         # endregion
-        logging.info(f"Saved results to: {out_file}")
+        logging.info(f"Saved results to: {ckpt_file}")
 
 
 if __name__ == "__main__":
-    main()
+    # region args/config setup
+    setup_logging()
+
+    parser = flags.get_parser()
+    args, override_args = parser.parse_known_args()
+    assert not args.distributed, "This doesn't work with DDP"
+    config = build_config(args, override_args)
+
+    compute_scaling_factors(config)
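
With this change, the fitted factors are written as a plain dictionary via torch.save rather than as a full trainer checkpoint. A minimal sketch of inspecting such a file; the "scaling.pt" filename is an assumption, not from the PR:

    import torch

    # Keys are module paths (the ".scale_factor" suffix already stripped);
    # values are the fitted scale tensors.
    factors = torch.load("scaling.pt", map_location="cpu")
    for name, tensor in sorted(factors.items()):
        print(name, tensor)
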
19 changes: 11 additions & 8 deletions tests/core/e2e/test_e2e_commons.py
@@ -93,6 +93,16 @@ def merge_dictionary(d, u):
     return d
 
 
+def update_yaml_with_dict(input_yaml, output_yaml, update_dict_with):
+    with open(input_yaml) as yaml_file:
+        yaml_config = yaml.safe_load(yaml_file)
+    if update_dict_with is not None:
+        yaml_config = merge_dictionary(yaml_config, update_dict_with)
+        yaml_config["backend"] = "gloo"
+    with open(str(output_yaml), "w") as yaml_file:
+        yaml.dump(yaml_config, yaml_file)
+
+
 def _run_main(
     rundir,
     input_yaml,
@@ -103,14 +113,7 @@ def _run_main(
     world_size=0,
 ):
     config_yaml = Path(rundir) / "train_and_val_on_val.yml"
-
-    with open(input_yaml) as yaml_file:
-        yaml_config = yaml.safe_load(yaml_file)
-    if update_dict_with is not None:
-        yaml_config = merge_dictionary(yaml_config, update_dict_with)
-        yaml_config["backend"] = "gloo"
-    with open(str(config_yaml), "w") as yaml_file:
-        yaml.dump(yaml_config, yaml_file)
+    update_yaml_with_dict(input_yaml, config_yaml, update_dict_with)
     run_args = {
         "run_dir": rundir,
         "logdir": f"{rundir}/logs",
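
For reference, a hypothetical standalone call to the new helper; the paths and the override dict below are illustrative, not from the PR:

    # Merge overrides into a base YAML config and write the result; when
    # overrides are given, the helper also forces the "gloo" backend.
    update_yaml_with_dict(
        input_yaml="configs/s2ef/gemnet_oc.yml",      # assumed base config path
        output_yaml="/tmp/train_and_val_on_val.yml",  # assumed output path
        update_dict_with={"optim": {"max_epochs": 1}},
    )
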
67 changes: 65 additions & 2 deletions tests/core/e2e/test_s2ef.py
@@ -8,11 +8,19 @@
 import numpy as np
 import numpy.testing as npt
 import pytest
-from test_e2e_commons import _run_main, oc20_lmdb_train_and_val_from_paths
+from fairchem.core._cli import Runner
+from fairchem.core.modules.scaling.fit import compute_scaling_factors
+from test_e2e_commons import (
+    _run_main,
+    oc20_lmdb_train_and_val_from_paths,
+    update_yaml_with_dict,
+)
 
-from fairchem.core.common.utils import setup_logging
+from fairchem.core.common.utils import build_config, setup_logging
 from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes
 
+from fairchem.core.common.flags import flags
+
 setup_logging()
@@ -98,6 +106,61 @@ def smoke_test_train(
             energy_from_train, energy_from_checkpoint, rtol=1e-6, atol=1e-6
         )
 
+    def test_gemnet_fit_scaling(self, configs, tutorial_val_src):
+
+        with tempfile.TemporaryDirectory() as tempdirname:
+            # (1) generate scaling factors for gemnet config
+            config_yaml = f"{tempdirname}/train_and_val_on_val.yml"
+            scaling_pt = f"{tempdirname}/scaling.pt"
+            # run
+            parser = flags.get_parser()
+            args, override_args = parser.parse_known_args(
+                [
+                    "--mode",
+                    "train",
+                    "--seed",
+                    "100",
+                    "--config-yml",
+                    config_yaml,
+                    "--cpu",
+                    "--checkpoint",
+                    scaling_pt,
+                ]
+            )
+            update_yaml_with_dict(
+                configs["gemnet_oc"],
+                config_yaml,
+                update_dict_with={
+                    "dataset": oc20_lmdb_train_and_val_from_paths(
+                        train_src=str(tutorial_val_src),
+                        val_src=str(tutorial_val_src),
+                        test_src=str(tutorial_val_src),
+                    ),
+                },
+            )
+            config = build_config(args, override_args)
+
+            # (2) if existing scaling factors are present remove them
+            if "scale_file" in config["model"]:
+                config["model"].pop("scale_file")
+
+            compute_scaling_factors(config)
+
+            # (3) try to run the config with the newly generated scaling factors
+            _ = _run_main(
+                rundir=tempdirname,
+                update_dict_with={
+                    "optim": {"max_epochs": 1},
+                    "model": {"use_pbc_single": True, "scale_file": scaling_pt},
+                    "dataset": oc20_lmdb_train_and_val_from_paths(
+                        train_src=str(tutorial_val_src),
+                        val_src=str(tutorial_val_src),
+                        test_src=str(tutorial_val_src),
+                    ),
+                },
+                input_yaml=configs["gemnet_oc"],
+            )
+
     # not all models are tested with otf normalization estimation
     # only gemnet_oc, escn, equiformer, and their hydra versions
     @pytest.mark.parametrize(
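
Outside pytest, the fitting flow the new test exercises can be driven directly. A minimal sketch mirroring the test above; the config and checkpoint paths are placeholder assumptions:

    from fairchem.core.common.flags import flags
    from fairchem.core.common.utils import build_config
    from fairchem.core.modules.scaling.fit import compute_scaling_factors

    # Build a config the same way the test does; --checkpoint names the file
    # the fitted factors will be written to (both paths are placeholders).
    parser = flags.get_parser()
    args, override_args = parser.parse_known_args(
        ["--mode", "train", "--config-yml", "gemnet_oc.yml",
         "--checkpoint", "scaling.pt", "--cpu"]
    )
    config = build_config(args, override_args)
    config["model"].pop("scale_file", None)  # drop stale factors before refitting
    compute_scaling_factors(config)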