Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix event action consuming. #1367

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion dlrover/python/diagnosis/common/diagnosis_action.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ def __init__(
):
super().__init__(
DiagnosisActionType.EVENT,
instance=DiagnosisConstant.MASTER_INSTANCE,
timestamp=timestamp,
expired_time_period=expired_time_period,
)
Expand Down Expand Up @@ -203,7 +204,7 @@ def add_action(self, new_action: DiagnosisAction):
try:
if new_action.is_needed():
ins_actions.put(new_action, timeout=3)
logger.info(f"New diagnosis action {new_action}")
logger.info(f"New diagnosis action: {new_action}")
except queue.Full:
logger.warning(
f"Diagnosis actions for {instance} is full, "
Expand Down
4 changes: 2 additions & 2 deletions dlrover/python/diagnosis/inferencechain/inference_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
self.operators = operators

def infer(self) -> List[Inference]:
logger.info(f"Infer {self.inferences}")
logger.debug(f"Infer {self.inferences}")
inferences = self.inferences
while True:
has_new_inference = False
Expand Down Expand Up @@ -66,5 +66,5 @@ def get_operator(self, inference: Inference) -> InferenceOperator:
for operator in self.operators:
if operator.is_compatible(inference):
return operator
logger.info(f"No operator for inference: {inference.__dict__}")
logger.debug(f"No operator for inference: {inference.__dict__}")
return None # type: ignore
5 changes: 4 additions & 1 deletion dlrover/python/master/elastic_training/net_topology.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from dataclasses import dataclass
from typing import Dict, List, Tuple

from dlrover.python.common.serialize import JsonSerializable


@dataclass
class NodeTopologyMeta(object):
class NodeTopologyMeta(JsonSerializable):
node_id: int = 0
node_rank: int = 0
process_num: int = 0
Expand Down
20 changes: 14 additions & 6 deletions dlrover/python/master/node/dist_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,10 @@
from dlrover.python.common.grpc import ParallelConfig
from dlrover.python.common.log import default_logger as logger
from dlrover.python.common.node import Node, NodeGroupResource
from dlrover.python.diagnosis.common.constants import (
DiagnosisActionType,
DiagnosisConstant,
)
from dlrover.python.diagnosis.common.constants import DiagnosisConstant
from dlrover.python.diagnosis.common.diagnosis_action import (
DiagnosisAction,
EventAction,
NoAction,
)
from dlrover.python.master.monitor.error_monitor import K8sJobErrorMonitor
Expand Down Expand Up @@ -670,10 +668,20 @@ def _get_pod_unique_labels(self, node: Node):
}

def _process_diagnosis_action(self, action: DiagnosisAction):
if not action or action.action_type == DiagnosisActionType.NONE:
if not action or isinstance(action, NoAction):
return

# TODO
if isinstance(action, EventAction):
self._report_event(
action.event_type,
action.event_instance,
action.event_action,
action.event_msg,
action.event_labels,
)
else:
# TODO: handle other action types
pass

def _process_event(self, event: NodeEvent):
node_type = event.node.type
Expand Down
6 changes: 3 additions & 3 deletions dlrover/python/tests/test_diagnosis.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_action_basic(self):
)
self.assertEqual(event_action.action_type, DiagnosisActionType.EVENT)
self.assertEqual(
event_action._instance, DiagnosisConstant.LOCAL_INSTANCE
event_action._instance, DiagnosisConstant.MASTER_INSTANCE
)
self.assertEqual(event_action.event_type, "info")
self.assertEqual(event_action.event_instance, "job")
Expand Down Expand Up @@ -101,15 +101,15 @@ def test_action_queue(self):
)
self.assertEqual(
action_queue.next_action(
instance=DiagnosisConstant.LOCAL_INSTANCE
instance=DiagnosisConstant.MASTER_INSTANCE
).action_type,
DiagnosisActionType.EVENT,
)
self.assertEqual(
action_queue.next_action(
instance=DiagnosisConstant.LOCAL_INSTANCE
).action_type,
DiagnosisActionType.EVENT,
DiagnosisActionType.NONE,
)
self.assertEqual(
action_queue.next_action(instance=1).action_type,
Expand Down
19 changes: 19 additions & 0 deletions dlrover/python/tests/test_job_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
ParallelConfig,
)
from dlrover.python.common.node import NodeGroupResource, NodeResource
from dlrover.python.diagnosis.common.diagnosis_action import (
EventAction,
NoAction,
)
from dlrover.python.master.dist_master import DistributedJobMaster
from dlrover.python.master.monitor.error_monitor import SimpleErrorMonitor
from dlrover.python.master.monitor.speed_monitor import SpeedMonitor
Expand Down Expand Up @@ -869,6 +873,21 @@ def test_get_pending_timeout(self):
# reset
_dlrover_context.seconds_to_wait_pending_pod = 900

@patch.object(DistributedJobManager, "_report_event")
def test_process_diagnosis_action(self, mock_method):
    """Check which diagnosis actions lead to a reported event.

    ``_report_event`` is patched out so the test only counts how often
    ``_process_diagnosis_action`` forwards an action to it.
    """
    params = MockK8sPSJobArgs()
    params.initilize()
    manager = create_job_manager(params, SpeedMonitor())

    # A missing action must be ignored.
    manager._process_diagnosis_action(None)
    self.assertEqual(mock_method.call_count, 0)

    # Pass an *instance*: the manager filters with isinstance(), which the
    # bare class ``NoAction`` would not satisfy, so the no-op branch would
    # never actually be exercised.
    manager._process_diagnosis_action(NoAction())
    self.assertEqual(mock_method.call_count, 0)

    # An event action is forwarded to _report_event exactly once.
    manager._process_diagnosis_action(EventAction())
    self.assertEqual(mock_method.call_count, 1)


class LocalJobManagerTest(unittest.TestCase):
def test_local_job_manager(self):
Expand Down