From 66d797c2e773774807fc0177c7c97eb2eb50edee Mon Sep 17 00:00:00 2001
From: fotstrt
Date: Sun, 10 Nov 2024 16:18:25 +0100
Subject: [PATCH] multi-gpu

---
 sailor/run_ft_train.py     | 15 ++++++++++-----
 sailor/run_train_custom.py |  1 +
 src/nanotron/trainer.py    |  8 ++++----
 3 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/sailor/run_ft_train.py b/sailor/run_ft_train.py
index 8bbcec9f..b42c1d80 100644
--- a/sailor/run_ft_train.py
+++ b/sailor/run_ft_train.py
@@ -18,6 +18,7 @@ class ElasticWorkerAgent(WorkerAgentServicer):
     def __init__(self, script_args):
         self.training_process_alive = False
         self.hostname = socket.gethostname()
+        self.gpus_per_node = 4
         self.world_size = 0
         self.node_rank = -1
         self.master_addr = None
@@ -44,8 +45,12 @@ def ConfigurationChange(self, request, context):
         topology_list = list(request.topology)
         if self.is_in_topo(topology_list):
             print(f"Starting new process, node rank is {self.node_rank}")
-            start_cmd = f"python run_train_custom.py --config-file {self.script_args.config_file} --world-size {self.world_size} --rank {self.node_rank} --master-ip {self.master_addr}"
-            os.system(start_cmd)
+            start_cmd_base = f"python run_train_custom.py --config-file {self.script_args.config_file} --world-size {self.world_size} --master-ip {self.master_addr}"
+            for i in range(self.gpus_per_node):
+                print(f"Start for process {i}")
+                rank_i = self.node_rank*self.gpus_per_node + i
+                start_cmd_i = start_cmd_base + f" --rank {rank_i} &"
+                os.system(start_cmd_i)
             self.training_process_alive = True
         return WorkerConfigurationResponse()
 
@@ -53,7 +58,7 @@ def is_in_topo(self, topology):
         if self.hostname not in topology:
             return False
         self.node_rank = topology.index(self.hostname)
-        self.world_size = len(topology)
+        self.world_size = len(topology) * self.gpus_per_node
         self.master_addr = topology[0]
         return True
 
@@ -73,8 +78,8 @@
     server.add_insecure_port(f'[::]:{args.grpc_port}')
 
     def terminate(signum, _):
-        if agent.training_process is not None:
-            agent.training_process.terminate()
+        if not agent.training_process_alive:
+            os.system("pkill -f run_train_custom.py")
         done = server.stop(5)
         done.wait()
         print(f"Received {signum}, stop complete!")
diff --git a/sailor/run_train_custom.py b/sailor/run_train_custom.py
index 2dc2e811..8667ce1c 100644
--- a/sailor/run_train_custom.py
+++ b/sailor/run_train_custom.py
@@ -236,6 +236,7 @@ def get_args():
 
     os.environ['WORLD_SIZE'] = str(args.world_size)
     os.environ['RANK'] = str(args.rank)
+    os.environ['LOCAL_RANK'] = str(args.rank % 4)
     os.environ['MASTER_ADDR'] = args.master_ip
     os.environ['MASTER_PORT'] = "1234" # TODO
 
diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py
index 21251a32..98cdf8c7 100644
--- a/src/nanotron/trainer.py
+++ b/src/nanotron/trainer.py
@@ -105,10 +105,10 @@
     "Starcoder2Config": Starcoder2ForTraining,
 }
 
-try:
-    import wandb
-except ImportError:
-    wandb = None
+#try:
+#    import wandb
+#except ImportError:
+wandb = None
 
 
 class DistributedTrainer:
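
The sketch below is not part of the patch; it restates, in plain Python, the launch pattern the patch wires into ElasticWorkerAgent.ConfigurationChange and run_train_custom.py: one training process per local GPU, a global rank of node_rank * gpus_per_node + local_rank, and WORLD_SIZE/RANK/LOCAL_RANK/MASTER_ADDR/MASTER_PORT exported for torch.distributed. The helper name launch_node and the use of subprocess.Popen (instead of the patch's backgrounded os.system calls) are illustrative assumptions, not the committed code.

import os
import subprocess

def launch_node(config_file, node_rank, num_nodes, master_addr, gpus_per_node=4):
    # Hypothetical helper, not in the patch: spawn one run_train_custom.py
    # process per local GPU, mirroring the loop added in ConfigurationChange.
    world_size = num_nodes * gpus_per_node            # total ranks across all nodes
    procs = []
    for local_rank in range(gpus_per_node):
        rank = node_rank * gpus_per_node + local_rank  # global rank of this process
        env = dict(os.environ,
                   WORLD_SIZE=str(world_size),
                   RANK=str(rank),
                   LOCAL_RANK=str(local_rank),         # the patch derives this as rank % 4
                   MASTER_ADDR=master_addr,
                   MASTER_PORT="1234")
        procs.append(subprocess.Popen(
            ["python", "run_train_custom.py",
             "--config-file", config_file,
             "--world-size", str(world_size),
             "--rank", str(rank),
             "--master-ip", master_addr],
            env=env))
    return procs

Using subprocess handles rather than shell-backgrounded commands would also let the termination hook call proc.terminate() directly instead of pkill, but the patch as written tracks liveness only via the training_process_alive flag.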