From 4d8f77e2d4d1d4f3c9b27cf0806f02abff78376e Mon Sep 17 00:00:00 2001 From: wei-chenglai Date: Sun, 22 Sep 2024 22:08:53 -0400 Subject: [PATCH] [SDK] test: add unit test for list_jobs method of the training_client Signed-off-by: wei-chenglai --- .../training/api/training_client_test.py | 83 +++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/sdk/python/kubeflow/training/api/training_client_test.py b/sdk/python/kubeflow/training/api/training_client_test.py index 2e147c5e35..3b31cbc3c4 100644 --- a/sdk/python/kubeflow/training/api/training_client_test.py +++ b/sdk/python/kubeflow/training/api/training_client_test.py @@ -70,6 +70,25 @@ def get_namespaced_custom_object_response(*args, **kwargs): return mock_thread +def list_namespaced_custom_object_response(*args, **kwargs): + if args[2] == TIMEOUT: + raise multiprocessing.TimeoutError() + elif args[2] == RUNTIME: + raise RuntimeError() + + # Create a serialized Job + serialized_job = serialize_k8s_object(generate_job_with_status(create_job())) + + # Mock the response containing a list of jobs + mock_response = {"items": [serialized_job]} + + # Mock the thread and set it's return value to the mock response + mock_thread = Mock() + mock_thread.get.return_value = mock_response + + return mock_thread + + def list_namespaced_pod_response(*args, **kwargs): class MockResponse: def get(self, timeout): @@ -482,6 +501,51 @@ def __init__(self, kind) -> None: ), ] +test_data_list_jobs = [ + ( + "valid flow with default namespace and default timeout", + {}, + SUCCESS, + ), + ( + "valid flow with all parameters set", + { + "namespace": TEST_NAME, + "job_kind": constants.PYTORCHJOB_KIND, + "timeout": 120, + }, + SUCCESS, + ), + ( + "invalid flow with default namespace and a Job that doesn't exist", + {"job_kind": constants.TFJOB_KIND}, + RuntimeError, + ), + ( + "invalid flow with incorrect parameter", + {"test": "example"}, + TypeError, + ), + ( + "invalid flow with incorrect job_kind value", + {"job_kind": "FailJob"}, + ValueError, + ), + ( + "invalid flow with runtime error", + { + "namespace": RUNTIME, + "job_kind": constants.PYTORCHJOB_KIND, + }, + RuntimeError, + ), + ( + "invalid flow with timeout error", + {"namespace": TIMEOUT}, + TimeoutError, + ), +] + test_data_delete_job = [ ( @@ -854,6 +918,9 @@ def training_client(): get_namespaced_custom_object=Mock( side_effect=get_namespaced_custom_object_response ), + list_namespaced_custom_object=Mock( + side_effect=list_namespaced_custom_object_response + ), ), ), patch( "kubernetes.client.CoreV1Api", @@ -1109,3 +1176,19 @@ def test_is_job_succeeded(training_client, test_name, kwargs, expected_output): assert type(e) is expected_output print("test execution complete") + + +@pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_list_jobs) +def test_list_jobs(training_client, test_name, kwargs, expected_output): + """ + test list_jobs function of training client + """ + print("Executing test: ", test_name) + + try: + training_client.list_jobs(**kwargs) + assert expected_output == SUCCESS + except Exception as e: + assert type(e) is expected_output + + print("test execution complete")