Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable multi-device for some models #30207

Merged
merged 35 commits into from
Apr 19, 2024

Conversation

jla524
Copy link
Contributor

@jla524 jla524 commented Apr 12, 2024

What does this PR do?

Fixes #29786 (issue)

Includes a fix for unit tests on Backbone models, where base_output[0] and new_output[0] are tuples

Tested on a system with 2x RTX A4000

$ pytest tests/models/convnext/test_modeling_convnext.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 117 items / 110 deselected / 7 selected                                                                                                                                                                 

tests/models/convnext/test_modeling_convnext.py .......                                                                                                                                                     [100%]

<warnings redacted>
================================================================================== 7 passed, 110 deselected, 2 warnings in 8.95s ==================================================================================
$ pytest tests/models/convnextv2/test_modeling_convnextv2.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 108 items / 101 deselected / 7 selected                                                                                                                                                                 

tests/models/convnextv2/test_modeling_convnextv2.py .......                                                                                                                                                 [100%]

<warnings redacted>
================================================================================== 7 passed, 101 deselected, 2 warnings in 8.30s ==================================================================================

$ pytest tests/models/cvt/test_modeling_cvt.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 108 items / 101 deselected / 7 selected                                                                                                                                                                 

tests/models/cvt/test_modeling_cvt.py .......                                                                                                                                                               [100%]

<warnings redacted>
================================================================================= 7 passed, 101 deselected, 2 warnings in 16.37s ==================================================================================
$ pytest tests/models/focalnet/test_modeling_focalnet.py -k "offload or parallel"
============================================================================================== test session starts ==============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 119 items / 112 deselected / 7 selected                                                                                                                                                               

tests/models/focalnet/test_modeling_focalnet.py .......                                                                                                                                                   [100%]

<warnings redacted>
================================================================================ 7 passed, 112 deselected, 2 warnings in 15.75s =================================================================================
$ pytest tests/models/glpn/test_modeling_glpn.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 108 items / 101 deselected / 7 selected                                                                                                                                                                 

tests/models/glpn/test_modeling_glpn.py .......                                                                                                                                                             [100%]

<warnings redacted>
================================================================================== 7 passed, 101 deselected, 2 warnings in 9.81s ==================================================================================
$ pytest tests/models/imagegpt/test_modeling_imagegpt.py -k "offload or parallel"                                                                                                    
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0                                                                                                                                                       
rootdir: /root/transformers                                                                                                                                                                                        
configfile: pyproject.toml                                                                                                                                                                                         
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0                                                                                                                                               
collected 140 items / 132 deselected / 8 selected                                                                                                                                                                  
                                                                                                                                                                                                                   
tests/models/imagegpt/test_modeling_imagegpt.py ........                                                                                                                                                    [100%] 
                                                                                                                                                                                                                   
<warnings redacted>                                                                                                          
================================================================================= 8 passed, 132 deselected, 29 warnings in 10.51s =================================================================================
$ pytest tests/models/levit/test_modeling_levit.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 108 items / 101 deselected / 7 selected                                                                                                                                                                 

tests/models/levit/test_modeling_levit.py .......                                                                                                                                                           [100%]

<warnings redacted>
================================================================================= 7 passed, 101 deselected, 2 warnings in 13.72s ==================================================================================
$ pytest tests/models/mgp_str/test_modeling_mgp_str.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 105 items / 98 deselected / 7 selected                                                                                                                                                                  

tests/models/mgp_str/test_modeling_mgp_str.py .......                                                                                                                                                       [100%]

<warnings redacted>
================================================================================== 7 passed, 98 deselected, 2 warnings in 4.03s ===================================================================================
$ pytest tests/models/mobilenet_v1/test_modeling_mobilenet_v1.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 108 items / 101 deselected / 7 selected                                                                                                                                                                 

tests/models/mobilenet_v1/test_modeling_mobilenet_v1.py .......                                                                                                                                             [100%]

<warnings redacted>
================================================================================== 7 passed, 101 deselected, 3 warnings in 5.93s ==================================================================================
$ pytest tests/models/mobilenet_v2/test_modeling_mobilenet_v2.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 110 items / 103 deselected / 7 selected                                                                                                                                                                 

tests/models/mobilenet_v2/test_modeling_mobilenet_v2.py .......                                                                                                                                             [100%]

<warnings redacted>
================================================================================= 7 passed, 103 deselected, 13 warnings in 10.57s =================================================================================
$ pytest tests/models/mobilevit/test_modeling_mobilevit.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 111 items / 104 deselected / 7 selected                                                                                                                                                                 

tests/models/mobilevit/test_modeling_mobilevit.py .......                                                                                                                                                   [100%]

<warnings redacted>
================================================================================= 7 passed, 104 deselected, 3 warnings in 15.42s ==================================================================================
$ pytest tests/models/poolformer/test_modeling_poolformer.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 107 items / 100 deselected / 7 selected                                                                                                                                                                 

tests/models/poolformer/test_modeling_poolformer.py .......                                                                                                                                                 [100%]

<warnings redacted>
================================================================================== 7 passed, 100 deselected, 2 warnings in 6.10s ==================================================================================

$ pytest tests/models/regnet/test_modeling_regnet.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 108 items / 101 deselected / 7 selected                                                                                                                                                                 

tests/models/regnet/test_modeling_regnet.py .......                                                                                                                                                         [100%]

<warnings redacted>
================================================================================== 7 passed, 101 deselected, 2 warnings in 7.05s ==================================================================================
$ pytest tests/models/resnet/test_modeling_resnet.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 117 items / 110 deselected / 7 selected                                                                                                                                                                 

tests/models/resnet/test_modeling_resnet.py .......                                                                                                                                                         [100%]

<warnings redacted>
================================================================================== 7 passed, 110 deselected, 5 warnings in 9.31s ==================================================================================
$ pytest tests/models/sam/test_modeling_sam.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 122 items / 115 deselected / 7 selected                                                                                                                                                                 

tests/models/sam/test_modeling_sam.py .......                                                                                                                                                               [100%]

<warnings redacted>
================================================================================== 7 passed, 115 deselected, 2 warnings in 4.14s ==================================================================================
$ pytest tests/models/swiftformer/test_modeling_swiftformer.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 110 items / 103 deselected / 7 selected                                                                                                                                                                 

tests/models/swiftformer/test_modeling_swiftformer.py .......                                                                                                                                               [100%]

<warnings redacted>
================================================================================== 7 passed, 103 deselected, 2 warnings in 8.28s ==================================================================================
$ pytest tests/models/swin/test_modeling_swin.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 121 items / 114 deselected / 7 selected                                                                                                                                                                 

tests/models/swin/test_modeling_swin.py ......s                                                                                                                                                             [100%]

<warnings redacted>
============================================================================ 6 passed, 1 skipped, 114 deselected, 3 warnings in 8.37s =============================================================================

^ test skipped due to CUDA error: misaligned address with PyTorch 2.0.0. which occurs when running on single GPU too

$ pytest tests/models/swinv2/test_modeling_swinv2.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 121 items / 114 deselected / 7 selected                                                                                                                                                                 

tests/models/swinv2/test_modeling_swinv2.py ......s                                                                                                                                                         [100%]

<warnings redacted>
============================================================================ 6 passed, 1 skipped, 114 deselected, 2 warnings in 7.05s =============================================================================
$ pytest tests/models/trocr/test_modeling_trocr.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 138 items / 130 deselected / 8 selected                                                                                                                                                                 

tests/models/trocr/test_modeling_trocr.py ........                                                                                                                                                          [100%]

<warnings redacted>
================================================================================== 8 passed, 130 deselected, 2 warnings in 5.55s ==================================================================================

$ pytest tests/models/upernet/test_modeling_upernet.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 111 items / 104 deselected / 7 selected                                                                                                                                                                 

tests/models/upernet/test_modeling_upernet.py ......s                                                                                                                                                       [100%]

<warnings redacted>
============================================================================ 6 passed, 1 skipped, 104 deselected, 3 warnings in 4.41s =============================================================================
$ pytest tests/models/yolos/test_modeling_yolos.py -k "offload or parallel"
=============================================================================================== test session starts ===============================================================================================
platform linux -- Python 3.10.13, pytest-7.4.4, pluggy-1.4.0
rootdir: /root/transformers
configfile: pyproject.toml
plugins: hypothesis-6.98.10, xdist-3.5.0, timeout-2.3.1, anyio-4.3.0
collected 108 items / 101 deselected / 7 selected                                                                                                                                                                 

tests/models/yolos/test_modeling_yolos.py .......                                                                                                                                                           [100%]

<warnings redacted>
================================================================================== 7 passed, 101 deselected, 2 warnings in 8.97s ==================================================================================

Who can review?

@amyeroberts

@jla524 jla524 changed the title Enable multi-device for Resnet and ConvNext Enable multi-device for ConvNext, ResNet, and RegNet Apr 13, 2024
@jla524 jla524 changed the title Enable multi-device for ConvNext, ResNet, and RegNet Enable multi-device for some models Apr 14, 2024
@jla524
Copy link
Contributor Author

jla524 commented Apr 16, 2024

Hey @amyeroberts! We have added support for 10 models now. Can I get a review for this PR?

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Amazing piece of work - thanks for adding this feature for all these models! ❤️

For the quality checks, running make fix-copies and pushing the changes should resolve this.

Could you update the issue to mark all of these models as done once merged in?

@jla524
Copy link
Contributor Author

jla524 commented Apr 17, 2024

Thanks Amy! I ran make fix-copies and pushed the changes. Interestingly, it automatically updated some of the _no_split_modules code. I have verified that the tests are still passing in a multi-GPU environment.

@amyeroberts
Copy link
Collaborator

@jackylee328 Ah yes - that's what make fix-copies should do! Because of the one file per model policy, it can end up with a lot of repeated code. The way we get around this is using # Copied from comments, which enable us to reuse code without forcing a modular design. make fix-copies ensures all the places where # Copied from occurs the code it correctly updated if the source has changed.

Thanks for running tests for the new models too ❤️

Re the failing checks at the moment:

  • The copy checks are failing because there's some unrelated diff changes in the PR. Doing git checkout upstream/main -- src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py src/transformers/models/cohere/modeling_cohere.py tests/models/roc_bert/test_tokenization_roc_bert.py should remove the diff and fix the break copies test.
  • For the quality check, running make fixup should resolve

@jla524
Copy link
Contributor Author

jla524 commented Apr 19, 2024

Finally passing now! I had to merge all the changes from main.

@jla524
Copy link
Contributor Author

jla524 commented Apr 19, 2024

Hmm I'm not sure why the torch tests are failing. I'm not able to reproduce the errors on my machines

@amyeroberts
Copy link
Collaborator

@jackylee328 Unfortunately sometimes our test suite will fail for reasons unrelated to the PR e.g. timeouts. If this happens, feel free to ping on the PR and I can restart the runs for you, without you needing to push lots of commits!

Thank you for all the efforts adding this across our library - it's a mammoth addition! 🔥

@amyeroberts amyeroberts merged commit 30b4532 into huggingface:main Apr 19, 2024
17 checks passed
@jla524 jla524 deleted the resnet_multidevice branch April 19, 2024 08:24
@jla524
Copy link
Contributor Author

jla524 commented Apr 19, 2024

Thank you @amyeroberts. I don't have permission to update the tracking issue, but I can make a list of the models that are now supported.

@amyeroberts
Copy link
Collaborator

@jla524 Ah, sorry, I didn't realise. No worries - I can update!

ArthurZucker pushed a commit that referenced this pull request Apr 22, 2024
* feat: multidevice for resnet

* feat: yes! resnet

* fix: compare all elements in tuple

* feat: support for regnet

* feat: support for convnextv2

* feat: support for bit

* feat: support for cvt

* feat: add support for focalnet

* feat: support for yolos

* feat: support for glpn

* feat: support for imagegpt

* feat: support for levit

* feat: support for mgp_str

* feat: support for mobilnet_v1

* feat: support for mobilnet_v2

* feat: support for mobilevit

* feat: support for mobilevitv2

* feat: support for poolformer

* fix: copies

* fix: code quality check

* update: upstream changes from main

* fix: consistency check

* feat: support for sam

* feat: support for switchformer

* feat: support for swin

* feat: support for swinv2

* feat: support for timesformer

* feat: suport for trocr

* feat: support for upernet

* fix: check copies

* update: rerun CI

* update: rerun again, maybe

* update: one more rerun

---------

Co-authored-by: Jacky Lee <[email protected]>
ydshieh pushed a commit that referenced this pull request Apr 23, 2024
* feat: multidevice for resnet

* feat: yes! resnet

* fix: compare all elements in tuple

* feat: support for regnet

* feat: support for convnextv2

* feat: support for bit

* feat: support for cvt

* feat: add support for focalnet

* feat: support for yolos

* feat: support for glpn

* feat: support for imagegpt

* feat: support for levit

* feat: support for mgp_str

* feat: support for mobilnet_v1

* feat: support for mobilnet_v2

* feat: support for mobilevit

* feat: support for mobilevitv2

* feat: support for poolformer

* fix: copies

* fix: code quality check

* update: upstream changes from main

* fix: consistency check

* feat: support for sam

* feat: support for switchformer

* feat: support for swin

* feat: support for swinv2

* feat: support for timesformer

* feat: suport for trocr

* feat: support for upernet

* fix: check copies

* update: rerun CI

* update: rerun again, maybe

* update: one more rerun

---------

Co-authored-by: Jacky Lee <[email protected]>
itazap pushed a commit that referenced this pull request May 14, 2024
* feat: multidevice for resnet

* feat: yes! resnet

* fix: compare all elements in tuple

* feat: support for regnet

* feat: support for convnextv2

* feat: support for bit

* feat: support for cvt

* feat: add support for focalnet

* feat: support for yolos

* feat: support for glpn

* feat: support for imagegpt

* feat: support for levit

* feat: support for mgp_str

* feat: support for mobilnet_v1

* feat: support for mobilnet_v2

* feat: support for mobilevit

* feat: support for mobilevitv2

* feat: support for poolformer

* fix: copies

* fix: code quality check

* update: upstream changes from main

* fix: consistency check

* feat: support for sam

* feat: support for switchformer

* feat: support for swin

* feat: support for swinv2

* feat: support for timesformer

* feat: suport for trocr

* feat: support for upernet

* fix: check copies

* update: rerun CI

* update: rerun again, maybe

* update: one more rerun

---------

Co-authored-by: Jacky Lee <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Community contribution: enabling device_map="auto" support for more vision and multimodal models
3 participants