Skip to content

Commit

Permalink
Take advice from QL scan
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd committed Feb 28, 2024
1 parent 2618d98 commit f1585b2
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions source/tests/pt/test_multitask.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,16 @@ def test_multitask_train(self):
trainer.run()
# check model keys
self.assertEqual(len(trainer.wrapper.model), 2)
self.assertTrue("model_1" in trainer.wrapper.model)
self.assertTrue("model_2" in trainer.wrapper.model)
self.assertIn("model_1", trainer.wrapper.model)
self.assertIn("model_2", trainer.wrapper.model)

# check shared parameters
multi_state_dict = trainer.wrapper.model.state_dict()
for state_key in multi_state_dict:
if "model_1" in state_key:
self.assertTrue(
state_key.replace("model_1", "model_2") in multi_state_dict
)
self.assertIn(state_key.replace("model_1", "model_2"), multi_state_dict)
if "model_2" in state_key:
self.assertTrue(
state_key.replace("model_2", "model_1") in multi_state_dict
)
self.assertIn(state_key.replace("model_2", "model_1"), multi_state_dict)
if "model_1.descriptor" in state_key:
torch.testing.assert_allclose(
multi_state_dict[state_key],
Expand Down

0 comments on commit f1585b2

Please sign in to comment.