From 17c568aef32c24d1ec82a7d6e865cfdf78e8dd6f Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sat, 2 Mar 2024 18:44:00 -0500 Subject: [PATCH] pt: ban torch.testing.assert_allclose See https://github.com/pytorch/pytorch/issues/61844 Signed-off-by: Jinzhe Zeng --- pyproject.toml | 4 ++++ source/tests/pt/test_multitask.py | 2 +- source/tests/pt/test_training.py | 2 +- 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b1d110ff0a..36851b1401 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -237,6 +237,7 @@ select = [ "C4", # flake8-comprehensions "RUF", # ruff "NPY", # numpy + "TID251", # banned-api "TID253", # banned-module-level-imports ] @@ -272,6 +273,9 @@ banned-module-level-imports = [ "torch", ] +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"torch.testing.assert_allclose".msg = "Use `torch.testing.assert_close()` instead, see https://github.com/pytorch/pytorch/issues/61844." + [tool.ruff.lint.extend-per-file-ignores] # Also ignore `E402` in all `__init__.py` files. "deepmd/tf/**" = ["TID253"] diff --git a/source/tests/pt/test_multitask.py b/source/tests/pt/test_multitask.py index 3c0240dbdc..d06733b016 100644 --- a/source/tests/pt/test_multitask.py +++ b/source/tests/pt/test_multitask.py @@ -47,7 +47,7 @@ def test_multitask_train(self): if "model_2" in state_key: self.assertIn(state_key.replace("model_2", "model_1"), multi_state_dict) if "model_1.descriptor" in state_key: - torch.testing.assert_allclose( + torch.testing.assert_close( multi_state_dict[state_key], multi_state_dict[state_key.replace("model_1", "model_2")], ) diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index b9a42385dc..13e47a953b 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -39,7 +39,7 @@ def test_trainable(self): trainer_fix.run() model_dict_after_training = deepcopy(trainer_fix.model.state_dict()) for key in model_dict_before_training: - torch.testing.assert_allclose( + torch.testing.assert_close( model_dict_before_training[key], model_dict_after_training[key] ) self.tearDown()