Add simple inference FLOP counter to calc_transformer_flops.py (#31)
* add --infer arg to flops calculator

* add comment

* fix comment

* Update calc_transformer_flops.py
haileyschoelkopf authored Feb 20, 2024
1 parent 939fa3c commit 56aeee1
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion calc/calc_transformer_flops.py
@@ -66,11 +66,13 @@ def config_parser():
                        action='store_false',
                        help='Whether Megatron-style activation checkpointing is being used',
                        dest='checkpoint_activations')
    parser.add_argument("--infer", "-i",
                        action='store_true',
                        help='Pass to calculate FLOPs for inference-only workload (no backward pass)')
    return parser

# calculates the flops of a model given its hparams
def calc_params(args):

    assert args.topk <= args.num_experts, "You cannot route to more experts than you have!"
    assert args.num_layers % args.expert_interval == 0, "Require for simplicity that we don't have hanging dense layers"

@@ -82,6 +84,10 @@ def calc_params(args):
    iter_factor = 3
    if args.checkpoint_activations:
        iter_factor += 1
    # If inference-only, no bwd pass or activation ckpting necessary
    # This assumes simply running a single forward pass ('prefill' stage of decoding) and no subsequent autoregressively generated tokens.
    if args.infer:
        iter_factor = 1

    qkv_flops = int(iter_factor * 2 * (1 + 2 * args.kv_size_ratio) * args.num_layers * args.tokens * args.hidden_size * args.hidden_size)
    attention_matrix_flops = iter_factor * 2 * args.num_layers * args.tokens * args.sequence_length * args.hidden_size
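For reference, a minimal standalone sketch of the multiplier logic this commit introduces, using made-up hyperparameter values (the real script reads these from the argparse flags defined in config_parser, not shown in full here):

# Illustrative sketch only; the hyperparameter values below are hypothetical.
num_layers, hidden_size, tokens, kv_size_ratio = 32, 4096, 2048, 1.0
checkpoint_activations, infer = False, True

iter_factor = 3                  # training: 1x forward + 2x backward
if checkpoint_activations:
    iter_factor += 1             # activation checkpointing recomputes the forward pass
if infer:
    iter_factor = 1              # inference prefill: a single forward pass only

# Same QKV-projection term as in the script, scaled by iter_factor
qkv_flops = int(iter_factor * 2 * (1 + 2 * kv_size_ratio) * num_layers * tokens * hidden_size * hidden_size)
print(f"QKV projection FLOPs: {qkv_flops:.2e}")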
