Skip to content

Commit

Permalink
quick update for fp8 rowwise
Browse files Browse the repository at this point in the history
Summary: currently rowwise only uses production shapes and skips default shapes. we want to make this inclusive and not an either or.

Reviewed By: danzimm

Differential Revision: D67478772

fbshipit-source-id: edc1e903ab91fcdefcf880b8a892b41a8d3b39d5
  • Loading branch information
adamomainz authored and facebook-github-bot committed Dec 19, 2024
1 parent bec8389 commit 06c28ed
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tritonbench/operators/fp8_gemm_rowwise/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,14 +118,16 @@ def __init__(
super().__init__(tb_args, extra_args)
self.use_cuda_graphs = True
addmm_args = parse_args(self.extra_args)
if hasattr(tb_args, "production_shapes") and tb_args.production_shapes:
self.shapes = get_production_shapes(self.name, "fp32_gemm")
elif addmm_args.m and addmm_args.n and addmm_args.k:
if addmm_args.m and addmm_args.n and addmm_args.k:
self.shapes = [(addmm_args.m, addmm_args.n, addmm_args.k)]
elif addmm_args.llama:
self.shapes = gemm_shapes()
else:
self.shapes = BUILDIN_SHAPES
if hasattr(tb_args, "production_shapes") and tb_args.production_shapes:
extras = get_production_shapes(self.name, "fp32_gemm")
if len(extras):
self.shapes.extend(extras)
self.fp8_fast_accum = addmm_args.fp8_fast_accum
self.use_tma = addmm_args.use_tma
self.no_use_persistent = addmm_args.no_use_persistent
Expand Down

0 comments on commit 06c28ed

Please sign in to comment.