Skip to content

Commit

Permalink
KEP-2170: Add unit and E2E tests for model and dataset initializers
Browse files Browse the repository at this point in the history
Signed-off-by: wei-chenglai <[email protected]>
  • Loading branch information
seanlaii committed Dec 16, 2024
1 parent 0b8fb3e commit 12451ad
Show file tree
Hide file tree
Showing 13 changed files with 844 additions and 4 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/integration-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ jobs:
env:
GANG_SCHEDULER_NAME: ${{ matrix.gang-scheduler-name }}

- name: Run initializer_v2 e2e for Python 3.11+
if: ${{ matrix.python-version == '3.11' }}
run: |
pip install urllib3 huggingface_hub
pip install -U './sdk_v2'
pytest ./test/e2e/initializer_v2
- name: Collect volcano logs
if: ${{ failure() && matrix.gang-scheduler-name == 'volcano' }}
run: |
Expand Down
11 changes: 9 additions & 2 deletions .github/workflows/test-python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,14 @@ jobs:
- name: Install dependencies
run: |
pip install pytest python-dateutil urllib3 kubernetes
pip install -U './sdk/python[huggingface]'
pip install -U './sdk/python[huggingface]' './sdk_v2'
- name: Run unit test for training sdk
run: pytest ./sdk/python/kubeflow/training/api/training_client_test.py
run: |
pytest ./sdk/python/kubeflow/training/api/training_client_test.py
- name: Run Python unit tests for v2
run: |
pytest ./pkg/initializer_v2/model
pytest ./pkg/initializer_v2/dataset
pytest ./pkg/initializer_v2/utils
7 changes: 6 additions & 1 deletion pkg/initializer_v2/dataset/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
level=logging.INFO,
)

if __name__ == "__main__":

def main():
logging.info("Starting dataset initialization")

try:
Expand All @@ -29,3 +30,7 @@
case _:
logging.error("STORAGE_URI must have the valid dataset provider")
raise Exception


if __name__ == "__main__":
main()
144 changes: 144 additions & 0 deletions pkg/initializer_v2/dataset/huggingface_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
from unittest.mock import MagicMock, patch

import pytest
from kubeflow.training import DATASET_PATH

import pkg.initializer_v2.utils.utils as utils


@pytest.fixture
def huggingface_dataset_instance():
"""Fixture for HuggingFace Dataset instance"""
from pkg.initializer_v2.dataset.huggingface import HuggingFace

return HuggingFace()


# Test cases for config loading
@pytest.mark.parametrize(
"test_name, test_config, expected",
[
(
"Full config with token",
{"storage_uri": "hf://dataset/path", "access_token": "test_token"},
{"storage_uri": "hf://dataset/path", "access_token": "test_token"},
),
(
"Minimal config without token",
{"storage_uri": "hf://dataset/path"},
{"storage_uri": "hf://dataset/path", "access_token": None},
),
],
)
def test_load_config(test_name, test_config, expected, huggingface_dataset_instance):
"""Test config loading with different configurations"""
print(f"Running test: {test_name}")

with patch.object(utils, "get_config_from_env", return_value=test_config):
huggingface_dataset_instance.load_config()
assert (
huggingface_dataset_instance.config.storage_uri == expected["storage_uri"]
)
assert (
huggingface_dataset_instance.config.access_token == expected["access_token"]
)

print("Test execution completed")


@pytest.mark.parametrize(
"test_name, test_case",
[
(
"Successful download with token",
{
"config": {
"storage_uri": "hf://username/dataset-name",
"access_token": "test_token",
},
"should_login": True,
"expected_repo_id": "username/dataset-name",
"mock_login_side_effect": None,
"mock_download_side_effect": None,
"expected_error": None,
},
),
(
"Successful download without token",
{
"config": {"storage_uri": "hf://org/dataset-v1", "access_token": None},
"should_login": False,
"expected_repo_id": "org/dataset-v1",
"mock_login_side_effect": None,
"mock_download_side_effect": None,
"expected_error": None,
},
),
(
"Login failure",
{
"config": {
"storage_uri": "hf://username/dataset-name",
"access_token": "test_token",
},
"should_login": True,
"expected_repo_id": "username/dataset-name",
"mock_login_side_effect": Exception,
"mock_download_side_effect": None,
"expected_error": Exception,
},
),
(
"Download failure",
{
"config": {
"storage_uri": "hf://invalid/repo/name",
"access_token": None,
},
"should_login": False,
"expected_repo_id": "invalid/repo/name",
"mock_login_side_effect": None,
"mock_download_side_effect": Exception,
"expected_error": Exception,
},
),
],
)
def test_download_dataset(test_name, test_case, huggingface_dataset_instance):
"""Test dataset download with different configurations"""

print(f"Running test: {test_name}")

huggingface_dataset_instance.config = MagicMock(**test_case["config"])

with patch("huggingface_hub.login") as mock_login, patch(
"huggingface_hub.snapshot_download"
) as mock_download:

# Configure mock behavior
if test_case["mock_login_side_effect"]:
mock_login.side_effect = test_case["mock_login_side_effect"]
if test_case["mock_download_side_effect"]:
mock_download.side_effect = test_case["mock_download_side_effect"]

# Execute test
if test_case["expected_error"]:
with pytest.raises(test_case["expected_error"]):
huggingface_dataset_instance.download_dataset()
else:
huggingface_dataset_instance.download_dataset()

# Verify login behavior
if test_case["should_login"]:
mock_login.assert_called_once_with(test_case["config"]["access_token"])
else:
mock_login.assert_not_called()

# Verify download parameters
if test_case["expected_repo_id"]:
mock_download.assert_called_once_with(
repo_id=test_case["expected_repo_id"],
local_dir=DATASET_PATH,
repo_type="dataset",
)
print("Test execution completed")
122 changes: 122 additions & 0 deletions pkg/initializer_v2/dataset/main_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import os
from unittest.mock import MagicMock, patch

import pytest

from pkg.initializer_v2.dataset.__main__ import main


@pytest.fixture
def mock_env_vars():
"""Fixture to set and clean up environment variables"""
original_env = dict(os.environ)

def _set_env_vars(**kwargs):
for key, value in kwargs.items():
if value is None:
os.environ.pop(key, None)
else:
os.environ[key] = str(value)
return os.environ

yield _set_env_vars

# Cleanup
os.environ.clear()
os.environ.update(original_env)


@pytest.mark.parametrize(
"test_name, test_case",
[
(
"Successful download with HuggingFace provider",
{
"storage_uri": "hf://dataset/path",
"access_token": "test_token",
"mock_config_error": False,
"mock_download_error": False,
"expected_error": None,
},
),
(
"Missing storage URI environment variable",
{
"storage_uri": None,
"access_token": None,
"mock_config_error": False,
"mock_download_error": False,
"expected_error": Exception,
},
),
(
"Invalid storage URI scheme",
{
"storage_uri": "invalid://dataset/path",
"access_token": None,
"mock_config_error": False,
"mock_download_error": False,
"expected_error": Exception,
},
),
(
"Config loading failure",
{
"storage_uri": "hf://dataset/path",
"access_token": None,
"mock_config_error": True,
"mock_download_error": False,
"expected_error": Exception,
},
),
(
"Dataset download failure",
{
"storage_uri": "hf://dataset/path/error",
"access_token": None,
"mock_config_error": False,
"mock_download_error": True,
"expected_error": Exception,
},
),
],
)
def test_dataset_main(test_name, test_case, mock_env_vars):
"""Test main script with different scenarios"""
print(f"Running test: {test_name}")

# Setup mock environment variables
env_vars = {
"STORAGE_URI": test_case["storage_uri"],
"ACCESS_TOKEN": test_case["access_token"],
}
mock_env_vars(**env_vars)

# Setup mock HuggingFace instance
mock_hf_instance = MagicMock()
if test_case["mock_config_error"]:
mock_hf_instance.load_config.side_effect = Exception
if test_case["mock_download_error"]:
mock_hf_instance.download_dataset.side_effect = Exception

with patch(
"pkg.initializer_v2.dataset.__main__.HuggingFace",
return_value=mock_hf_instance,
) as mock_hf:

# Execute test
if test_case["expected_error"]:
with pytest.raises(test_case["expected_error"]):
main()
else:
main()

# Verify HuggingFace instance methods were called
mock_hf_instance.load_config.assert_called_once()
mock_hf_instance.download_dataset.assert_called_once()

# Verify HuggingFace class instantiation
if test_case["storage_uri"] and test_case["storage_uri"].startswith("hf://"):
mock_hf.assert_called_once()

print("Test execution completed")
7 changes: 6 additions & 1 deletion pkg/initializer_v2/model/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
level=logging.INFO,
)

if __name__ == "__main__":

def main():
logging.info("Starting pre-trained model initialization")

try:
Expand All @@ -31,3 +32,7 @@
f"STORAGE_URI must have the valid model provider. STORAGE_URI: {storage_uri}"
)
raise Exception


if __name__ == "__main__":
main()
Loading

0 comments on commit 12451ad

Please sign in to comment.