Skip to content

Commit

Permalink
[Utils] align_module_device (#3204)
Browse files Browse the repository at this point in the history
* implement align_module

* add docs

* move to modeling utils, integrate into existing source code

* update source, expose through utils

* Suggested docstring

Co-authored-by: Zach Mueller <[email protected]>

* Rewrite for readability, add try finally

Co-authored-by: Zach Mueller <[email protected]>

* Use try-finally when aligning with hook

Co-authored-by: Zach Mueller <[email protected]>

* apply style

* improve get_state_dict_from_offload readability

* Update docstring

Co-authored-by: Benjamin Bossan <[email protected]>

* rename to align_module_device, update docstring

---------

Co-authored-by: Zach Mueller <[email protected]>
Co-authored-by: Benjamin Bossan <[email protected]>
  • Loading branch information
3 people authored Nov 1, 2024
1 parent a4a44ac commit bf4572b
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 32 deletions.
6 changes: 5 additions & 1 deletion docs/source/package_reference/big_modeling.md
Original file line number Diff line number Diff line change
Expand Up @@ -95,4 +95,8 @@ rendered properly in your Markdown viewer.

### has_offloaded_params

[[autodoc]] utils.has_offloaded_params
[[autodoc]] utils.has_offloaded_params

### align_module_device

[[autodoc]] utils.align_module_device
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@
is_xpu_available,
)
from .modeling import (
align_module_device,
calculate_maximum_sizes,
check_device_map,
check_tied_parameters_in_config,
Expand Down
81 changes: 50 additions & 31 deletions src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1527,22 +1527,12 @@ def get_state_dict_offloaded_model(model: nn.Module):
for name, module in model.named_modules():
if name == "":
continue
if has_offloaded_params(module):
original_device = module._hf_hook.execution_device
# assign hook execution device to cpu
module._hf_hook.execution_device = "cpu"
# onload meta tensors to execution device
try:
module._hf_hook.pre_forward(module)
except MemoryError:
raise MemoryError("Offloaded module must fit in CPU memory to call save_model!") from None
module_state_dict = module.state_dict()
# offload meta tensors from cpu
module._hf_hook.post_forward(module, torch.tensor([]))
# re-assign hook to original execution device
module._hf_hook.execution_device = original_device
else:
module_state_dict = module.state_dict()

try:
with align_module_device(module, "cpu"):
module_state_dict = module.state_dict()
except MemoryError:
raise MemoryError("Offloaded module must fit in CPU memory to call save_model!") from None

for key in module_state_dict:
# ignore placeholder parameters that are still on the meta device
Expand Down Expand Up @@ -1582,22 +1572,12 @@ def get_state_dict_from_offload(
"""

root = module_name[: module_name.rfind(".")] # module name without .weight or .bias
preforward = False
if has_offloaded_params(module):
# assign the device to which the offloaded parameters will be sent
original_device = module._hf_hook.execution_device
module._hf_hook.execution_device = device_to_put_offload
module._hf_hook.pre_forward(module)
preforward = True

for m_key in module.state_dict():
params = module.state_dict()[m_key]
if (root + f".{m_key}") in state_dict:
state_dict[root + f".{m_key}"] = params

if preforward:
module._hf_hook.post_forward(module, torch.tensor([]))
module._hf_hook.execution_device = original_device
# assign the device to which the offloaded parameters will be sent
with align_module_device(module, device_to_put_offload):
for m_key, params in module.state_dict().items():
if (root + f".{m_key}") in state_dict:
state_dict[root + f".{m_key}"] = params

return state_dict

Expand Down Expand Up @@ -1915,3 +1895,42 @@ def has_offloaded_params(module: torch.nn.Module) -> bool:
from ..hooks import AlignDevicesHook # avoid circular import

return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload


@contextlib.contextmanager
def align_module_device(module: torch.nn.Module, execution_device: Optional[torch.device] = None):
"""
Context manager that moves a module's parameters to the specified execution device.
Args:
module (`torch.nn.Module`):
Module with parameters to align.
execution_device (`torch.device`, *optional*):
If provided, overrides the module's execution device within the context. Otherwise, use hook execution
device or pass
"""
if has_offloaded_params(module):
if execution_device is not None:
original_device = module._hf_hook.execution_device
module._hf_hook.execution_device = execution_device

try:
module._hf_hook.pre_forward(module)
yield
finally:
module._hf_hook.post_forward(module, None)
if execution_device is not None:
module._hf_hook.execution_device = original_device

elif execution_device is not None:
devices = {name: param.device for name, param in module.named_parameters()}
try:
for name in devices:
set_module_tensor_to_device(module, name, execution_device)
yield
finally:
for name, device in devices.items():
set_module_tensor_to_device(module, name, device)

else:
yield
49 changes: 49 additions & 0 deletions tests/test_modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from safetensors.torch import save_file

from accelerate import init_empty_weights
from accelerate.big_modeling import cpu_offload
from accelerate.test_utils import (
require_cuda,
require_huggingface_suite,
Expand All @@ -34,6 +35,7 @@
torch_device,
)
from accelerate.utils.modeling import (
align_module_device,
check_device_map,
clean_device_map,
compute_module_sizes,
Expand Down Expand Up @@ -785,3 +787,50 @@ def test_convert_file_size(self):

with self.assertRaises(ValueError):
convert_file_size_to_int("-1GB")

def test_align_module_device_simple(self):
model = ModelForTest()
execution_device = torch.device(torch_device)
model_device = torch.device("cpu")

# test default execution device
with align_module_device(model.batchnorm):
assert model.linear1.weight.device == model_device
assert model.batchnorm.weight.device == model_device
assert model.linear2.weight.device == model_device
assert model.linear1.weight.device == model_device
assert model.batchnorm.weight.device == model_device
assert model.linear2.weight.device == model_device

# test with explicit execution device
with align_module_device(model.batchnorm, execution_device=execution_device):
assert model.linear1.weight.device == model_device
assert model.batchnorm.weight.device == execution_device
assert model.linear2.weight.device == model_device
assert model.linear1.weight.device == model_device
assert model.batchnorm.weight.device == model_device
assert model.linear2.weight.device == model_device

def test_align_module_device_offloaded(self):
model = ModelForTest()
execution_device = torch.device(torch_device)
offload_device = torch.device("meta")
cpu_offload(model, execution_device=execution_device)

# test default execution device
with align_module_device(model.batchnorm):
assert model.linear1.weight.device == offload_device
assert model.batchnorm.weight.device == execution_device
assert model.linear2.weight.device == offload_device
assert model.linear1.weight.device == offload_device
assert model.batchnorm.weight.device == offload_device
assert model.linear2.weight.device == offload_device

# test with explicit execution device
with align_module_device(model.batchnorm, execution_device="cpu"):
assert model.linear1.weight.device == offload_device
assert model.batchnorm.weight.device == torch.device("cpu")
assert model.linear2.weight.device == offload_device
assert model.linear1.weight.device == offload_device
assert model.batchnorm.weight.device == offload_device
assert model.linear2.weight.device == offload_device

0 comments on commit bf4572b

Please sign in to comment.