added option to do backward AG over smaller set of gpus instead of full DDP world #1125

Open
wants to merge 1 commit into base: ngoyal_bf16_changes

Conversation

ngoyal2707 (Contributor)

No description provided.

@facebook-github-bot added the CLA Signed label on May 20, 2023

@awgu left a comment

This makes sense to me!

# Case where the zero2 world size is > 1 but less than the zero3 world size
zero2_world_size = dist.get_world_size(self.zero2_process_group)
zero2_rank = dist.get_rank(self.zero2_process_group)
chunks = p._full_param_padded.chunk(zero2_world_size)

I just want to mention that there is a divisibility assumption here (ZeRO-2 world size divides the ZeRO-3 world size), which should always hold in practice.
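For concreteness, here is a minimal standalone sketch (illustrative sizes only, not taken from the PR) of why chunking p._full_param_padded by the ZeRO-2 world size lines up with the existing ZeRO-3 shards only when the ZeRO-2 world size divides the ZeRO-3 world size:

import torch

# Illustrative sizes; the real values come from the process groups.
zero3_world_size = 8   # full FSDP (ZeRO-3) data-parallel world
zero2_world_size = 4   # smaller group used for the backward all-gather
numel_padded = 16      # p._full_param_padded is padded to a multiple of zero3_world_size

full_param_padded = torch.arange(numel_padded, dtype=torch.float32)

# ZeRO-3 keeps numel_padded / zero3_world_size elements per rank; each of the
# zero2 chunks below covers numel_padded / zero2_world_size elements.
zero2_chunks = full_param_padded.chunk(zero2_world_size)

# A zero2 chunk covers a whole number of ZeRO-3 shards (so the reshard does not
# split any rank's shard) only when zero2_world_size divides zero3_world_size.
assert zero3_world_size % zero2_world_size == 0
shards_per_zero2_chunk = zero3_world_size // zero2_world_size
print(shards_per_zero2_chunk)  # -> 2 ZeRO-3 shards per ZeRO-2 chunk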

if wait_for_all_gather:
torch.cuda.current_stream().wait_stream(self._streams["all_gather"])
return output_tensors

It looks like the only difference compared to _rebuild_full_params() is no SSD offload, no CPU offload, and using p._zero2_fp16_shard, self.zero2_process_group, and self._free_zero2_param_shard() -- this makes sense to me.
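As a side note on where self.zero2_process_group could come from: its construction is not shown in this hunk, but a smaller group like this would typically be built with the standard torch.distributed API. The sketch below is illustrative only; the contiguous rank partitioning and the helper name are assumptions, not the PR's code:

import torch.distributed as dist

def make_zero2_group(zero2_world_size: int):
    # Partition the full (ZeRO-3) world into consecutive blocks of
    # zero2_world_size ranks; each block can all-gather params among itself.
    world_size = dist.get_world_size()
    assert world_size % zero2_world_size == 0
    my_group = None
    # Every rank must call new_group() for every subgroup, including the
    # subgroups it does not belong to.
    for start in range(0, world_size, zero2_world_size):
        ranks = list(range(start, start + zero2_world_size))
        group = dist.new_group(ranks=ranks)
        if dist.get_rank() in ranks:
            my_group = group
    return my_group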

# free it until the work in the current stream completes.
p._zero2_fp16_shard.record_stream(current_stream)
free_storage_(p._zero2_fp16_shard)

It looks like _zero2_fp16_shard is allocated in the default stream (since _zero2_shard_to_smaller_group() is called from forward() without an explicit stream context manager):

if self.reshard_after_forward:
if self.zero2_process_group is not None:
self._zero2_shard_to_smaller_group()

_zero2_fp16_shard is consumed in the "all_gather" stream, and this _free_zero2_param_shard() is called from that "all_gather" stream as well:
# Fill output_tensor with (p.data for each shard in self.world_size)
if hasattr(dist, "_all_gather_base") and enable_nccl_base_collectives:
# New version of PyTorch has all_gather_base, which is faster than chunk and then all_gather.
dist._all_gather_base(output_tensor, p._zero2_fp16_shard, group=self.zero2_process_group)
else:
chunks = list(output_tensor.chunk(self.world_size))
dist.all_gather(chunks, p._zero2_fp16_shard, group=self.zero2_process_group)
# Set p.data = output_tensor (with padding trimmed)
update_p_data(output_tensor)
if (self.mixed_precision or self.move_params_to_cpu) and not force_full_precision:
self._free_zero2_param_shard([p])
if self.move_params_to_cpu and (self.params[0].dtype == self.compute_dtype):
self._free_zero2_param_shard([p])

In that case, I do agree this p._zero2_fp16_shard.record_stream(current_stream) call is necessary to notify the caching allocator of the usage in the "all_gather" stream. However, I think the comment can be changed to say that it was allocated in the default stream. Alternatively, you can do something like _cast_fp32_param_shards_to_fp16(), but I am not sure if there is any actual overlap opportunity given the data dependencies.

with torch.cuda.stream(self._streams["fp32_to_fp16"]):
for p in params:
assert p._fp16_shard is not None
alloc_storage_(p._fp16_shard, size=p._fp32_shard.size())
p._fp16_shard.copy_(
# If move_params_to_cpu is True, this will be non-blocking
# because _fp32_shard is pinned, otherwise it's a no-op.
p._fp32_shard.to(p._fp16_shard.device, non_blocking=True)
)
p.data = p._fp16_shard
torch.cuda.current_stream().wait_stream(self._streams["fp32_to_fp16"])
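To spell out the caching-allocator point outside of FSDP: record_stream() marks a tensor as in use by another stream, so the allocator will not hand its memory back out until that stream's queued work finishes. A minimal self-contained sketch of the same pattern (default-stream allocation, side-stream consumption) follows; it is not FSDP code:

import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()

    shard = torch.randn(1 << 20, device="cuda")   # allocated in the default stream
    gathered = torch.empty(1 << 21, device="cuda")

    side_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side_stream):
        # Consume `shard` in the side stream (stand-in for the all-gather).
        gathered[: shard.numel()].copy_(shard)
        # Record the side stream on `shard` so the caching allocator does not
        # reuse its memory before this stream's queued work completes.
        shard.record_stream(torch.cuda.current_stream())

    # Back in the default stream it is now safe to drop the reference; the
    # allocator defers reuse until the recorded stream catches up.
    del shard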

@awgu commented Jun 1, 2023

One more comment: I am not familiar with how the model checkpointing works in Fairscale FSDP, but one concern might be what happens if a user tries to checkpoint the model only after forward (and backward has not run yet). Will this work out of the box?

Perhaps this is not a major concern for your use case, since likely no one will save a state dict after running only the forward pass.
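In case it helps, a forward-only checkpointing smoke test might look like the sketch below. It is hypothetical: zero2_process_group is assumed here to be the new constructor argument from this PR (the actual name and plumbing may differ), and the input shape is made up.

import torch
from fairscale.nn import FullyShardedDataParallel as FSDP

def forward_only_state_dict_check(module, zero2_group):
    model = FSDP(
        module.cuda(),
        reshard_after_forward=True,
        mixed_precision=True,
        zero2_process_group=zero2_group,  # assumed kwarg added by this PR
    )
    # Forward only; backward is intentionally not run.
    model(torch.randn(2, 16, device="cuda"))  # illustrative input shape
    # If the post-forward reshard leaves params in the zero2 layout, any
    # mismatch with the expected full/ZeRO-3 shards should surface here.
    return model.state_dict()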
