Merge branch 'main' into llm_docs_upd
ssh-meister authored Oct 17, 2023
2 parents da2723f + 9610045 commit 057767d
Showing 6 changed files with 188 additions and 3 deletions.
@@ -167,6 +167,7 @@ def __init__(
        ub_tp_comm_overlap=False,
        use_flash_attention=False,
        seq_len_interpolation_factor=None,
+       rotary_base=10000,
    ):
        super(GPTModel, self).__init__(config=config, share_token_embeddings=share_embeddings_and_output_weights)

@@ -248,6 +249,7 @@ def __init__(
            ub_tp_comm_overlap=ub_tp_comm_overlap,
            use_flash_attention=use_flash_attention,
            seq_len_interpolation_factor=seq_len_interpolation_factor,
+           rotary_base=rotary_base,
        )

        if self.share_embeddings_and_output_weights:
@@ -384,6 +384,7 @@ def model_provider_func(self, pre_process, post_process):
            use_flash_attention=self.cfg.get('use_flash_attention', False),
            megatron_legacy=self.cfg.get('megatron_legacy', False),
            seq_len_interpolation_factor=self.cfg.get('seq_len_interpolation_factor', None),
+           rotary_base=self.cfg.get('rotary_base', 10000),
        )
        return model
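Note: because model_provider_func reads the new option with self.cfg.get('rotary_base', 10000), existing model configs are untouched and keep the previous default of 10000; only a config that sets the key explicitly changes the rotary base. An illustrative Hydra-style override for a NeMo GPT training config (the value is chosen arbitrarily and is not part of this commit) would be:

    model.rotary_base=500000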
@@ -126,6 +126,7 @@ def get_language_model(
    ub_tp_comm_overlap=False,
    use_flash_attention=False,
    seq_len_interpolation_factor=None,
+   rotary_base=10000,
):
    """Build language model and return along with the key to save."""

@@ -202,6 +203,7 @@ def get_language_model(
        ub_tp_comm_overlap=ub_tp_comm_overlap,
        use_flash_attention=use_flash_attention,
        seq_len_interpolation_factor=seq_len_interpolation_factor,
+       rotary_base=rotary_base,
    )
    # key used for checkpoints.
    language_model_key = 'language_model'

@@ -502,6 +504,7 @@ def __init__(
        ub_tp_comm_overlap=False,
        use_flash_attention=False,
        seq_len_interpolation_factor=None,
+       rotary_base=10000,
    ):
        super(TransformerLanguageModel, self).__init__(
            config=config, share_token_embeddings=share_embeddings_and_output_weights

@@ -557,6 +560,7 @@ def __init__(
                rotary_dim,
                seq_len_interpolation_factor=seq_len_interpolation_factor,
                pretrained_max_position_embeddings=max_position_embeddings,
+               rotary_base=rotary_base,
            )

        elif position_embedding_type == 'alibi':
@@ -26,19 +26,25 @@ class RotaryEmbedding(nn.Module):
    """

    def __init__(
-       self, dim: int, seq_len_interpolation_factor: int = None, pretrained_max_position_embeddings: int = None
+       self,
+       dim: int,
+       seq_len_interpolation_factor: int = None,
+       rotary_base: int = 10000,
+       pretrained_max_position_embeddings: int = None,
    ):
        """
        Args:
            dim (int): rotary embedding dimension
            seq_len_interpolation_factor (int): if not None, discrete positions will be interpolated
-           by this factor via the trick in https://arxiv.org/abs/2306.15595.
+               by this factor via the trick in https://arxiv.org/abs/2306.15595.
+           rotary_base (int): rotary_base for the positional frequency (default: 10000)
            pretrained_max_position_embeddings (int): pre-trained max_position_embeddings before position interpolation.
        """
        super().__init__()
        self.seq_len_interpolation_factor = seq_len_interpolation_factor
-       inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+       self.rotary_base = rotary_base
+       inv_freq = 1.0 / (self.rotary_base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.pretrained_max_position_embeddings = pretrained_max_position_embeddings
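For context, rotary_base is the base of the geometric progression of rotary frequencies, so raising it stretches the per-channel wavelengths, the usual lever for long-context model variants. A small self-contained sketch of the same computation the patched RotaryEmbedding performs (the helper name is illustrative, not NeMo API):

import math

import torch


def rotary_inv_freq(dim: int, rotary_base: float = 10000.0) -> torch.Tensor:
    # Same formula as RotaryEmbedding.__init__ above.
    return 1.0 / (rotary_base ** (torch.arange(0, dim, 2).float() / dim))


for base in (10000.0, 1000000.0):
    inv_freq = rotary_inv_freq(dim=128, rotary_base=base)
    # The longest wavelength (in token positions) grows with the base.
    print(f"base={base:>9.0f}  longest wavelength={(2 * math.pi / inv_freq[-1]).item():,.0f}")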
12 changes: 12 additions & 0 deletions nemo/utils/exp_manager.py
@@ -960,6 +960,18 @@ def state_dict(self) -> Dict[str, Any]:
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        return

+   def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
+       super()._check_time_remaining(trainer)
+       if trainer.should_stop:
+           checkpoint_callback: Optional[NeMoModelCheckpoint] = trainer.checkpoint_callback
+           if checkpoint_callback:
+               monitor_candidates = checkpoint_callback._monitor_candidates(trainer)
+               checkpoint_callback._save_last_checkpoint(trainer, monitor_candidates)
+           # Throw this exception to signal to Lightning to terminate gracefully.
+           from pytorch_lightning.utilities.exceptions import _TunerExitException
+
+           raise _TunerExitException()
+

def configure_no_restart_validation_training_loop(trainer: pytorch_lightning.Trainer) -> None:
    if type(trainer.fit_loop.epoch_loop) != _TrainingEpochLoop:
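The new _check_time_remaining override is added to a Timer subclass in exp_manager.py: once the parent Timer flags trainer.should_stop, it saves a final '-last' checkpoint and raises _TunerExitException so Lightning terminates gracefully instead of being killed mid-run. A minimal standalone sketch of the same pattern against plain PyTorch Lightning (class name and duration are illustrative; it assumes your Lightning version exposes the same Timer and ModelCheckpoint internals used in the diff above):

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Timer
from pytorch_lightning.utilities.exceptions import _TunerExitException


class SaveLastThenExitTimer(Timer):
    # Illustrative subclass mirroring the change above.

    def _check_time_remaining(self, trainer: "pl.Trainer") -> None:
        super()._check_time_remaining(trainer)  # sets trainer.should_stop once the budget is spent
        if trainer.should_stop:
            ckpt = trainer.checkpoint_callback
            if ckpt is not None:
                ckpt._save_last_checkpoint(trainer, ckpt._monitor_candidates(trainer))
            # Graceful termination rather than a hard kill mid-epoch.
            raise _TunerExitException()


# Illustrative usage: stop after 3 h 50 min of training, keeping a '-last' checkpoint.
# trainer = pl.Trainer(
#     callbacks=[ModelCheckpoint(save_last=True), SaveLastThenExitTimer(duration="00:03:50:00")]
# )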
160 changes: 160 additions & 0 deletions scripts/checkpoint_averaging/distributed_checkpoint_averaging.py
@@ -0,0 +1,160 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Copyright 2017 Johns Hopkins University (Shinji Watanabe)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Example: python scripts/checkpoint_averaging/distributed_checkpoint_averaging.py \
--name_prefix=<checkpoint name> \
--checkpoint_dir=<folder with mp_rank_X subfolders containing checkpoints>
--steps <optinally a list of checkpoint steps to average, if not provided, it will average all the checkpoints>
will generate a new directory in each of the distributed checkpoint subfolders named <checkpoint name>-averaged
"""

import argparse
import logging
import os
import shutil

import zarr

logging.basicConfig(level=logging.INFO)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--name_prefix', help='Name of the final checkpoint. Will append -averaged automatically.',
    )
    parser.add_argument(
        '--checkpoint_dir', help='Folder containing all the distributed checkpoints.',
    )
    # list of checkpoint steps to average
    parser.add_argument(
        '--steps',
        nargs='+',
        type=int,
        help='List of checkpoint steps to average. If not specified, will average all.',
    )

    args = parser.parse_args()

    if args.steps is not None:
        logging.info(f"Will average only steps {args.steps}")

    # collect the distributed checkpoint directories to average

    checkpoint_paths = []
    for ckpt_dir in os.listdir(args.checkpoint_dir):
        logging.info("Processing %s", ckpt_dir)
        if ckpt_dir.endswith('0-last'):
            continue
        if args.steps is None:
            checkpoint_paths.append(ckpt_dir)
        else:
            for step in args.steps:
                key = f"-step={step}-"
                if key in ckpt_dir:
                    checkpoint_paths.append(ckpt_dir)

    n = len(checkpoint_paths)
    # initialize dict, will be used to store the weights that need to be averaged
    avg_weights = {}

    logging.info(f"Averaging {n} checkpoints ... {'at steps:' + str(args.steps) if args.steps is not None else ''}")

    # items that need to be copied (not averaged) into the new checkpoint folder
    copy_items = []
    for ix, path in enumerate(checkpoint_paths):
        full_path = os.path.join(args.checkpoint_dir, path)

        for item in os.listdir(full_path):

            # plain files are copied over once, not averaged
            if not os.path.isdir(os.path.join(full_path, item)):
                if ix == 0:
                    copy_items.append(os.path.join(full_path, item))
                continue

            # transformer engine states are copied as-is, not averaged
            if item.endswith('._extra_state'):
                if ix == 0:
                    copy_items.append(os.path.join(full_path, item))
                continue

            # optimizer states, no point in averaging them
            if item.startswith('optimizer.'):
                if ix == 0:
                    copy_items.append(os.path.join(full_path, item))
                continue

            if item not in avg_weights:
                logging.info(f"Initialized average weights dict with: {item}")
                avg_weights[item] = zarr.open(os.path.join(full_path, item), mode='r')
            else:
                logging.info(f"Updated average weights dict with weight: {item}")
                array_z = zarr.open(os.path.join(full_path, item), mode='r')
                sum_array = avg_weights[item][:] + array_z[:]
                avg_weights[item] = zarr.array(sum_array, chunks=array_z.chunks, dtype=array_z.dtype)

    for k in avg_weights:
        logging.info(f"Average weights dict key : {k}, dtype : {avg_weights[k].dtype}, shape : {avg_weights[k].shape}")
        if str(avg_weights[k].dtype).startswith("int"):
            raise ValueError("Int type not supported")
        else:
            array_z = avg_weights[k][:]
            array_z = array_z / n
            avg_weights[k] = zarr.array(array_z, chunks=avg_weights[k].chunks, dtype=avg_weights[k].dtype)

    # Save model
    if args.steps is None:
        ckpt_name = os.path.join(args.checkpoint_dir, args.name_prefix + '-averaged')
    else:
        steps_combined = '_'.join([str(x) for x in args.steps])
        ckpt_name = os.path.join(args.checkpoint_dir, args.name_prefix + '-' + steps_combined + '-averaged')

    # save avg_weights
    for k in avg_weights:
        logging.info(f"Saving {k} to {ckpt_name}")
        zarr.save(os.path.join(ckpt_name, k), avg_weights[k])

    # copy other files
    for item in copy_items:
        is_file = os.path.isfile(item)
        logging.info(f"Copying {'file' if is_file else 'directory'} {item} to {ckpt_name}")
        if is_file:
            # copy single file
            shutil.copy(item, ckpt_name)
        else:
            # copy directory
            shutil.copytree(item, os.path.join(ckpt_name, os.path.basename(item)), dirs_exist_ok=True)

    logging.info(f"Averaged distributed checkpoint saved as : {ckpt_name}")


if __name__ == '__main__':
    main()
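The core of the script is a running elementwise sum over the zarr-backed weight shards followed by a divide by the number of checkpoints. A toy, self-contained illustration of that pattern (the paths, array name and values below are made up):

import numpy as np
import zarr

# Write two fake "checkpoints" holding one weight array each (illustrative only).
for i, value in enumerate((1.0, 3.0)):
    zarr.save(f"/tmp/avg_demo/ckpt_{i}/model.weight", np.full((4, 4), value, dtype=np.float32))

# Accumulate and average, mirroring the loop in distributed_checkpoint_averaging.py.
paths = [f"/tmp/avg_demo/ckpt_{i}/model.weight" for i in range(2)]
acc = None
for p in paths:
    shard = zarr.open(p, mode='r')[:]
    acc = shard if acc is None else acc + shard

zarr.save("/tmp/avg_demo/averaged/model.weight", acc / len(paths))
print(zarr.open("/tmp/avg_demo/averaged/model.weight", mode='r')[0, 0])  # expect 2.0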
