Community contribution: enabling device_map="auto" support for more vision and multimodal models #29786
I'm working on ResNet. Edit: I'm running into a strange issue where the tests pass on one system and fail on another. I'm going to close the PR for now and investigate further.
BERT is not included in the above list of models. Does that mean `device_map='auto'` will be available for BERT models in an upcoming version of HF transformers? I still see a message that BertForSequenceClassification does not support `device_map='auto'`.
Hi @amyeroberts, hope you are well :) I'm not sure why, but it looks like the unit tests are passing even without defining `_no_split_modules`.
Models updated so far:
Models remaining:
Hi! Would love to take the following models and give them a try:
Hi! I encountered an issue while running tests for some models, specifically the vision text dual encoder ones. Additionally, I want to know how to define certain models to be skipped in the test.
@WenheLI Ah, I should take the vision text dual encoder off the list: we can theoretically load any encoder and decoder there, so it's not possible to know which modules can be split and which can't. The same goes for the vision encoder-decoder.
Hey @amyeroberts, I was experimenting with defining `_no_split_modules` for Segformer. When I set it, some of the tests were failing, so to investigate further I ran the following:
```python
from accelerate import infer_auto_device_map
from accelerate.utils import compute_module_sizes

# `model` is the Segformer model under test (instantiation not shown in the original comment).
total_size = compute_module_sizes(model)[""]
max_memory = {0: int(0.7 * total_size), "cpu": total_size * 2}
print(f"Total model size: {total_size}, max memory: {max_memory}")
print(
    infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=[],
    )
)
```

Output:
```python
# Same as above, but giving the GPU 90% of the total model size.
max_memory = {0: int(0.9 * total_size), "cpu": total_size * 2}
print(f"Total model size: {total_size}, max memory: {max_memory}")
print(
    infer_auto_device_map(
        model,
        max_memory=max_memory,
        no_split_module_classes=[],
    )
)
```

Output:
The model can definitely be split into smaller modules, as the 90% split case shows. The problem with the 70% split case doesn't come from the smaller max_memory assigned to the GPU: the modules allocated to the GPU in the 90% case account for only 21,408 bytes of the total 1,195,632-byte model size. That is about 1.8% of the total model size, far smaller than both the 70% (836,942 bytes) and the 90% (1,076,068 bytes) max_memory budgets defined for the GPU. So the problem is not the max_memory defined for the GPU, but how `infer_auto_device_map` allocates modules.

After looking into the `infer_auto_device_map` source, two things stand out.
First, the function reserves space for the largest layer on each main device. For Segformer, where the decode_head (1,107,984 bytes) is significantly larger than the other layers, this approach may be too conservative, leaving little room for other layers on the GPU.
Second, the part of the function that decides whether to split a module or move on to the next device never goes back: once allocation moves to the next device (i.e. the CPU), it never returns to the GPU, even if there is still space available there. This could explain why smaller modules aren't being allocated to the GPU after the decode_head is moved to the CPU. @amyeroberts, should we wait until this is resolved?
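For illustration, here is a toy sketch of the two behaviours described above. It is an approximation added for clarity, not the actual accelerate source: every device budget is reduced by the largest layer size, and allocation only moves forward through the device list.

```python
def toy_device_map(module_sizes, max_memory, largest_layer_size):
    """Toy illustration (not accelerate code) of the allocation behaviour."""
    devices = list(max_memory)      # e.g. [0, "cpu"]
    device_idx, used = 0, 0
    device_map = {}
    for name, size in module_sizes.items():
        # Reserve room for the largest layer on the current device.
        budget = max_memory[devices[device_idx]] - largest_layer_size
        while used + size > budget and device_idx + 1 < len(devices):
            # Fall through to the next (slower) device and never come back,
            # even if the earlier device still has free space.
            device_idx += 1
            used = 0
            budget = max_memory[devices[device_idx]] - largest_layer_size
        device_map[name] = devices[device_idx]
        used += size
    return device_map


# With a decode_head that dwarfs everything else, nothing lands on GPU 0 here
# (the sizes below are made up and only roughly mirror the Segformer numbers above).
print(toy_device_map(
    {"encoder": 80_000, "decode_head": 1_100_000},
    max_memory={0: 1_076_068, "cpu": 2_391_264},
    largest_layer_size=1_100_000,
))
```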
Hi @Nech-C, thanks for writing all of this up! I don't know the `infer_auto_device_map` internals well enough to say. Regarding the order of things, having to update the tests is, I think, a sign that we should wait.
Hi @amyeroberts, sorry for getting back to you after so long. While working on this, I dug further into how the allocation works.

The `infer_auto_device_map` function allocates modules sequentially, both across model layers and across the device hierarchy. It processes modules from the first layer to the last and assigns them starting with the fastest device (e.g. GPUs), then moving to slower devices (CPUs, then disk). Although this reduces the overhead of moving offloaded modules between devices and simplifies the memory calculations, it does not make the most efficient use of the available memory: to successfully assign a module to a device, the available memory must exceed the combined size of the current module and the largest subsequent layer. If this condition is not met, the function moves to the next device without revisiting the current one (assuming no module split happens). If the allocation attempt fails on the very first module for a device, the resulting device map won't include that device at all, and that is what makes the tests fail when the split size given to a device is too small.

Here are the relevant snippets from the test file:

transformers/tests/test_modeling_common.py, line 207 at commit 1349321
transformers/tests/test_modeling_common.py, lines 3218 to 3234 at commit 1349321
At a minimum, a GPU must have sufficient memory to accommodate the largest layer for inference. To address this, we may need a dynamic function that calculates split sizes at test time instead of relying on fixed ratios. Thank you for reading through this (I know it's a bit lengthy!). I'd love to hear your thoughts and recommendations on how to proceed.
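For illustration, one possible shape for the dynamic split-size idea mentioned above. This is an assumption about what such a helper could look like, not code from transformers, accelerate, or any PR; the name `min_split_percent` is made up.

```python
from accelerate.utils import compute_module_sizes


def min_split_percent(model, margin: float = 0.05) -> float:
    """Smallest fraction of the total model size a device needs so that the
    largest top-level submodule still fits, plus a small safety margin."""
    sizes = compute_module_sizes(model)           # {module name: size in bytes}
    total = sizes[""]                             # "" holds the full model size
    largest_child = max(sizes[name] for name, _ in model.named_children())
    return min(1.0, largest_child / total + margin)


# e.g. instead of a fixed 0.7 ratio:
#   max_memory = {0: int(min_split_percent(model) * total_size), "cpu": 2 * total_size}
```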
Hi @Nech-C, thanks for such a detailed description of the problem! I think computing split sizes at test time might not be explicit enough. What do you think about improving the test behaviour with a proper message explaining why the test actually fails? Maybe we can check whether any module can be fitted onto a device?
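A rough sketch of the check suggested above, added for illustration; the helper name and message are assumptions, not code from transformers or the eventual PR.

```python
def assert_devices_received_modules(device_map: dict, max_memory: dict) -> None:
    """Fail with an explicit message when a requested device got no modules."""
    used = set(device_map.values())
    missing = [device for device in max_memory if device not in used]
    if missing:
        raise AssertionError(
            f"No module could be allocated to device(s) {missing} with "
            f"max_memory={max_memory}; the per-device budget is probably "
            f"smaller than the largest no-split layer of the model."
        )
```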
Hey @qubvel, sure thing! I have actually been working on a PR that adds warning messages for no-allocation situations in `infer_auto_device_map`.
Just so you know, only the first device's minimum requirement is guaranteed to work, since a device's assignment affects all subsequent devices. With the current implementation, accurate minimums for later devices can't be determined unless the function recomputes the device map with the new configuration. I also implemented a more flexible way of allocating in the same PR, enabled by a new option. The PR should get merged soon. I plan to continue working on the function to address an improvement that SunMarc brought up. If there is anything you would like me to address or change in the function, please let me know, and I will try to accommodate it.
Feature request
transformers models can be easily loaded across multiple devices using `device_map="auto"`. This will automatically allocate weights across available devices, e.g. GPUs, and offload any weights onto CPU, then disk as necessary. This is useful when doing inference with large models.

To enable this, `_no_split_modules` has to be defined in the model's pretrained model class, e.g. like here for LLaMa. This defines layers which should not be split across devices, and it should contain as few layers as possible.

Steps:
1. Add `_no_split_modules` in the PreTrainedModel subclass. Try with `_no_split_modules = []` first (a minimal sketch follows this list).
2. Make sure the following tests pass: `test_disk_offload_bin`, `test_disk_offload_safetensors`, `test_cpu_offload`, `test_model_parallelism`, `test_model_parallel_beam_search`. You can run them with:
   `pytest tests/models/{MODEL_NAME}/test_modeling_{MODEL_NAME}.py -vv -k "offload or parallelism"`
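For illustration, a minimal sketch of step 1. The model class and module name below are hypothetical; only the `_no_split_modules` attribute and the `device_map="auto"` argument are the real API.

```python
from transformers import PreTrainedModel


class MyVisionPreTrainedModel(PreTrainedModel):
    # Hypothetical model class, shown only to illustrate where the attribute lives.
    # Modules listed here are kept whole on a single device when device_map="auto"
    # infers the placement. Keep the list as small as possible: start with [] and
    # only add module class names (e.g. "MyVisionLayer") if the
    # offload/parallelism tests fail.
    _no_split_modules = []


# Once defined, the model can be loaded across devices, e.g.:
#   model = MyVisionModel.from_pretrained(checkpoint, device_map="auto")
```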
Models
Motivation
Enable a powerful HF feature for all of our vision models
Your contribution
Ping me for review 🤗