Commit 3e34a5e

Add a straggler detection test as well
Signed-off-by: Shriya Palsamudram <[email protected]>
ShriyaPalsamudram committed Nov 13, 2024
1 parent a8fda44 commit 3e34a5e
Showing 2 changed files with 29 additions and 25 deletions.
.github/workflows/cicd-main.yml (26 changes: 21 additions & 5 deletions)
@@ -3827,17 +3827,32 @@ jobs:
         --index-mapping-dir=/tmp/llm_tests/llama_index_mappings \
         --cp 1 --tp 2 --sp 1
-  L2_NeMo_2_llama3_fault_tolerance:
+  L2_NeMo_2_llama3_fault_tolerance_plugin:
     needs: [cicd-test-container-setup]
     uses: ./.github/workflows/_test_template.yml
-    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_llama3_fault_tolerance') || needs.cicd-test-container-setup.outputs.all == 'true'
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_llama3_fault_tolerance_plugin') || needs.cicd-test-container-setup.outputs.all == 'true'
     with:
       RUNNER: self-hosted-azure
       SCRIPT: |
-        python tests/collections/llm/test_fault_tolerance.py \
+        python tests/collections/llm/test_nvrx.py \
         --devices=2 \
+        --crash-step=4 \
         --experiment-dir=/tmp/llm_tests/llama_pretrain_results \
         --data-path=/home/TestData/nlp/megatron_llama/data/rp2_sample_sentencepiece_preproc_text_document \
         --tokenizer-path=/home/TestData/nlp/megatron_llama/tokenizer.model \
         --index-mapping-dir=/tmp/llm_tests/llama_index_mappings \
+  L2_NeMo_2_llama3_straggler_detection:
+    needs: [cicd-test-container-setup]
+    uses: ./.github/workflows/_test_template.yml
+    if: contains(fromJSON(needs.cicd-test-container-setup.outputs.test_to_run), 'L2_NeMo_2_llama3_straggler_detection') || needs.cicd-test-container-setup.outputs.all == 'true'
+    with:
+      RUNNER: self-hosted-azure
+      SCRIPT: |
+        python tests/collections/llm/test_nvrx.py \
+        --devices=2 \
+        --max-steps=20 \
+        --experiment-dir=/tmp/llm_tests/llama_pretrain_results \
+        --data-path=/home/TestData/nlp/megatron_llama/data/rp2_sample_sentencepiece_preproc_text_document \
+        --tokenizer-path=/home/TestData/nlp/megatron_llama/tokenizer.model \
@@ -4470,7 +4485,8 @@ jobs:
       - L2_NeMo_2_GPT_DDP_Param_Parity_check
       - L2_NeMo_2_HF_MODEL_IMPORT
       - L2_NeMo_2_llama3_pretraining_recipe
-      - L2_NeMo_2_llama3_fault_tolerance
+      - L2_NeMo_2_llama3_fault_tolerance_plugin
+      - L2_NeMo_2_llama3_straggler_detection
       - L2_NeMo_2_SSM_Pretraining
       - L2_NeMo_2_SSM_Finetuning
       - L2_NeMo_2_T5_Pretraining
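The new L2_NeMo_2_llama3_straggler_detection job reuses the same test_nvrx.py entry point as the fault tolerance job, but passes --max-steps=20 and no --crash-step, so the run only exercises the straggler detection callback that the test script below configures via straggler_det_callback(). As a rough, hypothetical illustration of the idea (not NeMo's actual implementation), a toy PyTorch Lightning callback could time each training step per rank and flag ranks that fall well behind the group median:

import time

import torch
import torch.distributed as dist
from lightning.pytorch import Callback


class ToyStragglerDetection(Callback):
    """Illustrative only: report ranks whose step time is well above the group median."""

    def __init__(self, report_interval_steps: int = 10, slowdown_factor: float = 1.5):
        self.report_interval_steps = report_interval_steps
        self.slowdown_factor = slowdown_factor
        self._step_start = None

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # Record when this rank started the current step.
        self._step_start = time.perf_counter()

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if self._step_start is None or not dist.is_available() or not dist.is_initialized():
            return
        step_time = time.perf_counter() - self._step_start
        if (trainer.global_step + 1) % self.report_interval_steps:
            return
        # Share every rank's step time and compare this rank against the median.
        local = torch.tensor([step_time], device=pl_module.device)
        gathered = [torch.zeros_like(local) for _ in range(dist.get_world_size())]
        dist.all_gather(gathered, local)
        median = torch.cat(gathered).median()
        if local[0] > self.slowdown_factor * median:
            print(f"[rank {dist.get_rank()}] possible straggler: "
                  f"{local[0].item():.3f}s vs median {median.item():.3f}s")

In the real test all of this sits behind the single straggler_det_callback(straggler_report_time_interval=0.5) call shown in the diff below.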
tests/collections/llm/test_nvrx.py
@@ -33,37 +33,22 @@


 class CrashCallback(Callback):
-    def __init__(self, crash_step=3, crash_time=None):
+    def __init__(self, crash_step=3):
         self.crash_step = crash_step
-        self.crash_time = crash_time
-        print(f"Setup to simulate a crash if step time > {self.crash_step} before {self.crash_time}")
+        print(f"Setup to simulate a crash if step == {self.crash_step}")
 
     def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
         if self.crash_step and trainer.global_step == self.crash_step:
             raise Exception(f"Simulating a crash at step {self.crash_step}!")
 
-        # if (
-        #     datetime.now() <= datetime.strptime(self.crash_time, "%Y-%m-%d %H:%M:%S")
-        #     and trainer.global_step >= self.crash_step
-        # ):
-        #     raise Exception("Simulating a crash!")
-
 
 def get_args():
     parser = argparse.ArgumentParser(prog="", description="")
     parser.add_argument('--devices', type=int, required=True, help="Number of devices to use for training")
     parser.add_argument('--max-steps', type=int, required=True, help="Number of steps to train for")
-    parser.add_argument(
-        '--crash-time',
-        type=str,
-        # required=True,
-        help="Datetime string which indicates when to simulate a crash. Set this to a few minutes after the training starts to ensure a successful crash.",
-    )
     parser.add_argument(
         '--crash-step',
         type=int,
-        required=True,
-        help="Crash step",
+        help="Step when a crash should be simulated",
     )
     parser.add_argument(
         '--experiment-dir', type=str, required=True, help="directory to write results and checkpoints to"
@@ -109,13 +94,16 @@ def main():
     pretrain_recipe.trainer.limit_val_batches = 2
 
     executor: run.SlurmExecutor = run.LocalExecutor(ntasks_per_node=args.devices, launcher="ft")
+    # Add the fault tolerance plugin which enables restart after a crash
     run_plugins: list[run.Plugin] = [FaultTolerancePlugin(num_in_process_restarts=1, num_job_retries_on_failure=0)]
     pretrain_recipe.trainer.callbacks = [
         run.Config(TimingCallback),
-        straggler_det_callback(),
-        run.Config(CrashCallback, crash_step=6, crash_time=args.crash_time),
+        straggler_det_callback(straggler_report_time_interval=0.5)
     ]
 
+    if args.crash_step:
+        pretrain_recipe.trainer.callbacks.append(run.Config(CrashCallback, crash_step=args.crash_step))
+
     run.run(pretrain_recipe, plugins=run_plugins, executor=executor)


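Both CI jobs now drive this one script: the fault tolerance plugin and the straggler detection callback are always configured, and a crash is only simulated when --crash-step is supplied. A minimal sketch of that conditional-callback pattern in plain PyTorch Lightning (the SimulatedCrash and build_callbacks names are illustrative, not part of the NeMo test):

import argparse

from lightning.pytorch import Callback


class SimulatedCrash(Callback):
    """Illustrative stand-in for CrashCallback: raise once a target step is reached."""

    def __init__(self, crash_step: int):
        self.crash_step = crash_step

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if self.crash_step and trainer.global_step == self.crash_step:
            raise RuntimeError(f"Simulating a crash at step {self.crash_step}!")


def build_callbacks(args: argparse.Namespace) -> list[Callback]:
    callbacks: list[Callback] = []
    # Timing and straggler detection callbacks would always be added here.
    if getattr(args, "crash_step", None):
        # Only the fault tolerance job passes --crash-step, so only that job
        # triggers the simulated crash; the straggler detection job trains normally.
        callbacks.append(SimulatedCrash(args.crash_step))
    return callbacks

With FaultTolerancePlugin(num_in_process_restarts=1, num_job_retries_on_failure=0) and the "ft" launcher, the run is expected to recover in process from exactly one such simulated crash, which is what allows the fault tolerance job to finish successfully after crashing at step 4.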
