Improve messages for enabling/disabling CUDA graphs (log only if something changed)

Signed-off-by: Vladimir Bataev <[email protected]>
artbataev committed Oct 29, 2024
1 parent a41595d commit cf7f2f1
Showing 5 changed files with 46 additions and 33 deletions.
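
The common change across all five files: disable_cuda_graphs() and maybe_enable_cuda_graphs() now return a bool reporting whether the CUDA-graphs state actually changed, so callers can log only on real transitions. A minimal sketch of that pattern follows; the names are illustrative, not the NeMo classes.

import logging

logging.basicConfig(level=logging.INFO)


class GraphDecoder:
    """Toy stand-in for a decoder with an optional CUDA-graphs path (illustrative only)."""

    def __init__(self) -> None:
        self.cuda_graphs_enabled = False

    def maybe_enable_cuda_graphs(self) -> bool:
        if self.cuda_graphs_enabled:
            return False  # already enabled: nothing changed
        self.cuda_graphs_enabled = True
        return True

    def disable_cuda_graphs(self) -> bool:
        if not self.cuda_graphs_enabled:
            return False  # already disabled: nothing changed
        self.cuda_graphs_enabled = False
        return True


decoder = GraphDecoder()
if decoder.maybe_enable_cuda_graphs():
    logging.info("Enabled CUDA graphs")  # logged: state actually changed
if decoder.maybe_enable_cuda_graphs():
    logging.info("Enabled CUDA graphs")  # not logged: second call returns False

The diffs below apply the same idea to the decoder classes and to the recursive enable/disable helpers.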
37 changes: 22 additions & 15 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
@@ -655,39 +655,44 @@ def __init__(
         else:
             self._greedy_decode = self._greedy_decode_masked

-    def disable_cuda_graphs(self):
+    def disable_cuda_graphs(self) -> bool:
         """Disable CUDA graphs (e.g., for decoding in training)"""
         if not self.use_cuda_graph_decoder:
             # CUDA graphs not allowed, nothing to do
-            return
+            return False

         if not self.decoder.blank_as_pad:
             # blank as pad uses decoding without CUDA graphs
-            return
+            return False

         if self.loop_labels:
             # Label-Looping implementation
-            self._decoding_computer.disable_cuda_graphs()
+            return self._decoding_computer.disable_cuda_graphs()
         else:
-            self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
+            if self._greedy_decode != self._greedy_decode_blank_as_pad_loop_frames:
+                self._greedy_decode = self._greedy_decode_blank_as_pad_loop_frames
+                return True
+            return False

-    def maybe_enable_cuda_graphs(self):
+    def maybe_enable_cuda_graphs(self) -> bool:
         """Enable CUDA graphs (if allowed)"""
         if not self.use_cuda_graph_decoder:
             # CUDA graphs not allowed, nothing to do
-            return
+            return False

         if not self.decoder.blank_as_pad:
             # blank as pad uses decoding without CUDA graphs
-            return
+            return False

         if self.loop_labels:
             # Label-Looping implementation
-            self._decoding_computer.maybe_enable_cuda_graphs()
+            return self._decoding_computer.maybe_enable_cuda_graphs()
         else:
             from nemo.collections.asr.parts.submodules.cuda_graph_rnnt_greedy_decoding import RNNTGreedyDecodeCudaGraph

-            self._greedy_decode = RNNTGreedyDecodeCudaGraph(self.max_symbols, self)
+            if not isinstance(self._greedy_decode, RNNTGreedyDecodeCudaGraph):
+                self._greedy_decode = RNNTGreedyDecodeCudaGraph(self.max_symbols, self)
+                return True
+            return False

     @typecheck()
     def forward(

@@ -2832,12 +2837,14 @@ def _greedy_decode_blank_as_pad_loop_labels(
                 hyp.dec_state = state
         return hyps

-    def disable_cuda_graphs(self):
+    def disable_cuda_graphs(self) -> bool:
         """Disable CUDA graphs (e.g., for decoding in training)"""
         if self._decoding_computer is not None:
-            self._decoding_computer.disable_cuda_graphs()
+            return self._decoding_computer.disable_cuda_graphs()
+        return False  # nothing changed

-    def maybe_enable_cuda_graphs(self):
+    def maybe_enable_cuda_graphs(self) -> bool:
         """Enable CUDA graphs (if allowed)"""
         if self._decoding_computer is not None:
-            self._decoding_computer.maybe_enable_cuda_graphs()
+            return self._decoding_computer.maybe_enable_cuda_graphs()
+        return False  # nothing changed
Additional changed file (path not shown):

@@ -248,11 +248,11 @@ def force_cuda_graphs_mode(self, mode: Optional[Union[str, CudaGraphsMode]]):
         self.cuda_graphs_mode = self.CudaGraphsMode(mode) if mode is not None else None
         self.state = None

-    def maybe_enable_cuda_graphs(self):
+    def maybe_enable_cuda_graphs(self) -> bool:
         """Enable CUDA graphs if conditions met"""
         if self.cuda_graphs_mode is not None:
             # CUDA graphs are already enabled
-            return
+            return False  # nothing changed

         if not self.allow_cuda_graphs:
             self.cuda_graphs_mode = None

@@ -274,14 +274,16 @@ def maybe_enable_cuda_graphs(self):
                 )
                 self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS
         self.reset_cuda_graphs_state()
+        return self.cuda_graphs_mode is not None

-    def disable_cuda_graphs(self):
+    def disable_cuda_graphs(self) -> bool:
         """Disable CUDA graphs, can be used to disable graphs temporary, e.g., in training process"""
         if self.cuda_graphs_mode is None:
             # nothing to disable
-            return
+            return False
         self.cuda_graphs_mode = None
         self.reset_cuda_graphs_state()
+        return True

     def reset_cuda_graphs_state(self):
         """Reset state to release memory (for CUDA graphs implementations)"""
Additional changed file (path not shown):

@@ -256,11 +256,11 @@ def __init__(
         self.cuda_graphs_mode = None
         self.maybe_enable_cuda_graphs()

-    def maybe_enable_cuda_graphs(self):
+    def maybe_enable_cuda_graphs(self) -> bool:
         """Enable CUDA graphs if conditions met"""
         if self.cuda_graphs_mode is not None:
             # CUDA graphs are enabled
-            return
+            return False  # nothing changed

         if not self.allow_cuda_graphs:
             self.cuda_graphs_mode = None

@@ -285,14 +285,16 @@ def maybe_enable_cuda_graphs(self):
                 )
                 self.cuda_graphs_mode = self.CudaGraphsMode.NO_WHILE_LOOPS
         self.reset_cuda_graphs_state()
+        return self.cuda_graphs_mode is not None

-    def disable_cuda_graphs(self):
+    def disable_cuda_graphs(self) -> bool:
         """Disable CUDA graphs, can be used to disable graphs temporary, e.g., in training process"""
         if self.cuda_graphs_mode is None:
             # nothing to disable
-            return
+            return False
         self.cuda_graphs_mode = None
         self.reset_cuda_graphs_state()
+        return True

     def reset_cuda_graphs_state(self):
         """Reset state to release memory (for CUDA graphs implementations)"""
16 changes: 8 additions & 8 deletions nemo/collections/common/parts/optional_cuda_graphs.py
@@ -49,8 +49,8 @@ def disable_cuda_graphs_recursive(cls, module: nn.Module, attribute_path: Optional[str] = None):
                 continue  # loop over modules, no attribute

             if isinstance(object_to_check, cls):
-                object_to_check.disable_cuda_graphs()
-                logging.info(f"Disabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes))
+                if object_to_check.disable_cuda_graphs():
+                    logging.info(f"Disabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes))

     @classmethod
     def enable_cuda_graphs_recursive(cls, module: nn.Module, attribute_path: Optional[str] = None):

@@ -75,15 +75,15 @@ def enable_cuda_graphs_recursive(cls, module: nn.Module, attribute_path: Optional[str] = None):
                 continue  # loop over modules, no attribute

             if isinstance(object_to_check, cls):
-                object_to_check.maybe_enable_cuda_graphs()
-                logging.info(f"Enabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes))
+                if object_to_check.maybe_enable_cuda_graphs():
+                    logging.info(f"Enabled CUDA graphs for module {type(submodule)}" + ".".join([name] + attributes))

     @abc.abstractmethod
-    def disable_cuda_graphs(self):
-        """Disable (maybe temporary) CUDA graphs"""
+    def disable_cuda_graphs(self) -> bool:
+        """Disable (maybe temporary) CUDA graphs. Return True if CUDA graphs status changed enabled->disabled"""
         raise NotImplementedError

     @abc.abstractmethod
-    def maybe_enable_cuda_graphs(self):
-        """Enable CUDA graphs if all conditions met"""
+    def maybe_enable_cuda_graphs(self) -> bool:
+        """Enable CUDA graphs if all conditions met. Return True if CUDA graphs status changed disabled->enabled"""
         raise NotImplementedError
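
For context, a hedged usage sketch of the recursive helpers above in a training loop. It assumes the abstract base in optional_cuda_graphs.py is named WithOptionalCudaGraphs, and the model object and the attribute path "decoding.decoding" are illustrative; with the bool-returning methods, each helper now logs only for submodules whose state actually changed.

from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs


def validation_step_with_graphs_disabled(model, batch):
    # Disable graphs on matching submodules; a log line is emitted only where
    # disable_cuda_graphs() returned True (a real enabled->disabled transition).
    WithOptionalCudaGraphs.disable_cuda_graphs_recursive(model, attribute_path="decoding.decoding")
    try:
        return model.validation_step(batch)  # illustrative call
    finally:
        # Re-enable where conditions allow; again, logging happens only on a real change.
        WithOptionalCudaGraphs.enable_cuda_graphs_recursive(model, attribute_path="decoding.decoding")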
6 changes: 4 additions & 2 deletions tests/collections/common/test_optional_cuda_graphs.py
@@ -23,11 +23,13 @@ def __init__(self):
         super().__init__()
         self.cuda_graphs_used = True

-    def disable_cuda_graphs(self):
+    def disable_cuda_graphs(self) -> bool:
         self.cuda_graphs_used = False
+        return True

-    def maybe_enable_cuda_graphs(self):
+    def maybe_enable_cuda_graphs(self) -> bool:
         self.cuda_graphs_used = True
+        return True


 class MockModuleWithCudaGraphs(MockClassWithCudaGraphs, nn.Module):
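
An illustrative follow-up check, not part of this commit, showing how the updated mock could be driven through the recursive helpers. It assumes the base class is WithOptionalCudaGraphs, that MockModuleWithCudaGraphs is importable from the test module above, and that the default attribute_path inspects each submodule itself.

import torch.nn as nn

from nemo.collections.common.parts.optional_cuda_graphs import WithOptionalCudaGraphs
from tests.collections.common.test_optional_cuda_graphs import MockModuleWithCudaGraphs


def test_recursive_toggle_flips_mock_flag():
    container = nn.Sequential(MockModuleWithCudaGraphs())
    # Each helper logs only for submodules whose method returned True (a real state change).
    WithOptionalCudaGraphs.disable_cuda_graphs_recursive(container)
    assert not container[0].cuda_graphs_used
    WithOptionalCudaGraphs.enable_cuda_graphs_recursive(container)
    assert container[0].cuda_graphs_used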
