diff --git a/nemo/core/optim/distributed_adam.py b/nemo/core/optim/distributed_adam.py index f58fb7352c38..50ad38978476 100644 --- a/nemo/core/optim/distributed_adam.py +++ b/nemo/core/optim/distributed_adam.py @@ -136,7 +136,8 @@ def hook(*unused): self._grad_copy(param) if self.overlap_grad_sync and not getattr(param, '_disable_overlap_grad_sync', False): self._try_start_bucket_grad_sync( - params=[param], ignore_last_bucket=need_to_initialize, + params=[param], + ignore_last_bucket=need_to_initialize, ) return hook @@ -167,10 +168,14 @@ def init_params( # Initialize FP8 and non-FP8 tensors separately if any(is_float8tensor(param) for param in params): super().init_params( - filter(is_float8tensor, params), param_sync_dtype=torch.uint8, **kwargs, + filter(is_float8tensor, params), + param_sync_dtype=torch.uint8, + **kwargs, ) super().init_params( - params, param_sync_dtype=param_sync_dtype, **kwargs, + params, + param_sync_dtype=param_sync_dtype, + **kwargs, ) def init_params_bucket( @@ -200,7 +205,10 @@ def init_params_bucket( params = remaining_params start_bucket_id = len(self.state["buckets"]) super().init_params_bucket( - fp32_params, grad_sync_dtype=torch.float32, param_sync_dtype=param_sync_dtype, **kwargs, + fp32_params, + grad_sync_dtype=torch.float32, + param_sync_dtype=param_sync_dtype, + **kwargs, ) end_bucket_id = len(self.state["buckets"]) fp32_buckets = self.state["buckets"][start_bucket_id:end_bucket_id] @@ -216,7 +224,10 @@ def init_params_bucket( params = remaining_params start_bucket_id = len(self.state["buckets"]) super().init_params_bucket( - fp8_params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=torch.uint8, **kwargs, + fp8_params, + grad_sync_dtype=grad_sync_dtype, + param_sync_dtype=torch.uint8, + **kwargs, ) end_bucket_id = len(self.state["buckets"]) fp8_buckets = self.state["buckets"][start_bucket_id:end_bucket_id] @@ -225,12 +236,18 @@ def init_params_bucket( normal_buckets = [] start_bucket_id = len(self.state["buckets"]) super().init_params_bucket( - params, grad_sync_dtype=grad_sync_dtype, param_sync_dtype=param_sync_dtype, **kwargs, + params, + grad_sync_dtype=grad_sync_dtype, + param_sync_dtype=param_sync_dtype, + **kwargs, ) end_bucket_id = len(self.state["buckets"]) normal_buckets = self.state["buckets"][start_bucket_id:end_bucket_id] - def add_param_to_bucket(param: torch.nn.Parameter, bucket: self.StateBucket,) -> None: + def add_param_to_bucket( + param: torch.nn.Parameter, + bucket: self.StateBucket, + ) -> None: """Add trivial param fragment to bucket""" param_fragments = self.state[param]["fragments"] param_group_id = param_fragments[0].param_group_id @@ -283,7 +300,11 @@ def _init_param_state( # Initialize non-FP8 params as usual if not is_float8tensor(param): super()._init_param_state( - param, param_group_id, param_id, param_sync_dtype=param_sync_dtype, **kwargs, + param, + param_group_id, + param_id, + param_sync_dtype=param_sync_dtype, + **kwargs, ) # Return immediately if already initialized @@ -293,7 +314,11 @@ def _init_param_state( # Initialize with FP32 copy of param fp32_param = param.float() super()._init_param_state( - fp32_param, param_group_id, param_id, param_sync_dtype=torch.uint8, **kwargs, + fp32_param, + param_group_id, + param_id, + param_sync_dtype=torch.uint8, + **kwargs, ) self.state[param].update(self.state[fp32_param]) del self.state[fp32_param] @@ -360,7 +385,9 @@ def init_param_buffer(self) -> None: # Copy values into param buffer _multi_tensor_copy( - param_flat_views, param_buffer_views, dummy_overflow_buf=self._dummy_overflow_buf, + param_flat_views, + param_buffer_views, + dummy_overflow_buf=self._dummy_overflow_buf, ) # Make all params a view into the param buffer @@ -393,7 +420,10 @@ def zero_grad(self, *args, **kwargs) -> None: param.main_grad = self.grad_buffer_view(param) def grad_norm( - self, parameters: Optional[Iterable[torch.nn.Parameter]] = None, norm_type: float = 2.0, force: bool = False, + self, + parameters: Optional[Iterable[torch.nn.Parameter]] = None, + norm_type: float = 2.0, + force: bool = False, ) -> torch.Tensor: assert norm_type == 2 @@ -411,7 +441,8 @@ def grad_norm( # Sum over all procs to get grad norm torch.distributed.all_reduce( - grad_norm_sq, op=torch.distributed.ReduceOp.SUM, + grad_norm_sq, + op=torch.distributed.ReduceOp.SUM, ) self._grad_norm = grad_norm_sq.sqrt() @@ -479,7 +510,9 @@ def _param_copy_fragments(self, fragments: Iterable[DistributedFusedAdam.Paramet # Copy data from parameter buckets to parameters _multi_tensor_copy( - buffers_in, buffers_out, dummy_overflow_buf=self._dummy_overflow_buf, + buffers_in, + buffers_out, + dummy_overflow_buf=self._dummy_overflow_buf, ) # Update transpose caches @@ -570,11 +603,15 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA packed_scales = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device) packed_scale_views = [packed_scales[i].view(1) for i in range(num_fp8_params)] _multi_tensor_copy( - scales, packed_scale_views, dummy_overflow_buf=self._dummy_overflow_buf, + scales, + packed_scale_views, + dummy_overflow_buf=self._dummy_overflow_buf, ) torch.reciprocal(packed_scales, out=packed_scales) _multi_tensor_copy( - packed_scale_views, scale_invs, dummy_overflow_buf=self._dummy_overflow_buf, + packed_scale_views, + scale_invs, + dummy_overflow_buf=self._dummy_overflow_buf, ) # Reduce amaxes @@ -582,13 +619,19 @@ def _check_params_shard_dtypes(self, params_buckets: Dict[int, DistributedFusedA packed_amaxes = torch.empty(num_fp8_params, dtype=torch.float32, device=self.device) packed_amax_views = [packed_amaxes[i].view(1) for i in range(num_fp8_params)] _multi_tensor_copy( - amaxes, packed_amax_views, dummy_overflow_buf=self._dummy_overflow_buf, + amaxes, + packed_amax_views, + dummy_overflow_buf=self._dummy_overflow_buf, ) torch.distributed.all_reduce( - packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=self.distributed_process_group, + packed_amaxes, + op=torch.distributed.ReduceOp.MAX, + group=self.distributed_process_group, ) _multi_tensor_copy( - packed_amax_views, amaxes, dummy_overflow_buf=self._dummy_overflow_buf, + packed_amax_views, + amaxes, + dummy_overflow_buf=self._dummy_overflow_buf, ) # Reset @@ -602,7 +645,8 @@ def sharded_state_dict(self, model_sharded_state_dict, optimizer_state_dict=None optimizer_state_dict = self.state_dict() id_to_sharded_param_map = get_param_id_to_sharded_param_map( - model_sharded_state_dict=model_sharded_state_dict, optim_params_iter=self.parameters(), + model_sharded_state_dict=model_sharded_state_dict, + optim_params_iter=self.parameters(), ) # Convert state step = optimizer_state_dict['state'].pop('step')