Skip to content

Commit

Permalink
fix tearDown
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Jan 26, 2024
1 parent 228fec3 commit 2079866
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
9 changes: 9 additions & 0 deletions source/tests/pt/test_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ def setUp(self):
self.config["training"]["numb_steps"] = 10
self.config["training"]["save_freq"] = 10

def tearDown(self):
JITTest.tearDown(self)


@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid(unittest.TestCase, JITTest):
Expand All @@ -111,6 +114,9 @@ def setUp(self):
self.config["training"]["numb_steps"] = 10
self.config["training"]["save_freq"] = 10

def tearDown(self):
JITTest.tearDown(self)


@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid2(unittest.TestCase, JITTest):
Expand All @@ -126,6 +132,9 @@ def setUp(self):
self.config["training"]["numb_steps"] = 10
self.config["training"]["save_freq"] = 10

def tearDown(self):
JITTest.tearDown(self)


if __name__ == "__main__":
unittest.main()
12 changes: 12 additions & 0 deletions source/tests/pt/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def setUp(self):
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1

def tearDown(self) -> None:
DPTrainTest.tearDown(self)


class TestEnergyModelDPA1(unittest.TestCase, DPTrainTest):
def setUp(self):
Expand All @@ -63,6 +66,9 @@ def setUp(self):
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1

def tearDown(self) -> None:
DPTrainTest.tearDown(self)


class TestEnergyModelDPA2(unittest.TestCase, DPTrainTest):
def setUp(self):
Expand All @@ -85,6 +91,9 @@ def setUp(self):
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1

def tearDown(self) -> None:
DPTrainTest.tearDown(self)


@unittest.skip("hybrid not supported at the moment")
class TestEnergyModelHybrid(unittest.TestCase, DPTrainTest):
Expand All @@ -99,6 +108,9 @@ def setUp(self):
self.config["training"]["numb_steps"] = 1
self.config["training"]["save_freq"] = 1

def tearDown(self) -> None:
DPTrainTest.tearDown(self)


if __name__ == "__main__":
unittest.main()

0 comments on commit 2079866

Please sign in to comment.