[Unet3d] - Add infinite data loader to align epochs->samples transition #697

Open · wants to merge 12 commits into base: master
7 changes: 3 additions & 4 deletions image_segmentation/pytorch/Dockerfile
@@ -2,15 +2,14 @@ ARG FROM_IMAGE_NAME=pytorch/pytorch:1.7.1-cuda11.0-cudnn8-runtime
 #ARG FROM_IMAGE_NAME=nvcr.io/nvidia/pytorch:21.02-py3
 FROM ${FROM_IMAGE_NAME}
 
-ADD . /workspace/unet3d
-WORKDIR /workspace/unet3d
-
 RUN apt-get update && \
     apt-get upgrade -y && \
     apt-get install -y git
+RUN apt-get install -y vim
 
+ADD . /workspace/unet3d
+WORKDIR /workspace/unet3d
 
 RUN pip install --upgrade pip
 RUN pip install --disable-pip-version-check -r requirements.txt
 
 #RUN pip uninstall -y apex; pip uninstall -y apex; git clone --branch seryilmaz/fused_dropout_softmax https://github.com/seryilmaz/apex.git; cd apex; pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--xentropy" --global-option="--deprecated_fused_adam" --global-option="--deprecated_fused_lamb" --global-option="--fast_multihead_attn" .
2 changes: 1 addition & 1 deletion image_segmentation/pytorch/data_loading/data_loader.py
@@ -90,7 +90,7 @@ def get_data_loaders(flags, num_shards, global_rank):
         raise ValueError(f"Loader {flags.loader} unknown. Valid loaders are: synthetic, pytorch")
 
     # The DistributedSampler seed should be the same for all workers
-    train_sampler = DistributedSampler(train_dataset, seed=flags.shuffling_seed, drop_last=True) if num_shards > 1 else None
+    train_sampler = None  # DistributedSampler(train_dataset, seed=flags.shuffling_seed, drop_last=True) if num_shards > 1 else None
     val_sampler = None
 
     train_dataloader = DataLoader(train_dataset,
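With the dataset made effectively infinite (see the pytorch_loader.py change below), the per-epoch resharding that DistributedSampler provides no longer applies, so the sampler is commented out. A minimal before/after sketch of the loader construction; argument values are illustrative, and how shuffling is configured for the no-sampler case is not shown in this hunk:

from torch.utils.data import DataLoader, DistributedSampler

def make_train_loader(train_dataset, num_shards, shuffling_seed, batch_size=2):
    # Before this PR: one 1/num_shards shard per rank, reshuffled each epoch
    # via sampler.set_epoch():
    # sampler = DistributedSampler(train_dataset, seed=shuffling_seed, drop_last=True) \
    #     if num_shards > 1 else None
    sampler = None  # after this PR: no sharding; the infinite dataset has no epoch boundary
    return DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=4)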
5 changes: 3 additions & 2 deletions image_segmentation/pytorch/data_loading/pytorch_loader.py
@@ -141,12 +141,13 @@ def __init__(self, images, labels, **kwargs):
         patch_size, oversampling = kwargs["patch_size"], kwargs["oversampling"]
         self.patch_size = patch_size
         self.rand_crop = RandBalancedCrop(patch_size=patch_size, oversampling=oversampling)
+        self.real_len = len(self.images)
 
     def __len__(self):
-        return len(self.images)
+        return int(168*10000)  # len(self.images)
 
     def __getitem__(self, idx):
-        data = {"image": np.load(self.images[idx]), "label": np.load(self.labels[idx])}
+        data = {"image": np.load(self.images[idx % self.real_len]), "label": np.load(self.labels[idx % self.real_len])}
         data = self.rand_crop(data)
         data = self.train_transforms(data)
         return data["image"], data["label"]
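This change turns the training set into a virtually infinite stream: __len__ reports an inflated length (the 168 appears to match the dataset's training-case count; 10,000 is simply a large multiplier) and __getitem__ wraps any index back onto the real files. A self-contained sketch of the same pattern, with illustrative names rather than the repo's:

import numpy as np
from torch.utils.data import Dataset

class InfiniteWrapper(Dataset):
    """Repeats a finite list of samples so the DataLoader never exhausts it."""
    def __init__(self, files, virtual_passes=10000):
        self.files = files
        self.real_len = len(files)               # actual number of cases on disk
        self.virtual_len = self.real_len * virtual_passes

    def __len__(self):
        # Inflated length: the loader believes the dataset is huge,
        # so "epochs" stop being a meaningful boundary.
        return self.virtual_len

    def __getitem__(self, idx):
        # Wrap any index back onto the real files.
        return np.load(self.files[idx % self.real_len])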
6 changes: 0 additions & 6 deletions image_segmentation/pytorch/main.py
@@ -20,7 +20,6 @@
 
 
 def main():
-    mllog.config(filename=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'unet3d.log'))
     mllog.config(filename=os.path.join("/results", 'unet3d.log'))
     mllogger = mllog.get_mllogger()
     mllogger.logger.propagate = False
@@ -50,11 +49,6 @@ def main():
     mllog_end(key=constants.INIT_STOP, sync=True)
     mllog_start(key=constants.RUN_START, sync=True)
     train_dataloader, val_dataloader = get_data_loaders(flags, num_shards=world_size, global_rank=local_rank)
-    samples_per_epoch = world_size * len(train_dataloader) * flags.batch_size
-    mllog_event(key='samples_per_epoch', value=samples_per_epoch, sync=False)
-    flags.evaluate_every = flags.evaluate_every or ceil(20*DATASET_SIZE/samples_per_epoch)
-    flags.start_eval_at = flags.start_eval_at or ceil(1000*DATASET_SIZE/samples_per_epoch)
-
     mllog_event(key=constants.GLOBAL_BATCH_SIZE, value=flags.batch_size * world_size * flags.ga_steps, sync=False)
     mllog_event(key=constants.GRADIENT_ACCUMULATION_STEPS, value=flags.ga_steps)
     loss_fn = DiceCELoss(to_onehot_y=True, use_softmax=True, layout=flags.layout,
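The deleted block derived the epoch-based eval settings at runtime from DATASET_SIZE; the PR replaces them with fixed sample-based constants in runtime/training.py. The old ceil-based defaults worked out to 20 and 1000 full dataset passes between/before evaluations, which matches the new constants if DATASET_SIZE is 168 (an assumption suggested by the 168* factors used throughout this PR). A quick consistency check:

DATASET_SIZE = 168              # assumed training-case count, per the 168* factors in this PR

EVALUATE_EVERY = 168 * 20       # 3360 samples  == 20 passes over 168 cases
START_EVAL_AT = 168 * 1000      # 168000 samples == 1000 passes over 168 cases

assert EVALUATE_EVERY == DATASET_SIZE * 20
assert START_EVAL_AT == DATASET_SIZE * 1000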
2 changes: 0 additions & 2 deletions image_segmentation/pytorch/oldREADME.md
@@ -165,8 +165,6 @@ The complete list of the available parameters for the main.py script contains:
 * `--batch_size`: Size of each minibatch per GPU (default: `2`).
 * `--ga_steps`: Number of steps for gradient accumulation (default: `1`).
 * `--epochs`: Maximum number of epochs for training (default: `1`).
-* `--evaluate_every`: Epoch interval for evaluation (default: `20`).
-* `--start_eval_at`: First epoch to start running evaluation at (default: `1000`).
 * `--layout`: Data layout (default: `NCDHW`. `NDHWC` is not implemented).
 * `--input_shape`: Input shape for images during training (default: `[128, 128, 128]`).
 * `--val_input_shape`: Input shape for images during evaluation (default: `[128, 128, 128]`).
4 changes: 1 addition & 3 deletions image_segmentation/pytorch/run_and_time.sh
@@ -33,15 +33,13 @@ mllog_event(key=constants.CACHE_CLEAR, value=True)"
 
 python main.py --data_dir ${DATASET_DIR} \
   --epochs ${MAX_EPOCHS} \
-  --evaluate_every ${EVALUATE_EVERY} \
-  --start_eval_at ${START_EVAL_AT} \
   --quality_threshold ${QUALITY_THRESHOLD} \
   --batch_size ${BATCH_SIZE} \
   --optimizer sgd \
   --ga_steps ${GRADIENT_ACCUMULATION_STEPS} \
   --learning_rate ${LEARNING_RATE} \
   --seed ${SEED} \
-  --lr_warmup_epochs ${LR_WARMUP_EPOCHS}
+  --lr_warmup_samples ${LR_WARMUP_SAMPLES}
 
 # end timing
 end=$(date +%s)
10 changes: 4 additions & 6 deletions image_segmentation/pytorch/runtime/arguments.py
@@ -14,7 +14,7 @@
 PARSER.add_argument('--quality_threshold', dest='quality_threshold', type=float, default=0.908)
 PARSER.add_argument('--ga_steps', dest='ga_steps', type=int, default=1)
 PARSER.add_argument('--warmup_steps', dest='warmup_steps', type=int, default=4)
-PARSER.add_argument('--batch_size', dest='batch_size', type=int, default=2)
+PARSER.add_argument('--batch_size', dest='batch_size', type=int, default=7)
 PARSER.add_argument('--layout', dest='layout', type=str, choices=['NCDHW'], default='NCDHW')
 PARSER.add_argument('--input_shape', nargs='+', type=int, default=[128, 128, 128])
 PARSER.add_argument('--val_input_shape', nargs='+', type=int, default=[128, 128, 128])
@@ -25,16 +25,14 @@
 PARSER.add_argument('--benchmark', dest='benchmark', action='store_true', default=False)
 PARSER.add_argument('--amp', dest='amp', action='store_true', default=False)
 PARSER.add_argument('--optimizer', dest='optimizer', default="sgd", choices=["sgd", "adam", "lamb"], type=str)
-PARSER.add_argument('--learning_rate', dest='learning_rate', type=float, default=1.0)
+PARSER.add_argument('--learning_rate', dest='learning_rate', type=float, default=2.0)
 PARSER.add_argument('--init_learning_rate', dest='init_learning_rate', type=float, default=1e-4)
-PARSER.add_argument('--lr_warmup_epochs', dest='lr_warmup_epochs', type=int, default=0)
-PARSER.add_argument('--lr_decay_epochs', nargs='+', type=int, default=[])
+PARSER.add_argument('--lr_warmup_samples', dest='lr_warmup_samples', type=int, default=168000)
+PARSER.add_argument('--lr_decay_samples', nargs='+', type=int, default=[])
 PARSER.add_argument('--lr_decay_factor', dest='lr_decay_factor', type=float, default=1.0)
 PARSER.add_argument('--lamb_betas', nargs='+', type=int, default=[0.9, 0.999])
 PARSER.add_argument('--momentum', dest='momentum', type=float, default=0.9)
 PARSER.add_argument('--weight_decay', dest='weight_decay', type=float, default=0.0)
-PARSER.add_argument('--evaluate_every', '--eval_every', dest='evaluate_every', type=int, default=None)
-PARSER.add_argument('--start_eval_at', dest='start_eval_at', type=int, default=None)
 PARSER.add_argument('--verbose', '-v', dest='verbose', action='store_true', default=False)
 PARSER.add_argument('--normalization', dest='normalization', type=str,
                     choices=['instancenorm', 'batchnorm'], default='instancenorm')
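Since the warmup and decay flags now count samples instead of epochs, existing launch configs need converting. A hypothetical helper (not part of the repo) that performs the translation, assuming one dataset pass equals 168 samples as above:

DATASET_SIZE = 168  # assumed training-case count, matching the 168* factors in this PR

def epochs_to_samples(epochs, dataset_size=DATASET_SIZE):
    """Translate an epoch-based hyperparameter into its sample-based equivalent."""
    return epochs * dataset_size

# e.g. 1000 epochs of warmup correspond to the new --lr_warmup_samples default:
assert epochs_to_samples(1000) == 168000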
2 changes: 1 addition & 1 deletion image_segmentation/pytorch/runtime/distributed_utils.py
@@ -97,7 +97,7 @@ def get_world_size():
 def reduce_tensor(tensor, num_gpus):
     if num_gpus > 1:
         rt = tensor.clone()
-        dist.all_reduce(rt, op=dist.reduce_op.SUM)
+        dist.all_reduce(rt, op=dist.ReduceOp.SUM)
         if rt.is_floating_point():
             rt = rt / num_gpus
         else:
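This one-line fix tracks a PyTorch API rename: torch.distributed.reduce_op was a deprecated alias of torch.distributed.ReduceOp and was removed in later releases. A standalone sketch of the same mean-reduction using only the supported spelling:

import torch
import torch.distributed as dist

def average_across_ranks(tensor: torch.Tensor, world_size: int) -> torch.Tensor:
    """Sum a tensor over all ranks, then divide to get the mean."""
    if world_size > 1:
        tensor = tensor.clone()
        dist.all_reduce(tensor, op=dist.ReduceOp.SUM)  # in-place sum across ranks
        tensor = tensor / world_size
    return tensor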
4 changes: 2 additions & 2 deletions image_segmentation/pytorch/runtime/logging.py
@@ -83,9 +83,9 @@ def mlperf_submission_log():
 def mlperf_run_param_log(flags):
     mllog_event(key=mllog.constants.OPT_NAME, value=flags.optimizer)
     mllog_event(key=mllog.constants.OPT_BASE_LR, value=flags.learning_rate)
-    mllog_event(key=mllog.constants.OPT_LR_WARMUP_EPOCHS, value=flags.lr_warmup_epochs)
+    mllog_event(key=mllog.constants.OPT_LR_WARMUP_EPOCHS, value=flags.lr_warmup_samples)
     # mllog_event(key=mllog.constants.OPT_LR_WARMUP_FACTOR, value=flags.lr_warmup_factor)
-    mllog_event(key=mllog.constants.OPT_LR_DECAY_BOUNDARY_EPOCHS, value=flags.lr_decay_epochs)
+    mllog_event(key=mllog.constants.OPT_LR_DECAY_BOUNDARY_EPOCHS, value=flags.lr_decay_samples)
     mllog_event(key=mllog.constants.OPT_LR_DECAY_FACTOR, value=flags.lr_decay_factor)
     mllog_event(key=mllog.constants.OPT_WEIGHT_DECAY, value=flags.weight_decay)
     mllog_event(key="opt_momentum", value=flags.momentum)
98 changes: 52 additions & 46 deletions image_segmentation/pytorch/runtime/training.py
@@ -1,4 +1,5 @@
 from tqdm import tqdm
+from time import time
 
 import torch
 from torch.optim import Adam, SGD
@@ -9,6 +10,10 @@
 from runtime.logging import mllog_event, mllog_start, mllog_end, CONSTANTS
 
 
+START_EVAL_AT = 168*1000
+EVALUATE_EVERY = 168*20
+
+
 def get_optimizer(params, flags):
     if flags.optimizer == "adam":
         optim = Adam(params, lr=flags.learning_rate, weight_decay=flags.weight_decay)
@@ -24,23 +29,27 @@ def get_optimizer(params, flags):
     return optim
 
 
-def lr_warmup(optimizer, init_lr, lr, current_epoch, warmup_epochs):
-    scale = current_epoch / warmup_epochs
+def lr_warmup(optimizer, init_lr, lr, current_samples, warmup_samples):
+    scale = current_samples / warmup_samples
     for param_group in optimizer.param_groups:
         param_group['lr'] = init_lr + (lr - init_lr) * scale
 
 
+def lr_decay(optimizer, lr_decay_samples, lr_decay_factor, total_samples):
+    if len(lr_decay_samples) > 0 and total_samples > lr_decay_samples[0]:
+        lr_decay_samples = lr_decay_samples[1:]
+        for param_group in optimizer.param_groups:
+            param_group['lr'] *= lr_decay_factor
+    return lr_decay_samples
+
+
 def train(flags, model, train_loader, val_loader, loss_fn, score_fn, device, callbacks, is_distributed):
     rank = get_rank()
 
     world_size = get_world_size()
     torch.backends.cudnn.benchmark = flags.cudnn_benchmark
     torch.backends.cudnn.deterministic = flags.cudnn_deterministic
 
     optimizer = get_optimizer(model.parameters(), flags)
-    if flags.lr_decay_epochs:
-        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
-                                                         milestones=flags.lr_decay_epochs,
-                                                         gamma=flags.lr_decay_factor)
     scaler = GradScaler()
 
     model.to(device)
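Both schedules are now driven purely by a running sample count rather than an epoch index: lr_warmup ramps linearly over the first warmup_samples samples, and lr_decay consumes its boundary list one threshold at a time, multiplying the learning rate by lr_decay_factor whenever total_samples crosses the next boundary. A standalone sketch of the warmup's shape; names and values here are illustrative, with SGD on a dummy parameter standing in for the real optimizer:

import torch
from torch.optim import SGD

def lr_warmup(optimizer, init_lr, lr, current_samples, warmup_samples):
    # Linear ramp from init_lr to lr over the first warmup_samples samples.
    scale = current_samples / warmup_samples
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr + (lr - init_lr) * scale

opt = SGD([torch.nn.Parameter(torch.zeros(1))], lr=2.0)
for seen in (0, 42000, 84000, 168000):          # samples processed so far
    lr_warmup(opt, init_lr=1e-4, lr=2.0, current_samples=seen, warmup_samples=168000)
    print(seen, opt.param_groups[0]['lr'])      # 1e-4 -> ~0.5 -> ~1.0 -> 2.0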
@@ -52,24 +61,32 @@ def train(flags, model, train_loader, val_loader, loss_fn, score_fn, device, callbacks, is_distributed):
 
     is_successful = False
     diverged = False
-    next_eval_at = flags.start_eval_at
+    total_samples = 0
+    iteration = 0
+    lr_decay_samples = flags.lr_decay_samples
+    next_eval_at = EVALUATE_EVERY
     model.train()
+    train_loader = iter(train_loader)
     for callback in callbacks:
         callback.on_fit_start()
-    for epoch in range(1, flags.epochs + 1):
-        cumulative_loss = []
-        if epoch <= flags.lr_warmup_epochs and flags.lr_warmup_epochs > 0:
-            lr_warmup(optimizer, flags.init_learning_rate, flags.learning_rate, epoch, flags.lr_warmup_epochs)
-
+    while not diverged and not is_successful:
         mllog_start(key=CONSTANTS.BLOCK_START, sync=False,
-                    metadata={CONSTANTS.FIRST_EPOCH_NUM: epoch, CONSTANTS.EPOCH_COUNT: 1})
-        mllog_start(key=CONSTANTS.EPOCH_START, metadata={CONSTANTS.EPOCH_NUM: epoch}, sync=False)
+                    metadata={CONSTANTS.FIRST_EPOCH_NUM: total_samples,
+                              CONSTANTS.EPOCH_COUNT: EVALUATE_EVERY})
 
+        t0 = time()
+        while total_samples < next_eval_at:
+            if total_samples <= flags.lr_warmup_samples and flags.lr_warmup_samples > 0:
+                lr_warmup(optimizer, flags.init_learning_rate, flags.learning_rate, total_samples, flags.lr_warmup_samples)
+            if len(flags.lr_decay_samples) > 0:
+                lr_decay_samples = lr_decay(optimizer, lr_decay_samples, flags.lr_decay_factor, total_samples)
 
-        if is_distributed:
-            train_loader.sampler.set_epoch(epoch)
-        optimizer.zero_grad()
+            batch = next(train_loader)
+            total_samples += flags.batch_size * world_size
 
-        loss_value = None
-        optimizer.zero_grad()
-        for iteration, batch in enumerate(tqdm(train_loader, disable=(rank != 0) or not flags.verbose)):
             image, label = batch
             image, label = image.to(device), label.to(device)
             for callback in callbacks:
@@ -93,32 +110,20 @@ def train(flags, model, train_loader, val_loader, loss_fn, score_fn, device, callbacks, is_distributed):
                 optimizer.step()
 
             optimizer.zero_grad()
+            iteration += 1
 
             loss_value = reduce_tensor(loss_value, world_size).detach().cpu().numpy()
-            cumulative_loss.append(loss_value)
 
-        mllog_end(key=CONSTANTS.EPOCH_STOP, sync=False,
-                  metadata={CONSTANTS.EPOCH_NUM: epoch, 'current_lr': optimizer.param_groups[0]['lr']})
-
-        if flags.lr_decay_epochs:
-            scheduler.step()
-
-        if epoch == next_eval_at:
-            next_eval_at += flags.evaluate_every
-            del output
-            mllog_start(key=CONSTANTS.EVAL_START, value=epoch, metadata={CONSTANTS.EPOCH_NUM: epoch}, sync=False)
+        # Evaluation
+        del output
+        if total_samples >= START_EVAL_AT:
+            mllog_start(key=CONSTANTS.EVAL_START, value=total_samples,
+                        metadata={CONSTANTS.EPOCH_NUM: total_samples}, sync=False)
 
-            eval_metrics = evaluate(flags, model, val_loader, loss_fn, score_fn, device, epoch)
-            eval_metrics["train_loss"] = sum(cumulative_loss) / len(cumulative_loss)
+            eval_metrics = evaluate(flags, model, val_loader, loss_fn, score_fn, device, total_samples)
 
-            mllog_event(key=CONSTANTS.EVAL_ACCURACY,
-                        value=eval_metrics["mean_dice"],
-                        metadata={CONSTANTS.EPOCH_NUM: epoch},
-                        sync=False)
-            mllog_end(key=CONSTANTS.EVAL_STOP, metadata={CONSTANTS.EPOCH_NUM: epoch}, sync=False)
+            mllog_event(key=CONSTANTS.EVAL_ACCURACY, value=eval_metrics["mean_dice"],
+                        metadata={CONSTANTS.EPOCH_NUM: total_samples}, sync=False)
+            mllog_end(key=CONSTANTS.EVAL_STOP, metadata={CONSTANTS.EPOCH_NUM: total_samples}, sync=False)
 
             for callback in callbacks:
                 callback.on_epoch_end(epoch=epoch, metrics=eval_metrics, model=model, optimizer=optimizer)
             model.train()
             if eval_metrics["mean_dice"] >= flags.quality_threshold:
                 is_successful = True
@@ -127,12 +132,13 @@ def train(flags, model, train_loader, val_loader, loss_fn, score_fn, device, callbacks, is_distributed):
                 diverged = True
 
         mllog_end(key=CONSTANTS.BLOCK_STOP, sync=False,
-                  metadata={CONSTANTS.FIRST_EPOCH_NUM: epoch, CONSTANTS.EPOCH_COUNT: 1})
-
-        if is_successful or diverged:
-            break
+                  metadata={CONSTANTS.FIRST_EPOCH_NUM: total_samples,
+                            CONSTANTS.EPOCH_COUNT: EVALUATE_EVERY})
+        next_eval_at += EVALUATE_EVERY
 
     mllog_end(key=CONSTANTS.RUN_STOP, sync=True,
-              metadata={CONSTANTS.STATUS: CONSTANTS.SUCCESS if is_successful else CONSTANTS.ABORTED})
+              metadata={CONSTANTS.STATUS: CONSTANTS.SUCCESS if is_successful else CONSTANTS.ABORTED,
+                        CONSTANTS.EPOCH_COUNT: total_samples})
 
     for callback in callbacks:
         callback.on_fit_end()
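Taken together, the loop now runs on sample counts end to end: an infinite iterator feeds training steps, warmup and decay consult total_samples, and evaluation fires every EVALUATE_EVERY samples once START_EVAL_AT is reached. A condensed, hypothetical skeleton of that control flow; step_fn and evaluate_fn are stand-ins for the real step and evaluation code, not names from the repo:

def train_by_samples(train_loader, step_fn, evaluate_fn, batch_size, world_size,
                     start_eval_at=168 * 1000, evaluate_every=168 * 20,
                     quality_threshold=0.908, max_samples=168 * 10000):
    # Sketch only: step_fn(batch) performs one optimizer step,
    # evaluate_fn() returns the current mean Dice score.
    stream = iter(train_loader)               # never exhausts: __len__ is inflated
    total_samples, next_eval_at = 0, evaluate_every
    while total_samples < max_samples:
        while total_samples < next_eval_at:   # one "block" of training
            step_fn(next(stream))
            total_samples += batch_size * world_size  # global samples seen
        if total_samples >= start_eval_at and evaluate_fn() >= quality_threshold:
            return total_samples              # run converged
        next_eval_at += evaluate_every
    return total_samples                      # sample budget exhausted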