-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
74 lines (61 loc) · 1.97 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Patrick Chao and Noah Gundotra
# 1/11/18
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
import torchvision
from torchvision import transforms,models
import numpy as np
import torch.nn as nn
from model import PolygonRNN
from tqdm import tqdm
import time
import pdb
starttime = time.time()
have_cuda = torch.cuda.is_available()
epochs = 5

# Creating Datasets
# The paper lists three augmentations:
#   1) a random flip of the image crop together with its label,
#   2) expanding the bounding box by a random 10-20%,
#   3) a random choice of the polygon annotation's starting vertex.
# Below, only a random crop and a random horizontal flip of the image are
# applied (ImageFolder's `transform` touches the image, not the label).
# NOTE(review): an earlier `transforms.Rescale(256)` step was disabled;
# torchvision has no `Rescale` (its equivalent is `transforms.Resize(256)`).
# RandomCrop(224) assumes every image is at least 224x224 -- confirm.
augmentations = [
    transforms.RandomCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
]
transform = transforms.Compose(augmentations)

# NOTE(review): the training set is read from "../val", the same directory as
# the validation set below -- looks like a copy-paste slip; confirm whether a
# separate training split (e.g. "../train") was intended.
image_dir = "../val"
train_set = torchvision.datasets.ImageFolder(image_dir, transform=transform)
train_set_size = len(train_set)

val_dir = "../val"
val_set = torchvision.datasets.ImageFolder(val_dir, transform=transform)
val_set_size = len(val_set)

train_loader = torch.utils.data.DataLoader(
    train_set, batch_size=32, shuffle=True, num_workers=1
)
val_loader = torch.utils.data.DataLoader(
    val_set, batch_size=8, shuffle=True, num_workers=1
)

# Initializing Model
# Image encoder: the convolutional `features` stack of an ImageNet-pretrained
# VGG-16 with its last layer (a max-pool in torchvision's VGG-16) removed.
backbone_layers = list(models.vgg16(pretrained=True).features.children())
vgg = nn.Sequential(*backbone_layers[:-1])
model = PolygonRNN(vgg)
if have_cuda:
    model.cuda()

elapsed = time.time() - starttime
print("About to train! Time elapsed: {}".format(elapsed))
# Train Process Pt.1/2
def train(epoch):
    """Run a gradient-free forward pass of ``model`` over ``train_loader``.

    NOTE(review): despite the name, nothing is optimized yet -- no loss,
    backward pass or optimizer step is computed and the model output is
    discarded. The model is put in eval mode, matching the no-grad forward
    pass; switch to ``model.train()`` once a real training step is added.

    Args:
        epoch: Epoch number; only used to label the progress bar.
    """
    import contextlib

    print("Training process has started")
    model.eval()
    # From PyTorch 0.4 on, ``Variable(..., volatile=True)`` is ignored (only a
    # warning is emitted), so autograd would silently record the forward pass.
    # Use torch.no_grad() when it exists; older releases only understand the
    # ``volatile`` flag, so keep it there. ``contextlib.suppress()`` with no
    # arguments is a do-nothing context manager.
    legacy_volatile = not hasattr(torch, "no_grad")
    grad_guard = contextlib.suppress if legacy_volatile else torch.no_grad
    # Iterate the loader directly so tqdm can read len(train_loader) and show
    # a total/ETA; wrapping enumerate() hid the length from tqdm.
    for data in tqdm(train_loader, desc="Epoch {}".format(epoch)):
        original_img = data[0].float()
        with grad_guard():
            if legacy_volatile:
                original_img = Variable(original_img, volatile=True)
            if have_cuda:
                original_img = original_img.cuda()
            # TODO(review): compute a loss from this output and step an optimizer.
            model(original_img)
# Train Process Pt.2/2
# Drive the per-epoch routine for epochs numbered 1 through `epochs` inclusive.
epoch_numbers = range(1, epochs + 1)
for epoch in epoch_numbers:
    print("Epoch: {}".format(epoch))
    train(epoch)