Skip to content

Commit

Permalink
Enable multi-device for some models (#30207)
Browse files Browse the repository at this point in the history
* feat: multidevice for resnet

* feat: yes! resnet

* fix: compare all elements in tuple

* feat: support for regnet

* feat: support for convnextv2

* feat: support for bit

* feat: support for cvt

* feat: add support for focalnet

* feat: support for yolos

* feat: support for glpn

* feat: support for imagegpt

* feat: support for levit

* feat: support for mgp_str

* feat: support for mobilenet_v1

* feat: support for mobilenet_v2

* feat: support for mobilevit

* feat: support for mobilevitv2

* feat: support for poolformer

* fix: copies

* fix: code quality check

* update: upstream changes from main

* fix: consistency check

* feat: support for sam

* feat: support for swiftformer

* feat: support for swin

* feat: support for swinv2

* feat: support for timesformer

* feat: support for trocr

* feat: support for upernet

* fix: check copies

* update: rerun CI

* update: rerun again, maybe

* update: one more rerun

---------

Co-authored-by: Jacky Lee <[email protected]>
  • Loading branch information
jla524 and jackylee328 authored Apr 19, 2024
1 parent ecfe9be commit 30b4532
Show file tree
Hide file tree
Showing 27 changed files with 42 additions and 4 deletions.
1 change: 1 addition & 0 deletions src/transformers/models/bit/modeling_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,7 @@ class BitPreTrainedModel(PreTrainedModel):
config_class = BitConfig
base_model_prefix = "bit"
main_input_name = "pixel_values"
_no_split_modules = ["BitEmbeddings"]

def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/convnext/modeling_convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ class ConvNextPreTrainedModel(PreTrainedModel):
config_class = ConvNextConfig
base_model_prefix = "convnext"
main_input_name = "pixel_values"
_no_split_modules = ["ConvNextLayer"]

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/convnextv2/modeling_convnextv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,7 @@ class ConvNextV2PreTrainedModel(PreTrainedModel):
config_class = ConvNextV2Config
base_model_prefix = "convnextv2"
main_input_name = "pixel_values"
_no_split_modules = ["ConvNextV2Layer"]

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/cvt/modeling_cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ class CvtPreTrainedModel(PreTrainedModel):
config_class = CvtConfig
base_model_prefix = "cvt"
main_input_name = "pixel_values"
_no_split_modules = ["CvtLayer"]

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/donut/modeling_donut_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,7 @@ class DonutSwinPreTrainedModel(PreTrainedModel):
base_model_prefix = "swin"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["DonutSwinStage"]

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/focalnet/modeling_focalnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,6 +636,7 @@ class FocalNetPreTrainedModel(PreTrainedModel):
base_model_prefix = "focalnet"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["FocalNetStage"]

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/glpn/modeling_glpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ class GLPNPreTrainedModel(PreTrainedModel):
config_class = GLPNConfig
base_model_prefix = "glpn"
main_input_name = "pixel_values"
_no_split_modules = []

# Copied from transformers.models.segformer.modeling_segformer.SegformerPreTrainedModel._init_weights
def _init_weights(self, module):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/imagegpt/modeling_imagegpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_no_split_modules = ["ImageGPTBlock"]

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/levit/modeling_levit.py
Original file line number Diff line number Diff line change
Expand Up @@ -491,6 +491,7 @@ class LevitPreTrainedModel(PreTrainedModel):
config_class = LevitConfig
base_model_prefix = "levit"
main_input_name = "pixel_values"
_no_split_modules = ["LevitResidualLayer"]

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,7 @@ class MaskFormerSwinPreTrainedModel(PreTrainedModel):
base_model_prefix = "model"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["MaskFormerSwinStage"]

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/mgp_str/modeling_mgp_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,7 @@ class MgpstrPreTrainedModel(PreTrainedModel):

config_class = MgpstrConfig
base_model_prefix = "mgp_str"
_no_split_modules = []

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ class MobileNetV1PreTrainedModel(PreTrainedModel):
base_model_prefix = "mobilenet_v1"
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
_no_split_modules = []

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None:
"""Initialize the weights"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ class MobileNetV2PreTrainedModel(PreTrainedModel):
base_model_prefix = "mobilenet_v2"
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
_no_split_modules = []

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None:
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/mobilevit/modeling_mobilevit.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,7 @@ class MobileViTPreTrainedModel(PreTrainedModel):
base_model_prefix = "mobilevit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["MobileViTLayer"]

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,7 @@ class MobileViTV2PreTrainedModel(PreTrainedModel):
base_model_prefix = "mobilevitv2"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["MobileViTV2Layer"]

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/poolformer/modeling_poolformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,7 @@ class PoolFormerPreTrainedModel(PreTrainedModel):
config_class = PoolFormerConfig
base_model_prefix = "poolformer"
main_input_name = "pixel_values"
_no_split_modules = ["PoolFormerLayer"]

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/regnet/modeling_regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,7 @@ class RegNetPreTrainedModel(PreTrainedModel):
config_class = RegNetConfig
base_model_prefix = "regnet"
main_input_name = "pixel_values"
_no_split_modules = ["RegNetYLayer"]

# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights
def _init_weights(self, module):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/resnet/modeling_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
config_class = ResNetConfig
base_model_prefix = "resnet"
main_input_name = "pixel_values"
_no_split_modules = ["ResNetConvLayer", "ResNetShortCut"]

def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/sam/modeling_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -1074,6 +1074,7 @@ class SamPreTrainedModel(PreTrainedModel):
config_class = SamConfig
base_model_prefix = "sam"
main_input_name = "pixel_values"
_no_split_modules = ["SamVisionAttention"]

def _init_weights(self, module):
std = self.config.initializer_range
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ class SwiftFormerPreTrainedModel(PreTrainedModel):
base_model_prefix = "swiftformer"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["SwiftFormerEncoderBlock"]

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/swin/modeling_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,7 @@ class SwinPreTrainedModel(PreTrainedModel):
base_model_prefix = "swin"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["SwinStage"]

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/swinv2/modeling_swinv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,7 @@ class Swinv2PreTrainedModel(PreTrainedModel):
base_model_prefix = "swinv2"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["Swinv2Stage"]

def _init_weights(self, module):
"""Initialize the weights"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ class TimesformerPreTrainedModel(PreTrainedModel):
base_model_prefix = "timesformer"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["TimesformerLayer"]

def _init_weights(self, module):
if isinstance(module, (nn.Linear, nn.Conv2d)):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/trocr/modeling_trocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ class TrOCRPreTrainedModel(PreTrainedModel):
config_class = TrOCRConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["TrOCRDecoderLayer"]

def _init_weights(self, module):
std = self.config.init_std
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/upernet/modeling_upernet.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ class UperNetPreTrainedModel(PreTrainedModel):

config_class = UperNetConfig
main_input_name = "pixel_values"
_no_split_modules = []

def _init_weights(self, module):
if isinstance(module, UperNetPreTrainedModel):
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/yolos/modeling_yolos.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,7 @@ class YolosPreTrainedModel(PreTrainedModel):
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = []

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
Expand Down
20 changes: 16 additions & 4 deletions tests/test_modeling_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2907,7 +2907,10 @@ def test_disk_offload_bin(self):
torch.manual_seed(0)
new_output = new_model(**inputs_dict_class)

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

@require_accelerate
@mark.accelerate_tests
Expand Down Expand Up @@ -2939,7 +2942,10 @@ def test_disk_offload_safetensors(self):
torch.manual_seed(0)
new_output = new_model(**inputs_dict_class)

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

@require_accelerate
@mark.accelerate_tests
Expand Down Expand Up @@ -2975,7 +2981,10 @@ def test_cpu_offload(self):
torch.manual_seed(0)
new_output = new_model(**inputs_dict_class)

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

@require_accelerate
@mark.accelerate_tests
Expand Down Expand Up @@ -3011,7 +3020,10 @@ def test_model_parallelism(self):
torch.manual_seed(0)
new_output = new_model(**inputs_dict_class)

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
self.assertTrue(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0]))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

def test_problem_types(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
Expand Down

0 comments on commit 30b4532

Please sign in to comment.