Skip to content

Commit

Permalink
[SDK] test: add unit test for list_jobs method of the training_client
Browse files Browse the repository at this point in the history
Signed-off-by: wei-chenglai <[email protected]>
  • Loading branch information
seanlaii committed Oct 2, 2024
1 parent 12d09d0 commit 30f0d2e
Showing 1 changed file with 88 additions and 0 deletions.
88 changes: 88 additions & 0 deletions sdk/python/kubeflow/training/api/training_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,25 @@ def get_namespaced_custom_object_response(*args, **kwargs):
return mock_thread


def list_namespaced_custom_object_response(*args, **kwargs):
if args[2] == TIMEOUT:
raise multiprocessing.TimeoutError()
elif args[2] == RUNTIME:
raise RuntimeError()

# Create a serialized Job
serialized_job = serialize_k8s_object(generate_job_with_status(create_job()))

# Mock the response containing a list of jobs
mock_response = {"items": [serialized_job]}

# Mock the thread and set it's return value to the mock response
mock_thread = Mock()
mock_thread.get.return_value = mock_response

return mock_thread


def list_namespaced_pod_response(*args, **kwargs):
class MockResponse:
def get(self, timeout):
Expand Down Expand Up @@ -482,6 +501,56 @@ def __init__(self, kind) -> None:
),
]

test_data_list_jobs = [
(
"valid flow with default namespace and default timeout",
{},
SUCCESS,
),
(
"valid flow with all parameters set",
{
"namespace": TEST_NAME,
"job_kind": constants.PYTORCHJOB_KIND,
"timeout": 120,
},
SUCCESS,
),
(
"invalid flow with default namespace and a Job that doesn't exist",
{"job_kind": constants.TFJOB_KIND},
RuntimeError,
),
(
"invalid flow with incorrect parameter",
{"test": "example"},
TypeError,
),
(
"invalid flow with incorrect job_kind value",
{"job_kind": "FailJob"},
ValueError,
),
(
"runtime error case",
{
"namespace": RUNTIME,
"job_kind": constants.PYTORCHJOB_KIND,
},
RuntimeError,
),
(
"invalid flow with timeout error",
{"namespace": TIMEOUT},
TimeoutError,
),
(
"invalid flow with runtime error",
{"namespace": RUNTIME},
RuntimeError,
),
]


test_data_delete_job = [
(
Expand Down Expand Up @@ -854,6 +923,9 @@ def training_client():
get_namespaced_custom_object=Mock(
side_effect=get_namespaced_custom_object_response
),
list_namespaced_custom_object=Mock(
side_effect=list_namespaced_custom_object_response
),
),
), patch(
"kubernetes.client.CoreV1Api",
Expand Down Expand Up @@ -1109,3 +1181,19 @@ def test_is_job_succeeded(training_client, test_name, kwargs, expected_output):
assert type(e) is expected_output

print("test execution complete")


@pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_list_jobs)
def test_list_jobs(training_client, test_name, kwargs, expected_output):
"""
test list_jobs function of training client
"""
print("Executing test: ", test_name)

try:
training_client.list_jobs(**kwargs)
assert expected_output == SUCCESS
except Exception as e:
assert type(e) is expected_output

print("test execution complete")

0 comments on commit 30f0d2e

Please sign in to comment.