refactor: add helper functions and eliminate redundant variables
This commit extracts the setup code of `infer_auto_device_map` into the helper `init_infer_auto_device_map`, adds `get_module_size_with_ties` for the tied-parameter size accounting, and replaces the redundant `current_memory_used` variable with the per-device `device_memory_used` dict, preparing `infer_auto_device_map` for further refactoring. The fallback allocation will be reintroduced once the branching logic is fully refactored.
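For context, `infer_auto_device_map` is the public Accelerate API being refactored here. Below is a minimal usage sketch; the model, memory limits, and no-split class are illustrative assumptions, not part of this commit.

from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

# Build the model on the meta device so no real memory is allocated.
config = AutoConfig.from_pretrained("gpt2")
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Hypothetical memory budgets; adjust to the actual hardware.
device_map = infer_auto_device_map(
    model,
    max_memory={0: "10GiB", "cpu": "30GiB"},
    no_split_module_classes=["GPT2Block"],
    dtype="float16",
)

The resulting dict maps module names to devices ("cpu", "disk", or a GPU index) and can be passed to `dispatch_model` or to `from_pretrained(device_map=...)`.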
Nech-C committed Oct 9, 2024
1 parent f040302 commit 9db401e
Showing 1 changed file with 105 additions and 62 deletions.
167 changes: 105 additions & 62 deletions src/accelerate/utils/modeling.py
@@ -1268,6 +1268,80 @@ def fallback_allocate(
return name, module, modules


def init_infer_auto_device_map(
model: nn.Module,
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
no_split_module_classes: Optional[List[str]] = None,
dtype: Optional[Union[str, torch.dtype]] = None,
special_dtypes: Optional[Dict[str, Union[str, torch.device]]] = None,
) -> Tuple:
    """Initialize the variables used by `infer_auto_device_map`: the device list, main devices, GPUs,
    module sizes, tied parameters, no-split module classes, and the initial list of modules to treat."""
max_memory = get_max_memory(max_memory)
if no_split_module_classes is None:
no_split_module_classes = []
elif not isinstance(no_split_module_classes, (list, tuple)):
no_split_module_classes = [no_split_module_classes]

devices = list(max_memory.keys())
if "disk" not in devices:
devices.append("disk")
gpus = [device for device in devices if device not in ["cpu", "disk"]]

# Devices that need to keep space for a potential offloaded layer.
if "mps" in gpus:
main_devices = ["mps"]
elif len(gpus) > 0:
main_devices = [gpus[0], "cpu"]
else:
main_devices = ["cpu"]

module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
tied_parameters = find_tied_parameters(model)

if check_tied_parameters_in_config(model) and len(tied_parameters) == 0:
logger.warn(
"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function."
)

# Direct submodules and parameters
modules_to_treat = (
list(model.named_parameters(recurse=False))
+ list(model.named_children())
+ list(model.named_buffers(recurse=False))
)

return (
devices,
main_devices,
gpus,
module_sizes,
tied_parameters,
no_split_module_classes,
modules_to_treat,
)


def get_module_size_with_ties(
tied_params,
module_size,
module_sizes,
modules_to_treat,
) -> Tuple[int, List[str], List[nn.Module]]:
    """Calculate the size of a module plus the modules holding its tied parameters (shared weights counted
    once), and return that size together with the names and modules of those tied modules."""
if not tied_params:
return module_size, [], []
tied_module_names = []
tied_modules = []
for tied_param in tied_params:
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0]
tied_module_names.append(modules_to_treat[tied_module_index][0])
tied_modules.append(modules_to_treat[tied_module_index][1])

module_size_with_ties = module_size
for tied_param, tied_module_name in zip(tied_params, tied_module_names):
module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param]

return module_size_with_ties, tied_module_names, tied_modules


def infer_auto_device_map(
model: nn.Module,
max_memory: Optional[Dict[Union[int, str], Union[int, str]]] = None,
@@ -1317,47 +1391,24 @@ def infer_auto_device_map(
In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
well as the parameters.
"""
# Get default / clean up max_memory
max_memory = get_max_memory(max_memory)
if no_split_module_classes is None:
no_split_module_classes = []
elif not isinstance(no_split_module_classes, (list, tuple)):
no_split_module_classes = [no_split_module_classes]

devices = list(max_memory.keys())
if "disk" not in devices:
devices.append("disk")
gpus = [device for device in devices if device not in ["cpu", "disk"]]

# Devices that need to keep space for a potential offloaded layer.
if "mps" in gpus:
main_devices = ["mps"]
elif len(gpus) > 0:
main_devices = [gpus[0], "cpu"]
else:
main_devices = ["cpu"]

module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
tied_parameters = find_tied_parameters(model)

if check_tied_parameters_in_config(model) and len(tied_parameters) == 0:
logger.warn(
"The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function."
)
# Initialize the variables
(
devices,
main_devices,
gpus,
module_sizes,
tied_parameters,
no_split_module_classes,
modules_to_treat,
) = init_infer_auto_device_map(model, max_memory, no_split_module_classes, dtype, special_dtypes)

device_map = OrderedDict()
current_device = 0
current_memory_used = 0
device_memory_used = {}
device_memory_used = {device: 0 for device in devices}
device_buffer_sizes = {}
device_minimum_assignment_memory = {}

# Direct submodules and parameters
modules_to_treat = (
list(model.named_parameters(recurse=False))
+ list(model.named_children())
+ list(model.named_buffers(recurse=False))
)
# Initialize maximum largest layer, to know which space to keep in memory
max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)

@@ -1381,18 +1432,18 @@ def infer_auto_device_map(
# and the other is not.
# Note: If we are currently processing the name `compute.weight`, another parameter named e.g. `compute.weight_submodule.parameter`
# needs to be considered outside the current module, hence the check with additional dots.
tied_param_goups = [
tied_param_groups = [
tied_group
for tied_group in tied_parameters
if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group)
]

if verbose and len(tied_param_goups) > 0:
print(f" Found the relevant tied param groups {tied_param_goups}")
if verbose and len(tied_param_groups) > 0:
print(f" Found the relevant tied param groups {tied_param_groups}")

# Then we keep track of all the parameters that are tied to the current module, but not in the current module
tied_params = sum(
[[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_goups], []
[[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], []
)

if verbose and len(tied_params) > 0:
@@ -1405,8 +1456,9 @@
if devices[current_device] in main_devices:
current_max_size = current_max_size - max_layer_size
current_memory_reserved = max_layer_size

# Case 1 -> We're too big!
if current_max_size is not None and current_memory_used + module_size > current_max_size:
if current_max_size is not None and device_memory_used[device] + module_size > current_max_size:
# Split or not split?
modules_children = (
[]
@@ -1416,18 +1468,17 @@
if verbose:
print(
f"Not enough space on {devices[current_device]} to put {name} (space available "
f"{current_max_size - current_memory_used}, module size {module_size})."
f"{current_max_size - device_memory_used[device]}, module size {module_size})."
)
if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
# -> no split, we go to the next device
if verbose:
print("This module cannot be split, going to the next device.")
if current_memory_used == 0:
if device_memory_used[device] == 0:
device_minimum_assignment_memory[device] = module_size + current_memory_reserved
device_memory_used[device] = current_memory_used + current_memory_reserved
device_memory_used[device] = device_memory_used[device] + current_memory_reserved
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
else:
# -> split, we replace the module studied by its children + parameters
if verbose:
@@ -1444,28 +1495,19 @@
# Case 2, it fits! We're not entirely out of the wood though, because we may have some tied parameters.
elif len(tied_params) > 0:
# First locate all tied modules
tied_module_names = []
tied_modules = []
for tied_param in tied_params:
tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n in tied_param][0]
tied_module_names.append(modules_to_treat[tied_module_index][0])
tied_modules.append(modules_to_treat[tied_module_index][1])
module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(tied_params, module_size, module_sizes, modules_to_treat)
if verbose:
print(
f" It looks like {name} is going to fit on {devices[current_device]} but we have tied "
f"parameters to account for.\n - Names {tied_params}\n - Module names {tied_module_names}"
)

# Let's see if it all fits first
module_size_with_ties = module_size
for tied_param, tied_module_name in zip(tied_params, tied_module_names):
module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param]

if current_max_size is None or current_memory_used + module_size_with_ties <= current_max_size:
if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size:
# We really really fit!
if verbose:
print(f"Putting {name} and {tied_module_names} on {devices[current_device]}.")
current_memory_used += module_size_with_ties
device_memory_used[device] += module_size_with_ties
device_map[name] = devices[current_device]
for tied_module_name in tied_module_names:
if tied_module_name in [m[0] for m in modules_to_treat]:
@@ -1488,7 +1530,7 @@
if verbose:
print(
f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
f"available {current_max_size - current_memory_used}, needed size {module_size_with_ties})."
f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})."
)
split_happened = False
for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
@@ -1522,13 +1564,13 @@
# If the tied module is not split, we go to the next device
if verbose:
print("None of the tied module can be split, going to the next device.")
if current_memory_used == 0:
if device_memory_used[device] == 0:
device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved

device_memory_used[device] = current_memory_used + current_memory_reserved
device_memory_used[device] = device_memory_used[device] + current_memory_reserved
current_device += 1
modules_to_treat = [(name, module)] + modules_to_treat
current_memory_used = 0
device_memory_used[device] = 0

else:
if verbose:
Expand All @@ -1537,10 +1579,9 @@ def infer_auto_device_map(
else:
print(
f"Putting {name} (size={module_size}) on {devices[current_device]} "
f"(available={current_max_size - current_memory_used})."
f"(available={current_max_size - device_memory_used[device]})."
)
current_memory_used += module_size
device_memory_used[device] = current_memory_used + current_memory_reserved
device_memory_used[device] += module_size
device_map[name] = devices[current_device]

if not offload_buffers and isinstance(module, nn.Module):
@@ -1549,6 +1590,8 @@
)
device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size

device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0}

if clean_result:
device_map = clean_device_map(device_map)

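To make the tied-parameter accounting in `get_module_size_with_ties` concrete, here is a toy sketch with made-up module names and byte counts (pure illustration, not library code):

# A language-model head whose weight is tied to the input embedding.
module_sizes = {                      # hypothetical sizes in bytes
    "lm_head": 1000,
    "transformer.wte": 1200,          # embedding module: shared weight plus extra buffers
    "transformer.wte.weight": 1000,   # the weight tied to lm_head.weight
}
module_size = module_sizes["lm_head"]
tied_params = ["transformer.wte.weight"]
tied_module_names = ["transformer.wte"]

# Same computation as in get_module_size_with_ties: add each tied module's size
# minus the tied parameter itself, so the shared weight is counted only once.
module_size_with_ties = module_size
for tied_param, tied_module_name in zip(tied_params, tied_module_names):
    module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param]

print(module_size_with_ties)  # 1200, not 2200

This combined footprint (the module plus its tied modules, with shared weights counted once) is what the main loop compares against the space remaining on the current device.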
