Skip to content

Commit

Permalink
add amp test
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Sep 3, 2024
1 parent 94e4a7f commit 2e206ec
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
16 changes: 13 additions & 3 deletions tests/core/e2e/test_e2e_commons.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def _run_main(
save_checkpoint_to=None,
save_predictions_to=None,
world_size=0,
amp=False,
):
config_yaml = Path(rundir) / "train_and_val_on_val.yml"
update_yaml_with_dict(input_yaml, config_yaml, update_dict_with)
Expand All @@ -124,9 +125,18 @@ def _run_main(

# run
parser = flags.get_parser()
args, override_args = parser.parse_known_args(
["--mode", "train", "--seed", "100", "--config-yml", "config.yml", "--cpu"]
)
command_line_args = [
"--mode",
"train",
"--seed",
"100",
"--config-yml",
"config.yml",
"--cpu",
]
if amp:
command_line_args.append("--amp")
args, override_args = parser.parse_known_args(command_line_args)
for arg_name, arg_value in run_args.items():
setattr(args, arg_name, arg_value)
config = build_config(args, override_args)
Expand Down
18 changes: 17 additions & 1 deletion tests/core/e2e/test_s2ef.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,13 @@

class TestSmoke:
def smoke_test_train(
self, input_yaml, tutorial_val_src, world_size, num_workers, otf_norms=False
self,
input_yaml,
tutorial_val_src,
world_size,
num_workers,
otf_norms=False,
amp=False,
):
with tempfile.TemporaryDirectory() as tempdirname:
# first train a very simple model, checkpoint
Expand Down Expand Up @@ -60,6 +66,7 @@ def smoke_test_train(
save_checkpoint_to=checkpoint_path,
save_predictions_to=training_predictions_filename,
world_size=world_size,
amp=amp,
)
assert "train/energy_mae" in acc.Tags()["scalars"]
assert "val/energy_mae" in acc.Tags()["scalars"]
Expand Down Expand Up @@ -197,6 +204,15 @@ def test_train_and_predict(
configs,
tutorial_val_src,
):
# test with amp
self.smoke_test_train(
input_yaml=configs[model_name],
tutorial_val_src=tutorial_val_src,
otf_norms=otf_norms,
world_size=0,
num_workers=1,
amp=True,
)
# test without ddp
self.smoke_test_train(
input_yaml=configs[model_name],
Expand Down

0 comments on commit 2e206ec

Please sign in to comment.