Skip to content

Commit

Permalink
remove distributed flag (#86)
Browse files: browse the repository at this point in the history
  • Loading branch information
mshuaibii authored Oct 4, 2020
1 parent af6f9d7 commit 421e0f3
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
15 changes: 5 additions & 10 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ def main(config):
amp=config.get("amp", False),
)
import time

start_time = time.time()
trainer.train()
distutils.synchronize()
print('Time = ', time.time() - start_time)
print("Time = ", time.time() - start_time)


def distributed_main(config):
Expand Down Expand Up @@ -61,19 +62,13 @@ def distributed_main(config):
slurm_partition=args.slurm_partition,
gpus_per_node=args.num_gpus,
cpus_per_task=(args.num_workers + 1),
tasks_per_node=(args.num_gpus if args.distributed else 1),
tasks_per_node=args.num_gpus,
nodes=args.num_nodes,
)
if args.distributed:
jobs = executor.map_array(distributed_main, configs)
else:
jobs = executor.map_array(main, configs)
jobs = executor.map_array(distributed_main, configs)
print("Submitted jobs:", ", ".join([job.job_id for job in jobs]))
log_file = save_experiment_log(args, jobs, configs)
print(f"Experiment log saved to: {log_file}")

else: # Run locally
if args.distributed:
distributed_main(config)
else:
main(config)
distributed_main(config)
18 changes: 12 additions & 6 deletions ocpmodels/common/flags.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,16 +91,22 @@ def add_core_args(self):
"--num-gpus", default=1, type=int, help="Number of GPUs to request"
)
self.parser.add_argument(
"--num-nodes", default=1, type=int, help="Number of Nodes to request"
)
self.parser.add_argument(
"--distributed", action='store_true', help='Run with DDP'
"--num-nodes",
default=1,
type=int,
help="Number of Nodes to request",
)
self.parser.add_argument(
"--distributed-port", type=int, default=13356, help='Port on master for DDP'
"--distributed-port",
type=int,
default=13356,
help="Port on master for DDP",
)
self.parser.add_argument(
"--distributed-backend", type=str, default='nccl', help='Backend for DDP'
"--distributed-backend",
type=str,
default="nccl",
help="Backend for DDP",
)
self.parser.add_argument(
"--local_rank", default=0, type=int, help="Local rank"
Expand Down

0 comments on commit 421e0f3

Please sign in to comment.