Skip to content

Commit

Permalink
Bug Fixes: Small batch barrier and Broken model loading (#72)
Browse files Browse the repository at this point in the history
Fixes model loading on multi-GPU setups and raises a clear error when the effective batch size is too small for multipack sampling

---------

Signed-off-by: Mustafa Eyceoz <[email protected]>
  • Loading branch information
Maxusmusti authored Jun 25, 2024
1 parent ebc5e31 commit 60dc322
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 19 deletions.
4 changes: 4 additions & 0 deletions src/instructlab/training/multipack_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,10 @@ def find_packing_max_batch_len_and_grad_accum(
while packing_max_batch_len > max_batch_len_per_gpu:
grad_accum += 1
total_micro_batch = (effective_batch_size / grad_accum) / num_gpus
if int(avg_sample_len * total_micro_batch) < dataset.get_lengths().max():
raise RuntimeError(
f"Effective batch size is too low for multipack sampling, max sample length={dataset.get_lengths().max()} and min packing length={int(avg_sample_len * total_micro_batch)}"
)
if is_padding:
addition = find_padding_max_batch_len_addition(
avg_sample_len,
Expand Down
37 changes: 18 additions & 19 deletions src/instructlab/training/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,26 +483,25 @@ class UniversalCheckpointArgs:

@contextmanager
def ensure_loadable_granite_checkpoint(model_name_or_path: str):
    """Yield a filesystem path from which a granite (GPTDolomite) model loads.

    If ``model_name_or_path`` already contains a valid GPTDolomite config it
    is yielded unchanged. Otherwise the checkpoint is assumed to be a
    llama-style model: rank 0 converts it into a temporary directory with
    ``import_from_huggingface``, all ranks synchronize on a barrier, and the
    temp directory is yielded, then removed on exit from the ``with`` block.

    Args:
        model_name_or_path: local path or HF hub id of the checkpoint.

    Yields:
        str: a path loadable as a GPTDolomite checkpoint.
    """
    # Probe first, yield outside the try: a bare `except:` wrapped around the
    # `yield` would also catch exceptions raised by the `with`-block body
    # (thrown into the generator at the yield point), fall through to the
    # conversion path, and yield a second time — which `@contextmanager`
    # rejects with "generator didn't stop".
    try:
        GPTDolomiteConfig.from_pretrained(model_name_or_path)
        needs_conversion = False
    except Exception:  # pylint: disable=broad-except
        # If the config load failed, it must not be a granite checkpoint;
        # for now just assume it is a llama-style model.
        needs_conversion = True

    if not needs_conversion:
        yield model_name_or_path
        return

    log_rank_0(
        f"\033[93mModel saved in {model_name_or_path} requires conversion \033[0m",
        to_print=True,
    )
    # with TemporaryDirectory("w") as tmpdir:
    # make a temp directory name, but do not create it
    tmpdir = mktemp()
    # NOTE(review): mktemp() runs on every rank, so each rank computes a
    # different random path while only rank 0 performs the conversion —
    # non-zero ranks may yield a directory that does not exist. Confirm
    # callers only read the path on rank 0, or broadcast tmpdir instead.
    if not dist.is_initialized() or dist.get_rank() == 0:
        import_from_huggingface(model_name_or_path, tmpdir)
    if dist.is_initialized():
        # Make every rank wait until rank 0 has finished the conversion
        # before anyone tries to load from the temp directory.
        dist.barrier()
    yield tmpdir
    shutil.rmtree(tmpdir, ignore_errors=True)


# this function is for supporting gradient checkpointing for padding free
Expand Down

0 comments on commit 60dc322

Please sign in to comment.