Update both training scripts
TheCodez committed Jun 27, 2019
1 parent cbc53eb commit 42bff56
Showing 2 changed files with 13 additions and 16 deletions.
train.py (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,6 @@
 import random
 import sys
 from argparse import ArgumentParser
-from datetime import datetime
 
 import torch
 import torch.nn as nn
@@ -99,6 +98,7 @@ def run(args):
         sys.exit()
 
     if args.freeze_bn:
+        print("Freezing batch norm")
         model = freeze_batchnorm(model)
 
     trainer = create_supervised_trainer(model, optimizer, criterion, device, non_blocking=True)
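In train.py, the --freeze_bn branch now prints a notice before handing the model to freeze_batchnorm. That helper is defined elsewhere in the repository and is not part of this commit; as a rough sketch of what such a function typically does (an assumption, not the project's actual implementation), it switches BatchNorm layers to eval mode and stops their parameters from receiving gradient updates:

import torch.nn as nn

def freeze_batchnorm(model):
    # Sketch only: the real helper in this repo may differ.
    # Put every BatchNorm2d layer into eval mode so its running statistics
    # stay fixed, and exclude its affine parameters from gradient updates.
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            module.eval()
            for param in module.parameters():
                param.requires_grad = False
    return model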
train_sbd.py (27 changes: 12 additions & 15 deletions)
@@ -2,16 +2,15 @@
 import random
 import sys
 from argparse import ArgumentParser
-from datetime import datetime
 
 import torch
 import torch.nn as nn
 import torch.optim as optim
+import torch.utils.data as data
 from ignite.contrib.handlers import ProgressBar, TensorboardLogger
 from ignite.contrib.handlers.tensorboard_logger import OutputHandler, OptimizerParamsHandler
 from ignite.engine import Events, create_supervised_evaluator, create_supervised_trainer
 from ignite.metrics import RunningAverage, Loss
-from torch.utils.data import DataLoader
 from torchvision.datasets import SBDataset
 
 from googlenet_fcn.datasets.transforms.transforms import Compose, ToTensor, \
@@ -34,14 +33,14 @@ def get_data_loaders(data_dir, batch_size, val_batch_size, num_workers, download
         Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
     ])
 
-    train_loader = DataLoader(SBDataset(root=os.path.join(data_dir, 'sbd'), image_set='train', mode='segmentation',
-                              download=download, transforms=transform),
-                              batch_size=batch_size, shuffle=True, num_workers=num_workers, collate_fn=collate_fn,
-                              pin_memory=True)
+    train_loader = data.DataLoader(SBDataset(root=os.path.join(data_dir, 'sbd'), image_set='train', mode='segmentation',
+                                   download=download, transforms=transform),
+                                   batch_size=batch_size, shuffle=True, num_workers=num_workers,
+                                   collate_fn=collate_fn, pin_memory=True)
 
-    val_loader = DataLoader(VOC(root=data_dir, download=download, transforms=val_transform),
-                            batch_size=val_batch_size, shuffle=False, num_workers=num_workers, collate_fn=collate_fn,
-                            pin_memory=True)
+    val_loader = data.DataLoader(VOC(root=data_dir, download=download, transforms=val_transform),
+                                 batch_size=val_batch_size, shuffle=False, num_workers=num_workers,
+                                 collate_fn=collate_fn, pin_memory=True)
 
     return train_loader, val_loader

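The loaders themselves behave the same; only the import style changed from the bare DataLoader name to the aliased data.DataLoader. For orientation, a minimal usage sketch of get_data_loaders (the data directory and download flag are placeholder assumptions; the batch sizes mirror the argparse defaults further down):

# Sketch only: 'data' as the dataset root and download=False are assumptions.
train_loader, val_loader = get_data_loaders('data', batch_size=1, val_batch_size=8,
                                            num_workers=4, download=False)

for images, targets in train_loader:
    # each batch is produced by the script's custom collate_fn
    break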
@@ -80,7 +79,6 @@ def run(args):
         print("Loading checkpoint '{}'".format(args.resume))
         checkpoint = torch.load(args.resume)
         args.start_epoch = checkpoint['epoch']
-        args.start_iteration = checkpoint['iteration']
         best_iou = checkpoint['bestIoU']
         model.load_state_dict(checkpoint['model'])
         optimizer.load_state_dict(checkpoint['optimizer'])
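With 'iteration' gone from the resume path, a checkpoint only needs the entries read above. The save_checkpoint body is outside this diff, so the following layout is inferred rather than quoted:

# Assumed checkpoint layout, based on the keys restored at resume time.
checkpoint = {
    'epoch': trainer.state.epoch,
    'bestIoU': trainer.state.best_iou,
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
}
torch.save(checkpoint, 'checkpoint.pth')  # file name is illustrative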
@@ -112,8 +110,7 @@
     def _global_step_transform(engine, event_name):
         return trainer.state.iteration
 
-    exp_name = datetime.now().strftime("%Y%m%d-%H%M%S")
-    tb_logger = TensorboardLogger(os.path.join(args.log_dir, exp_name))
+    tb_logger = TensorboardLogger(args.log_dir)
     tb_logger.attach(trainer,
                      log_handler=OutputHandler(tag='training',
                                                metric_names=['loss']),
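TensorBoard events are now written directly to args.log_dir instead of a fresh timestamped subdirectory per run, which is why the datetime import could be dropped from both scripts. Anyone who still wants one directory per run can stamp the path before launching; an illustrative snippet, assuming the script exposes a matching --log-dir flag:

import os
from datetime import datetime

# Illustrative only: build a per-run directory outside the training script,
# then pass it in, e.g. python train_sbd.py --log-dir <run_dir>.
run_dir = os.path.join('logs', datetime.now().strftime("%Y%m%d-%H%M%S"))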
@@ -149,7 +146,7 @@ def save_checkpoint(engine):
     def initialize(engine):
         if args.resume:
             engine.state.epoch = args.start_epoch
-            engine.state.iteration = args.start_iteration
+            engine.state.iteration = args.start_epoch * len(engine.state.dataloader)
             engine.state.best_iou = best_iou
         else:
             engine.state.best_iou = 0.0
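Resuming now derives the global step instead of storing it: for illustration, with a start_epoch of 5 and a 1000-batch training loader, the engine would resume at iteration 5000. This assumes the dataloader length matches the run that produced the checkpoint (same dataset and batch size).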
@@ -175,11 +172,11 @@ def log_validation_results(engine):
     parser = ArgumentParser('GoogLeNet-FCN with PyTorch')
     parser.add_argument('--batch-size', type=int, default=1,
                         help='input batch size for training')
-    parser.add_argument('--val-batch-size', type=int, default=4,
+    parser.add_argument('--val-batch-size', type=int, default=8,
                         help='input batch size for validation')
    parser.add_argument('--num-workers', type=int, default=4,
                        help='number of workers')
-    parser.add_argument('--epochs', type=int, default=250,
+    parser.add_argument('--epochs', type=int, default=50,
                         help='number of epochs to train')
    parser.add_argument('--lr', type=float, default=1e-10,
                        help='learning rate')
