Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable CUDA graphs in DDP (ASR). Improve toggle messages #11087

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,39 +655,45 @@ def __init__(
else:
self._greedy_decode = self._greedy_decode_masked

def disable_cuda_graphs(self):
def disable_cuda_graphs(self) -> bool:
"""Disable CUDA graphs (e.g., for decoding in training)"""
if not self.use_cuda_graph_decoder:
# CUDA graphs not allowed, nothing to do
return
return False

if not self.decoder.blank_as_pad:
# blank as pad uses decoding without CUDA graphs
return
return False

if self.loop_labels:
# Label-Looping implementation
self._decoding_computer.disable_cuda_graphs()
return self._decoding_computer.disable_cuda_graphs()
else:
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
if self._greedy_decode != self._greedy_decode_blank_as_pad_loop_frames:
self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
return True
return False

def maybe_enable_cuda_graphs(self):
def maybe_enable_cuda_graphs(self) -> bool:
"""Enable CUDA graphs (if allowed)"""
if not self.use_cuda_graph_decoder:
# CUDA graphs not allowed, nothing to do
return
return False

if not self.decoder.blank_as_pad:
# blank as pad uses decoding without CUDA graphs
return
return False

if self.loop_labels:
# Label-Looping implementation
self._decoding_computer.maybe_enable_cuda_graphs()
return self._decoding_computer.maybe_enable_cuda_graphs()
else:
from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import RNNTGreedyDecodeCudaGraph

self._greedy_decode = RNNTGreedyDecodeCudaGraph(self.max_symbols, self)
if not isinstance(self._greedy_decode, RNNTGreedyDecodeCudaGraph):
self._greedy_decode = RNNTGreedyDecodeCudaGraph(self.max_symbols, self)
return True
return False

@typecheck()
def forward(
Expand Down Expand Up @@ -2832,12 +2838,14 @@ def _greedy_decode_blank_as_pad_loop_labels(
hyp.dec_state = state
return hyps

def disable_cuda_graphs(self):
def disable_cuda_graphs(self) -> bool:
"""Disable CUDA graphs (e.g., for decoding in training)"""
if self._decoding_computer is not None:
self._decoding_computer.disable_cuda_graphs()
return self._decoding_computer.disable_cuda_graphs()
return False # nothing changed

def maybe_enable_cuda_graphs(self):
def maybe_enable_cuda_graphs(self) -> bool:
"""Enable CUDA graphs (if allowed)"""
if self._decoding_computer is not None:
self._decoding_computer.maybe_enable_cuda_graphs()
return self._decoding_computer.maybe_enable_cuda_graphs()
return False # nothing changed
15 changes: 11 additions & 4 deletions nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,11 @@ def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]):
self.cuda_graphs_mode = self.CudaGraphsMode(mode) if mode is not None else None
self.state = None

def maybe_enable_cuda_graphs(self):
def maybe_enable_cuda_graphs(self) -> bool:
"""Enable CUDA graphs if conditions met"""
if self.cuda_graphs_mode is not None:
# CUDA graphs are already enabled
return
return False # nothing changed

if not self.allow_cuda_graphs:
self.cuda_graphs_mode = None
Expand All @@ -274,14 +274,16 @@ def maybe_enable_cuda_graphs(self):
)
self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS
self.reset_cuda_graphs_state()
return self.cuda_graphs_mode is not None

def disable_cuda_graphs(self):
def disable_cuda_graphs(self) -> bool:
"""Disable CUDA graphs, can be used to disable graphs temporary, e.g., in training process"""
if self.cuda_graphs_mode is None:
# nothing to disable
return
return False
self.cuda_graphs_mode = None
self.reset_cuda_graphs_state()
return True

def reset_cuda_graphs_state(self):
"""Reset state to release memory (for CUDA graphs implementations)"""
Expand Down Expand Up @@ -895,6 +897,11 @@ def __call__(
x: torch.Tensor,
out_len: torch.Tensor,
) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]:
if self.cuda_graphs_mode is not None and torch.distributed.is_initialized():
# TODO(vbataev): fix torch.distributed + CUDA graphs, remove this switch
logging.warning(f"In distributed mode, CUDA graphs are not supported yet. Switching off CUDA graphs.")
self.disable_cuda_graphs()

if self.cuda_graphs_mode is not None and x.device.type == "cuda":
return self.loop_labels_cuda_graphs(encoder_output=x, encoder_output_length=out_len)

Expand Down
15 changes: 11 additions & 4 deletions nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,11 +256,11 @@ def __init__(
self.cuda_graphs_mode = None
self.maybe_enable_cuda_graphs()

def maybe_enable_cuda_graphs(self):
def maybe_enable_cuda_graphs(self) -> bool:
"""Enable CUDA graphs if conditions met"""
if self.cuda_graphs_mode is not None:
# CUDA graphs are enabled
return
return False # nothing changed

if not self.allow_cuda_graphs:
self.cuda_graphs_mode = None
Expand All @@ -285,14 +285,16 @@ def maybe_enable_cuda_graphs(self):
)
self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS
self.reset_cuda_graphs_state()
return self.cuda_graphs_mode is not None

def disable_cuda_graphs(self):
def disable_cuda_graphs(self) -> bool:
"""Disable CUDA graphs, can be used to disable graphs temporary, e.g., in training process"""
if self.cuda_graphs_mode is None:
# nothing to disable
return
return False
self.cuda_graphs_mode = None
self.reset_cuda_graphs_state()
return True

def reset_cuda_graphs_state(self):
"""Reset state to release memory (for CUDA graphs implementations)"""
Expand Down Expand Up @@ -1013,6 +1015,11 @@ def __call__(
x: torch.Tensor,
out_len: torch.Tensor,
) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]:
if self.cuda_graphs_mode is not None and torch.distributed.is_initialized():
# TODO(vbataev): fix torch.distributed + CUDA graphs, remove this switch
logging.warning(f"In distributed mode, CUDA graphs are not supported yet. Switching off CUDA graphs.")
self.disable_cuda_graphs()

if self.cuda_graphs_mode is not None and x.device.type == "cuda":
return self.loop_labels_cuda_graphs(encoder_output=x, encoder_output_length=out_len)

Expand Down
16 changes: 8 additions & 8 deletions nemo/collections/common/parts/optional_cuda_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def disable_cuda_graphs_recursive(cls, module: nn.Module, attribute_path: Option
continue # loop over modules, no attribute

if isinstance(object_to_check, cls):
object_to_check.disable_cuda_graphs()
logging.info(f"Disabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes))
if object_to_check.disable_cuda_graphs():
logging.info(f"Disabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes))

@classmethod
def enable_cuda_graphs_recursive(cls, module: nn.Module, attribute_path: Optional[str] = None):
Expand All @@ -75,15 +75,15 @@ def enable_cuda_graphs_recursive(cls, module: nn.Module, attribute_path: Optiona
continue # loop over modules, no attribute

if isinstance(object_to_check, cls):
object_to_check.maybe_enable_cuda_graphs()
logging.info(f"Enabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes))
if object_to_check.maybe_enable_cuda_graphs():
logging.info(f"Enabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes))

@abc.abstractmethod
def disable_cuda_graphs(self):
"""Disable (maybe temporary) CUDA graphs"""
def disable_cuda_graphs(self) -> bool:
"""Disable (maybe temporary) CUDA graphs. Return True if CUDA graphs status changed enabled->disabled"""
raise NotImplementedError

@abc.abstractmethod
def maybe_enable_cuda_graphs(self):
"""Enable CUDA graphs if all conditions met"""
def maybe_enable_cuda_graphs(self) -> bool:
"""Enable CUDA graphs if all conditions met. Return True if CUDA graphs status changed disabled->enabled"""
raise NotImplementedError
6 changes: 4 additions & 2 deletions tests/collections/common/test_optional_cuda_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ def __init__(self):
super().__init__()
self.cuda_graphs_used = True

def disable_cuda_graphs(self):
def disable_cuda_graphs(self) -> bool:
self.cuda_graphs_used = False
return True

def maybe_enable_cuda_graphs(self):
def maybe_enable_cuda_graphs(self) -> bool:
self.cuda_graphs_used = True
return True


class MockModuleWithCudaGraphs(MockClassWithCudaGraphs, nn.Module):
Expand Down
Loading