Commit

Merge branch 'main' into aot/gemma2
suiyoubi authored Oct 29, 2024
2 parents efd426f + 100e093 commit cf33e03
Showing 5 changed files with 50 additions and 3 deletions.
14 changes: 14 additions & 0 deletions nemo/collections/nlp/modules/common/hyena/hyena.py
@@ -1,3 +1,17 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Implementation of Hyena operator
#
# Michael Poli and Stefano Massaroli and Eric Nguyen and Daniel Y Fu and Tri Dao and Stephen Baccus and
6 changes: 6 additions & 0 deletions nemo/lightning/data.py
@@ -19,6 +19,7 @@
from typing import List, Literal, Optional

import torch
from pytorch_lightning.overrides.distributed import _IndexBatchSamplerWrapper
from torch.utils.data import DataLoader, Dataset


@@ -139,6 +140,7 @@ def add_megatron_sampler(
dataloader_type: Literal["single", "cyclic", "batch"] = "single",
drop_last: bool = True,
pad_samples_to_global_batch_size: bool = False,
dataloader_mode: Literal["train", "validation", "test", "predict"] = "train",
rank: int = 0,
world_size: int = 1,
# data_sharding: bool = False
@@ -170,6 +172,7 @@
pad_samples_to_global_batch_size (bool, optional): Whether to pad the last incomplete
batch to the `global_batch_size` (defaults to False, only applies when
`drop_last` is False).
dataloader_mode (Literal["train", "validation", "test", "predict"]): The mode in which the dataloader is used (defaults to "train").
Returns:
DataLoader: A new DataLoader instance with the configured Megatron sampler.
@@ -214,6 +217,9 @@
else:
raise Exception(f'{dataloader_type} dataloader type is not supported.')

if dataloader_mode in ["test", "predict"]:
batch_sampler = _IndexBatchSamplerWrapper(batch_sampler) # BatchSampler wrapper to capture its indices

return DataLoader(
dataloader.dataset,
batch_sampler=batch_sampler,
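Taken together, the data.py changes route "test" and "predict" dataloaders through Lightning's _IndexBatchSamplerWrapper, which records the dataset indices of every batch it yields (its seen_batch_indices attribute) so predictions can be mapped back to their source samples. Below is a minimal sketch of the resulting behavior, assuming a toy dataset and illustrative batch sizes; the remaining add_megatron_sampler arguments are left at their defaults.

import torch
from torch.utils.data import DataLoader, TensorDataset

from nemo.lightning.data import add_megatron_sampler

dataset = TensorDataset(torch.arange(10).float())  # 10 samples; batch size 4 leaves a ragged tail
loader = DataLoader(dataset, batch_size=4)

# "predict" mode wraps the Megatron batch sampler in _IndexBatchSamplerWrapper;
# combined with drop_last=False, the final 2-sample batch is kept.
predict_loader = add_megatron_sampler(
    loader,
    micro_batch_size=4,
    global_batch_size=4,
    drop_last=False,
    dataloader_mode="predict",
)

for batch in predict_loader:
    pass  # run inference here

# After one pass, the wrapper exposes the indices it served,
# e.g. [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]].
print(predict_loader.batch_sampler.seen_batch_indices)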
14 changes: 14 additions & 0 deletions nemo/lightning/fabric/strategies.py
@@ -1,3 +1,17 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import ExitStack, contextmanager
from datetime import timedelta
from typing import (
5 changes: 2 additions & 3 deletions nemo/lightning/pytorch/plugins/data_sampler.py
@@ -44,7 +44,6 @@ def __init__(
init_consumed_samples: int = 0,
init_global_step: int = 0,
output_log: bool = True,
drop_last: bool = True,
):
self.seq_len = seq_len
self.output_log = output_log
@@ -57,7 +56,6 @@ def __init__(
self.if_first_step = 0
self.prev_global_batch_size = None
self.init_global_step = init_global_step
self.drop_last = drop_last

def setup(self, global_rank: int) -> None:
from nemo.lightning.data import setup_microbatch_calculator
@@ -80,7 +78,8 @@ def transform_dataloader(self, dataloader: DataLoader, consumed_samples: int = 0
rampup_batch_size=self.rampup_batch_size,
consumed_samples=self.init_consumed_samples if mode == 'train' else 0,
dataloader_type=self.dataloader_type,
drop_last=self.drop_last,
drop_last=mode not in ["test", "predict"], # don't drop the incomplete batch in test and predict modes
dataloader_mode=mode, # dataloader wrapped with nemo.lightning.data.WrappedDataLoader has mode attribute
rank=data_parallel_rank,
world_size=data_parallel_size,
)
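The net effect of this change is that drop_last is no longer a constructor argument but is derived from the dataloader mode: train and validation runs drop the ragged final batch to keep the global batch size uniform, while test and predict runs keep it so every sample is seen exactly once. A toy illustration of the batch-count arithmetic (hypothetical sizes, not from the source):

def num_batches(total_samples: int, batch_size: int, drop_last: bool) -> int:
    # Number of batches a sampler yields under a given drop_last policy.
    full, remainder = divmod(total_samples, batch_size)
    return full if (drop_last or remainder == 0) else full + 1

# train/validation: drop_last=True keeps every step at a uniform batch size
assert num_batches(10, 4, drop_last=True) == 2   # samples 8 and 9 are skipped
# test/predict: drop_last=False covers all 10 samples
assert num_batches(10, 4, drop_last=False) == 3  # final batch holds 2 samples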
14 changes: 14 additions & 0 deletions tests/collections/llm/recipes/test_mixtral_8x7b.py
@@ -1,3 +1,17 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import nemo_run as run
import pytest
import torch
