diff --git a/source/tests/pt/test_multitask.py b/source/tests/pt/test_multitask.py index c7a2784367..0c1abf1f44 100644 --- a/source/tests/pt/test_multitask.py +++ b/source/tests/pt/test_multitask.py @@ -33,20 +33,16 @@ def test_multitask_train(self): trainer.run() # check model keys self.assertEqual(len(trainer.wrapper.model), 2) - self.assertTrue("model_1" in trainer.wrapper.model) - self.assertTrue("model_2" in trainer.wrapper.model) + self.assertIn("model_1", trainer.wrapper.model) + self.assertIn("model_2", trainer.wrapper.model) # check shared parameters multi_state_dict = trainer.wrapper.model.state_dict() for state_key in multi_state_dict: if "model_1" in state_key: - self.assertTrue( - state_key.replace("model_1", "model_2") in multi_state_dict - ) + self.assertIn(state_key.replace("model_1", "model_2"), multi_state_dict) if "model_2" in state_key: - self.assertTrue( - state_key.replace("model_2", "model_1") in multi_state_dict - ) + self.assertIn(state_key.replace("model_2", "model_1"), multi_state_dict) if "model_1.descriptor" in state_key: torch.testing.assert_allclose( multi_state_dict[state_key],