From 9db401e7b9efa0fdf7d9ac603eefbcc34e76e517 Mon Sep 17 00:00:00 2001
From: Chen
Date: Tue, 8 Oct 2024 22:37:54 -0400
Subject: [PATCH] refactor: add helper functions and eliminate redundant
 variables

This commit prepares infer_auto_device_map for further refactoring: the
initialization logic is extracted into init_infer_auto_device_map, the
tied-parameter size computation is extracted into get_module_size_with_ties,
and the redundant current_memory_used counter is replaced by the per-device
device_memory_used mapping. The fallback allocation will be reintroduced once
the branching logic is fully refactored.
---
 src/accelerate/utils/modeling.py | 167 +++++++++++++++++++------------
 1 file changed, 105 insertions(+), 62 deletions(-)

diff --git a/src/accelerate/utils/modeling.py b/src/accelerate/utils/modeling.py
index e1e61382955..8b863221972 100644
--- a/src/accelerate/utils/modeling.py
+++ b/src/accelerate/utils/modeling.py
@@ -1268,6 +1268,80 @@ def fallback_allocate(
     return name, module, modules
 
 
+def init_infer_auto_device_map(
+    model: nn.Module,
+    max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
+    no_split_module_classes: Optional[List[str]] = None,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
+) -> Tuple:
+    max_memory = get_max_memory(max_memory)
+    if no_split_module_classes is None:
+        no_split_module_classes = []
+    elif not isinstance(no_split_module_classes, (list, tuple)):
+        no_split_module_classes = [no_split_module_classes]
+
+    devices = list(max_memory.keys())
+    if "disk" not in devices:
+        devices.append("disk")
+    gpus = [device for device in devices if device not in ["cpu", "disk"]]
+
+    # Devices that need to keep space for a potential offloaded layer.
+    if "mps" in gpus:
+        main_devices = ["mps"]
+    elif len(gpus) > 0:
+        main_devices = [gpus[0], "cpu"]
+    else:
+        main_devices = ["cpu"]
+
+    module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
+    tied_parameters = find_tied_parameters(model)
+
+    if check_tied_parameters_in_config(model) and len(tied_parameters) == 0:
+        logger.warning(
+            "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device_map` function."
+        )
+
+    # Direct submodules and parameters
+    modules_to_treat = (
+        list(model.named_parameters(recurse=False))
+        + list(model.named_children())
+        + list(model.named_buffers(recurse=False))
+    )
+
+    return (
+        devices,
+        main_devices,
+        gpus,
+        module_sizes,
+        tied_parameters,
+        no_split_module_classes,
+        modules_to_treat,
+    )
+
+
+def get_module_size_with_ties(
+    tied_params,
+    module_size,
+    module_sizes,
+    modules_to_treat,
+) -> Tuple[int, List[str], List[nn.Module]]:
+    if not tied_params:
+        return module_size, [], []
+    tied_module_names = []
+    tied_modules = []
+    for tied_param in tied_params:
+        tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0]
+        tied_module_names.append(modules_to_treat[tied_module_index][0])
+        tied_modules.append(modules_to_treat[tied_module_index][1])
+
+    module_size_with_ties = module_size
+    for tied_param, tied_module_name in zip(tied_params, tied_module_names):
+        module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param]
+
+    return module_size_with_ties, tied_module_names, tied_modules
+
+
 def infer_auto_device_map(
     model: nn.Module,
     max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
@@ -1317,47 +1391,24 @@ def infer_auto_device_map(
             In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
             well as the parameters.
""" - # Get default / clean up max_memory - max_memory = get_max_memory(max_memory) - if no_split_module_classes is None: - no_split_module_classes = [] - elif not isinstance(no_split_module_classes, (list, tuple)): - no_split_module_classes = [no_split_module_classes] - - devices = list(max_memory.keys()) - if "disk" not in devices: - devices.append("disk") - gpus = [device for device in devices if device not in ["cpu", "disk"]] - - # Devices that need to keep space for a potential offloaded layer. - if "mps" in gpus: - main_devices = ["mps"] - elif len(gpus) > 0: - main_devices = [gpus[0], "cpu"] - else: - main_devices = ["cpu"] - - module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes) - tied_parameters = find_tied_parameters(model) - if check_tied_parameters_in_config(model) and len(tied_parameters) == 0: - logger.warn( - "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function." - ) + # Initialize the variables + ( + devices, + main_devices, + gpus, + module_sizes, + tied_parameters, + no_split_module_classes, + modules_to_treat, + ) = init_infer_auto_device_map(model, max_memory, no_split_module_classes, dtype, special_dtypes) device_map = OrderedDict() current_device = 0 - current_memory_used = 0 - device_memory_used = {} + device_memory_used = {device: 0 for device in devices} device_buffer_sizes = {} device_minimum_assignment_memory = {} - # Direct submodules and parameters - modules_to_treat = ( - list(model.named_parameters(recurse=False)) - + list(model.named_children()) - + list(model.named_buffers(recurse=False)) - ) # Initialize maximum largest layer, to know which space to keep in memory max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes) @@ -1381,18 +1432,18 @@ def infer_auto_device_map( # and the other is not. # Note: If we are currently processing the name `compute.weight`, an other parameter named e.g. `compute.weight_submodule.parameter` # needs to be considered outside the current module, hence the check with additional dots. - tied_param_goups = [ + tied_param_groups = [ tied_group for tied_group in tied_parameters if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group) ] - if verbose and len(tied_param_goups) > 0: - print(f" Found the relevant tied param groups {tied_param_goups}") + if verbose and len(tied_param_groups) > 0: + print(f" Found the relevant tied param groups {tied_param_groups}") # Then we keep track of all the parameters that are tied to the current module, but not in the current module tied_params = sum( - [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_goups], [] + [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], [] ) if verbose and len(tied_params) > 0: @@ -1405,8 +1456,9 @@ def infer_auto_device_map( if devices[current_device] in main_devices: current_max_size = current_max_size - max_layer_size current_memory_reserved = max_layer_size + # Case 1 -> We're too big! - if current_max_size is not None and current_memory_used + module_size > current_max_size: + if current_max_size is not None and device_memory_used[device] + module_size > current_max_size: # Split or not split? 
modules_children = ( [] @@ -1416,18 +1468,17 @@ def infer_auto_device_map( if verbose: print( f"Not enough space on {devices[current_device]} to put {name} (space available " - f"{current_max_size - current_memory_used}, module size {module_size})." + f"{current_max_size - device_memory_used[device]}, module size {module_size})." ) if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes: # -> no split, we go to the next device if verbose: print("This module cannot be split, going to the next device.") - if current_memory_used == 0: + if device_memory_used[device] == 0: device_minimum_assignment_memory[device] = module_size + current_memory_reserved - device_memory_used[device] = current_memory_used + current_memory_reserved + device_memory_used[device] = device_memory_used[device] + current_memory_reserved current_device += 1 modules_to_treat = [(name, module)] + modules_to_treat - current_memory_used = 0 else: # -> split, we replace the module studied by its children + parameters if verbose: @@ -1444,12 +1495,7 @@ def infer_auto_device_map( # Case 2, it fits! We're not entirely out of the wood though, because we may have some tied parameters. elif len(tied_params) > 0: # First locate all tied modules - tied_module_names = [] - tied_modules = [] - for tied_param in tied_params: - tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0] - tied_module_names.append(modules_to_treat[tied_module_index][0]) - tied_modules.append(modules_to_treat[tied_module_index][1]) + module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(tied_params, module_size, module_sizes, modules_to_treat) if verbose: print( f" It looks like {name} is going to fit on {devices[current_device]} but we have tied " @@ -1457,15 +1503,11 @@ def infer_auto_device_map( ) # Let's see if it all fits first - module_size_with_ties = module_size - for tied_param, tied_module_name in zip(tied_params, tied_module_names): - module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param] - - if current_max_size is None or current_memory_used + module_size_with_ties <= current_max_size: + if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size: # We really really fit! if verbose: print(f"Putting {name} and {tied_module_names} on {devices[current_device]}.") - current_memory_used += module_size_with_ties + device_memory_used[device] += module_size_with_ties device_map[name] = devices[current_device] for tied_module_name in tied_module_names: if tied_module_name in [m[0] for m in modules_to_treat]: @@ -1488,7 +1530,7 @@ def infer_auto_device_map( if verbose: print( f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space " - f"available {current_max_size - current_memory_used}, needed size {module_size_with_ties})." + f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})." 
                     )
                 split_happened = False
                 for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
@@ -1522,13 +1564,12 @@ def infer_auto_device_map(
                     # If the tied module is not split, we go to the next device
                     if verbose:
                         print("None of the tied module can be split, going to the next device.")
-                    if current_memory_used == 0:
+                    if device_memory_used[device] == 0:
                         device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved
 
-                    device_memory_used[device] = current_memory_used + current_memory_reserved
+                    device_memory_used[device] = device_memory_used[device] + current_memory_reserved
                     current_device += 1
                     modules_to_treat = [(name, module)] + modules_to_treat
-                    current_memory_used = 0
 
         else:
             if verbose:
@@ -1537,10 +1578,9 @@ def infer_auto_device_map(
                 else:
                     print(
                         f"Putting {name} (size={module_size}) on {devices[current_device]} "
-                        f"(available={current_max_size - current_memory_used})."
+                        f"(available={current_max_size - device_memory_used[device]})."
                     )
-            current_memory_used += module_size
-            device_memory_used[device] = current_memory_used + current_memory_reserved
+            device_memory_used[device] += module_size
             device_map[name] = devices[current_device]
 
             if not offload_buffers and isinstance(module, nn.Module):
                 current_buffer_size = compute_module_total_buffer_size(
                     module, dtype=dtype, special_dtypes=special_dtypes
                 )
                 device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size
 
+    device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0}
+
     if clean_result:
         device_map = clean_device_map(device_map)
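
Note on the tied-parameter size computation: the arithmetic inside the new
get_module_size_with_ties helper can be exercised with plain Python data,
since it only looks at parameter/module names and a name-to-size mapping.
The sketch below restates the helper body from the hunk above (without the
return annotation, so it needs no imports); the toy module names and byte
sizes are invented purely for illustration.

    def get_module_size_with_ties(tied_params, module_size, module_sizes, modules_to_treat):
        if not tied_params:
            return module_size, [], []
        tied_module_names = []
        tied_modules = []
        for tied_param in tied_params:
            tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0]
            tied_module_names.append(modules_to_treat[tied_module_index][0])
            tied_modules.append(modules_to_treat[tied_module_index][1])

        module_size_with_ties = module_size
        for tied_param, tied_module_name in zip(tied_params, tied_module_names):
            module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param]

        return module_size_with_ties, tied_module_names, tied_modules

    # Toy bookkeeping standing in for what infer_auto_device_map builds while it
    # walks the model: we are about to place "decoder" (1_000 bytes), whose
    # embedding weight is tied to "encoder.embed.weight" inside the not-yet-placed
    # "encoder.embed" module.
    module_sizes = {"encoder.embed": 600, "encoder.embed.weight": 400}
    modules_to_treat = [("encoder.embed", object()), ("encoder.norm", object())]
    tied_params = ["encoder.embed.weight"]

    size, tied_names, _ = get_module_size_with_ties(tied_params, 1_000, module_sizes, modules_to_treat)

    # The tied module must land on the same device, but the shared weight is only
    # counted once: 1_000 + (600 - 400) = 1_200.
    print(size, tied_names)  # 1200 ['encoder.embed']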
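For context, the public entry point that consumes these helpers is
infer_auto_device_map itself, and the refactor is expected to keep its
observable behavior unchanged. A typical call looks like the following; the
checkpoint name, memory limits, and no-split class are arbitrary placeholders
rather than values taken from this patch.

    import torch
    from accelerate import infer_auto_device_map, init_empty_weights
    from transformers import AutoConfig, AutoModelForCausalLM

    # Build the model skeleton on the meta device so no real memory is allocated.
    config = AutoConfig.from_pretrained("gpt2")
    with init_empty_weights():
        model = AutoModelForCausalLM.from_config(config)

    # Plan the placement across GPU 0, CPU RAM and (implicitly) disk, sizing the
    # weights as float16. Assumes one visible CUDA device (index 0).
    device_map = infer_auto_device_map(
        model,
        max_memory={0: "10GiB", "cpu": "30GiB"},
        no_split_module_classes=["GPT2Block"],
        dtype=torch.float16,
    )
    print(device_map)  # e.g. {"transformer.wte": 0, ..., "lm_head": "cpu"}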