From bf4572b6ce0a534a9d73537485a0edf1d68144b8 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Fri, 1 Nov 2024 09:05:50 -0400
Subject: [PATCH] [Utils] `align_module_device` (#3204)

* implement align_module

* add docs

* move to modeling utils, integrate into existing source code

* update source, expose through utils

* Suggested docstring

Co-authored-by: Zach Mueller

* Rewrite for readability, add try finally

Co-authored-by: Zach Mueller

* Use try-finally when aligning with hook

Co-authored-by: Zach Mueller

* apply style

* improve get_state_dict_from_offload readability

* Update docstring

Co-authored-by: Benjamin Bossan

* rename to align_module_device, update docstring

---------

Co-authored-by: Zach Mueller
Co-authored-by: Benjamin Bossan
---
 docs/source/package_reference/big_modeling.md |  6 +-
 src/accelerate/utils/__init__.py              |  1 +
 src/accelerate/utils/modeling.py              | 81 ++++++++++++-------
 tests/test_modeling_utils.py                  | 49 +++++++++++
 4 files changed, 105 insertions(+), 32 deletions(-)

diff --git a/docs/source/package_reference/big_modeling.md b/docs/source/package_reference/big_modeling.md
index 12f95583043..73ffe90d601 100644
--- a/docs/source/package_reference/big_modeling.md
+++ b/docs/source/package_reference/big_modeling.md
@@ -95,4 +95,8 @@ rendered properly in your Markdown viewer.
 
 ### has_offloaded_params
 
-[[autodoc]] utils.has_offloaded_params
\ No newline at end of file
+[[autodoc]] utils.has_offloaded_params
+
+### align_module_device
+
+[[autodoc]] utils.align_module_device
\ No newline at end of file
diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py
index f1a856360cf..5b8917fcd48 100644
--- a/src/accelerate/utils/__init__.py
+++ b/src/accelerate/utils/__init__.py
@@ -124,6 +124,7 @@
     is_xpu_available,
 )
 from .modeling import (
+    align_module_device,
     calculate_maximum_sizes,
     check_device_map,
     check_tied_parameters_in_config,
diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index ed8cd350cce..1e6b1c9c6f6 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -1527,22 +1527,12 @@ def get_state_dict_offloaded_model(model: nn.Module):
     for name, module in model.named_modules():
         if name == "":
             continue
-        if has_offloaded_params(module):
-            original_device = module._hf_hook.execution_device
-            # assign hook execution device to cpu
-            module._hf_hook.execution_device = "cpu"
-            # onload meta tensors to execution device
-            try:
-                module._hf_hook.pre_forward(module)
-            except MemoryError:
-                raise MemoryError("Offloaded module must fit in CPU memory to call save_model!") from None
-            module_state_dict = module.state_dict()
-            # offload meta tensors from cpu
-            module._hf_hook.post_forward(module, torch.tensor([]))
-            # re-assign hook to original execution device
-            module._hf_hook.execution_device = original_device
-        else:
-            module_state_dict = module.state_dict()
+
+        try:
+            with align_module_device(module, "cpu"):
+                module_state_dict = module.state_dict()
+        except MemoryError:
+            raise MemoryError("Offloaded module must fit in CPU memory to call save_model!") from None
 
         for key in module_state_dict:
             # ignore placeholder parameters that are still on the meta device
@@ -1582,22 +1572,12 @@ def get_state_dict_from_offload(
     """
     root = module_name[: module_name.rfind(".")]  # module name without .weight or .bias
-    preforward = False
-    if has_offloaded_params(module):
-        # assign the device to which the offloaded parameters will be sent
-        original_device = module._hf_hook.execution_device
-        module._hf_hook.execution_device = device_to_put_offload
-        module._hf_hook.pre_forward(module)
-        preforward = True
-
-    for m_key in module.state_dict():
-        params = module.state_dict()[m_key]
-        if (root + f".{m_key}") in state_dict:
-            state_dict[root + f".{m_key}"] = params
-
-    if preforward:
-        module._hf_hook.post_forward(module, torch.tensor([]))
-        module._hf_hook.execution_device = original_device
+
+    # assign the device to which the offloaded parameters will be sent
+    with align_module_device(module, device_to_put_offload):
+        for m_key, params in module.state_dict().items():
+            if (root + f".{m_key}") in state_dict:
+                state_dict[root + f".{m_key}"] = params
 
     return state_dict
@@ -1915,3 +1895,42 @@ def has_offloaded_params(module: torch.nn.Module) -> bool:
     from ..hooks import AlignDevicesHook  # avoid circular import
 
     return hasattr(module, "_hf_hook") and isinstance(module._hf_hook, AlignDevicesHook) and module._hf_hook.offload
+
+
+@contextlib.contextmanager
+def align_module_device(module: torch.nn.Module, execution_device: Optional[torch.device] = None):
+    """
+    Context manager that moves a module's parameters to the specified execution device.
+
+    Args:
+        module (`torch.nn.Module`):
+            Module with parameters to align.
+        execution_device (`torch.device`, *optional*):
+            If provided, overrides the module's execution device within the context. Otherwise, the hook's
+            execution device is used, or, if the module has no offloading hook, its parameters are left in place.
+    """
+    if has_offloaded_params(module):
+        if execution_device is not None:
+            original_device = module._hf_hook.execution_device
+            module._hf_hook.execution_device = execution_device
+
+        try:
+            module._hf_hook.pre_forward(module)
+            yield
+        finally:
+            module._hf_hook.post_forward(module, None)
+            if execution_device is not None:
+                module._hf_hook.execution_device = original_device
+
+    elif execution_device is not None:
+        devices = {name: param.device for name, param in module.named_parameters()}
+        try:
+            for name in devices:
+                set_module_tensor_to_device(module, name, execution_device)
+            yield
+        finally:
+            for name, device in devices.items():
+                set_module_tensor_to_device(module, name, device)
+
+    else:
+        yield
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 50dd06cd61a..41ce475c6de 100644
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -26,6 +26,7 @@
 from safetensors.torch import save_file
 
 from accelerate import init_empty_weights
+from accelerate.big_modeling import cpu_offload
 from accelerate.test_utils import (
     require_cuda,
     require_huggingface_suite,
@@ -34,6 +35,7 @@
     torch_device,
 )
 from accelerate.utils.modeling import (
+    align_module_device,
     check_device_map,
     clean_device_map,
     compute_module_sizes,
@@ -785,3 +787,50 @@ def test_convert_file_size(self):
 
         with self.assertRaises(ValueError):
             convert_file_size_to_int("-1GB")
+
+    def test_align_module_device_simple(self):
+        model = ModelForTest()
+        execution_device = torch.device(torch_device)
+        model_device = torch.device("cpu")
+
+        # test default execution device
+        with align_module_device(model.batchnorm):
+            assert model.linear1.weight.device == model_device
+            assert model.batchnorm.weight.device == model_device
+            assert model.linear2.weight.device == model_device
+        assert model.linear1.weight.device == model_device
+        assert model.batchnorm.weight.device == model_device
+        assert model.linear2.weight.device == model_device
+
+        # test with explicit execution device
+        with align_module_device(model.batchnorm, execution_device=execution_device):
+            assert model.linear1.weight.device == model_device
+            assert model.batchnorm.weight.device == execution_device
+            assert model.linear2.weight.device == model_device
+        assert model.linear1.weight.device == model_device
+        assert model.batchnorm.weight.device == model_device
+        assert model.linear2.weight.device == model_device
+
+    def test_align_module_device_offloaded(self):
+        model = ModelForTest()
+        execution_device = torch.device(torch_device)
+        offload_device = torch.device("meta")
+        cpu_offload(model, execution_device=execution_device)
+
+        # test default execution device
+        with align_module_device(model.batchnorm):
+            assert model.linear1.weight.device == offload_device
+            assert model.batchnorm.weight.device == execution_device
+            assert model.linear2.weight.device == offload_device
+        assert model.linear1.weight.device == offload_device
+        assert model.batchnorm.weight.device == offload_device
+        assert model.linear2.weight.device == offload_device
+
+        # test with explicit execution device
+        with align_module_device(model.batchnorm, execution_device="cpu"):
+            assert model.linear1.weight.device == offload_device
+            assert model.batchnorm.weight.device == torch.device("cpu")
+            assert model.linear2.weight.device == offload_device
+        assert model.linear1.weight.device == offload_device
+        assert model.batchnorm.weight.device == offload_device
+        assert model.linear2.weight.device == offload_device
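
---

For reviewers, a minimal usage sketch of the new context manager (not part of the patch itself; assumes a build of accelerate that includes this change). It offloads a toy model with `cpu_offload`, then uses `align_module_device` to temporarily onload one submodule's weights, mirroring what the tests above exercise:

```python
# Illustrative sketch only: exercise align_module_device on an offloaded submodule.
import torch
from torch import nn

from accelerate.big_modeling import cpu_offload
from accelerate.utils import align_module_device

model = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
# Attach offloading hooks: parameters become meta tensors, real weights stay on CPU
cpu_offload(model, execution_device=torch.device("cpu"))

print(model[0].weight.device)  # meta

# Temporarily onload the first layer's weights within the context
with align_module_device(model[0], execution_device="cpu"):
    print(model[0].weight.device)  # cpu; real tensors are safe to read or copy here

print(model[0].weight.device)  # meta again; devices are restored on exit
```

Because the implementation wraps the onload in `try`/`finally`, the original devices (and the hook's original execution device) are restored even if the body of the `with` block raises.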