Guard cuda memory allocator update (#9312)
* Guard cuda memory allocator update

Signed-off-by: smajumdar <[email protected]>

* Apply isort and black reformatting

Signed-off-by: titu1994 <[email protected]>

---------

Signed-off-by: smajumdar <[email protected]>
Signed-off-by: titu1994 <[email protected]>
titu1994 authored May 25, 2024
1 parent 6040af5 commit 7235f2b
Showing 1 changed file with 34 additions and 8 deletions.
nemo/collections/common/data/lhotse/dataloader.py (+34, -8)
@@ -95,7 +95,9 @@ class LhotseDataLoadingConfig:
 
     # 4. Optional Lhotse data augmentation.
     # a. On-the-fly noise/audio mixing.
-    noise_path: Any | None = None  # str | dict where dict can have any of keys: manifest_filepath, tarred_audio_filepaths, cuts_path, shar_path
+    noise_path: Any | None = (
+        None  # str | dict where dict can have any of keys: manifest_filepath, tarred_audio_filepaths, cuts_path, shar_path
+    )
     noise_snr: tuple[float, float] = (10.0, 20.0)
     noise_mix_prob: float = 0.5
     # b. On-the-fly 3-way speed perturbation.
@@ -114,7 +116,9 @@ class LhotseDataLoadingConfig:
     cut_into_windows_duration: Optional[float] = None  # set this to enable
     cut_into_windows_hop: Optional[float] = None
     # III) common options
-    keep_excessive_supervisions: bool = True  # when a cut is truncated in the middle of a supervision, should we keep them.
+    keep_excessive_supervisions: bool = (
+        True  # when a cut is truncated in the middle of a supervision, should we keep them.
+    )
     # e. RIR augmentation (synthetic RIR if rir_path is None)
     # at the moment supports only Lhotse recording manifests, e.g. https://github.com/lhotse-speech/lhotse/blob/master/lhotse/recipes/rir_noise.py
     rir_enabled: bool = False
@@ -130,7 +134,11 @@
 
 
 def get_lhotse_dataloader_from_config(
-    config: DictConfig, global_rank: int, world_size: int, dataset: torch.utils.data.Dataset, tokenizer=None,
+    config: DictConfig,
+    global_rank: int,
+    world_size: int,
+    dataset: torch.utils.data.Dataset,
+    tokenizer=None,
 ) -> torch.utils.data.DataLoader:
     """
     Set up a Lhotse training dataloder.
@@ -205,7 +213,11 @@ def get_lhotse_dataloader_from_config(
     # and applying it here (before sampler/dataset) ensures optimal
     # bucket allocation.
     if config.perturb_speed:
-        cuts = CutSet.mux(cuts, cuts.perturb_speed(0.9), cuts.perturb_speed(1.1),)
+        cuts = CutSet.mux(
+            cuts,
+            cuts.perturb_speed(0.9),
+            cuts.perturb_speed(1.1),
+        )
 
     # 2.d: truncation/slicing
     if config.truncate_duration is not None:
@@ -291,7 +303,10 @@ def get_lhotse_dataloader_from_config(
         # object with texts joined by a whitespace so that "regular" dataset classes don't
         # have to add a special support for multi-supervision cuts.
         sampler = sampler.map(
-            CutConcatenate(gap=config.concatenate_gap_seconds, duration_factor=config.concatenate_duration_factor,)
+            CutConcatenate(
+                gap=config.concatenate_gap_seconds,
+                duration_factor=config.concatenate_duration_factor,
+            )
         )
         if config.db_norm is not None:
             sampler = sampler.map(partial(_normalize_loudness, db_norm=config.db_norm))
@@ -326,7 +341,10 @@ def get_lhotse_dataloader_from_config(
     # the meta-data to Dataset, which performs the actual I/O inside its __getitem__ method.
     dloader_kwargs = dict(dataset=dataset, sampler=sampler)
     dloader = torch.utils.data.DataLoader(
-        **dloader_kwargs, batch_size=None, num_workers=config.num_workers, pin_memory=config.pin_memory,
+        **dloader_kwargs,
+        batch_size=None,
+        num_workers=config.num_workers,
+        pin_memory=config.pin_memory,
     )
 
     return dloader
@@ -377,7 +395,9 @@ class MultimodalSamplingConstraint(SamplingConstraint):
 
     def __post_init__(self):
         self._internal = TokenConstraint(
-            max_tokens=self.batch_tokens, max_examples=self.batch_size, quadratic_length=self.quadratic_factor,
+            max_tokens=self.batch_tokens,
+            max_examples=self.batch_size,
+            quadratic_length=self.quadratic_factor,
         )
 
     def add(self, example: Any) -> None:
@@ -487,7 +507,13 @@ def maybe_set_cuda_expandable_segments(enabled: bool):
             warnings.warn(
                 "You have set PYTORCH_CUDA_ALLOC_CONF without expandable_segments:True option. We're setting that option anyway. To disable it, set cuda_expandable_segments=False in NeMo dataloader configuration."
             )
-        torch.cuda.memory._set_allocator_settings("expandable_segments:True")
+
+        try:
+            torch.cuda.memory._set_allocator_settings("expandable_segments:True")
+        except RuntimeError:
+            logging.info(
+                "Failed to set expandable_segments:True for PyTorch CUDA allocator. You may get training speed improvements if you enable this"
+            )
 
 
 def _select_channel(cut, channel_selector: int | str) -> list:
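For reference, the behavioral change in this commit is the guarded allocator update in the last hunk: instead of calling the private PyTorch allocator API unconditionally, the call is wrapped in try/except so that a RuntimeError (e.g., when the setting is unsupported or cannot be applied) only logs an informational message rather than aborting dataloader setup. A minimal standalone sketch of the same pattern, using only the calls visible in the diff (not the full NeMo function):

import logging

import torch


def maybe_set_cuda_expandable_segments(enabled: bool):
    # Sketch of the guarded update; mirrors the diff above in simplified form.
    if enabled and torch.cuda.is_available():
        try:
            # Private PyTorch allocator API; may raise RuntimeError if the
            # setting is unsupported or cannot be applied at this point.
            torch.cuda.memory._set_allocator_settings("expandable_segments:True")
        except RuntimeError:
            logging.info(
                "Failed to set expandable_segments:True for PyTorch CUDA allocator."
            )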
