Skip to content

Commit

Permalink
debug
Browse files: browse the repository at this point in the history
  • Loading branch information
bachvudinh committed Jul 18, 2024
1 parent 87d69af commit 536cd58
Showing 1 changed file with 16 additions and 19 deletions.
35 changes: 16 additions & 19 deletions tests/test_flows.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -89,61 +89,58 @@ class TestS3Helper(unittest.TestCase):
@classmethod
def setUpClass(cls):
# Set up any necessary test environment
os.environ['S3_ACCESS_KEY'] = 'test_access_key'
os.environ['S3_SECRET_KEY'] = 'test_secret_key'
os.environ['S3_ENDPOINT_URL'] = 'http://test.endpoint:9000'
cls.model_name_or_path = "jan-hq-test/tinyllama-v1.1"
cls.dataset_name_or_path = "jan-hq-test/test-dataset"


@test_name("S3Helper Initialization")
@test_name("Connect to Minio")
def test_s3helper_initialization(self):
with patch('s3helper.S3Helper') as mock_s3helper:
S3Helper()
mock_s3helper.assert_called_once()

@test_name("AutoTokenizer from_pretrained")
@test_name("Load tokenizer from Minio")
def test_auto_tokenizer_from_pretrained(self):
with patch('s3helper.S3HelperAutoTokenizer.from_pretrained') as mock_from_pretrained:
model_name = "jan-hq-test/tokenizer-tinyllama"
mock_tokenizer = MagicMock()
mock_from_pretrained.return_value = mock_tokenizer

tokenizer = S3HelperAutoTokenizer.from_pretrained(model_name)
tokenizer = S3HelperAutoTokenizer.from_pretrained(self.model_name_or_path)

mock_from_pretrained.assert_called_once_with(model_name)
mock_from_pretrained.assert_called_once_with(self.model_name_or_path)
self.assertEqual(tokenizer, mock_tokenizer)

@test_name("s3_load_dataset")
@test_name("Load dataset from Minio")
def test_s3_load_dataset(self):
with patch('s3helper.s3_load_dataset') as mock_load_dataset:
mock_dataset = MagicMock()
mock_load_dataset.return_value = mock_dataset

dataset = s3_load_dataset("jan-hq-test/test-dataset", file_format='parquet', split='train')
dataset = s3_load_dataset(self.dataset_name_or_path, file_format='parquet', split='train')

mock_load_dataset.assert_called_once_with("jan-hq-test/test-dataset", file_format='parquet', split='train')
mock_load_dataset.assert_called_once_with(self.dataset_name_or_path, file_format='parquet', split='train')
self.assertEqual(dataset, mock_dataset)

@test_name("AutoModelForCausalLM from_pretrained")
@test_name("Load Causal LM model from Minio")
def test_auto_model_for_causal_lm_from_pretrained(self):
with patch('s3helper.S3HelperAutoModelForCausalLM.from_pretrained') as mock_from_pretrained:
model_name = "jan-hq-test/tokenizer-tinyllama"
mock_model = MagicMock()
mock_from_pretrained.return_value = mock_model

model = S3HelperAutoModelForCausalLM.from_pretrained(model_name)
model = S3HelperAutoModelForCausalLM.from_pretrained(self.model_name_or_path)

mock_from_pretrained.assert_called_once_with(model_name)
mock_from_pretrained.assert_called_once_with(self.model_name_or_path)
self.assertEqual(model, mock_model)

@test_name("AutoConfig from_pretrained")
@test_name("Load Model Config from Minio")
def test_auto_config_from_pretrained(self):
with patch('s3helper.S3HelperAutoConfig.from_pretrained') as mock_from_pretrained:
model_name = "jan-hq-test/tokenizer-tinyllama"
mock_config = MagicMock()
mock_from_pretrained.return_value = mock_config

config = S3HelperAutoConfig.from_pretrained(model_name)
config = S3HelperAutoConfig.from_pretrained(self.model_name_or_path)

mock_from_pretrained.assert_called_once_with(model_name)
mock_from_pretrained.assert_called_once_with(self.model_name_or_path)
self.assertEqual(config, mock_config)

if __name__ == "__main__":
Expand Down

0 comments on commit 536cd58

Please sign in to comment.