Skip to content

Commit

Permalink
fix cudagraph mem
Browse files (browse the repository at this point in the history)
  • Loading branch information
FindHao committed Dec 11, 2024
1 parent d94fae4 commit 9102236
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
@@ -1337,9 +1337,10 @@ def get_peak_mem(
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
if use_cuda_graphs:
- self.do_bench_cudagraph_mem(
- fn, n_repeat=2, grad_to_none=grad_to_none, device_type=device_type
- )
+ with torch.cuda.stream(torch.cuda.Stream()):
+ self.do_bench_cudagraph_mem(
+ fn, n_repeat=2, grad_to_none=grad_to_none, device_type=device_type
+ )
else:
self.do_bench_mem(
fn, n_repeat=2, grad_to_none=grad_to_none, device_type=device_type
Expand Down

0 comments on commit 9102236

Please sign in to comment.