Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid poisoning process with CUDA calls as soon as importing #6810

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

HollowMan6
Copy link

Switch from torch.cuda.is_available() to torch.cuda.device_count() > 0, to give priority to nvml based availability, so that we can try not to poison process with CUDA calls as soon as we execute import deepspeed.

https://github.com/pytorch/pytorch/blob/v2.5.1/torch/cuda/__init__.py#L120-L124

There are 2 reasons to make this change:

Firstly, if we accidentally import deepspeed, since the CUDA runtime initializes when the first CUDA API call is made and caches the device list, changing the CUDA_VISIBLE_DEVICES within the same process after initialization won't have any effect on the visible devices. The specific case:
OpenRLHF/OpenRLHF#524 (comment)

A demo for reproduction before the fix is applied:

import torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import deepspeed
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
torch.cuda.set_device('cuda:0')

Secondly, https://pytorch.org/docs/stable/notes/cuda.html

When assessing the availability of CUDA in a given environment (is_available()), PyTorch’s default behavior is to call the CUDA Runtime API method cudaGetDeviceCount. Because this call in turn initializes the CUDA Driver API (via cuInit) if it is not already initialized, subsequent forks of a process that has run is_available() will fail with a CUDA initialization error.

Switch from `torch.cuda.is_available()` to
`torch.cuda.device_count() > 0`, to give priority to nvml based availability,
so that we can try not to poison process with CUDA calls as soon as we
execute `import deepspeed`.

https://github.com/pytorch/pytorch/blob/v2.5.1/torch/cuda/__init__.py#L120-L124

There are 2 reasons to make this change:

Firstly, if we accidentally import deepspeed, since the CUDA runtime initializes
when the first CUDA API call is made and caches the device list, changing the
CUDA_VISIBLE_DEVICES within the same process after initialization won't have any
effect on the visible devices. The specific case:
OpenRLHF/OpenRLHF#524 (comment)

A demo for reproduction before the fix is applied:

```python
import torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = ""
import deepspeed
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
torch.cuda.set_device('cuda:0')
```

Secondly, https://pytorch.org/docs/stable/notes/cuda.html

When assessing the availability of CUDA in a given environment (is_available()),
PyTorch’s default behavior is to call the CUDA Runtime API method cudaGetDeviceCount.
Because this call in turn initializes the CUDA Driver API (via cuInit) if it is not
already initialized, subsequent forks of a process that has run is_available() will
fail with a CUDA initialization error.

Signed-off-by: Hollow Man <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant