Skip to content

Commit

Permalink
Merge pull request #58 from epfLLM/instruct_loss_scalar
Browse files Browse the repository at this point in the history
Instruct loss scalar
  • Loading branch information
AleHD authored Sep 6, 2023
2 parents ce75fc9 + 20d6f09 commit 10aaed8
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 12 deletions.
31 changes: 25 additions & 6 deletions examples/finetune.sh
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,15 @@ ITERS=1000
SEQ_LEN=none
DATA_PATH=none
TRAINED_PATH=none
VAL_PATH=none
USR_LR=none
USR_MIN_LR=none
LOSS_MASK=0.0
HELP_STR="[--rank=$RANK] [--size=$SIZE] [--tp=$TP] [--pp=$PP] [--gpus=$GPUS_PER_NODE] \
[--micro-batch=$MICRO_BATCH] [--global-batch=$GLOBAL_BATCH] [--nodes=$N_NODES] \
[--addr=$ADDR] [--wandb] [--instruct] [--checkpoint=...] [--data=...] [--iters=$ITERS] \
[--wandb-proj=none] [--wandb-id=none] [--wandb-entity=none] [--seq-len=...] [--out=...] [--help]"
[--wandb-proj=none] [--wandb-id=none] [--wandb-entity=none] [--seq-len=...] \
[--val-path=none] [--out=...] [--lr=lr minlr] [--loss-mask=$LOSS_MASK] --help]"


# define help function
Expand Down Expand Up @@ -68,6 +73,9 @@ while [[ $# -gt 0 ]]; do
--iters) ITERS=$2; shift; shift;;
--seq-len) SEQ_LEN=$2; shift; shift;;
--out) TRAINED_PATH=$2; shift; shift;;
--val-path) VAL_PATH=$2; shift; shift;;
--lr) USR_LR=$2; USR_MIN_LR=$3; shift; shift; shift;;
--loss-mask) LOSS_MASK=$2; shift; shift;;
*) echo unknown argument $1; help; exit 1;;
esac
done
Expand Down Expand Up @@ -107,7 +115,8 @@ if [[ $MODEL = falcon ]]; then
fi
elif [[ $MODEL = llama ]] || [[ $MODEL = llama2 ]] || [[ $MODEL = codellama ]]; then
EXTRA_IDS="[bib_ref],[/bib_ref],[fig_ref],[/fig_ref],[bib],[/bib],[fig],[/fig],[table],[/table],[formula],[/formula]"
EXTRA_ARGS="--use_rms_norm --glu_activation swiglu --no_tie_embed_logits"
EXTRA_ARGS="--vocab_file=/pure-mlo-scratch/llama/tokenizer.model --use_rms_norm
--glu_activation swiglu --no_tie_embed_logits"
if [[ $INSTRUCT = 1 ]]; then
if [[ $DATA_PATH = none ]]; then
DATA_PATH=/pure-mlo-scratch/alhernan/data/orca/orca
Expand Down Expand Up @@ -163,13 +172,16 @@ COMMON_ARGS="--use_flash_attn --no_bias_gelu_fusion
--attention_dropout 0.0 --adam_beta1 0.9 --adam_beta2 0.95 --adam_eps 1e-5
--lr_decay_style cosine --lr_warmup_fraction 0.1 --lr $LR --min_lr $MIN_LR
--weight_decay 0.1 --sequence_parallel --recompute_granularity selective
--log_timers_to_tensorboard --rope_scaling_factor 1.0"
--log_timers_to_tensorboard --scalar_loss_mask=$LOSS_MASK
--rope_scaling_factor 1.0"

if [[ $INSTRUCT = 1 ]]; then
COMMON_ARGS="$COMMON_ARGS --variable_seq_lengths --data_type instruction"
COMMON_ARGS="$COMMON_ARGS --variable_seq_lengths --data_type instruction --metrics all"
if [[ $CHECKPOINT_PATH != $TRAINED_PATH ]]; then
COMMON_ARGS="$COMMON_ARGS --finetune"
fi
else
COMMON_ARGS="$COMMON_ARGS --metrics perplexity accuracy count_loss_mask"
fi

if [[ $CHECKPOINT_PATH != $TRAINED_PATH ]]; then
Expand All @@ -189,13 +201,19 @@ if [[ $WANDB = 1 ]]; then
fi
fi

if [[ $VAL_PATH = none ]]; then
DATA_ARGS="--data_path $DATA_PATH"
else
DATA_ARGS="--train_data_path $DATA_PATH --valid_data_path $VAL_PATH"
fi

# print some args
echo
echo Settings:
echo RANK=$RANK
echo ADDR=$ADDR
echo N_NODES=$N_NODES
echo DATA_PATH=$DATA_PATH
echo DATA_ARGS=$DATA_ARGS
echo CHECKPOINT_PATH=$CHECKPOINT_PATH
echo TRAINED_PATH=$TRAINED_PATH
echo MODEL=$MODEL
Expand All @@ -216,11 +234,12 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 OMP_NUM_THREADS=16 torchrun $DISTRIBUTED_ARGS fine
--load $CHECKPOINT_PATH \
--save $TRAINED_PATH \
--tensorboard_dir $TENSORBOARD_PATH \
--data_path $DATA_PATH \
$DATA_ARGS \
--model_name $MODEL \
--tokenizer_type $TOKENIZER \
--bf16 \
--global_batch_size $GLOBAL_BATCH \
--micro_batch_size $MICRO_BATCH \
--num_workers=2 \
$EXTRA_ARGS \
$COMMON_ARGS
9 changes: 7 additions & 2 deletions finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def get_batch(data_iterator):
if args.data_type == "gpt":
keys = ["text"]
elif args.data_type == "instruction":
keys = ["text", "attention_mask", "loss_mask"]
keys = ["text", "attention_mask", "assistant_mask", "pad_mask"]
else:
raise KeyError(f"Unknown dataset type {args.data_type}")

Expand Down Expand Up @@ -134,7 +134,12 @@ def get_batch(data_iterator):
# Instruction dataset.
# Heavily inspired by Andreas Köpf: https://github.com/andreaskoepf/epfl-megatron/tree/local_changes/
attention_mask = data_b["attention_mask"][:, :-1]
loss_mask = data_b["loss_mask"][:, 1:].float().to(tokens.device)
assistant_mask = data_b["assistant_mask"][:, 1:].to(tokens.device)
pad_mask = data_b["pad_mask"][:, 1:].to(tokens.device)
loss_mask = torch.full(labels.size(), args.scalar_loss_mask, dtype=torch.float,
device=tokens.device)
loss_mask[assistant_mask == 1] = 1.0
loss_mask[pad_mask == 1] = 0.0
attention_mask, position_ids = get_attention_mask_and_position_ids(
tokens, attention_mask
)
Expand Down
5 changes: 5 additions & 0 deletions megatron/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -930,6 +930,11 @@ def _add_data_args(parser):
help='Maximum sequence length to process.')
group.add_argument('--variable_seq_lengths', action='store_true', default=None,
help='Enable variable sequence lengths.')
group.add_argument('--scalar_loss_mask', type=float, default=0.0,
help=('Instruction-tuning argument: Scalar to multiply the '
'loss of the "masked out" tokens (usually the user '
'tokens, not assistant ones). Set to zero (default) '
'to completely remove the loss of said tokens'))
group.add_argument('--encoder_seq_length', type=int, default=None,
help='Maximum encoder sequence length to process.'
'This should be exclusive of --seq_length')
Expand Down
8 changes: 4 additions & 4 deletions megatron/data/instruction_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def instruction_collator(data):
input[i] = torch.from_numpy(t[:seq_len])
role[i] = torch.from_numpy(r[:seq_len])

# assistant tokens are masked out in the loss for finetuning (can be changed if better with or without...)
loss_mask = (role == Role.assistant.value).long() # assistant tokens have role == 2

return {"text": input, "attention_mask": attention_mask, "loss_mask": loss_mask}
assistant_mask = (role == Role.assistant.value).long()
pad_mask = (input == pad_id).long()
return {"text": input, "attention_mask": attention_mask,
"assistant_mask": assistant_mask, "pad_mask": pad_mask}

0 comments on commit 10aaed8

Please sign in to comment.