From 14dae6ff9803cec1d5998d10fa5f9d4f10a9a117 Mon Sep 17 00:00:00 2001
From: Iztok Lebar Bajec
Date: Wed, 10 Jan 2024 00:32:30 +0100
Subject: [PATCH] fix: numba.*_num_threads resets torch num_threads #8141

temporary fix until https://github.com/numba/numba/issues/9387 gets
resolved.

Signed-off-by: Iztok Lebar Bajec
---
 .../asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py  | 3 +++
 .../asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py | 2 ++
 2 files changed, 5 insertions(+)

diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py
index 3feb7b513a50..bcc1865ab78e 100644
--- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py
+++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cpu_utils/cpu_rnnt.py
@@ -203,10 +203,13 @@ def __init__(
         self.num_threads_ = num_threads
         self.batch_first = batch_first
 
+        _torch_num_threads = torch.get_num_threads()
         if num_threads > 0:
             numba.set_num_threads(min(multiprocessing.cpu_count(), num_threads))
+            self.num_threads_ = numba.get_num_threads()
         else:
             self.num_threads_ = numba.get_num_threads()
+        torch.set_num_threads(_torch_num_threads)
 
     def cost_and_grad_kernel(
         self,
diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py
index 70ffb459cb97..87d6ee147dea 100644
--- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py
+++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt.py
@@ -82,11 +82,13 @@ def __init__(
         self.num_threads_ = num_threads
         self.stream_ = stream  # type: cuda.cudadrv.driver.Stream
 
+        _torch_num_threads = torch.get_num_threads()
         if num_threads > 0:
             numba.set_num_threads(min(multiprocessing.cpu_count(), num_threads))
             self.num_threads_ = numba.get_num_threads()
         else:
             self.num_threads_ = numba.get_num_threads()
+        torch.set_num_threads(_torch_num_threads)
 
     def log_softmax(self, acts: torch.Tensor, denom: torch.Tensor):
         """
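
Reviewer note, not part of the patch: both hunks apply the same save/restore
pattern around numba.set_num_threads, which (per numba#9387) can also reset
torch's intra-op thread count. A minimal standalone sketch of that pattern
follows; the helper name set_numba_threads_preserving_torch is illustrative
only, and the sketch assumes nothing beyond the documented numba and torch
threading APIs used in the patch itself.

    import multiprocessing

    import numba
    import torch


    def set_numba_threads_preserving_torch(num_threads: int) -> int:
        """Set numba's thread count without clobbering torch's, and
        return the thread count numba actually ends up using."""
        # Save torch's intra-op thread count before touching numba.
        _torch_num_threads = torch.get_num_threads()
        if num_threads > 0:
            numba.set_num_threads(min(multiprocessing.cpu_count(), num_threads))
        effective_threads = numba.get_num_threads()
        # Restore torch's setting in case numba.set_num_threads reset it.
        torch.set_num_threads(_torch_num_threads)
        return effective_threads

With such a helper, each constructor above would reduce to a single call,
e.g. self.num_threads_ = set_numba_threads_preserving_torch(num_threads);
the patch instead inlines the pattern at both call sites, which keeps the
temporary workaround easy to find and remove once the upstream issue is
resolved.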