Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
bachvudinh committed Jul 18, 2024
1 parent db599fe commit 26832a1
Showing 1 changed file with 12 additions and 18 deletions.
30 changes: 12 additions & 18 deletions tests/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,16 @@ def setUpClass(cls):
S3Helper()

@test_name("Tokenizer Loading Test")
@patch('s3helper.S3HelperAutoTokenizer.from_pretrained')
def test_tokenizer_loading(self, mock_from_pretrained):
mock_tokenizer = MagicMock()
mock_from_pretrained.return_value = mock_tokenizer
# @patch('s3helper.S3HelperAutoTokenizer.from_pretrained')
def test_tokenizer_loading(self):
# mock_tokenizer = MagicMock()
# mock_from_pretrained.return_value = mock_tokenizer

tokenizer = S3HelperAutoTokenizer.from_pretrained(self.model_name)

mock_from_pretrained.assert_called_once_with(self.model_name)
# mock_from_pretrained.assert_called_once_with(self.model_name)
self.assertIsNotNone(tokenizer)
self.assertEqual(tokenizer, mock_tokenizer)
# self.assertEqual(tokenizer, mock_tokenizer)

@test_name("Dataset Loading Test")
# @patch('s3helper.s3_load_dataset')
Expand All @@ -107,28 +107,22 @@ def test_dataset_loading(self):
except Exception as e:
self.fail(f"s3_load_dataset raised an exception: {e}")
@test_name("Config Loading Test")
@patch('s3helper.S3HelperAutoConfig.from_pretrained')
def test_config_loading(self, mock_from_pretrained):
mock_config = MagicMock()
mock_from_pretrained.return_value = mock_config
def test_config_loading(self):

config = S3HelperAutoConfig.from_pretrained(self.model_name)

mock_from_pretrained.assert_called_once_with(self.model_name)
self.assertIsNotNone(config)
self.assertEqual(config, mock_config)

@test_name("Model Loading Test")
@patch('s3helper.S3HelperAutoModelForCausalLM.from_pretrained')
def test_model_loading(self, mock_from_pretrained):
mock_model = MagicMock()
mock_from_pretrained.return_value = mock_model
def test_model_loading(self):
# mock_model = MagicMock()
# mock_from_pretrained.return_value = mock_model

model = S3HelperAutoModelForCausalLM.from_pretrained(self.model_name)

mock_from_pretrained.assert_called_once_with(self.model_name)
# mock_from_pretrained.assert_called_once_with(self.model_name)
self.assertIsNotNone(model)
self.assertEqual(model, mock_model)
# self.assertEqual(model, mock_model)
if __name__ == '__main__':
runner = CustomTestRunner()
test_suite = unittest.TestLoader().loadTestsFromTestCase(TestS3Helper)
Expand Down

0 comments on commit 26832a1

Please sign in to comment.