From 2e206ec2cc98fc26e26ff010c4c1f570071a8931 Mon Sep 17 00:00:00 2001 From: Misko Date: Tue, 3 Sep 2024 21:27:18 +0000 Subject: [PATCH] add amp test --- tests/core/e2e/test_e2e_commons.py | 16 +++++++++++++--- tests/core/e2e/test_s2ef.py | 18 +++++++++++++++++- 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/tests/core/e2e/test_e2e_commons.py b/tests/core/e2e/test_e2e_commons.py index ef2b860bf..88fe82b1c 100644 --- a/tests/core/e2e/test_e2e_commons.py +++ b/tests/core/e2e/test_e2e_commons.py @@ -111,6 +111,7 @@ def _run_main( save_checkpoint_to=None, save_predictions_to=None, world_size=0, + amp=False, ): config_yaml = Path(rundir) / "train_and_val_on_val.yml" update_yaml_with_dict(input_yaml, config_yaml, update_dict_with) @@ -124,9 +125,18 @@ def _run_main( # run parser = flags.get_parser() - args, override_args = parser.parse_known_args( - ["--mode", "train", "--seed", "100", "--config-yml", "config.yml", "--cpu"] - ) + command_line_args = [ + "--mode", + "train", + "--seed", + "100", + "--config-yml", + "config.yml", + "--cpu", + ] + if amp: + command_line_args.append("--amp") + args, override_args = parser.parse_known_args(command_line_args) for arg_name, arg_value in run_args.items(): setattr(args, arg_name, arg_value) config = build_config(args, override_args) diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 10e3203c9..12c1b74d7 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -32,7 +32,13 @@ class TestSmoke: def smoke_test_train( - self, input_yaml, tutorial_val_src, world_size, num_workers, otf_norms=False + self, + input_yaml, + tutorial_val_src, + world_size, + num_workers, + otf_norms=False, + amp=False, ): with tempfile.TemporaryDirectory() as tempdirname: # first train a very simple model, checkpoint @@ -60,6 +66,7 @@ def smoke_test_train( save_checkpoint_to=checkpoint_path, save_predictions_to=training_predictions_filename, world_size=world_size, + amp=amp, ) assert "train/energy_mae" in acc.Tags()["scalars"] assert "val/energy_mae" in acc.Tags()["scalars"] @@ -197,6 +204,15 @@ def test_train_and_predict( configs, tutorial_val_src, ): + # test with amp + self.smoke_test_train( + input_yaml=configs[model_name], + tutorial_val_src=tutorial_val_src, + otf_norms=otf_norms, + world_size=0, + num_workers=1, + amp=True, + ) # test without ddp self.smoke_test_train( input_yaml=configs[model_name],