diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index f9cf368fe405..383efceb20b9 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -655,39 +655,45 @@ def __init__( else: self._greedy_decode = self._greedy_decode_masked - def disable_cuda_graphs(self): + def disable_cuda_graphs(self) -> bool: """Disable CUDA graphs (e.g., for decoding in training)""" if not self.use_cuda_graph_decoder: # CUDA graphs not allowed, nothing to do - return + return False if not self.decoder.blank_as_pad: # blank as pad uses decoding without CUDA graphs - return + return False if self.loop_labels: # Label-Looping implementation - self._decoding_computer.disable_cuda_graphs() + return self._decoding_computer.disable_cuda_graphs() else: - self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames + if self._greedy_decode != self._greedy_decode_blank_as_pad_loop_frames: + self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames + return True + return False - def maybe_enable_cuda_graphs(self): + def maybe_enable_cuda_graphs(self) -> bool: """Enable CUDA graphs (if allowed)""" if not self.use_cuda_graph_decoder: # CUDA graphs not allowed, nothing to do - return + return False if not self.decoder.blank_as_pad: # blank as pad uses decoding without CUDA graphs - return + return False if self.loop_labels: # Label-Looping implementation - self._decoding_computer.maybe_enable_cuda_graphs() + return self._decoding_computer.maybe_enable_cuda_graphs() else: from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import RNNTGreedyDecodeCudaGraph - self._greedy_decode = RNNTGreedyDecodeCudaGraph(self.max_symbols, self) + if not isinstance(self._greedy_decode, RNNTGreedyDecodeCudaGraph): + self._greedy_decode = RNNTGreedyDecodeCudaGraph(self.max_symbols, self) + return True + return False @typecheck() def forward( @@ -2832,12 +2838,14 @@ def _greedy_decode_blank_as_pad_loop_labels( hyp.dec_state = state return hyps - def disable_cuda_graphs(self): + def disable_cuda_graphs(self) -> bool: """Disable CUDA graphs (e.g., for decoding in training)""" if self._decoding_computer is not None: - self._decoding_computer.disable_cuda_graphs() + return self._decoding_computer.disable_cuda_graphs() + return False # nothing changed - def maybe_enable_cuda_graphs(self): + def maybe_enable_cuda_graphs(self) -> bool: """Enable CUDA graphs (if allowed)""" if self._decoding_computer is not None: - self._decoding_computer.maybe_enable_cuda_graphs() + return self._decoding_computer.maybe_enable_cuda_graphs() + return False # nothing changed diff --git a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py index 13bb0b471ed2..aee269b7277a 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/rnnt_loop_labels_computer.py @@ -248,11 +248,11 @@ def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]): self.cuda_graphs_mode = self.CudaGraphsMode(mode) if mode is not None else None self.state = None - def maybe_enable_cuda_graphs(self): + def maybe_enable_cuda_graphs(self) -> bool: """Enable CUDA graphs if conditions met""" if self.cuda_graphs_mode is not None: # CUDA graphs are already enabled - return + return False # nothing changed if not self.allow_cuda_graphs: self.cuda_graphs_mode = None @@ -274,14 +274,16 @@ def maybe_enable_cuda_graphs(self): ) self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS self.reset_cuda_graphs_state() + return self.cuda_graphs_mode is not None - def disable_cuda_graphs(self): + def disable_cuda_graphs(self) -> bool: """Disable CUDA graphs, can be used to disable graphs temporary, e.g., in training process""" if self.cuda_graphs_mode is None: # nothing to disable - return + return False self.cuda_graphs_mode = None self.reset_cuda_graphs_state() + return True def reset_cuda_graphs_state(self): """Reset state to release memory (for CUDA graphs implementations)""" @@ -895,6 +897,11 @@ def __call__( x: torch.Tensor, out_len: torch.Tensor, ) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]: + if self.cuda_graphs_mode is not None and torch.distributed.is_initialized(): + # TODO(vbataev): fix torch.distributed + CUDA graphs, remove this switch + logging.warning(f"In distributed mode, CUDA graphs are not supported yet. Switching off CUDA graphs.") + self.disable_cuda_graphs() + if self.cuda_graphs_mode is not None and x.device.type == "cuda": return self.loop_labels_cuda_graphs(encoder_output=x, encoder_output_length=out_len) diff --git a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py index c0fbe5361761..1688a9221946 100644 --- a/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py +++ b/nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py @@ -256,11 +256,11 @@ def __init__( self.cuda_graphs_mode = None self.maybe_enable_cuda_graphs() - def maybe_enable_cuda_graphs(self): + def maybe_enable_cuda_graphs(self) -> bool: """Enable CUDA graphs if conditions met""" if self.cuda_graphs_mode is not None: # CUDA graphs are enabled - return + return False # nothing changed if not self.allow_cuda_graphs: self.cuda_graphs_mode = None @@ -285,14 +285,16 @@ def maybe_enable_cuda_graphs(self): ) self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS self.reset_cuda_graphs_state() + return self.cuda_graphs_mode is not None - def disable_cuda_graphs(self): + def disable_cuda_graphs(self) -> bool: """Disable CUDA graphs, can be used to disable graphs temporary, e.g., in training process""" if self.cuda_graphs_mode is None: # nothing to disable - return + return False self.cuda_graphs_mode = None self.reset_cuda_graphs_state() + return True def reset_cuda_graphs_state(self): """Reset state to release memory (for CUDA graphs implementations)""" @@ -1013,6 +1015,11 @@ def __call__( x: torch.Tensor, out_len: torch.Tensor, ) -> Tuple[rnnt_utils.BatchedHyps, Optional[rnnt_utils.BatchedAlignments], Any]: + if self.cuda_graphs_mode is not None and torch.distributed.is_initialized(): + # TODO(vbataev): fix torch.distributed + CUDA graphs, remove this switch + logging.warning(f"In distributed mode, CUDA graphs are not supported yet. Switching off CUDA graphs.") + self.disable_cuda_graphs() + if self.cuda_graphs_mode is not None and x.device.type == "cuda": return self.loop_labels_cuda_graphs(encoder_output=x, encoder_output_length=out_len) diff --git a/nemo/collections/common/parts/optional_cuda_graphs.py b/nemo/collections/common/parts/optional_cuda_graphs.py index 2417d9e00370..6c677e15a101 100644 --- a/nemo/collections/common/parts/optional_cuda_graphs.py +++ b/nemo/collections/common/parts/optional_cuda_graphs.py @@ -49,8 +49,8 @@ def disable_cuda_graphs_recursive(cls, module: nn.Module, attribute_path: Option continue # loop over modules, no attribute if isinstance(object_to_check, cls): - object_to_check.disable_cuda_graphs() - logging.info(f"Disabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes)) + if object_to_check.disable_cuda_graphs(): + logging.info(f"Disabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes)) @classmethod def enable_cuda_graphs_recursive(cls, module: nn.Module, attribute_path: Optional[str] = None): @@ -75,15 +75,15 @@ def enable_cuda_graphs_recursive(cls, module: nn.Module, attribute_path: Optiona continue # loop over modules, no attribute if isinstance(object_to_check, cls): - object_to_check.maybe_enable_cuda_graphs() - logging.info(f"Enabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes)) + if object_to_check.maybe_enable_cuda_graphs(): + logging.info(f"Enabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes)) @abc.abstractmethod - def disable_cuda_graphs(self): - """Disable (maybe temporary) CUDA graphs""" + def disable_cuda_graphs(self) -> bool: + """Disable (maybe temporary) CUDA graphs. Return True if CUDA graphs status changed enabled->disabled""" raise NotImplementedError @abc.abstractmethod - def maybe_enable_cuda_graphs(self): - """Enable CUDA graphs if all conditions met""" + def maybe_enable_cuda_graphs(self) -> bool: + """Enable CUDA graphs if all conditions met. Return True if CUDA graphs status changed disabled->enabled""" raise NotImplementedError diff --git a/tests/collections/common/test_optional_cuda_graphs.py b/tests/collections/common/test_optional_cuda_graphs.py index 7b1dda775863..d15c6cb09d92 100644 --- a/tests/collections/common/test_optional_cuda_graphs.py +++ b/tests/collections/common/test_optional_cuda_graphs.py @@ -23,11 +23,13 @@ def __init__(self): super().__init__() self.cuda_graphs_used = True - def disable_cuda_graphs(self): + def disable_cuda_graphs(self) -> bool: self.cuda_graphs_used = False + return True - def maybe_enable_cuda_graphs(self): + def maybe_enable_cuda_graphs(self) -> bool: self.cuda_graphs_used = True + return True class MockModuleWithCudaGraphs(MockClassWithCudaGraphs, nn.Module):