diff --git a/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py b/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py index 0755d6d533a..f8ded73eeff 100644 --- a/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py +++ b/sdk/python/v1beta1/kubeflow/katib/api/katib_client_test.py @@ -1,6 +1,6 @@ import multiprocessing from typing import List, Optional -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import kubeflow.katib.katib_api_pb2 as katib_api_pb2 import pytest @@ -341,12 +341,12 @@ def __init__( self.secret_key = secret_key -class PyTorchJobSpec: +class KubeflowOrgV1PyTorchJobSpec: def __init__( self, elastic_policy=None, nproc_per_node=None, - pytorch_replica_specs=None, + pytorch_replica_specs={}, run_policy=None, ): self.elastic_policy = elastic_policy @@ -355,13 +355,13 @@ def __init__( self.run_policy = run_policy -class PyTorchJob: +class KubeflowOrgV1PyTorchJob: def __init__( self, api_version=None, kind=None, metadata=None, - spec=PyTorchJobSpec, + spec=KubeflowOrgV1PyTorchJobSpec, status=None, ): self.api_version = api_version @@ -571,6 +571,7 @@ def test_tune(katib_client, test_name, kwargs, expected_output): """ print("\n\nExecuting test:", test_name) + PYTORCHJOB_KIND = "PyTorchJob" JOB_PARAMETERS = { "PyTorchJob": { "model": "KubeflowOrgV1PyTorchJob", @@ -580,7 +581,20 @@ def test_tune(katib_client, test_name, kwargs, expected_output): } } - with patch( + with patch.dict( + "sys.modules", + { + "kubeflow.storage_initializer": Mock(), + "kubeflow.storage_initializer.hugging_face": Mock(), + "kubeflow.storage_initializer.s3": Mock(), + "kubeflow.storage_initializer.constants": Mock(), + "kubeflow.training": MagicMock(), + "kubeflow.training.models": Mock(), + "kubeflow.training.utils": Mock(), + "kubeflow.training.constants": Mock(), + "kubeflow.training.constants.constants": Mock(), + }), \ + patch( "kubeflow.storage_initializer.hugging_face.HuggingFaceModelParams", HuggingFaceModelParams, ), patch( @@ -592,21 +606,16 @@ def test_tune(katib_client, test_name, kwargs, expected_output): ), patch( "kubeflow.storage_initializer.s3.S3DatasetParams", S3DatasetParams ), patch( - "kubeflow.training.models.KubeflowOrgV1PyTorchJob", PyTorchJob + "kubeflow.training.models.KubeflowOrgV1PyTorchJob", KubeflowOrgV1PyTorchJob ), patch( "kubeflow.training.constants.constants.JOB_PARAMETERS", JOB_PARAMETERS + ), patch( + "kubeflow.training.constants.constants.PYTORCHJOB_KIND", PYTORCHJOB_KIND ), patch( "kubeflow.katib.utils.utils.get_trial_substitutions_from_trainer", return_value={"param": "value"}, - ), patch.dict( - "sys.modules", - { - "kubeflow.storage_initializer.constants": Mock(), - "kubeflow.training.models": Mock(), - "kubeflow.training.utils": Mock(), - "kubeflow.training.constants": Mock(), - }, - ), patch.object( + ), \ + patch.object( katib_client, "create_experiment", return_value=Mock() ) as mock_create_experiment: try: