Skip to content

Commit

Permalink
update unit test
Browse files Browse the repository at this point in the history
Signed-off-by: helenxie-bit <[email protected]>
  • Loading branch information
helenxie-bit committed Sep 11, 2024
1 parent 5ddcc30 commit 4909456
Showing 1 changed file with 25 additions and 16 deletions.
41 changes: 25 additions & 16 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import multiprocessing
from typing import List, Optional
from unittest.mock import Mock, patch
from unittest.mock import MagicMock, Mock, patch

import kubeflow.katib.katib_api_pb2 as katib_api_pb2
import pytest
Expand Down Expand Up @@ -341,12 +341,12 @@ def __init__(
self.secret_key = secret_key


class PyTorchJobSpec:
class KubeflowOrgV1PyTorchJobSpec:
def __init__(
self,
elastic_policy=None,
nproc_per_node=None,
pytorch_replica_specs=None,
pytorch_replica_specs={},
run_policy=None,
):
self.elastic_policy = elastic_policy
Expand All @@ -355,13 +355,13 @@ def __init__(
self.run_policy = run_policy


class PyTorchJob:
class KubeflowOrgV1PyTorchJob:
def __init__(
self,
api_version=None,
kind=None,
metadata=None,
spec=PyTorchJobSpec,
spec=KubeflowOrgV1PyTorchJobSpec,
status=None,
):
self.api_version = api_version
Expand Down Expand Up @@ -571,6 +571,7 @@ def test_tune(katib_client, test_name, kwargs, expected_output):
"""
print("\n\nExecuting test:", test_name)

PYTORCHJOB_KIND = "PyTorchJob"
JOB_PARAMETERS = {
"PyTorchJob": {
"model": "KubeflowOrgV1PyTorchJob",
Expand All @@ -580,7 +581,20 @@ def test_tune(katib_client, test_name, kwargs, expected_output):
}
}

with patch(
with patch.dict(
"sys.modules",
{
"kubeflow.storage_initializer": Mock(),
"kubeflow.storage_initializer.hugging_face": Mock(),
"kubeflow.storage_initializer.s3": Mock(),
"kubeflow.storage_initializer.constants": Mock(),
"kubeflow.training": MagicMock(),
"kubeflow.training.models": Mock(),
"kubeflow.training.utils": Mock(),
"kubeflow.training.constants": Mock(),
"kubeflow.training.constants.constants": Mock(),
}), \
patch(
"kubeflow.storage_initializer.hugging_face.HuggingFaceModelParams",
HuggingFaceModelParams,
), patch(
Expand All @@ -592,21 +606,16 @@ def test_tune(katib_client, test_name, kwargs, expected_output):
), patch(
"kubeflow.storage_initializer.s3.S3DatasetParams", S3DatasetParams
), patch(
"kubeflow.training.models.KubeflowOrgV1PyTorchJob", PyTorchJob
"kubeflow.training.models.KubeflowOrgV1PyTorchJob", KubeflowOrgV1PyTorchJob
), patch(
"kubeflow.training.constants.constants.JOB_PARAMETERS", JOB_PARAMETERS
), patch(
"kubeflow.training.constants.constants.PYTORCHJOB_KIND", PYTORCHJOB_KIND
), patch(
"kubeflow.katib.utils.utils.get_trial_substitutions_from_trainer",
return_value={"param": "value"},
), patch.dict(
"sys.modules",
{
"kubeflow.storage_initializer.constants": Mock(),
"kubeflow.training.models": Mock(),
"kubeflow.training.utils": Mock(),
"kubeflow.training.constants": Mock(),
},
), patch.object(
), \
patch.object(
katib_client, "create_experiment", return_value=Mock()
) as mock_create_experiment:
try:
Expand Down

0 comments on commit 4909456

Please sign in to comment.