
Commit

multi-gpu
fotstrt committed Nov 10, 2024
1 parent dae940c commit 66d797c
Showing 3 changed files with 15 additions and 9 deletions.
15 changes: 10 additions & 5 deletions sailor/run_ft_train.py
@@ -18,6 +18,7 @@ class ElasticWorkerAgent(WorkerAgentServicer):
     def __init__(self, script_args):
         self.training_process_alive = False
         self.hostname = socket.gethostname()
+        self.gpus_per_node = 4
         self.world_size = 0
         self.node_rank = -1
         self.master_addr = None
@@ -44,16 +45,20 @@ def ConfigurationChange(self, request, context):
         topology_list = list(request.topology)
         if self.is_in_topo(topology_list):
             print(f"Starting new process, node rank is {self.node_rank}")
-            start_cmd = f"python run_train_custom.py --config-file {self.script_args.config_file} --world-size {self.world_size} --rank {self.node_rank} --master-ip {self.master_addr}"
-            os.system(start_cmd)
+            start_cmd_base = f"python run_train_custom.py --config-file {self.script_args.config_file} --world-size {self.world_size} --master-ip {self.master_addr}"
+            for i in range(self.gpus_per_node):
+                print(f"Start for process {i}")
+                rank_i = self.node_rank*self.gpus_per_node + i
+                start_cmd_i = start_cmd_base + f" --rank {rank_i} &"
+                os.system(start_cmd_i)
             self.training_process_alive = True
         return WorkerConfigurationResponse()

     def is_in_topo(self, topology):
         if self.hostname not in topology:
             return False
         self.node_rank = topology.index(self.hostname)
-        self.world_size = len(topology)
+        self.world_size = len(topology) * self.gpus_per_node
         self.master_addr = topology[0]
         return True
@@ -73,8 +78,8 @@ def is_in_topo(self, topology):
     server.add_insecure_port(f'[::]:{args.grpc_port}')

     def terminate(signum, _):
-        if agent.training_process is not None:
-            agent.training_process.terminate()
+        if not agent.training_process_alive:
+            os.system("pkill -f run_train_custom.py")
         done = server.stop(5)
         done.wait()
         print(f"Received {signum}, stop complete!")
1 change: 1 addition & 0 deletions sailor/run_train_custom.py
@@ -236,6 +236,7 @@ def get_args():

 os.environ['WORLD_SIZE'] = str(args.world_size)
 os.environ['RANK'] = str(args.rank)
+os.environ['LOCAL_RANK'] = str(args.rank % 4)
 os.environ['MASTER_ADDR'] = args.master_ip
 os.environ['MASTER_PORT'] = "1234" # TODO
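For orientation: LOCAL_RANK is the per-node process index conventionally read by torch.distributed setups, and the modulo 4 matches the hard-coded gpus_per_node = 4 in the agent above. A minimal sketch, assuming a PyTorch NCCL setup, of how a launched trainer would typically consume these environment variables (the exact initialization inside run_train_custom.py may differ):

import os
import torch
import torch.distributed as dist

local_rank = int(os.environ["LOCAL_RANK"])   # 0..3 with 4 GPUs per node
torch.cuda.set_device(local_rank)            # pin this process to one GPU
dist.init_process_group(
    backend="nccl",                           # MASTER_ADDR / MASTER_PORT are read from the environment
    rank=int(os.environ["RANK"]),             # global rank assigned by the agent
    world_size=int(os.environ["WORLD_SIZE"]),
)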

8 changes: 4 additions & 4 deletions src/nanotron/trainer.py
@@ -105,10 +105,10 @@
     "Starcoder2Config": Starcoder2ForTraining,
 }

-try:
-    import wandb
-except ImportError:
-    wandb = None
+#try:
+#    import wandb
+#except ImportError:
+wandb = None


class DistributedTrainer:
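A short note on the stub above (an assumption about how it is used, not code taken from nanotron): with the import commented out and wandb bound to None, any logging guarded on the module being available becomes a no-op, e.g.:

loss = 0.0  # placeholder metric for illustration
if wandb is not None:                      # always False with the stub above
    wandb.init(project="nanotron-run")     # hypothetical project name
    wandb.log({"train/loss": loss})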
