Enable multi-device for some models #30207

Merged
merged 35 commits into main from jla524:resnet_multidevice on Apr 19, 2024
Changes from 18 commits
Commits (35)
9962db7
feat: multidevice for resnet
jla524 Apr 12, 2024
9292c98
feat: yes! resnet
jla524 Apr 12, 2024
26cbabc
fix: compare all elements in tuple
jla524 Apr 12, 2024
3623ddb
feat: support for regnet
jla524 Apr 13, 2024
a28789c
feat: support for convnextv2
jla524 Apr 14, 2024
7bc5ca9
feat: support for bit
jla524 Apr 14, 2024
de0ee8d
feat: support for cvt
jla524 Apr 14, 2024
ed4c1cc
feat: add support for focalnet
jackylee328 Apr 15, 2024
ad7f023
feat: support for yolos
jackylee328 Apr 15, 2024
da150ab
feat: support for glpn
jla524 Apr 16, 2024
dff4972
feat: support for imagegpt
jla524 Apr 16, 2024
6cc5243
feat: support for levit
jla524 Apr 16, 2024
d6cf3eb
feat: support for mgp_str
jla524 Apr 17, 2024
e122cf4
feat: support for mobilenet_v1
jla524 Apr 17, 2024
bb376da
feat: support for mobilenet_v2
jla524 Apr 17, 2024
7300b01
feat: support for mobilevit
jla524 Apr 17, 2024
32e5a45
feat: support for mobilevitv2
jla524 Apr 17, 2024
65c75d8
feat: support for poolformer
jla524 Apr 17, 2024
249e9ec
fix: copies
jla524 Apr 17, 2024
b015b9c
fix: code quality check
jla524 Apr 18, 2024
d0c095c
update: upstream changes from main
jla524 Apr 18, 2024
d5b7e83
fix: consistency check
jla524 Apr 18, 2024
c45fd56
Merge branch 'upstream' into resnet_multidevice
jla524 Apr 18, 2024
0b97fab
feat: support for sam
jla524 Apr 19, 2024
59d17da
feat: support for switchformer
jla524 Apr 19, 2024
142b749
feat: support for swin
jla524 Apr 19, 2024
54047fe
feat: support for swinv2
jla524 Apr 19, 2024
d823e4d
feat: support for timesformer
jla524 Apr 19, 2024
7f730c8
feat: support for trocr
jla524 Apr 19, 2024
64959f6
feat: support for upernet
jla524 Apr 19, 2024
268484c
fix: check copies
jla524 Apr 19, 2024
90f52e6
update: rerun CI
jla524 Apr 19, 2024
b470e5f
update: rerun again, maybe
jla524 Apr 19, 2024
71d0034
Merge branch 'resnet_multidevice' of github.com:jla524/transformers i…
jla524 Apr 19, 2024
a223c8a
update: one more rerun
jla524 Apr 19, 2024
1 change: 1 addition & 0 deletions src/transformers/models/bit/modeling_bit.py
@@ -658,6 +658,7 @@ class BitPreTrainedModel(PreTrainedModel):
config_class = BitConfig
base_model_prefix = "bit"
main_input_name = "pixel_values"
_no_split_modules = ["BitEmbeddings"]

def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
1 change: 1 addition & 0 deletions src/transformers/models/convnext/modeling_convnext.py
@@ -280,6 +280,7 @@ class ConvNextPreTrainedModel(PreTrainedModel):
config_class = ConvNextConfig
base_model_prefix = "convnext"
main_input_name = "pixel_values"
_no_split_modules = ["ConvNextLayer"]

def _init_weights(self, module):
"""Initialize the weights"""
1 change: 1 addition & 0 deletions src/transformers/models/convnextv2/modeling_convnextv2.py
@@ -301,6 +301,7 @@ class ConvNextV2PreTrainedModel(PreTrainedModel):
config_class = ConvNextV2Config
base_model_prefix = "convnextv2"
main_input_name = "pixel_values"
_no_split_modules = ["ConvNextV2LayerNorm"]

def _init_weights(self, module):
"""Initialize the weights"""
1 change: 1 addition & 0 deletions src/transformers/models/cvt/modeling_cvt.py
@@ -534,6 +534,7 @@ class CvtPreTrainedModel(PreTrainedModel):
config_class = CvtConfig
base_model_prefix = "cvt"
main_input_name = "pixel_values"
_no_split_modules = ["CvtLayer"]

def _init_weights(self, module):
"""Initialize the weights"""
1 change: 1 addition & 0 deletions src/transformers/models/focalnet/modeling_focalnet.py
@@ -636,6 +636,7 @@ class FocalNetPreTrainedModel(PreTrainedModel):
base_model_prefix = "focalnet"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = []

def _init_weights(self, module):
"""Initialize the weights"""
1 change: 1 addition & 0 deletions src/transformers/models/glpn/modeling_glpn.py
@@ -426,6 +426,7 @@ class GLPNPreTrainedModel(PreTrainedModel):
config_class = GLPNConfig
base_model_prefix = "glpn"
main_input_name = "pixel_values"
_no_split_modules = []

# Copied from transformers.models.segformer.modeling_segformer.SegformerPreTrainedModel._init_weights
def _init_weights(self, module):
1 change: 1 addition & 0 deletions src/transformers/models/imagegpt/modeling_imagegpt.py
@@ -491,6 +491,7 @@ class ImageGPTPreTrainedModel(PreTrainedModel):
base_model_prefix = "transformer"
main_input_name = "input_ids"
supports_gradient_checkpointing = True
_no_split_modules = ["ImageGPTBlock"]

def __init__(self, *inputs, **kwargs):
super().__init__(*inputs, **kwargs)
1 change: 1 addition & 0 deletions src/transformers/models/levit/modeling_levit.py
@@ -491,6 +491,7 @@ class LevitPreTrainedModel(PreTrainedModel):
config_class = LevitConfig
base_model_prefix = "levit"
main_input_name = "pixel_values"
_no_split_modules = ["LevitResidualLayer"]

def _init_weights(self, module):
"""Initialize the weights"""
1 change: 1 addition & 0 deletions src/transformers/models/mgp_str/modeling_mgp_str.py
@@ -317,6 +317,7 @@ class MgpstrPreTrainedModel(PreTrainedModel):

config_class = MgpstrConfig
base_model_prefix = "mgp_str"
_no_split_modules = []

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
1 change: 1 addition & 0 deletions src/transformers/models/mobilenet_v1/modeling_mobilenet_v1.py
@@ -254,6 +254,7 @@ class MobileNetV1PreTrainedModel(PreTrainedModel):
base_model_prefix = "mobilenet_v1"
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
_no_split_modules = []

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None:
"""Initialize the weights"""
1 change: 1 addition & 0 deletions src/transformers/models/mobilenet_v2/modeling_mobilenet_v2.py
@@ -453,6 +453,7 @@ class MobileNetV2PreTrainedModel(PreTrainedModel):
base_model_prefix = "mobilenet_v2"
main_input_name = "pixel_values"
supports_gradient_checkpointing = False
_no_split_modules = []

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d]) -> None:
"""Initialize the weights"""
1 change: 1 addition & 0 deletions src/transformers/models/mobilevit/modeling_mobilevit.py
@@ -644,6 +644,7 @@ class MobileViTPreTrainedModel(PreTrainedModel):
base_model_prefix = "mobilevit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = ["MobileViTLayer"]

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
1 change: 1 addition & 0 deletions src/transformers/models/mobilevitv2/modeling_mobilevitv2.py
@@ -606,6 +606,7 @@ class MobileViTV2PreTrainedModel(PreTrainedModel):
base_model_prefix = "mobilevitv2"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = []

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
1 change: 1 addition & 0 deletions src/transformers/models/poolformer/modeling_poolformer.py
@@ -268,6 +268,7 @@ class PoolFormerPreTrainedModel(PreTrainedModel):
config_class = PoolFormerConfig
base_model_prefix = "poolformer"
main_input_name = "pixel_values"
_no_split_modules = ["PoolFormerLayer"]

def _init_weights(self, module):
"""Initialize the weights"""
1 change: 1 addition & 0 deletions src/transformers/models/regnet/modeling_regnet.py
@@ -281,6 +281,7 @@ class RegNetPreTrainedModel(PreTrainedModel):
config_class = RegNetConfig
base_model_prefix = "regnet"
main_input_name = "pixel_values"
_no_split_modules = ["RegNetYLayer"]

# Copied from transformers.models.resnet.modeling_resnet.ResNetPreTrainedModel._init_weights
def _init_weights(self, module):
1 change: 1 addition & 0 deletions src/transformers/models/resnet/modeling_resnet.py
@@ -272,6 +272,7 @@ class ResNetPreTrainedModel(PreTrainedModel):
config_class = ResNetConfig
base_model_prefix = "resnet"
main_input_name = "pixel_values"
_no_split_modules = ["ResNetConvLayer", "ResNetShortCut"]

def _init_weights(self, module):
if isinstance(module, nn.Conv2d):
1 change: 1 addition & 0 deletions src/transformers/models/yolos/modeling_yolos.py
@@ -533,6 +533,7 @@ class YolosPreTrainedModel(PreTrainedModel):
base_model_prefix = "vit"
main_input_name = "pixel_values"
supports_gradient_checkpointing = True
_no_split_modules = []

def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None:
"""Initialize the weights"""
20 changes: 16 additions & 4 deletions tests/test_modeling_common.py
@@ -2905,7 +2905,10 @@ def test_disk_offload_bin(self):
torch.manual_seed(0)
new_output = new_model(**inputs_dict_class)

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
self.assertTrue(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0])))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

@require_accelerate
@mark.accelerate_tests
@@ -2937,7 +2940,10 @@ def test_disk_offload_safetensors(self):
torch.manual_seed(0)
new_output = new_model(**inputs_dict_class)

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
self.assertTrue(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0])))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

@require_accelerate
@mark.accelerate_tests
@@ -2973,7 +2979,10 @@ def test_cpu_offload(self):
torch.manual_seed(0)
new_output = new_model(**inputs_dict_class)

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
self.assertTrue(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0])))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

@require_accelerate
@mark.accelerate_tests
@@ -3009,7 +3018,10 @@ def test_model_parallelism(self):
torch.manual_seed(0)
new_output = new_model(**inputs_dict_class)

self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))
if isinstance(base_output[0], tuple) and isinstance(new_output[0], tuple):
self.assertTrue(all(torch.allclose(a, b, atol=1e-5) for a, b in zip(base_output[0], new_output[0])))
else:
self.assertTrue(torch.allclose(base_output[0], new_output[0], atol=1e-5))

def test_problem_types(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
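
The four offload and parallelism tests above now handle models whose first output element is itself a tuple, comparing the outputs element-wise. Note the `all(...)` around the generator: `assertTrue` treats a bare generator expression as truthy, so the wrapper is what forces every pair to actually be compared. A standalone sketch of the pattern (`outputs_close` is a hypothetical helper, not part of this change):

import torch

def outputs_close(a, b, atol=1e-5):
    # Compare outputs that may be single tensors or tuples of tensors.
    if isinstance(a, tuple) and isinstance(b, tuple):
        # all(...) forces evaluation; a bare generator would always be truthy.
        return all(torch.allclose(x, y, atol=atol) for x, y in zip(a, b))
    return torch.allclose(a, b, atol=atol)

# Usage in a test: self.assertTrue(outputs_close(base_output[0], new_output[0]))
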