From 42bff566ac01e289c99b48a6fd4dad4a8451cfc3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Michael=20K=C3=B6sel?=
Date: Thu, 27 Jun 2019 19:33:29 +0200
Subject: [PATCH] Update both training scripts

---
 train.py     |  2 +-
 train_sbd.py | 27 ++++++++++++---------------
 2 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/train.py b/train.py
index 8b31503..70d59d8 100644
--- a/train.py
+++ b/train.py
@@ -2,7 +2,6 @@
 import random
 import sys
 from argparse import ArgumentParser
-from datetime import datetime
 
 import torch
 import torch.nn as nn
@@ -99,6 +98,7 @@ def run(args):
         sys.exit()
 
     if args.freeze_bn:
+        print("Freezing batch norm")
         model = freeze_batchnorm(model)
 
     trainer = create_supervised_trainer(model, optimizer, criterion, device, non_blocking=True)
diff --git a/train_sbd.py b/train_sbd.py
index e3cbb22..1b6fc59 100644
--- a/train_sbd.py
+++ b/train_sbd.py
@@ -2,16 +2,15 @@
 import random
 import sys
 from argparse import ArgumentParser
-from datetime import datetime
 
 import torch
 import torch.nn as nn
 import torch.optim as optim
+import torch.utils.data as data
 from ignite.contrib.handlers import ProgressBar, TensorboardLogger
 from ignite.contrib.handlers.tensorboard_logger import OutputHandler, OptimizerParamsHandler
 from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
 from ignite.metrics import RunningAverage, Loss
-from torch.utils.data import DataLoader
 from torchvision.datasets import SBDataset
 
 from googlenet_fcn.datasets.transforms.transforms import Compose, ToTensor, \
@@ -34,14 +33,14 @@ def get_data_loaders(data_dir, batch_size, val_batch_size, num_workers, download
         Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ])
 
-    train_loader = DataLoader(SBDataset(root=os.path.join(data_dir, 'sbd'), image_set='train', mode='segmentation',
-                                        download=download, transforms=transform),
-                              batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn,
-                              pin_memory=True)
+    train_loader = data.DataLoader(SBDataset(root=os.path.join(data_dir, 'sbd'), image_set='train', mode='segmentation',
+                                             download=download, transforms=transform),
+                                   batch_size=batch_size, shuffle=True, num_workers=num_workers,
+                                   collate_fn=collate_fn, pin_memory=True)
 
-    val_loader = DataLoader(VOC(root=data_dir, download=download, transforms=val_transform),
-                            batch_size=val_batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn,
-                            pin_memory=True)
+    val_loader = data.DataLoader(VOC(root=data_dir, download=download, transforms=val_transform),
+                                 batch_size=val_batch_size, shuffle=False, num_workers=num_workers,
+                                 collate_fn=collate_fn, pin_memory=True)
 
     return train_loader, val_loader
 
@@ -80,7 +79,6 @@ def run(args):
         print("Loading checkpoint '{}'".format(args.resume))
         checkpoint = torch.load(args.resume)
         args.start_epoch = checkpoint['epoch']
-        args.start_iteration = checkpoint['iteration']
         best_iou = checkpoint['bestIoU']
         model.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
@@ -112,8 +110,7 @@ def _global_step_transform(engine, event_name):
         return trainer.state.iteration
 
-    exp_name = datetime.now().strftime("%Y%m%d-%H%M%S")
-    tb_logger = TensorboardLogger(os.path.join(args.log_dir, exp_name))
+    tb_logger = TensorboardLogger(args.log_dir)
 
     tb_logger.attach(trainer,
                      log_handler=OutputHandler(tag='training',
                                                metric_names=['loss']),
@@ -149,7 +146,7 @@ def save_checkpoint(engine):
     def initialize(engine):
         if args.resume:
             engine.state.epoch = args.start_epoch
-            engine.state.iteration = args.start_iteration
+            engine.state.iteration = args.start_epoch * len(engine.state.dataloader)
             engine.state.best_iou = best_iou
         else:
             engine.state.best_iou = 0.0
@@ -175,11 +172,11 @@ def log_validation_results(engine):
     parser = ArgumentParser('GoogLeNet-FCN with PyTorch')
     parser.add_argument('--batch-size', type=int, default=1,
                         help='input batch size for training')
-    parser.add_argument('--val-batch-size', type=int, default=4,
+    parser.add_argument('--val-batch-size', type=int, default=8,
                         help='input batch size for validation')
    parser.add_argument('--num-workers', type=int, default=4,
                         help='number of workers')
-    parser.add_argument('--epochs', type=int, default=250,
+    parser.add_argument('--epochs', type=int, default=50,
                         help='number of epochs to train')
     parser.add_argument('--lr', type=float, default=1e-10,
                         help='learning rate')
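
Note on the resume change in train_sbd.py: the checkpoint no longer stores a
separate 'iteration' counter; the global step is derived from the saved epoch
instead. A minimal sketch of that arithmetic, assuming each epoch consumes the
full dataloader and checkpoints are written only at epoch boundaries (the
numeric values and the name 'num_batches' below are hypothetical, not from
this repository):

    # Sketch only; mirrors the logic of initialize() above.
    # 'num_batches' stands in for len(engine.state.dataloader).
    num_batches = 1464        # hypothetical batches per epoch
    start_epoch = 10          # as restored from checkpoint['epoch']

    # Ignite's engine counts one iteration per batch, cumulatively
    # across epochs, so after 10 completed epochs the global step is:
    resumed_iteration = start_epoch * num_batches   # 14640

    # This keeps the TensorBoard global_step continuous after a resume
    # without persisting the iteration in the checkpoint.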