From 60dc3226abd8e4c19df141552177a68f2528469e Mon Sep 17 00:00:00 2001
From: Mustafa Eyceoz
Date: Tue, 25 Jun 2024 16:18:13 -0400
Subject: [PATCH] Bug Fixes: Small batch barrier and Broken model loading
 (#72)

Fixes model loading on multi-GPU setups and raises an error when the
effective batch size is too small for multipack sampling.

---------

Signed-off-by: Mustafa Eyceoz
---
 src/instructlab/training/multipack_sampler.py |  4 ++
 src/instructlab/training/utils.py             | 37 +++++++++----------
 2 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/src/instructlab/training/multipack_sampler.py b/src/instructlab/training/multipack_sampler.py
index 254dafb2..682871aa 100644
--- a/src/instructlab/training/multipack_sampler.py
+++ b/src/instructlab/training/multipack_sampler.py
@@ -204,6 +204,10 @@ def find_packing_max_batch_len_and_grad_accum(
     while packing_max_batch_len > max_batch_len_per_gpu:
         grad_accum += 1
         total_micro_batch = (effective_batch_size / grad_accum) / num_gpus
+        if int(avg_sample_len * total_micro_batch) < dataset.get_lengths().max():
+            raise RuntimeError(
+                f"Effective batch size is too low for multipack sampling, max sample length={dataset.get_lengths().max()} and min packing length={int(avg_sample_len * total_micro_batch)}"
+            )
         if is_padding:
             addition = find_padding_max_batch_len_addition(
                 avg_sample_len,
diff --git a/src/instructlab/training/utils.py b/src/instructlab/training/utils.py
index d5b05700..c58182a7 100644
--- a/src/instructlab/training/utils.py
+++ b/src/instructlab/training/utils.py
@@ -483,26 +483,25 @@ class UniversalCheckpointArgs:
 
 @contextmanager
 def ensure_loadable_granite_checkpoint(model_name_or_path: str):
-    if not dist.is_initialized() or dist.get_rank() == 0:
-        try:
-            GPTDolomiteConfig.from_pretrained(model_name_or_path)
-            yield model_name_or_path
-        except:  # pylint: disable=bare-except
-            log_rank_0(
-                f"\033[93mModel saved in {model_name_or_path} requires conversion \033[0m",
-                to_print=True,
-            )
-            # if the load failed then it must not be a granite
-            # for now just assume its a llama
-            # with TemporaryDirectory("w") as tmpdir:
-            # make a temp directory name, but do not create it
-            tmpdir = mktemp()
+    try:
+        GPTDolomiteConfig.from_pretrained(model_name_or_path)
+        yield model_name_or_path
+    except:  # pylint: disable=bare-except
+        log_rank_0(
+            f"\033[93mModel saved in {model_name_or_path} requires conversion \033[0m",
+            to_print=True,
+        )
+        # if the load failed then it must not be a granite
+        # for now just assume its a llama
+        # with TemporaryDirectory("w") as tmpdir:
+        # make a temp directory name, but do not create it
+        tmpdir = mktemp()
+        if not dist.is_initialized() or dist.get_rank() == 0:
             import_from_huggingface(model_name_or_path, tmpdir)
-            yield tmpdir
-            shutil.rmtree(tmpdir, ignore_errors=True)
-
-    if dist.is_initialized():
-        dist.barrier()
+        if dist.is_initialized():
+            dist.barrier()
+        yield tmpdir
+        shutil.rmtree(tmpdir, ignore_errors=True)
 
 
 # this function is for supporting gradient checkpointing for padding free
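
Reviewer notes (illustrative sketches, not part of the applied patch):

The new guard in find_packing_max_batch_len_and_grad_accum rejects
configurations in which the longest sample in the dataset could never fit
into a single pack. A rough sketch of the arithmetic with hypothetical
numbers (only the formula itself comes from the patch):

    # Hypothetical configuration, for illustration only.
    effective_batch_size = 128  # samples per optimizer step across all GPUs
    num_gpus = 8
    grad_accum = 4
    avg_sample_len = 600        # mean tokenized sample length
    max_sample_len = 4096       # what dataset.get_lengths().max() would return

    total_micro_batch = (effective_batch_size / grad_accum) / num_gpus  # 4.0
    min_packing_len = int(avg_sample_len * total_micro_batch)           # 2400

    # 2400 < 4096, so this configuration now fails fast with the new
    # RuntimeError rather than attempting packs the longest sample can
    # never fit into.
    assert min_packing_len < max_sample_len

The utils.py change addresses a multi-GPU failure: as the old code reads,
ranks other than 0 never reached a yield, so the context manager raised on
those ranks under torch.distributed. The corrected shape, sketched below
with a hypothetical convert() standing in for import_from_huggingface,
keeps the conversion on rank 0 but has every rank hit the barrier and
yield a path:

    import torch.distributed as dist

    def convert(src: str, dst: str) -> None:
        """Hypothetical stand-in for import_from_huggingface."""

    def prepare_checkpoint(model_name_or_path: str, tmpdir: str) -> str:
        # Only rank 0 writes the converted checkpoint, so ranks never
        # race on the same files.
        if not dist.is_initialized() or dist.get_rank() == 0:
            convert(model_name_or_path, tmpdir)
        # Every rank must reach this barrier; it blocks the other ranks
        # until rank 0 has finished writing, so no rank reads tmpdir
        # too early.
        if dist.is_initialized():
            dist.barrier()
        # Every rank now returns (in the real code, yields) a usable path.
        return tmpdir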