Skip to content

Commit

Permalink
fix varname bug (#25)
Browse files Browse the repository at this point in the history
  • Loading branch information
Quentin-Anthony authored Feb 11, 2024
1 parent ba41da6 commit ef6afa8
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion calc/calc_transformer_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def calc_mem(args):
per_gpu_model_mem = (EP_total_params * bytes_per_param) / (args.tensor_parallel_size * args.pipeline_parallel_size)
# ZeRO stage 3 shards the model parameters across GPUs (plus the gradients and optimizer states)
if args.zero_stage == 3:
model_mem_per_gpu /= args.num_gpus
per_gpu_model_mem /= args.num_gpus

# --- GRADIENT MEMORY ---
# E.g. 4 bytes in fp32, 2 bytes in fp16/bf16, 1 byte in fp8
Expand Down

0 comments on commit ef6afa8

Please sign in to comment.