From 5b4ae479f2969a19610f2f86b4ce61b8e0ec479b Mon Sep 17 00:00:00 2001
From: Alejandro Hernández Cano
Date: Fri, 22 Sep 2023 15:59:45 +0200
Subject: [PATCH] Small optimization: replaced gather() with allreduce()

---
 finetune.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/finetune.py b/finetune.py
index bd96949..22e6b7e 100644
--- a/finetune.py
+++ b/finetune.py
@@ -128,10 +128,10 @@ def get_batch(data_iterator):
         n_tokens = n_tokens.item()
     else:
         group = get_data_parallel_group()
-        token_counts = torch.zeros(args.data_parallel_size, dtype=torch.long,
-                                   device=tokens.device)
-        torch.distributed.all_gather_into_tensor(token_counts, n_tokens, group=group)
-        n_tokens = torch.sum(token_counts).item()
+        torch.distributed.all_reduce(
+            n_tokens, op=torch.distributed.ReduceOp.SUM, group=group
+        )
+        n_tokens = n_tokens.item()
     counters["tokens"] += n_tokens
 
     if args.data_type == "gpt":
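
Note (not part of the patch): the replacement is valid because gathering every
rank's count and summing locally yields the same total as summing inside the
collective, while all_reduce avoids allocating the O(data_parallel_size)
token_counts buffer. Below is a minimal standalone sketch of that equivalence.
It is illustrative only: it assumes the gloo backend so it runs on CPU, uses
the default process group instead of get_data_parallel_group(), makes up the
per-rank token counts, and uses the list-based all_gather rather than
all_gather_into_tensor to stay backend-agnostic. Launch with e.g.
torchrun --nproc_per_node=2 demo.py (demo.py is a hypothetical file name).

import torch
import torch.distributed as dist

def main():
    # CPU-friendly backend for the demo; the patch itself is backend-agnostic.
    dist.init_process_group(backend="gloo")
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Illustrative per-rank token count (stand-in for n_tokens in get_batch).
    n_tokens = torch.tensor(100 + rank, dtype=torch.long)

    # Old approach: gather every rank's count, then sum locally.
    # (finetune.py used the single-tensor variant, all_gather_into_tensor.)
    token_counts = [torch.zeros_like(n_tokens) for _ in range(world_size)]
    dist.all_gather(token_counts, n_tokens)
    total_via_gather = torch.stack(token_counts).sum().item()

    # New approach: sum inside the collective, in place on n_tokens.
    dist.all_reduce(n_tokens, op=dist.ReduceOp.SUM)
    total_via_reduce = n_tokens.item()

    # Both paths produce the same global token count.
    assert total_via_gather == total_via_reduce
    if rank == 0:
        print(f"total tokens across {world_size} ranks: {total_via_reduce}")

    dist.destroy_process_group()

if __name__ == "__main__":
    main()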