Commit
add param (PaddlePaddle#9481)
AndSonder authored Nov 22, 2024
1 parent 9494e9a commit 1ba3bef
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions paddlenlp/trainer/training_args.py
@@ -1361,21 +1361,13 @@ def is_segment_parallel_supported():
                     strategy.hybrid_configs["sharding_configs"].comm_buffer_size_MB = int(
                         self.sharding_comm_buffer_size_MB
                     )
-                    # The `comm_buffer_size_MB` is added directly to sharding properties
-                    # for semi-auto mode, avoiding potential confusion with strategy config,
-                    # as parameters in semi-auto mode are managed via strategy.
-                    strategy.sharding.comm_buffer_size_MB = int(self.sharding_comm_buffer_size_MB)

                 if "split_param" in sharding_parallel_config:
                     strategy.hybrid_configs["sharding_configs"].split_param = True
                     assert self.amp_master_grad, "Currently sharding stage1 v2 only support amp_master_grad"

                 if "enable_release_grads" in sharding_parallel_config:
                     strategy.hybrid_configs["sharding_configs"].release_gradients = True
-                    # `release_gradients` is set directly in sharding properties for the same
-                    # reason as `comm_buffer_size_MB`, to avoid confusion with centralized
-                    # strategy management in semi-auto mode.
-                    strategy.sharding.release_gradients = True

                 if self.pipeline_parallel_degree == 1:
                     strategy.hybrid_configs["sharding_configs"].tensor_fusion = (
@@ -1588,6 +1580,8 @@ def is_segment_parallel_supported():
                 sharding.stage = 2
             elif ShardingOption.FULL_SHARD in self.sharding:
                 sharding.stage = 3
+            if self.sharding_comm_buffer_size_MB > 0:
+                sharding.comm_buffer_size_MB = int(self.sharding_comm_buffer_size_MB)

             sharding_parallel_config = split_parallel_config(self.sharding_parallel_config)
             for x in sharding_parallel_config:
@@ -1596,6 +1590,7 @@ def is_segment_parallel_supported():
                     "enable_stage1_tensor_fusion",
                     "enable_stage1_overlap",
                     "enable_stage2_overlap",
+                    "enable_release_grads",
                 ]:
                     raise ValueError(
                         f"Found unknown pipeline mode config {x}, " f"accpet config is reduce_overlap."
@@ -1610,6 +1605,9 @@ def is_segment_parallel_supported():
             if "enable_stage1_tensor_fusion" in sharding_parallel_config:
                 sharding.grad_bucket_size_numel = 210355872

+            if "enable_release_grads" in sharding_parallel_config:
+                sharding.release_gradients = True
+
             if self.bf16 or self.fp16:
                 amp = strategy.amp
                 amp.enable = True

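For context, the sketch below (not part of the commit, and independent of Paddle) mirrors how the auto-parallel branch consumes the two options this change relocates: sharding_comm_buffer_size_MB is forwarded onto the sharding config when positive, and an "enable_release_grads" entry in sharding_parallel_config toggles release_gradients. The ShardingConfig dataclass, apply_sharding_args helper, and the splitting rule in split_parallel_config are illustrative stand-ins, not PaddleNLP APIs.

# Illustrative stand-in (assumption: not PaddleNLP code) for the logic added
# in the auto-parallel branch by this commit.
from dataclasses import dataclass


def split_parallel_config(value: str) -> set:
    # Stand-in for the helper used in training_args.py; the real separator
    # rules (space vs. comma) may differ.
    return set(value.replace(",", " ").split())


@dataclass
class ShardingConfig:
    stage: int = 1
    comm_buffer_size_MB: int = -1
    release_gradients: bool = False


def apply_sharding_args(sharding: ShardingConfig, comm_buffer_size_MB: int, parallel_config: str) -> ShardingConfig:
    # Forward the buffer size only when the user set a positive value.
    if comm_buffer_size_MB > 0:
        sharding.comm_buffer_size_MB = int(comm_buffer_size_MB)
    # "enable_release_grads" is now an accepted sharding_parallel_config option.
    if "enable_release_grads" in split_parallel_config(parallel_config):
        sharding.release_gradients = True
    return sharding


print(apply_sharding_args(ShardingConfig(), 256, "enable_release_grads"))
# ShardingConfig(stage=1, comm_buffer_size_MB=256, release_gradients=True)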