Skip to content

Commit

Permalink
feat(test): support testing optcast ring-allreduce implementation
Browse files Browse the repository at this point in the history
Signed-off-by: Wataru Ishida <[email protected]>
  • Loading branch information
ishidawataru committed Mar 11, 2024
1 parent a128f7c commit a3223d0
Showing 1 changed file with 26 additions and 7 deletions.
33 changes: 26 additions & 7 deletions test/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import os
import re
import argparse
import time
from functools import reduce
from dateutil import parser

epoch = 0

Expand Down Expand Up @@ -186,6 +186,8 @@ def get_time(line):


def analyze_optcast_client_log(filename, output, xlim=None, no_plot=False):
from dateutil import parser

global epoch
epoch = 0

Expand Down Expand Up @@ -281,6 +283,8 @@ def get_time(line):


def analyze_server_log(filename, output, xlim=None, no_plot=False):
from dateutil import parser

global epoch
epoch = 0

Expand Down Expand Up @@ -467,14 +471,30 @@ async def client(args):
rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
if args.no_gpu:
client_cmd = f"{args.shared_dir}/{SERVER_CMD}"
count = parse_chunksize(args.chunksize) // (4 if args.data_type == "f32" else 2)
cmd = f"{client_cmd} -c -a {args.reduction_servers} --count {count} --try-count 1000 --nreq 4"
os.environ["RUST_LOG"] = "TRACE" if rank == 0 else "INFO"

nreq = 4
try_count = 1000
count = parse_chunksize(args.chunksize) // (
4 if args.data_type == "f32" else 2
)
if args.type == "optcast":
cmd = f"{client_cmd} -c -a {args.reduction_servers} --count {count} --try-count {try_count} --nreq {nreq}"
elif args.type == "ring":
with open(args.config) as f:
config = yaml.load(f, Loader=yaml.FullLoader)
if args.nrank == 0:
args.nrank = len(config["clients"])
clients = config["clients"][: args.nrank]
neighs = [(rank + 1 + i) % args.nrank for i in range(2)]
addrs = ",".join(
clients[j]["name"] + f":{8080+i}" for i, j in enumerate(neighs)
)
cmd = f"{client_cmd} -a {addrs} --reduce-threads 2 --count {count} --try-count {try_count} --nrank {args.nrank} --ring-rank {rank+1} --nreq {nreq}"
else:
dt = "float" if args.data_type == "f32" else "half"
client_cmd = f"{args.shared_dir}/{CLIENT_CMD}"
cmd = f"{client_cmd} -d {dt} -e {args.size} -b {args.size} {args.nccl_test_options}"
# print(f"[{platform.node()}] client:", cmd, file=sys.stderr)

os.environ["NCCL_DEBUG"] = "TRACE" if rank == 0 else "INFO"
os.environ["NCCL_P2P_DISABLE"] = "1"
Expand Down Expand Up @@ -558,7 +578,7 @@ async def run(
clients = config["clients"][: args.nrank]

if args.no_gpu:
if args.type not in ["optcast"]:
if args.type not in ["optcast", "ring"]:
raise ValueError(f"no-gpu option doesn't work with {args.type}")

reduction_servers = ",".join(
Expand All @@ -574,7 +594,6 @@ async def run(
f"--client --reduction-servers {reduction_servers}",
)
)
print("client:", cmd)
client = await asyncio.create_subprocess_shell(
cmd,
stdout=subprocess.PIPE,
Expand Down Expand Up @@ -688,7 +707,7 @@ def arguments():
parser.add_argument("--nsplit", default=1, type=int)
parser.add_argument("--reduction-servers")
parser.add_argument(
"--type", choices=["optcast", "sharp", "nccl"], default="optcast"
"--type", choices=["optcast", "sharp", "nccl", "ring"], default="optcast"
)
parser.add_argument("--nccl-test-options", default="-c 1 -n 1 -w 1")
parser.add_argument("--data-type", default="f32", choices=["f32", "f16"])
Expand Down

0 comments on commit a3223d0

Please sign in to comment.