Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: dimapihtar <[email protected]>
  • Loading branch information
dimapihtar committed May 17, 2024
1 parent ff0eb2f commit 43dcb2b
Showing 1 changed file with 63 additions and 19 deletions.
82 changes: 63 additions & 19 deletions nemo/core/optim/distributed_adam.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,8 @@ def hook(*unused):
self._grad_copy(param)
if self.overlap_grad_sync and not getattr(param, '_disable_overlap_grad_sync', False):
self._try_start_bucket_grad_sync(
params=[param], ignore_last_bucket=need_to_initialize,
params=[param],
ignore_last_bucket=need_to_initialize,
)

return hook
Expand Down Expand Up @@ -167,10 +168,14 @@ def init_params(
# Initialize FP8 and non-FP8 tensors separately
if any(is_float8tensor(param) for param in params):
super().init_params(
filter(is_float8tensor, params), param_sync_dtype=torch.uint8, **kwargs,
filter(is_float8tensor, params),
param_sync_dtype=torch.uint8,
**kwargs,
)
super().init_params(
params, param_sync_dtype=param_sync_dtype, **kwargs,
params,
param_sync_dtype=param_sync_dtype,
**kwargs,
)

def init_params_bucket(
Expand Down Expand Up @@ -200,7 +205,10 @@ def init_params_bucket(
params = remaining_params
start_bucket_id = len(self.state["buckets"])
super().init_params_bucket(
fp32_params, grad_sync_dtype=torch.float32, param_sync_dtype=param_sync_dtype, **kwargs,
fp32_params,
grad_sync_dtype=torch.float32,
param_sync_dtype=param_sync_dtype,
**kwargs,
)
end_bucket_id = len(self.state["buckets"])
fp32_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]
Expand All @@ -216,7 +224,10 @@ def init_params_bucket(
params = remaining_params
start_bucket_id = len(self.state["buckets"])
super().init_params_bucket(
fp8_params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=torch.uint8, **kwargs,
fp8_params,
grad_sync_dtype=grad_sync_dtype,
param_sync_dtype=torch.uint8,
**kwargs,
)
end_bucket_id = len(self.state["buckets"])
fp8_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]
Expand All @@ -225,12 +236,18 @@ def init_params_bucket(
normal_buckets = []
start_bucket_id = len(self.state["buckets"])
super().init_params_bucket(
params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=param_sync_dtype, **kwargs,
params,
grad_sync_dtype=grad_sync_dtype,
param_sync_dtype=param_sync_dtype,
**kwargs,
)
end_bucket_id = len(self.state["buckets"])
normal_buckets = self.state["buckets"][start_bucket_id:end_bucket_id]

def add_param_to_bucket(param: torch.nn.Parameter, bucket: self.StateBucket,) -> None:
def add_param_to_bucket(
param: torch.nn.Parameter,
bucket: self.StateBucket,
) -> None:
"""Add trivial param fragment to bucket"""
param_fragments = self.state[param]["fragments"]
param_group_id = param_fragments[0].param_group_id
Expand Down Expand Up @@ -283,7 +300,11 @@ def _init_param_state(
# Initialize non-FP8 params as usual
if not is_float8tensor(param):
super()._init_param_state(
param, param_group_id, param_id, param_sync_dtype=param_sync_dtype, **kwargs,
param,
param_group_id,
param_id,
param_sync_dtype=param_sync_dtype,
**kwargs,
)

# Return immediately if already initialized
Expand All @@ -293,7 +314,11 @@ def _init_param_state(
# Initialize with FP32 copy of param
fp32_param = param.float()
super()._init_param_state(
fp32_param, param_group_id, param_id, param_sync_dtype=torch.uint8, **kwargs,
fp32_param,
param_group_id,
param_id,
param_sync_dtype=torch.uint8,
**kwargs,
)
self.state[param].update(self.state[fp32_param])
del self.state[fp32_param]
Expand Down Expand Up @@ -360,7 +385,9 @@ def init_param_buffer(self) -> None:

# Copy values into param buffer
_multi_tensor_copy(
param_flat_views, param_buffer_views, dummy_overflow_buf=self._dummy_overflow_buf,
param_flat_views,
param_buffer_views,
dummy_overflow_buf=self._dummy_overflow_buf,
)

# Make all params a view into the param buffer
Expand Down Expand Up @@ -393,7 +420,10 @@ def zero_grad(self, *args, **kwargs) -> None:
param.main_grad = self.grad_buffer_view(param)

def grad_norm(
self, parameters: Optional[Iterable[torch.nn.Parameter]] = None, norm_type: float = 2.0, force: bool = False,
self,
parameters: Optional[Iterable[torch.nn.Parameter]] = None,
norm_type: float = 2.0,
force: bool = False,
) -> torch.Tensor:
assert norm_type == 2

Expand All @@ -411,7 +441,8 @@ def grad_norm(

# Sum over all procs to get grad norm
torch.distributed.all_reduce(
grad_norm_sq, op=torch.distributed.ReduceOp.SUM,
grad_norm_sq,
op=torch.distributed.ReduceOp.SUM,
)
self._grad_norm = grad_norm_sq.sqrt()

Expand Down Expand Up @@ -479,7 +510,9 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet

# Copy data from parameter buckets to parameters
_multi_tensor_copy(
buffers_in, buffers_out, dummy_overflow_buf=self._dummy_overflow_buf,
buffers_in,
buffers_out,
dummy_overflow_buf=self._dummy_overflow_buf,
)

# Update transpose caches
Expand Down Expand Up @@ -570,25 +603,35 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA
packed_scales = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
packed_scale_views = [packed_scales[i].view(1) for i in range(num_fp8_params)]
_multi_tensor_copy(
scales, packed_scale_views, dummy_overflow_buf=self._dummy_overflow_buf,
scales,
packed_scale_views,
dummy_overflow_buf=self._dummy_overflow_buf,
)
torch.reciprocal(packed_scales, out=packed_scales)
_multi_tensor_copy(
packed_scale_views, scale_invs, dummy_overflow_buf=self._dummy_overflow_buf,
packed_scale_views,
scale_invs,
dummy_overflow_buf=self._dummy_overflow_buf,
)

# Reduce amaxes
# Note: Assume each param has a separate amax
packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device)
packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)]
_multi_tensor_copy(
amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf,
amaxes,
packed_amax_views,
dummy_overflow_buf=self._dummy_overflow_buf,
)
torch.distributed.all_reduce(
packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group,
packed_amaxes,
op=torch.distributed.ReduceOp.MAX,
group=self.distributed_process_group,
)
_multi_tensor_copy(
packed_amax_views, amaxes, dummy_overflow_buf=self._dummy_overflow_buf,
packed_amax_views,
amaxes,
dummy_overflow_buf=self._dummy_overflow_buf,
)

# Reset
Expand All @@ -602,7 +645,8 @@ def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None
optimizer_state_dict = self.state_dict()

id_to_sharded_param_map = get_param_id_to_sharded_param_map(
model_sharded_state_dict=model_sharded_state_dict, optim_params_iter=self.parameters(),
model_sharded_state_dict=model_sharded_state_dict,
optim_params_iter=self.parameters(),
)
# Convert state
step = optimizer_state_dict['state'].pop('step')
Expand Down

0 comments on commit 43dcb2b

Please sign in to comment.