From 9380370621bc9e2ebb300c16add4705548c732e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alejandro=20Hern=C3=A1ndez=20Cano?=
Date: Fri, 15 Sep 2023 20:09:02 +0200
Subject: [PATCH 1/5] Added tokens/sec metric for gpt datasets

---
 megatron/__init__.py    |  1 +
 megatron/global_vars.py | 15 +++++++++++++++
 megatron/training.py    |  7 +++++++
 3 files changed, 23 insertions(+)

diff --git a/megatron/__init__.py b/megatron/__init__.py
index 19bb819..eb67678 100644
--- a/megatron/__init__.py
+++ b/megatron/__init__.py
@@ -10,6 +10,7 @@
 from .global_vars import get_tensorboard_writer
 from .global_vars import get_adlr_autoresume
 from .global_vars import get_timers
+from .global_vars import get_counters
 from .utils import (print_rank_0,
                     print_all_nodes,
diff --git a/megatron/global_vars.py b/megatron/global_vars.py
index a55a939..f4c812c 100644
--- a/megatron/global_vars.py
+++ b/megatron/global_vars.py
@@ -4,6 +4,7 @@
 
 import os
 import sys
+from collections import defaultdict
 
 from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer
@@ -17,6 +18,7 @@
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
 _GLOBAL_SIGNAL_HANDLER = None
+_GLOBAL_COUNTERS = None
 
 
 def get_args():
@@ -62,6 +64,12 @@ def get_timers():
     return _GLOBAL_TIMERS
 
 
+def get_counters():
+    """Return counters."""
+    _ensure_var_is_initialized(_GLOBAL_COUNTERS, 'counters')
+    return _GLOBAL_COUNTERS
+
+
 def get_signal_handler():
     _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
     return _GLOBAL_SIGNAL_HANDLER
@@ -90,6 +98,7 @@ def set_global_variables(args):
     _set_tensorboard_writer(args)
     _set_adlr_autoresume(args)
     _set_timers(args)
+    _set_counters(args)
 
     if args.exit_signal_handler:
         _set_signal_handler()
@@ -178,6 +187,12 @@ def _set_timers(args):
     _GLOBAL_TIMERS = Timers(args.timing_log_level, args.timing_log_option)
 
 
+def _set_counters(args):
+    global _GLOBAL_COUNTERS
+    _ensure_var_is_not_initialized(_GLOBAL_COUNTERS, 'counters')
+    _GLOBAL_COUNTERS = defaultdict(int)
+
+
 def _ensure_var_is_initialized(var, name):
     """Make sure the input variable is not None."""
     assert var is not None, '{} is not initialized.'.format(name)
diff --git a/megatron/training.py b/megatron/training.py
index 6e6d1ce..bb0cb5d 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -17,6 +17,7 @@
 from megatron import get_args
 from megatron import get_signal_handler
 from megatron import get_timers
+from megatron import get_counters
 from megatron import get_tensorboard_writer
 from megatron import get_current_global_batch_size
 from megatron import get_num_microbatches
@@ -590,10 +591,15 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     if iteration % args.log_interval == 0:
         elapsed_time = timers('interval-time').elapsed(barrier=True)
         elapsed_time_per_iteration = elapsed_time / total_iterations
+        counters = get_counters()
+        tokens = args.data_parallel_size*counters['tokens']
+        del counters['tokens']  # reset counter for future iterations
+        tokens_per_sec = tokens/(elapsed_time)
         if writer:
             if args.log_timers_to_tensorboard:
                 writer.add_scalar('iteration-time',
                                   elapsed_time_per_iteration, iteration)
+                writer.add_scalar('tokens-per-sec', tokens_per_sec, iteration)
         log_string = ' iteration {:8d}/{:8d} |'.format(
             iteration, args.train_iters)
@@ -601,6 +607,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
             args.consumed_train_samples)
         log_string += ' elapsed time per iteration (ms): {:.1f} |'.format(
             elapsed_time_per_iteration * 1000.0)
+        log_string += f' rate (tokens/sec): {tokens_per_sec:.2f} |'
         log_string += ' learning rate: {:.3E} |'.format(learning_rate)
         log_string += ' global batch size: {:5d} |'.format(batch_size)
         for key in total_loss_dict:
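The mechanism behind this first patch is small: a process-wide defaultdict(int) that any code path can bump and that the logging path drains once per log interval. A minimal standalone sketch of the pattern, in plain Python with no Megatron globals (names and batch sizes here are illustrative only, and it uses the pop()-style reset that a later patch in this series switches to):

import time
from collections import defaultdict

_COUNTERS = defaultdict(int)   # missing keys start at 0, so `+=` always works


def get_counters():
    return _COUNTERS


def train_step(batch_tokens):
    # Producer side: every step adds the number of tokens it processed.
    get_counters()['tokens'] += batch_tokens


def tokens_per_sec(elapsed_time):
    # Consumer side: drain the counter so the next interval starts from zero.
    tokens = get_counters().pop('tokens', 0)
    return tokens / elapsed_time


start = time.time()
for _ in range(10):
    train_step(batch_tokens=2048)
print(f'rate (tokens/sec): {tokens_per_sec(time.time() - start):.2f}')

Using defaultdict(int) means producers never have to check whether a key exists, and popping the key both reads and clears the interval's count in one step.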
From 5c0409f48ee12e1facfb8a3954bd40514f9fc7c4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alejandro=20Hern=C3=A1ndez=20Cano?=
Date: Fri, 15 Sep 2023 20:35:01 +0200
Subject: [PATCH 2/5] Dynamically calculate token count

---
 finetune.py          | 18 ++++++++++++++++--
 megatron/training.py |  3 +--
 2 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/finetune.py b/finetune.py
index 26020ae..bd96949 100644
--- a/finetune.py
+++ b/finetune.py
@@ -5,9 +5,10 @@
 
 import torch
 
-from megatron import get_args, get_tokenizer, get_timers, print_rank_0
+from megatron import get_args, get_tokenizer, get_timers, get_counters, print_rank_0
 from megatron.training import pretrain
 from megatron.core import tensor_parallel
+from megatron.core.parallel_state import get_data_parallel_group
 from megatron.model import GPTModel, ModelType, LlamaModel, FalconModel
 from megatron.utils import get_ltor_masks_and_position_ids, average_losses_across_data_parallel_group
 from megatron.data.gpt_dataset import build_train_valid_test_datasets as gpt_build_datasets
@@ -119,8 +120,21 @@ def get_batch(data_iterator):
     tokens = data_b["text"]
     labels = tokens[:, 1:].contiguous()
     tokens = tokens[:, :-1].contiguous()
 
-    if args.data_type == "gpt":
+    # Update tokens counter.
+    counters = get_counters()
+    n_tokens = torch.tensor(tokens.numel(), device=tokens.device)
+    if args.data_parallel_size == 1:
+        n_tokens = n_tokens.item()
+    else:
+        group = get_data_parallel_group()
+        token_counts = torch.zeros(args.data_parallel_size, dtype=torch.long,
+                                   device=tokens.device)
+        torch.distributed.all_gather_into_tensor(token_counts, n_tokens, group=group)
+        n_tokens = torch.sum(token_counts).item()
+    counters["tokens"] += n_tokens
+
+    if args.data_type == "gpt":
         # Get the masks and position ids.
         attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
             tokens,
diff --git a/megatron/training.py b/megatron/training.py
index bb0cb5d..6f366b8 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -592,8 +592,7 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
         elapsed_time = timers('interval-time').elapsed(barrier=True)
         elapsed_time_per_iteration = elapsed_time / total_iterations
         counters = get_counters()
-        tokens = args.data_parallel_size*counters['tokens']
-        del counters['tokens']  # reset counter for future iterations
+        tokens = counters.pop('tokens')  # reset counter for future iterations
         tokens_per_sec = tokens/(elapsed_time)
         if writer:
             if args.log_timers_to_tensorboard:
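The aggregation step this patch adds can be seen in isolation: each data-parallel rank counts the tokens in its own micro-batch, and the counts are combined across the group before being added to the shared counter. A rough, self-contained sketch under simplified assumptions (two CPU processes on the gloo backend, the default process group standing in for get_data_parallel_group(), and the list-based all_gather for backend portability, whereas the patch itself calls all_gather_into_tensor):

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def count_batch_tokens(rank, world_size):
    # Stand-in setup; the real code passes group=get_data_parallel_group()
    # to the collective instead of using the default group.
    dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29500',
                            rank=rank, world_size=world_size)

    # Pretend this rank's micro-batch held (rank + 1) * 1000 tokens.
    n_tokens = torch.tensor([(rank + 1) * 1000], dtype=torch.long)

    # Gather every rank's count, then sum locally.
    token_counts = [torch.zeros_like(n_tokens) for _ in range(world_size)]
    dist.all_gather(token_counts, n_tokens)
    total = torch.stack(token_counts).sum().item()

    if rank == 0:
        print(f'tokens this step across {world_size} ranks: {total}')  # 3000
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(count_batch_tokens, args=(2,), nprocs=2)

The single-rank shortcut in the patch (args.data_parallel_size == 1) simply skips the collective and reads tokens.numel() directly.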
From 503e4cf0a9f743f8352d5e8b29c80e48fa57ff0b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alejandro=20Hern=C3=A1ndez=20Cano?=
Date: Tue, 19 Sep 2023 20:40:18 +0200
Subject: [PATCH 3/5] Fixed logging of timers

---
 megatron/timers.py | 3 ---
 1 file changed, 3 deletions(-)

diff --git a/megatron/timers.py b/megatron/timers.py
index 073602a..a9478fa 100644
--- a/megatron/timers.py
+++ b/megatron/timers.py
@@ -302,6 +302,3 @@ def write(self, names, writer, iteration, normalizer=1.0,
             for name in name_to_min_max_time:
                 _, max_time = name_to_min_max_time[name]
                 writer.add_scalar(name + '-time', max_time, iteration)
-        # if using wandb writer, flush the stats we just filled here, close to the creation time
-        if hasattr(writer,"flush_all"):
-            writer.flush_all()

From 339cf5bf384b374b345b73e6e7d8762956be6aba Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alejandro=20Hern=C3=A1ndez=20Cano?=
Date: Tue, 19 Sep 2023 20:40:35 +0200
Subject: [PATCH 4/5] Fixed overestimation of tokens/sec after evaluation steps

---
 megatron/training.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/megatron/training.py b/megatron/training.py
index 6f366b8..3e87b91 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -674,6 +674,7 @@ def _train(args, forward_step_func,
     # Iterations.
     iteration = args.iteration
+    counters = get_counters()
 
     timers('interval-time', log_level=0).start(barrier=True)
     print_datetime('before the start of training step')
     report_memory_flag = True
@@ -712,10 +713,13 @@ def _train(args, forward_step_func,
         if args.eval_interval and iteration % args.eval_interval == 0 and \
            args.do_valid:
             prefix = 'iteration {}'.format(iteration)
+            current_tokens = counters['tokens']
             evaluate_and_print_results(prefix, forward_step_func,
                                        valid_data_iterator, model,
                                        iteration, process_non_loss_data_func,
                                        verbose=False, args=args)
+            counters['tokens'] = current_tokens
+
             # if using wandb writer, flush the stats of train_step & potentially evaluate
             writer = get_tensorboard_writer()
 
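The issue the preceding patch addresses is that get_batch() also runs during evaluation, so validation tokens land in the same counter and inflate the next training tokens/sec line. A toy reproduction of the fix, reusing the counter pattern sketched earlier (all step counts and batch sizes below are made up):

from collections import defaultdict

counters = defaultdict(int)


def get_batch(n_tokens):
    # Called by both training and evaluation steps.
    counters['tokens'] += n_tokens


def evaluate(num_batches=5):
    for _ in range(num_batches):
        get_batch(1024)


# A training interval of 10 steps.
for _ in range(10):
    get_batch(2048)

# The fix: snapshot the counter around evaluation and restore it afterwards,
# so only training tokens survive into the next log line.
saved_tokens = counters['tokens']
evaluate()
counters['tokens'] = saved_tokens

print(counters['tokens'])  # 20480 -- the 5120 evaluation tokens are discarded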
From 5b4ae479f2969a19610f2f86b4ce61b8e0ec479b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Alejandro=20Hern=C3=A1ndez=20Cano?=
Date: Fri, 22 Sep 2023 15:59:45 +0200
Subject: [PATCH 5/5] Small optimization: replaced gather() with allreduce()

---
 finetune.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/finetune.py b/finetune.py
index bd96949..22e6b7e 100644
--- a/finetune.py
+++ b/finetune.py
@@ -128,10 +128,10 @@ def get_batch(data_iterator):
         n_tokens = n_tokens.item()
     else:
         group = get_data_parallel_group()
-        token_counts = torch.zeros(args.data_parallel_size, dtype=torch.long,
-                                   device=tokens.device)
-        torch.distributed.all_gather_into_tensor(token_counts, n_tokens, group=group)
-        n_tokens = torch.sum(token_counts).item()
+        torch.distributed.all_reduce(
+            n_tokens, op=torch.distributed.ReduceOp.SUM, group=group
+        )
+        n_tokens = n_tokens.item()
     counters["tokens"] += n_tokens
 
     if args.data_type == "gpt":
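For comparison with the gather-based sketch after patch 2, the reduce-based version under the same toy assumptions (two CPU ranks on gloo, default group, arbitrary port and counts) looks like this:

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def count_batch_tokens(rank, world_size):
    dist.init_process_group('gloo', init_method='tcp://127.0.0.1:29501',
                            rank=rank, world_size=world_size)
    n_tokens = torch.tensor((rank + 1) * 1000, dtype=torch.long)

    # One in-place SUM over the scalar tensor replaces the world_size-sized
    # gather buffer plus the local torch.sum() of the gathered counts.
    dist.all_reduce(n_tokens, op=dist.ReduceOp.SUM)

    if rank == 0:
        print('total tokens:', n_tokens.item())  # 3000
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(count_batch_tokens, args=(2,), nprocs=2)

For a pure sum the two collectives give the same result; the all_reduce form just moves the reduction into the collective and avoids allocating a buffer that grows with the data-parallel size.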