Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add use_all_gather for option #3164

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/accelerate/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2405,7 +2405,7 @@ def clip_grad_value_(self, parameters, clip_value):
self.unscale_gradients()
torch.nn.utils.clip_grad_value_(parameters, clip_value)

def gather(self, tensor):
def gather(self, tensor, use_all_gather=True):
"""
Gather the values in *tensor* across all processes and concatenate them on the first dimension. Useful to
regroup the predictions from all processes when doing evaluation.
Expand All @@ -2416,6 +2416,8 @@ def gather(self, tensor):
Args:
tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`):
The tensors to gather across all processes.
use_all_gather (`bool`, *optional*, defaults to `True`):
Whether to use `all_gather` (every process receives the gathered result) or `gather` (only the main process receives it).

Returns:
`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`: The gathered tensor(s). Note that the
Expand All @@ -2435,9 +2437,9 @@ def gather(self, tensor):
tensor([0, 1, 2, 3])
```
"""
return gather(tensor)
return gather(tensor, use_all_gather)

def gather_for_metrics(self, input_data, use_gather_object=False):
def gather_for_metrics(self, input_data, use_gather_object=False, use_all_gather=True):
"""
Gathers `input_data` and potentially drops duplicates in the last batch if on a distributed system. Should be
used for gathering the inputs and targets for metric calculation.
Expand All @@ -2450,6 +2452,11 @@ def gather_for_metrics(self, input_data, use_gather_object=False):
not contain tensors). This flag can be useful for gathering tensors with different sizes that we don't
want to pad and concatenate along the first dimension. Using it with GPU tensors is not well supported
and inefficient as it incurs GPU -> CPU transfer since tensors would be pickled.
use_all_gather (`bool`, *optional*, defaults to `True`):
Whether to use `all_gather` instead of `gather`. `all_gather` collects the tensors from all processes on
every process, while `gather` collects them on a single process. `all_gather` is useful when every process
needs access to the complete data, but it may use more memory; for evaluation performed only on the main
process, prefer `gather`.

Example:

Expand Down Expand Up @@ -2477,9 +2484,9 @@ def gather_for_metrics(self, input_data, use_gather_object=False):
use_gather_object = use_gather_object or not all_tensors

if use_gather_object:
data = gather_object(input_data)
data = gather_object(input_data, use_all_gather)
else:
data = self.gather(input_data)
data = self.gather(input_data, use_all_gather)

try:
if self.gradient_state.end_of_dataloader:
Expand Down
17 changes: 15 additions & 2 deletions src/accelerate/test_utils/scripts/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,19 +36,32 @@ def create_tensor(state):

def test_gather(state):
    """Check that `gather` collects each process's tensor in rank order.

    Covers both paths: `use_all_gather=True` (every rank receives the full
    concatenated result) and `use_all_gather=False` (only the main process
    receives it, so only the main process asserts).
    """
    tensor = create_tensor(state)
    gathered_tensor = gather(tensor, use_all_gather=True)
    assert gathered_tensor.tolist() == list(range(1, state.num_processes**2 + 1))

    gathered_tensor = gather(tensor, use_all_gather=False)
    # `is_main_process` is a property, not a method (it is used as
    # `state.is_main_process` elsewhere in this PR); calling it would raise
    # `TypeError: 'bool' object is not callable`.
    if state.is_main_process:
        assert gathered_tensor.tolist() == list(range(1, state.num_processes**2 + 1))


def test_gather_object(state):
    """Check that `gather_object` collects one entry per process, in rank order.

    Covers both paths: `use_all_gather=True` (every rank receives the full
    list) and `use_all_gather=False` (only the main process receives it).
    """
    # Gather objects in TorchXLA is not supported.
    if state.distributed_type == DistributedType.XLA:
        return
    obj = [state.process_index]
    gathered_obj = gather_object(obj, use_all_gather=True)
    assert len(gathered_obj) == state.num_processes, f"{gathered_obj}, {len(gathered_obj)} != {state.num_processes}"
    assert gathered_obj == list(range(state.num_processes)), f"{gathered_obj} != {list(range(state.num_processes))}"

    gathered_obj = gather_object(obj, use_all_gather=False)
    # `is_main_process` is a property, not a method — calling it would raise
    # `TypeError: 'bool' object is not callable`.
    if state.is_main_process:
        assert (
            len(gathered_obj) == state.num_processes
        ), f"{gathered_obj}, {len(gathered_obj)} != {state.num_processes}"
        assert gathered_obj == list(
            range(state.num_processes)
        ), f"{gathered_obj} != {list(range(state.num_processes))}"


def test_gather_non_contigous(state):
# Skip this test because the 'is_contiguous' function of XLA tensor always returns True.
Expand Down
42 changes: 31 additions & 11 deletions src/accelerate/utils/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def _tpu_gather_one(tensor):
return res


def _gpu_gather(tensor):
def _gpu_gather(tensor, use_all_gather=True):
state = PartialState()
if is_torch_version(">=", "1.13"):
gather_op = torch.distributed.all_gather_into_tensor
Expand Down Expand Up @@ -350,8 +350,15 @@ def _gpu_gather_one(tensor):
# also gloo does not support `all_gather_into_tensor`,
# which will result in a larger memory overhead for the op
output_tensors = [torch.empty_like(tensor) for _ in range(state.num_processes)]
torch.distributed.all_gather(output_tensors, tensor)
return torch.cat(output_tensors, dim=0)
if use_all_gather:
torch.distributed.all_gather(output_tensors, tensor)
return torch.cat(output_tensors, dim=0)
else:
if state.is_main_process:
torch.distributed.gather(output_tensors, tensor)
return torch.cat(output_tensors, dim=0)
else:
return []

return recursively_apply(_gpu_gather_one, tensor, error_on_other_type=True)

Expand Down Expand Up @@ -420,47 +427,60 @@ def wrapper(*args, **kwargs):


@verify_operation
def gather(tensor, use_all_gather=True):
    """
    Recursively gather tensor in a nested list/tuple/dictionary of tensors from all devices.

    Args:
        tensor (nested list/tuple/dictionary of `torch.Tensor`):
            The data to gather.
        use_all_gather (`bool`, *optional*, defaults to `True`):
            Whether to use `all_gather` (every process receives the gathered result) or `gather` (only the main
            process receives it). Only forwarded on torch-distributed backends; ignored on XLA.

    Returns:
        The same data structure as `tensor` with all tensors sent to the proper device.
    """
    distributed_type = PartialState().distributed_type
    if distributed_type == DistributedType.XLA:
        return _tpu_gather(tensor)
    if distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
        return _gpu_gather(tensor, use_all_gather)
    # Single-process / unsupported backend: nothing to gather.
    return tensor


def _gpu_gather_object(object: Any):
def _gpu_gather_object(object: Any, use_all_gather: bool):
    """Gather a picklable object from every process via torch.distributed.

    Args:
        object (`Any`):
            The picklable object held by the current process (expected to be a
            list, since the result is flattened one level).
        use_all_gather (`bool`):
            If `True`, every process receives the flattened list of gathered
            objects. If `False`, only the main process receives it; every
            other process gets an empty list.

    Returns:
        `list`: The flattened gathered objects (empty on non-main processes
        when `use_all_gather` is `False`).
    """
    state = PartialState()
    if use_all_gather:
        output_objects = [None for _ in range(state.num_processes)]
        torch.distributed.all_gather_object(output_objects, object)
        # all_gather_object returns a list of lists, so we need to flatten it
        return [x for y in output_objects for x in y]
    # `torch.distributed.gather_object(obj, object_gather_list=None, dst=0)`:
    # the object to send comes first, and only the destination rank supplies
    # the output list — all other ranks must pass `None` (the default).
    # `is_main_process` is a property, not a method.
    if state.is_main_process:
        output_objects = [None for _ in range(state.num_processes)]
        torch.distributed.gather_object(object, output_objects)
        # gather_object returns a list of lists, so we need to flatten it
        return [x for y in output_objects for x in y]
    torch.distributed.gather_object(object)
    return []


def gather_object(object: Any):
def gather_object(object: Any, use_all_gather: bool = True):
    """
    Recursively gather object in a nested list/tuple/dictionary of objects from all devices.

    Args:
        object (nested list/tuple/dictionary of picklable object):
            The data to gather.
        use_all_gather (`bool`, *optional*, defaults to `True`):
            Whether to use `all_gather` (every process receives the gathered result) or `gather` (only the main
            process receives it).

    Returns:
        The same data structure as `object` with all the objects sent to every device.

    Raises:
        NotImplementedError: Gathering objects is not supported on TPU/XLA.
    """
    distributed_type = PartialState().distributed_type
    if distributed_type == DistributedType.XLA:
        raise NotImplementedError("gather objects in TPU is not supported")
    if distributed_type in TORCH_DISTRIBUTED_OPERATION_TYPES:
        return _gpu_gather_object(object, use_all_gather)
    # Single-process / unsupported backend: return the object untouched.
    return object

Expand Down