Skip to content

Commit

Permalink
[cm] Setting num_threads
Browse files Browse the repository at this point in the history
  • Loading branch information
christhetree committed Aug 28, 2024
1 parent 9ab9cc4 commit 9c5a6a1
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions tests/benchmark_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import os

import numba
import torch as tr
import torch.utils.cpp_extension
from torch.utils import benchmark
Expand Down Expand Up @@ -36,6 +37,12 @@

def main() -> None:
tr.manual_seed(42)
# Set intraop threads
tr.set_num_threads(num_threads)
# Set interop threads
tr.set_num_interop_threads(num_threads)
numba.set_num_threads(num_threads)
log.info(f"numba.get_num_threads(): {numba.get_num_threads()}")

results = []

Expand All @@ -59,10 +66,10 @@ def main() -> None:
benchmark.Timer(
stmt="y = forward_func(x, a, zi)",
globals=globals,
sub_label=f"bs_{bs}__n_{n}__order_{order}",
sub_label=f"bs_{bs}__n_{n}__order_{order}__threads_{num_threads}",
description=forward_func.__name__,
num_threads=num_threads,
).blocked_autorange(min_run_time=1)
).blocked_autorange(min_run_time=0.5)
)

compare = benchmark.Compare(results)
Expand Down

0 comments on commit 9c5a6a1

Please sign in to comment.