-
Notifications
You must be signed in to change notification settings - Fork 4
/
run_lsgan.sh
executable file
·66 lines (66 loc) · 1.94 KB
/
run_lsgan.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#!/bin/bash
#
# run_lsgan.sh - launch dist-lsgan.py ranks over SSH on the machines listed
# in a nodes file (one host per line).
#
# Usage:
#   run_lsgan.sh n E C B s w model magic nodes_file [iid [job_id]]
#
# Positional arguments (each is forwarded to dist-lsgan.py as shown):
#   $1  n           total number of ranks             --size
#   $2  E                                             --local_steps
#   $3  C                                             --frac_workers
#   $4  B                                             --batch_size
#   $5  s                                             --sample
#   $6  w                                             --weight_avg
#   $7  model                                         --model
#   $8  magic                                         --magic_num
#   $9  nodes_file  file with one host name per line
#   $10 iid                                           --iid
#   $11 job_id      accepted for compatibility; only the disabled
#                   torchelastic launcher (see below) would use it
#
# Placement: the GPU budget is (machines - 1) * 2 + 1, i.e. the first machine
# in the nodes file is given one share of ranks and every other machine two.
# n must therefore be a multiple of that GPU count, otherwise the number of
# ranks started would not match the --size handed to every rank.
#
# The ssh sessions are started in the background and are not waited for.

set -euo pipefail

# Print a message on stderr and abort with a failure status.
die() {
  printf '%s\n' "$*" >&2
  exit 1
}

if (( $# < 9 )); then
  printf 'Usage: %s n E C B s w model magic nodes_file [iid [job_id]]\n' \
    "${0##*/}" >&2
  exit 2
fi

n=$1
E=$2
C=$3
B=$4
s=$5
w=$6
model=$7
magic=$8
nodes=$9
iid=${10:-}
readonly port=2379

# Reject anything that is not a plain positive decimal number; this also
# keeps values with a leading zero from being read as octal in arithmetic.
[[ "$n" =~ ^[1-9][0-9]*$ ]] \
  || die "ERROR: n must be a positive integer; the current value for n is $n"
[[ -r "$nodes" ]] || die "ERROR: cannot read nodes file: $nodes"

# Load the host list once. Blank lines are skipped and a final line without
# a trailing newline is still picked up.
machines=()
num_machines=0
while read -r node || [[ -n "$node" ]]; do
  [[ -n "$node" ]] || continue
  machines+=("$node")
  num_machines=$((num_machines + 1))
done < "$nodes"

(( num_machines > 0 )) || die "ERROR: no machines listed in $nodes"

# The first listed machine is passed to every rank as --master.
fid_master=${machines[0]}

num_gpus=$(( (num_machines - 1) * 2 + 1 ))
echo "Total number of GPUs for workers in this deployment: $num_gpus"
node_per_gpu=$(( n / num_gpus ))
echo "Number of nodes per GPU $node_per_gpu"

# The loops below start exactly node_per_gpu * num_gpus ranks, so n has to be
# a multiple of num_gpus (this also guarantees node_per_gpu >= 1).
if (( n % num_gpus != 0 )); then
  echo "ERROR: Choose a value of n that is divisible by $num_gpus in this setup; the current value for n is $n" >&2
  echo "This run file is only to help deploy FeGAN; feel free to write your own run file with whatever parameters you like" >&2
  exit 1
fi

if (( node_per_gpu > 16 )); then
  echo "WARNING: The number of nodes per GPU to be placed is HUGE...the maximum number is 16...You can go to 25 nodes on 96 GB GPU.." >&2
fi

work_dir=$(pwd)
common="$work_dir/dist-lsgan.py --size $n --master $fid_master --local_steps $E --frac_workers $C --batch_size $B --sample $s --weight_avg $w --model $model --port $port --magic_num $magic --iid $iid"

# Disabled alternative launcher, kept for reference (the original guarded it
# with `if false`). It would use job_id ($11) and the first machine as the
# etcd rendezvous endpoint:
#   python -m torchelastic.distributed.launch --nnodes=$n --nproc_per_node=1 \
#     --rdzv_id=$job_id --rdzv_backend=etcd \
#     --rdzv_endpoint=$fid_master:$port $common --rank $r

r=0
for idx in "${!machines[@]}"; do
  node=${machines[idx]}
  if (( idx == 0 )); then
    node_per_machine=$node_per_gpu
  else
    node_per_machine=$(( node_per_gpu * 2 ))
  fi

  for (( i = 0; i < node_per_machine; i++ )); do
    cmd="python3 $common --rank $r"
    echo "running $cmd on $node"
    # The command is handed to ssh as one string; the remote shell splits it.
    # stdin is taken from the terminal, exactly as in the original script.
    ssh "$node" "$cmd" < /dev/tty &
    r=$((r + 1))
  done
done