diff --git a/sdk/python/kubeflow/training/api/training_client_test.py b/sdk/python/kubeflow/training/api/training_client_test.py index ea8c495032..93f20534a0 100644 --- a/sdk/python/kubeflow/training/api/training_client_test.py +++ b/sdk/python/kubeflow/training/api/training_client_test.py @@ -1,4 +1,6 @@ import multiprocessing +import queue +from datetime import datetime, timedelta from typing import Optional from unittest.mock import Mock, patch @@ -41,6 +43,15 @@ SUCCEEDED = "Succeeded" INVALID = "invalid" +FAIL_LOGS = "fail_logs" +FAIL_EVENTS = "fail_events" +MULTI_PODS = "multi_pods" +PENDING_POD = "pending_pod" +NO_STATUS_POD = "no_status_pod" +QUEUE_TIMEOUT = "queue_timeout" +QUEUE_EMPTY = "queue_empty" +EVENT_CREATION_TIMESTAMP = datetime(2024, 1, 5, 22, 58, 20) + def conditional_error_handler(*args, **kwargs): if args[2] == TIMEOUT: @@ -55,9 +66,9 @@ def serialize_k8s_object(obj): def get_namespaced_custom_object_response(*args, **kwargs): - if args[2] == "timeout": + if args[2] == TIMEOUT: raise multiprocessing.TimeoutError() - elif args[2] == "runtime": + elif args[2] == RUNTIME: raise RuntimeError() # Create a serialized Job @@ -203,6 +214,124 @@ def __init__(self, kind) -> None: self.kind = kind +def generate_pod(status, name=DUMMY_POD_NAME, timestamp=None): + pod = Mock(metadata=Mock()) + pod.metadata.name = name + pod.metadata.creation_timestamp = timestamp + pod.status = status + return pod + + +def mock_get_job_pods(*args, **kwargs): + """Mock get_job_pods to return controlled pod objects""" + namespace = kwargs.get("namespace") + if namespace == f"pod {TIMEOUT}": + raise TimeoutError() + if namespace == f"pod {RUNTIME}": + raise RuntimeError() + if namespace == INVALID: + raise ValueError() + + # Handle different test scenarios + if namespace == MULTI_PODS: + return [generate_pod(Mock(phase=RUNNING), f"pod-{i}") for i in range(3)] + + # To find relevant events, the pod's creation time must precede the event's creation time + pod_creation_timestamp = EVENT_CREATION_TIMESTAMP - timedelta(seconds=1) + pod = generate_pod(None, timestamp=pod_creation_timestamp) + if namespace == PENDING_POD: + pod.status = Mock(phase=constants.POD_PHASE_PENDING) + elif namespace == NO_STATUS_POD: + pod.status = None + else: + pod.status = Mock(phase=RUNNING) + return [pod] + + +def mock_get_job(*args, **kwargs): + """Mock get_job_pods to return controlled pod objects""" + namespace = kwargs.get("namespace") + if namespace == f"job {TIMEOUT}": + raise TimeoutError() + if namespace == f"job {RUNTIME}": + raise RuntimeError() + + # Handle different test scenarios + job = Mock() + # To find relevant events, the job's creation time must precede the event's creation time + job.metadata = Mock( + creation_timestamp=EVENT_CREATION_TIMESTAMP - timedelta(seconds=1) + ) + return job + + +def mock_read_namespaced_pod_log(*args, **kwargs): + """Mock for reading pod logs""" + if kwargs.get("namespace") == FAIL_LOGS: + raise Exception("Failed to read logs") + return "test log content" + + +def mock_watch(self, *args, **kwargs): + namespace = kwargs.get("namespace") + if namespace == FAIL_LOGS: + raise Exception("Failed to read logs") + if namespace == QUEUE_TIMEOUT: + log_lines = [TIMEOUT] + elif namespace == QUEUE_EMPTY: + log_lines = [QUEUE_EMPTY] + else: + log_lines = ["line 1 of pod logs", "line 2 of pod logs", "line 3 of pod logs"] + return iter(log_lines) + + +def mock_get_log_queue_pool(log_streams): + mock_logs = [] + for stream in log_streams: + # Convert iterator to list to preserve values + log_lines = list(stream) + mock_queue = Mock() + # Use a list to maintain state between calls + remaining_logs = log_lines.copy() # Make a copy to avoid modifying original + + def get_next(timeout, logs=remaining_logs): + if logs: + log = logs.pop(0) + if log == TIMEOUT: + raise TimeoutError + if log == QUEUE_EMPTY: + raise queue.Empty + return log + return None + + mock_queue.get = Mock(side_effect=get_next) + mock_queue.put = Mock() + mock_logs.append(mock_queue) + return mock_logs + + +def mock_list_namespaced_event(*args, **kwargs): + """Mock for listing namespace events""" + + class MockEvent: + def __init__(self, kind, name): + self.involved_object = Mock(kind=kind) + self.involved_object.name = name + self.metadata = Mock(creation_timestamp=EVENT_CREATION_TIMESTAMP) + self.message = f"{kind} Event 1" + + class MockEventList: + def __init__(self): + self.items = [ + MockEvent(constants.POD_KIND, DUMMY_POD_NAME), + MockEvent(constants.PYTORCHJOB_KIND, TEST_NAME), + ] + + if kwargs.get("namespace") == FAIL_EVENTS: + raise Exception("Failed to read events") + return MockEventList() + + test_data_create_job = [ ( "invalid extra parameter", @@ -934,6 +1063,230 @@ def __init__(self, kind) -> None: ), ] +test_data_get_job_logs = [ + # Basic cases + ( + "valid flow with default parameters", + { + "name": TEST_NAME, + }, + {DUMMY_POD_NAME: "test log content"}, + {}, + SUCCESS, + ), + ( + "pod with pending status", + { + "name": TEST_NAME, + "namespace": PENDING_POD, + }, + {}, # No logs expected for pending pods + {}, + SUCCESS, + ), + ( + "pod with pending status and follow", + { + "name": TEST_NAME, + "namespace": PENDING_POD, + "follow": True, + }, + {}, # No logs expected for pending pods + {}, + SUCCESS, + ), + ( + "pod with no status", + { + "name": TEST_NAME, + "namespace": NO_STATUS_POD, + }, + {}, # No logs expected + {}, + SUCCESS, + ), + ( + "pod with no status and follow", + { + "name": TEST_NAME, + "namespace": NO_STATUS_POD, + "follow": True, + }, + {}, # No logs expected + {}, + SUCCESS, + ), + ( + "valid flow with logs and verbose", + { + "name": TEST_NAME, + "namespace": TEST_NAME, + "verbose": True, + }, + {DUMMY_POD_NAME: "test log content"}, + { + f"{constants.PYTORCHJOB_KIND.lower()}/{TEST_NAME}": [ + f"{EVENT_CREATION_TIMESTAMP.strftime('%Y-%m-%d %H:%M:%S')} PyTorchJob Event 1" + ], + f"{constants.POD_KIND.lower()}/{DUMMY_POD_NAME}": [ + f"{EVENT_CREATION_TIMESTAMP.strftime('%Y-%m-%d %H:%M:%S')} Pod Event 1" + ], + }, + SUCCESS, + ), + ( + "valid flow with worker logs", + { + "name": TEST_NAME, + "namespace": TEST_NAME, + "is_master": False, + "replica_type": constants.REPLICA_TYPE_WORKER.lower(), + "replica_index": 0, + }, + {DUMMY_POD_NAME: "test log content"}, + {}, + SUCCESS, + ), + # Streaming cases + ( + "valid flow with follow logs", + { + "name": TEST_NAME, + "follow": True, + }, + { + DUMMY_POD_NAME: ( + "line 1 of pod logs" "line 2 of pod logs" "line 3 of pod logs" + ) + }, + {}, + SUCCESS, + ), + ( + "valid flow with follow logs and multiple pods", + { + "name": TEST_NAME, + "namespace": MULTI_PODS, + "follow": True, + }, + { + "pod-0": ("line 1 of pod logs" "line 2 of pod logs" "line 3 of pod logs"), + "pod-1": ("line 1 of pod logs" "line 2 of pod logs" "line 3 of pod logs"), + "pod-2": ("line 1 of pod logs" "line 2 of pod logs" "line 3 of pod logs"), + }, + {}, + SUCCESS, + ), + ( + "follow logs with queue empty", + { + "name": TEST_NAME, + "namespace": QUEUE_EMPTY, + "follow": True, + }, + {}, + {}, + SUCCESS, + ), + # Error cases + ( + "invalid replica type", + { + "name": TEST_NAME, + "namespace": INVALID, + "replica_type": "invalid_replica", + }, + None, + None, + ValueError, + ), + ( + "timeout error when getting pods", + { + "name": TEST_NAME, + "namespace": f"pod {TIMEOUT}", + }, + None, + None, + TimeoutError, + ), + ( + "runtime error when getting pods", + { + "name": TEST_NAME, + "namespace": f"pod {RUNTIME}", + }, + None, + None, + RuntimeError, + ), + ( + "exception when reading logs with follow", + { + "name": TEST_NAME, + "namespace": FAIL_LOGS, + "follow": True, + }, + None, + None, + Exception, + ), + ( + "runtime error when reading logs", + { + "name": TEST_NAME, + "namespace": FAIL_LOGS, + }, + None, + None, + RuntimeError, + ), + ( + "exception when reading events", + { + "name": TEST_NAME, + "namespace": FAIL_EVENTS, + "verbose": True, + }, + None, + None, + Exception, + ), + ( + "timeout error when getting job", + { + "name": TEST_NAME, + "namespace": f"job {TIMEOUT}", + "verbose": True, + }, + None, + None, + TimeoutError, + ), + ( + "runtime error when getting job", + { + "name": TEST_NAME, + "namespace": f"job {RUNTIME}", + "verbose": True, + }, + None, + None, + RuntimeError, + ), + ( + "follow logs with queue timeout", + { + "name": TEST_NAME, + "namespace": QUEUE_TIMEOUT, + "follow": True, + }, + {}, # Empty logs due to timeout + {}, + TimeoutError, + ), +] + @pytest.fixture def training_client(): @@ -962,6 +1315,30 @@ def training_client(): yield client +@pytest.fixture +def training_client_for_job_logs(): + """Fixture providing a mocked training client""" + with patch( + "kubernetes.client.CoreV1Api", + return_value=Mock( + read_namespaced_pod_log=Mock(side_effect=mock_read_namespaced_pod_log), + list_namespaced_event=Mock(side_effect=mock_list_namespaced_event), + ), + ), patch("kubernetes.config.load_kube_config", return_value=Mock()), patch( + "kubernetes.watch.Watch", + return_value=Mock( + stream=Mock(side_effect=mock_watch), + ), + ), patch( + "kubeflow.training.utils.utils.get_log_queue_pool", + side_effect=mock_get_log_queue_pool, + ): + client = TrainingClient(job_kind=constants.PYTORCHJOB_KIND) + client.get_job_pods = Mock(side_effect=mock_get_job_pods) + client.get_job = Mock(side_effect=mock_get_job) + yield client + + @pytest.mark.parametrize("test_name,kwargs,expected_output", test_data_create_job) def test_create_job(training_client, test_name, kwargs, expected_output): """ @@ -1227,3 +1604,56 @@ def test_list_jobs( assert type(e) is expected_status print("test execution complete") + + +@pytest.mark.parametrize( + "test_name,kwargs,expected_logs,expected_events,expected_output", + test_data_get_job_logs, +) +def test_get_job_logs( + training_client_for_job_logs, + test_name, + kwargs, + expected_logs, + expected_events, + expected_output, +): + """ + test get_job_logs function of training client + """ + print("Executing test:", test_name) + + try: + logs_dict, events_dict = training_client_for_job_logs.get_job_logs(**kwargs) + + assert expected_output == SUCCESS + assert logs_dict == expected_logs + + if kwargs.get("verbose", False): + assert events_dict == expected_events + else: + assert events_dict == {} + + # Verify API calls + training_client_for_job_logs.get_job_pods.assert_called_with( + name=kwargs["name"], + namespace=kwargs.get("namespace", constants.DEFAULT_NAMESPACE), + is_master=kwargs.get("is_master", True), + replica_type=kwargs.get("replica_type"), + replica_index=kwargs.get("replica_index"), + timeout=kwargs.get("timeout", constants.DEFAULT_TIMEOUT), + ) + + if kwargs.get("verbose", False): + training_client_for_job_logs.get_job.assert_called_with( + name=kwargs["name"], + namespace=kwargs.get("namespace", constants.DEFAULT_NAMESPACE), + ) + training_client_for_job_logs.core_api.list_namespaced_event.assert_called_with( + namespace=kwargs.get("namespace", constants.DEFAULT_NAMESPACE) + ) + + except Exception as e: + assert type(e) is expected_output + + print("test execution complete")