Skip to content

Commit

Permalink
[SDK] Unit tests for TrainingClient APIs - get_job_pod_names and upda…
Browse files Browse the repository at this point in the history
…te_job (#2192)

* [SDK] Add more unit tests for TrainingClient APIs - get_job_pod_names and update_job

Signed-off-by: yelias <[email protected]>

* fix isort

Signed-off-by: yelias <[email protected]>

* Remove dict_to_object and SimpleNamespace, and add option of return pod obj

Signed-off-by: yelias <[email protected]>

* Fix typo in docstring

Signed-off-by: yelias <[email protected]>

* Change str to constants

Signed-off-by: yelias <[email protected]>

---------

Signed-off-by: yelias <[email protected]>
Co-authored-by: yelias <[email protected]>
  • Loading branch information
YosiElias and yelias authored Aug 22, 2024
1 parent 3f9b0a4 commit 6900714
Showing 1 changed file with 136 additions and 16 deletions.
152 changes: 136 additions & 16 deletions sdk/python/kubeflow/training/api/training_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,51 @@
from kubernetes.client import V1ResourceRequirements
import pytest

LIST_RESPONSE = [{"metadata": {"name": "Dummy V1PodList"}}]
TEST_NAME = "test"
TIMEOUT = "timeout"
RUNTIME = "runtime"
MOCK_POD_OBJ = "mock_pod_obj"
NO_PODS = "no_pods"
DUMMY_POD_NAME = "Dummy V1PodList"
LIST_RESPONSE = [
{"metadata": {"name": DUMMY_POD_NAME}},
]


def create_namespaced_custom_object_response(*args, **kwargs):
if args[2] == "timeout":
def conditional_error_handler(*args, **kwargs):
if args[2] == TIMEOUT:
raise multiprocessing.TimeoutError()
elif args[2] == "runtime":
elif args[2] == RUNTIME:
raise RuntimeError()


def list_namespaced_pod_response(*args, **kwargs):
class MockResponse:
def get(self, timeout):
# Simulate a response from the Kubernetes API, and pass timeout for verification
LIST_RESPONSE[0]["timeout"] = timeout
if args[0] == "timeout":
"""
Simulates Kubernetes API response for listing namespaced pods,
and pass timeout for verification
:return:
- If `args[0] == "timeout"`, raises `TimeoutError`.
- If `args[0] == "runtime"`, raises `Exception`.
- If `args[0] == "mock_pod_obj"`, returns a mock pod object
with `metadata.name = "Dummy V1PodList"`.
- If `args[0] == "no_pods"`, returns an empty list of pods.
- Otherwise, returns a default list of dicts representing pods,
with `timeout` included, for testing.
"""
LIST_RESPONSE[0][TIMEOUT] = timeout
if args[0] == TIMEOUT:
raise multiprocessing.TimeoutError()
if args[0] == "runtime":
if args[0] == RUNTIME:
raise Exception()
if args[0] == MOCK_POD_OBJ:
pod_obj = Mock(metadata=Mock())
pod_obj.metadata.name = DUMMY_POD_NAME
return Mock(items=[pod_obj])
if args[0] == NO_PODS:
return Mock(items=[])
return Mock(items=LIST_RESPONSE)

return MockResponse()
Expand Down Expand Up @@ -156,12 +181,12 @@ def __init__(self, kind) -> None:
),
(
"create_namespaced_custom_object timeout error",
{"job": create_job(), "namespace": "timeout"},
{"job": create_job(), "namespace": TIMEOUT},
TimeoutError,
),
(
"create_namespaced_custom_object runtime error",
{"job": create_job(), "namespace": "runtime"},
{"job": create_job(), "namespace": RUNTIME},
RuntimeError,
),
(
Expand Down Expand Up @@ -235,7 +260,7 @@ def __init__(self, kind) -> None:
"invalid flow with TimeoutError",
{
"name": TEST_NAME,
"namespace": "timeout",
"namespace": TIMEOUT,
},
"Label not relevant",
TimeoutError,
Expand All @@ -244,22 +269,80 @@ def __init__(self, kind) -> None:
"invalid flow with RuntimeError",
{
"name": TEST_NAME,
"namespace": "runtime",
"namespace": RUNTIME,
},
"Label not relevant",
RuntimeError,
),
]


test_data_get_job_pod_names = [
(
"valid flow",
{
"name": TEST_NAME,
"namespace": MOCK_POD_OBJ,
},
[DUMMY_POD_NAME],
),
(
"valid flow with no pods available",
{
"name": TEST_NAME,
"namespace": NO_PODS,
},
[],
),
]


test_data_update_job = [
(
"valid flow",
{
"name": TEST_NAME,
"job": create_job(),
},
"No output",
),
(
"invalid job_kind",
{
"name": TEST_NAME,
"job": create_job(),
"job_kind": "invalid_job_kind",
},
ValueError,
),
(
"invalid flow with TimeoutError",
{
"name": TEST_NAME,
"namespace": TIMEOUT,
"job": create_job(),
},
TimeoutError,
),
(
"invalid flow with RuntimeError",
{
"name": TEST_NAME,
"namespace": RUNTIME,
"job": create_job(),
},
RuntimeError,
),
]


@pytest.fixture
def training_client():
with patch(
"kubernetes.client.CustomObjectsApi",
return_value=Mock(
create_namespaced_custom_object=Mock(
side_effect=create_namespaced_custom_object_response
)
create_namespaced_custom_object=Mock(side_effect=conditional_error_handler),
patch_namespaced_custom_object=Mock(side_effect=conditional_error_handler),
),
), patch(
"kubernetes.client.CoreV1Api",
Expand Down Expand Up @@ -306,8 +389,45 @@ def test_get_job_pods(
label_selector=expected_label_selector,
async_req=True,
)
assert out[0].pop("timeout") == kwargs.get("timeout", constants.DEFAULT_TIMEOUT)
assert out[0].pop(TIMEOUT) == kwargs.get(TIMEOUT, constants.DEFAULT_TIMEOUT)
assert out == expected_output
except Exception as e:
assert type(e) is expected_output
print("test execution complete")


@pytest.mark.parametrize(
"test_name,kwargs,expected_output",
test_data_get_job_pod_names,
)
def test_get_job_pod_names(training_client, test_name, kwargs, expected_output):
"""
test get_job_pod_names function of training client
"""
print("Executing test:", test_name)
out = training_client.get_job_pod_names(**kwargs)
assert out == expected_output
print("test execution complete")


@pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_update_job)
def test_update_job(training_client, test_name, kwargs, expected_output):
"""
test update_job function of training client
"""
print("Executing test:", test_name)
try:
training_client.update_job(**kwargs)
training_client.custom_api.patch_namespaced_custom_object.assert_called_with(
constants.GROUP,
constants.VERSION,
kwargs.get("namespace", constants.DEFAULT_NAMESPACE),
constants.JOB_PARAMETERS[kwargs.get("job_kind", training_client.job_kind)][
"plural"
],
kwargs.get("name"),
kwargs.get("job"),
)
except Exception as e:
assert type(e) is expected_output
print("test execution complete")

0 comments on commit 6900714

Please sign in to comment.