Skip to content

Commit

Permalink
fix killing and add general script
Browse files Browse the repository at this point in the history
  • Loading branch information
fotstrt committed Nov 10, 2024
1 parent b3c5d6f commit d72159e
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 4 deletions.
20 changes: 20 additions & 0 deletions sailor/run.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os
import argparse

if __name__ == "__main__":
    # Entry point: launch the local fault-tolerant training agent
    # (run_ft_train.py) in the background and, when SLURM_NODEID is "0",
    # also launch the controller (controller.py) in the background.
    #
    # Local import: only this entry point needs shell quoting.
    import shlex

    parser = argparse.ArgumentParser(description='Arguments for Agent')
    parser.add_argument('--grpc_port', type=int,
                        help='Port to start grpc server', required=True)
    parser.add_argument('--training_master_port', type=int,
                        help='Port used for training', required=True)
    parser.add_argument("--config-file", type=str, required=True,
                        help="Path to the YAML or python config file")
    parser.add_argument('--world_size', type=int,
                        help='world_size (in number of nodes)', required=True)
    args = parser.parse_args()

    # The config path is free-form user input that ends up in a shell command
    # line; quote it so spaces or shell metacharacters in the path are passed
    # through as a single literal argument. The remaining arguments are
    # validated as integers by argparse and need no quoting.
    config_file = shlex.quote(args.config_file)

    # Trailing '&' backgrounds the agent so this script does not block on it.
    os.system(
        "python /workspace/nanotron/sailor/run_ft_train.py"
        f" --grpc_port {args.grpc_port} --config-file {config_file} &"
    )

    # Raises KeyError if SLURM_NODEID is not set; note that the agent above
    # has already been launched by the time this lookup happens.
    node_id = os.environ["SLURM_NODEID"]
    if node_id == "0":
        # --training_master_port and --world_size are only forwarded to the
        # controller; the agent command above does not receive them.
        print("Start controller")
        os.system(
            "python /workspace/nanotron/sailor/controller.py"
            f" --grpc_port {args.grpc_port}"
            f" --training_master_port {args.training_master_port}"
            f" --world_size {args.world_size} &"
        )
9 changes: 5 additions & 4 deletions sailor/run_ft_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def Kill(self, request, context):
print(f"Killing local process ...")
if self.training_process_alive:
print("HERE")
os.system("pkill -f run_train_custom.py") # TODO: check cleanup
os.system("pkill -f run_train_custom") # TODO: check cleanup
self.training_process_alive = False
# TODO: check abort
return KillResponse()
Expand All @@ -49,7 +49,7 @@ def ConfigurationChange(self, request, context):
for i in range(self.gpus_per_node):
print(f"Start for process {i}")
rank_i = self.node_rank*self.gpus_per_node + i
start_cmd_i = start_cmd_base + f" --rank {rank_i} &"
start_cmd_i = start_cmd_base + f" --rank {rank_i} > log_{i} &"
os.system(start_cmd_i)
self.training_process_alive = True
return WorkerConfigurationResponse()
Expand Down Expand Up @@ -78,8 +78,9 @@ def is_in_topo(self, topology):
server.add_insecure_port(f'[::]:{args.grpc_port}')

def terminate(signum, _):
if not agent.training_process_alive:
os.system("pkill -f run_train_custom.py")
if agent.training_process_alive:
print(f"KILL ALL")
os.system("pkill -f run_train_custom")
done = server.stop(5)
done.wait()
print(f"Received {signum}, stop complete!")
Expand Down

0 comments on commit d72159e

Please sign in to comment.