Commit

add --dist-opt arg
Signed-off-by: Alexandros Koumparoulis <[email protected]>
akoumpa committed Nov 1, 2024
1 parent 777be9c commit 582012f
Showing 2 changed files with 8 additions and 4 deletions.
9 changes: 6 additions & 3 deletions .github/workflows/cicd-main.yml
@@ -4279,7 +4279,8 @@ jobs:
--max-steps 3 \
--tp 1 \
--mbs 1 \
- --model mixtral
+ --model mixtral \
+ --dist-opt
L2_NeMo_2_Mistral_LoRA_TP1PP1_MBS1:
needs: [cicd-test-container-setup]
@@ -4293,7 +4294,8 @@ jobs:
--max-steps 3 \
--tp 1 \
--mbs 1 \
- --model mistral
+ --model mistral \
+ --dist-opt
L2_NeMo_2_Mistral_LoRA_TP2PP1_MBS1:
needs: [cicd-test-container-setup]
@@ -4307,7 +4309,8 @@ jobs:
--max-steps 3 \
--tp 2 \
--mbs 1 \
- --model mistral
+ --model mistral \
+ --dist-opt
L2_NeMo_2_NeMo_Mcore_Mixtral_bitexact:
3 changes: 2 additions & 1 deletion tests/collections/llm/lora_mistralai.py
@@ -30,6 +30,7 @@ def get_args():
parser.add_argument('--mbs', type=int, default=2, help="micro batch size")
parser.add_argument('--gbs', type=int, default=4, help="global batch size")
parser.add_argument('--tp', type=int, default=1, help="tensor parallel size")
+ parser.add_argument('--dist-opt', action='store_true', help='use dist opt')
return parser.parse_args()


@@ -128,7 +129,7 @@ def mistral_7b() -> pl.LightningModule:
optimizer="adam",
lr=0.0001,
adam_beta2=0.98,
- use_distributed_optimizer=True,
+ use_distributed_optimizer=args.dist_opt,
clip_grad=1.0,
bf16=True,
),
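
Editor's note (not part of the commit): a minimal sketch of how the new flag likely behaves, using only the standard-library `argparse` module. The helper name `build_parser` is hypothetical and the real `get_args()` defines more options. With `action='store_true'` the value defaults to `False`, and argparse derives the attribute name `dist_opt` from `--dist-opt` by replacing the dash with an underscore.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical helper mirroring two of the arguments defined in get_args().
    parser = argparse.ArgumentParser()
    parser.add_argument('--tp', type=int, default=1, help="tensor parallel size")
    # store_true: the attribute is False unless the flag appears on the command line.
    parser.add_argument('--dist-opt', action='store_true', help='use dist opt')
    return parser


# Flag omitted: args.dist_opt falls back to False.
flag_absent = build_parser().parse_args([])
# Flag given: args.dist_opt becomes True.
flag_present = build_parser().parse_args(['--tp', '2', '--dist-opt'])

print(flag_absent.dist_opt, flag_present.dist_opt)  # prints "False True"
```

One consequence worth noting: before this change the test always set `use_distributed_optimizer=True`, so any invocation that omits `--dist-opt` now runs with the distributed optimizer disabled. This is presumably why the three CI jobs in the workflow diff pass the flag explicitly.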