
add distributed backend param
abhishekkrthakur committed Sep 19, 2024
1 parent 2831933 commit 52eb2d4
Showing 3 changed files with 72 additions and 58 deletions.
1 change: 1 addition & 0 deletions src/autotrain/app/params.py
@@ -81,6 +81,7 @@
    padding="right",
    chat_template="none",
    max_completion_length=128,
    distributed_backend="ddp",
).model_dump()

PARAMS["text-classification"] = TextClassificationParams(
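Since the UI defaults above come from model_dump(), the dumped dict now carries the new key. A minimal sketch of what this default implies, assuming every other constructor argument can stay at its default:

from autotrain.trainers.clm.params import LLMTrainingParams

# Mirrors the hunk above: the dumped defaults now include the backend choice.
defaults = LLMTrainingParams(distributed_backend="ddp").model_dump()
assert defaults["distributed_backend"] == "ddp"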
128 changes: 70 additions & 58 deletions src/autotrain/commands.py
@@ -20,6 +20,75 @@
from autotrain.trainers.vlm.params import VLMTrainingParams


CPU_COMMAND = [
    "accelerate",
    "launch",
    "--cpu",
]

SINGLE_GPU_COMMAND = [
    "accelerate",
    "launch",
    "--num_machines",
    "1",
    "--num_processes",
    "1",
]


def get_accelerate_command(num_gpus, gradient_accumulation_steps=1, distributed_backend=None):
    """
    Returns the accelerate command based on the number of GPUs available.
    Args:
        num_gpus: Number of GPUs available.
        gradient_accumulation_steps: Number of gradient accumulation steps.
        distributed_backend: Distributed backend to use: ddp, deepspeed, None.
    Returns:
        List: Accelerate command.
    """
    if num_gpus == 0:
        logger.warning("No GPU found. Forcing training on CPU. This will be super slow!")
        return CPU_COMMAND

    if num_gpus == 1:
        return SINGLE_GPU_COMMAND

    if distributed_backend in ("ddp", None):
        return [
            "accelerate",
            "launch",
            "--multi_gpu",
            "--num_machines",
            "1",
            "--num_processes",
            str(num_gpus),
        ]
    elif distributed_backend == "deepspeed":
        return [
            "accelerate",
            "launch",
            "--use_deepspeed",
            "--zero_stage",
            "3",
            "--offload_optimizer_device",
            "none",
            "--offload_param_device",
            "none",
            "--zero3_save_16bit_model",
            "true",
            "--zero3_init_flag",
            "true",
            "--deepspeed_multinode_launcher",
            "standard",
            "--gradient_accumulation_steps",
            str(gradient_accumulation_steps),
        ]
    else:
        raise ValueError("Unsupported distributed backend")


def launch_command(params):
    """
    Launches training command based on the given parameters.
@@ -43,64 +112,7 @@ def launch_command(params):
    else:
        num_gpus = 0
    if isinstance(params, LLMTrainingParams):
        if num_gpus == 0:
            logger.warning("No GPU found. Forcing training on CPU. This will be super slow!")
            cmd = [
                "accelerate",
                "launch",
                "--cpu",
            ]
        elif num_gpus == 1:
            cmd = [
                "accelerate",
                "launch",
                "--num_machines",
                "1",
                "--num_processes",
                "1",
            ]
        elif num_gpus == 2:
            cmd = [
                "accelerate",
                "launch",
                "--multi_gpu",
                "--num_machines",
                "1",
                "--num_processes",
                "2",
            ]
        else:
            if params.quantization in ("int8", "int4") and params.peft and params.mixed_precision == "bf16":
                cmd = [
                    "accelerate",
                    "launch",
                    "--multi_gpu",
                    "--num_machines",
                    "1",
                    "--num_processes",
                    str(num_gpus),
                ]
            else:
                cmd = [
                    "accelerate",
                    "launch",
                    "--use_deepspeed",
                    "--zero_stage",
                    "3",
                    "--offload_optimizer_device",
                    "none",
                    "--offload_param_device",
                    "none",
                    "--zero3_save_16bit_model",
                    "true",
                    "--zero3_init_flag",
                    "true",
                    "--deepspeed_multinode_launcher",
                    "standard",
                    "--gradient_accumulation_steps",
                    str(params.gradient_accumulation),
                ]

        cmd = get_accelerate_command(num_gpus, params.gradient_accumulation, params.distributed_backend)
        if num_gpus > 0:
            cmd.append("--mixed_precision")
            if params.mixed_precision == "fp16":
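A minimal usage sketch of the new helper, with the expected commands echoed as comments (the values simply follow the branches added above):

from autotrain.commands import get_accelerate_command

# Default backend (None or "ddp") on 4 GPUs -> plain multi-GPU launch
get_accelerate_command(num_gpus=4)
# ['accelerate', 'launch', '--multi_gpu', '--num_machines', '1', '--num_processes', '4']

# DeepSpeed ZeRO-3 on 4 GPUs with 2 gradient accumulation steps
get_accelerate_command(num_gpus=4, gradient_accumulation_steps=2, distributed_backend="deepspeed")
# ['accelerate', 'launch', '--use_deepspeed', '--zero_stage', '3', ..., '--gradient_accumulation_steps', '2']

# Any other value raises ValueError("Unsupported distributed backend")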
1 change: 1 addition & 0 deletions src/autotrain/trainers/clm/params.py
@@ -69,3 +69,4 @@ class LLMTrainingParams(AutoTrainParams):

    # unsloth
    unsloth: bool = Field(False, title="Use unsloth")
    distributed_backend: Optional[str] = Field(None, title="Distributed backend")
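A minimal sketch of how the new field reaches the launcher, assuming the remaining LLMTrainingParams fields can stay at their defaults:

from autotrain.commands import get_accelerate_command
from autotrain.trainers.clm.params import LLMTrainingParams

params = LLMTrainingParams(distributed_backend="deepspeed")
# launch_command forwards the field; None (the default) falls into the ddp branch above
cmd = get_accelerate_command(num_gpus=4, gradient_accumulation_steps=params.gradient_accumulation, distributed_backend=params.distributed_backend)
assert "--use_deepspeed" in cmd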
