Commit

Small optimization: replaced gather() with allreduce()
AleHD committed Sep 22, 2023
1 parent 339cf5b commit 5b4ae47
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions finetune.py
@@ -128,10 +128,10 @@ def get_batch(data_iterator):
         n_tokens = n_tokens.item()
     else:
         group = get_data_parallel_group()
-        token_counts = torch.zeros(args.data_parallel_size, dtype=torch.long,
-                                   device=tokens.device)
-        torch.distributed.all_gather_into_tensor(token_counts, n_tokens, group=group)
-        n_tokens = torch.sum(token_counts).item()
+        torch.distributed.all_reduce(
+            n_tokens, op=torch.distributed.ReduceOp.SUM, group=group
+        )
+        n_tokens = n_tokens.item()
     counters["tokens"] += n_tokens
 
     if args.data_type == "gpt":
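Both collectives compute the same global token count: all_gather_into_tensor copies every rank's scalar into a world-size buffer that each rank then sums locally, while all_reduce performs the summation inside the collective and writes the result in place, so no per-rank buffer needs to be allocated. Below is a minimal runnable sketch of the new pattern; the sum_token_counts helper, the example token counts, and the gloo backend choice are illustrative assumptions, not taken from finetune.py.

import torch
import torch.distributed as dist

def sum_token_counts(n_tokens: torch.Tensor, group=None) -> int:
    # all_reduce sums the scalar in place on every rank; unlike the old
    # gather-then-sum version, no O(world_size) buffer is allocated.
    dist.all_reduce(n_tokens, op=dist.ReduceOp.SUM, group=group)
    return n_tokens.item()

if __name__ == "__main__":
    # Launch with: torchrun --nproc_per_node=2 allreduce_sketch.py
    dist.init_process_group("gloo")  # use "nccl" for GPU tensors
    rank = dist.get_rank()
    n_tokens = torch.tensor(100 * (rank + 1), dtype=torch.long)
    print(f"rank {rank}: total tokens = {sum_token_counts(n_tokens)}")
    dist.destroy_process_group()

For a single scalar the payload difference is negligible; the practical win is dropping the buffer allocation and simplifying the code path.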
