Skip to content

Commit

Permalink
[DCP] Adds better handling in logging of specific kwargs (pytorch#123658)
Browse files Browse the repository at this point in the history

Adds additional signpost integrations to DCP Logger, to add support for MLU and metric collection.

Differential Revision: [D55803461](https://our.internmc.facebook.com/intern/diff/D55803461/)

Pull Request resolved: pytorch#123658
Approved by: https://github.com/fegin
  • Loading branch information
LucasLLC authored and pytorchmergebot committed Apr 11, 2024
1 parent b7fac76 commit 13070e2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 7 deletions.
7 changes: 4 additions & 3 deletions torch/distributed/c10d_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,16 @@ def _get_logging_handler(destination: str = _DEFAULT_DESTINATION) -> Tuple[loggi

def _get_msg_dict(func_name, *args, **kwargs) -> Dict[str, Any]:
if dist.is_initialized():
group = kwargs.get("group") or kwargs.get("process_group")
msg_dict = {
"func_name": f"{func_name}",
"args": f"{args}, {kwargs}",
"pg_name": f"{dist._get_process_group_name(kwargs.get('pg'))}", # type: ignore[arg-type]
"backend": f"{dist.get_backend(kwargs.get('group') or kwargs.get('process_group'))}",
"backend": f"{dist.get_backend(group)}",
"world_size": f"{dist.get_world_size()}",
"group_size": f"{dist.get_world_size(kwargs.get('group'))}",
"group_size": f"{dist.get_world_size(group)}",
"global_rank": f"{dist.get_rank()}",
"local_rank": f"{dist.get_rank(kwargs.get('group'))}",
"local_rank": f"{dist.get_rank(group)}",
}
if msg_dict["backend"] == "nccl":
nccl_version = torch.cuda.nccl.version()
Expand Down
9 changes: 5 additions & 4 deletions torch/distributed/checkpoint/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> Dict[str, Any]:
# checkpoint ID can be passed in through the serializer or through the checkpoint id directly
storage_writer = kwargs.get("storage_writer", None)
storage_reader = kwargs.get("storage_reader", None)
if kwargs.get("checkpoint_id") is None and (
serializer := storage_writer or storage_reader
):
msg_dict["checkpoint_id"] = getattr(serializer, "checkpoint_id", None)
checkpoint_id = kwargs.get("checkpoint_id", None)
if not checkpoint_id and (serializer := storage_writer or storage_reader):
checkpoint_id = getattr(serializer, "checkpoint_id", None)

msg_dict["checkpoint_id"] = str(checkpoint_id)

return msg_dict

Expand Down

0 comments on commit 13070e2

Please sign in to comment.