
Commit

precommit
AleHD committed Jun 27, 2024
1 parent ed1ca7d commit 9b0de5b
Showing 4 changed files with 13 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/nanotron/models/llama.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 """PyTorch LLaMa model."""

-from typing import Dict, Optional, Union, List
+from typing import Dict, Optional, Union

 import torch
 from torch import nn
@@ -103,7 +103,9 @@ def forward(ctx, tensor, group: Optional[ProcessGroup]):
         # https://cs.github.com/pytorch/pytorch/blob/2b267fa7f28e18ca6ea1de4201d2541a40411457/torch/distributed/nn/functional.py#L305
         tensor = tensor.contiguous()

-        sharded_tensor = MemoryBuffer().get("dist", (unsharded_batch_size//group.size(), *rest_size), dtype=tensor.dtype)
+        sharded_tensor = MemoryBuffer().get(
+            "dist", (unsharded_batch_size // group.size(), *rest_size), dtype=tensor.dtype
+        )
         dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
         return sharded_tensor
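
The hunk above only re-wraps (for line length) a call that reduce-scatters a full tensor into a sharded output drawn from the shared MemoryBuffer. As a rough illustration of that pattern, here is a minimal standalone sketch, not nanotron's code: it allocates the output with torch.empty instead of the buffer, and falls back to a plain copy when no process group is initialized so it runs on a single process.

import torch
import torch.distributed as dist


def reduce_scatter_into_shard(tensor: torch.Tensor, group=None) -> torch.Tensor:
    # The leading (batch) dimension of the output is split across the group.
    unsharded_batch_size, *rest_size = tensor.shape
    world_size = dist.get_world_size(group) if dist.is_initialized() else 1
    assert unsharded_batch_size % world_size == 0, "batch dim must be divisible by the group size"

    # Collectives expect contiguous inputs (see the pytorch link in the hunk).
    tensor = tensor.contiguous()
    sharded_tensor = torch.empty(
        (unsharded_batch_size // world_size, *rest_size), dtype=tensor.dtype, device=tensor.device
    )
    if dist.is_initialized():
        dist.reduce_scatter_tensor(sharded_tensor, tensor, group=group, op=dist.ReduceOp.SUM)
    else:
        sharded_tensor.copy_(tensor)  # single-process fallback so the sketch runs standalone
    return sharded_tensor


print(reduce_scatter_into_shard(torch.randn(8, 4)).shape)  # torch.Size([8, 4]) without a group
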
7 changes: 4 additions & 3 deletions src/nanotron/parallel/utils.py
@@ -6,9 +6,9 @@
 from torch import nn

 from nanotron import distributed as dist
-from nanotron.utils import Singleton
 from nanotron.parallel import ParallelContext
 from nanotron.parallel.tied_parameters import get_tied_id_to_param
+from nanotron.utils import Singleton


 class MemoryBuffer(metaclass=Singleton):
@@ -22,8 +22,9 @@ def __init__(self):
     def get(self, name: str, shape: tuple[int], dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
         required_numel = functools.reduce(operator.mul, shape, 1)
         if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel:
-            self.buffer[name, dtype] = torch.empty(required_numel, dtype=dtype, device=torch.cuda.current_device(),
-                                                   requires_grad=False)
+            self.buffer[name, dtype] = torch.empty(
+                required_numel, dtype=dtype, device=torch.cuda.current_device(), requires_grad=False
+            )
         return self.buffer[name, dtype][:required_numel].view(shape)


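
As a usage note on the get() method reformatted above: it hands out views into one flat, lazily grown allocation per (name, dtype) pair, so repeated calls reuse memory instead of reallocating. Below is a simplified standalone sketch of that behaviour (CPU-only and without the Singleton metaclass; it is not nanotron's class itself).

import functools
import operator

import torch


class SimpleMemoryBuffer:
    def __init__(self):
        self.buffer = {}

    def get(self, name: str, shape: tuple, dtype: torch.dtype = torch.bfloat16) -> torch.Tensor:
        required_numel = functools.reduce(operator.mul, shape, 1)
        if (name, dtype) not in self.buffer or self.buffer[name, dtype].numel() < required_numel:
            # (Re)allocate a flat buffer large enough for this request.
            self.buffer[name, dtype] = torch.empty(required_numel, dtype=dtype, requires_grad=False)
        # Return a view into the flat buffer, reshaped to the requested shape.
        return self.buffer[name, dtype][:required_numel].view(shape)


buf = SimpleMemoryBuffer()
a = buf.get("dist", (4, 8))
b = buf.get("dist", (2, 8))          # smaller request: the same storage is reused
assert b.data_ptr() == a.data_ptr()
c = buf.get("dist", (16, 8))         # larger request: triggers a reallocation
assert c.data_ptr() != a.data_ptr()
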
8 changes: 5 additions & 3 deletions src/nanotron/utils.py
@@ -1,11 +1,10 @@
 import functools
 import inspect
-import math
 import os
 import random
 import socket
 from contextlib import ExitStack, contextmanager
-from typing import Callable, ContextManager, List, Optional
+from typing import ContextManager, List, Optional

 import torch
 from packaging import version
@@ -25,7 +24,9 @@ class Logger(metaclass=Singleton):
         ...
     ```
     """
+
     _instances = {}
+
     def __call__(cls, *args, **kwargs):
         if cls not in cls._instances:
             cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
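
The two blank lines added above are purely formatting; the surrounding Singleton metaclass is what lets MemoryBuffer() in the earlier hunk be constructed inline, since every call returns the same cached instance. A minimal sketch of the pattern, using the Logger example from the docstring (assumed equivalent behaviour, not a verbatim copy of the module):

class Singleton(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Build the instance on the first call only; afterwards return the cached one.
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]


class Logger(metaclass=Singleton):
    ...


assert Logger() is Logger()  # every construction returns the same object
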
@@ -69,7 +70,7 @@ def main_rank_first(group: dist.ProcessGroup):
 @contextmanager
 def local_ranks_zero_first(group: Optional[dist.ProcessGroup] = None):
     """Context manager that executes the code in the context with all the local rank zero of the group going first.
-    Usefull to run only once per node first (e.g. to create local files, etc)
+    Useful to run only once per node first (e.g. to create local files, etc)
     """
     is_main = int(os.environ.get("LOCAL_RANK", 0)) == 0
     if is_main:
@@ -140,6 +141,7 @@ def get_untyped_storage(tensor: torch.Tensor) -> torch.UntypedStorage:
     else:
         return tensor.storage().untyped()

+
 def tensor_from_untyped_storage(untyped_storage: torch.UntypedStorage, dtype: torch.dtype):
     # TODO @thomasw21: Figure out what's the best Pytorch way of building a tensor from a storage.
     device = untyped_storage.device
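
Finally, the docstring whose typo is fixed above describes a "local rank zero goes first" pattern: each node's local rank 0 runs the block before the other local ranks on that node, which is handy for once-per-node work such as creating local files. One common way to write such a context manager, given here only as a sketch and not necessarily nanotron's exact body:

import os
from contextlib import contextmanager
from typing import Optional

import torch.distributed as dist


@contextmanager
def local_rank_zero_first_sketch(group: Optional[dist.ProcessGroup] = None):
    # Local rank 0 runs the with-block first; the other local ranks wait at the barrier.
    is_main = int(os.environ.get("LOCAL_RANK", 0)) == 0
    if is_main:
        yield
    if dist.is_initialized():
        dist.barrier(group)
    if not is_main:
        yield
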
