[CUDA][cuBLAS] Remove explicit cuBLAS workspace allocation for CUDA 12.2+ (pytorch#113994)

cuBLAS should be using `cudaMallocAsync` in CUDA 12.2+, so the explicit workspace allocation that PyTorch used to keep memory usage from growing across multiple graph captures is no longer needed.
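A minimal sketch of the gating idea, with a hypothetical helper name (this is not the PyTorch implementation): before 12.2 the handle is pinned to an explicitly allocated workspace, and from 12.2 on the explicit path compiles away because cuBLAS allocates internally via `cudaMallocAsync`.

```cpp
// Sketch of the CUDA_VERSION gate this PR applies (illustrative names only).
#include <cstddef>
#include <cublas_v2.h>

void maybeSetWorkspace(cublasHandle_t handle, void* buf, size_t bytes) {
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12200
  // Pre-12.2: reuse one explicitly allocated workspace so repeated CUDA
  // graph captures don't each leave a new workspace behind.
  cublasSetWorkspace(handle, buf, bytes);
#else
  // 12.2+: nothing to do; cuBLAS draws workspaces from the stream-ordered
  // (cudaMallocAsync) pool on its own.
  (void)handle; (void)buf; (void)bytes;
#endif
}
```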

CC @ptrblck @malfet

Pull Request resolved: pytorch#113994
Approved by: https://github.com/ezyang, https://github.com/malfet
eqy authored and pytorchmergebot committed Nov 22, 2023
1 parent 5f504d1 commit 6a86cf0
Showing 2 changed files with 10 additions and 4 deletions.
aten/src/ATen/cuda/CublasHandlePool.cpp (7 additions, 3 deletions)

@@ -40,7 +40,9 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle
 } // namespace

 void clearCublasWorkspaces() {
-  cublas_handle_stream_to_workspace().clear();
+#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12200
+  cublas_handle_stream_to_workspace().clear();
+#endif
 }

 size_t parseChosenWorkspaceSize() {
@@ -105,8 +107,10 @@ cublasHandle_t getCurrentCUDABlasHandle() {
   auto handle = myPoolWindow->reserve(device);
   auto stream = c10::cuda::getCurrentCUDAStream();
   TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
-#if !defined(USE_ROCM)
-  // cublasSetWorkspace not available on CUDA 10.2
+#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION < 12200
+  // cuBLAS should not need an explicitly allocated workspace after CUDA 12.2
+  // to avoid increasing memory usage during graph captures
+  // original issue: https://github.com/pytorch/pytorch/pull/83461
   cudaStream_t _stream = stream;
   auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
   auto workspace_it = cublas_handle_stream_to_workspace().find(key);
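The hunk is truncated here; on the pre-12.2 path, its remainder looks up or creates a cached workspace for the (handle, stream) pair and registers it with the handle. A minimal sketch of that pattern, with hypothetical names (`attachWorkspace`, `g_workspaces`) standing in for the real code:

```cpp
// Sketch (hypothetical names, not the actual truncated code) of the
// pre-12.2 path this hunk guards: cache one workspace per (handle, stream)
// pair and register it with cublasSetWorkspace.
#include <cstddef>
#include <cublas_v2.h>
#include <cuda_runtime.h>
#include <map>
#include <tuple>

using WorkspaceKey = std::tuple<void*, void*>;
static std::map<WorkspaceKey, void*> g_workspaces;  // stand-in for cublas_handle_stream_to_workspace()

void attachWorkspace(cublasHandle_t handle, cudaStream_t stream, size_t bytes) {
  WorkspaceKey key{static_cast<void*>(handle), static_cast<void*>(stream)};
  auto it = g_workspaces.find(key);
  if (it == g_workspaces.end()) {
    void* buf = nullptr;
    cudaMalloc(&buf, bytes);  // PyTorch would use its caching allocator here
    it = g_workspaces.emplace(key, buf).first;
  }
  cublasSetWorkspace(handle, it->second, bytes);  // pin the handle to the cached buffer
}
```

Keying by both handle and stream means each handle/stream combination reuses one stable buffer, which is presumably what keeps repeated graph captures from growing memory on the pre-12.2 path.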
test/test_cuda.py (3 additions, 1 deletion)

@@ -29,7 +29,8 @@
     NO_MULTIPROCESSING_SPAWN, skipIfRocm, load_tests, IS_WINDOWS, \
     slowTest, skipCUDANonDefaultStreamIf, skipCUDAMemoryLeakCheckIf, TEST_CUDA, TEST_CUDA_GRAPH, TEST_WITH_ROCM, TEST_NUMPY, \
     get_cycles_per_ms, parametrize, instantiate_parametrized_tests, subtest, IS_JETSON, gcIfJetson, NoTest, IS_LINUX
-from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_MULTIGPU, _create_scaling_case, _create_scaling_models_optimizers
+from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_MULTIGPU, \
+    _create_scaling_case, _create_scaling_models_optimizers, _get_torch_cuda_version
 from torch.testing._internal.autocast_test_lists import AutocastTestLists
 from torch.utils.viz._cycles import observe_tensor_cycles

@@ -296,6 +297,7 @@ def test_serialization_array_with_storage(self):
         self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

     @unittest.skipIf(TEST_CUDAMALLOCASYNC or TEST_WITH_ROCM, "temporarily disabled for async")
+    @unittest.skipIf(_get_torch_cuda_version() >= (12, 2), "skipped as explicit workspace allocation is removed")
     def test_cublas_workspace_explicit_allocation(self):
         a = torch.randn(7, 7, device='cuda', requires_grad=False)
         default_workspace_size = 4096 * 2 * 1024 + 16 * 8 * 1024  # :4096:2:16:8
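For reference, the `:4096:2:16:8` spec in that comment encodes pairs of (chunk size in KiB, chunk count), so the expected default works out to 2 x 4096 KiB + 8 x 16 KiB = 8,519,680 bytes, matching the expression in the test. A hedged sketch of totaling such a spec (illustrative only, not PyTorch's `parseChosenWorkspaceSize`):

```cpp
// Sketch: total a ":SIZE:COUNT[:SIZE:COUNT...]" workspace spec such as
// ":4096:2:16:8" (sizes in KiB). Not PyTorch's parser.
#include <cstddef>
#include <sstream>
#include <string>

size_t totalWorkspaceBytes(const std::string& spec) {
  std::istringstream in(spec);
  std::string token;
  size_t total = 0;
  std::getline(in, token, ':');  // skip the empty field before the leading ':'
  while (std::getline(in, token, ':')) {
    size_t size_kib = std::stoul(token);
    if (!std::getline(in, token, ':')) break;
    total += size_kib * std::stoul(token) * 1024;  // KiB -> bytes, times chunk count
  }
  return total;  // ":4096:2:16:8" -> 8,519,680 bytes
}
```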
