-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
116 lines (95 loc) · 3.66 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow.data import Dataset
from tensorflow.keras.optimizers import Adam
from tensorflow.data.experimental import AUTOTUNE
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.config.experimental import set_memory_growth
from tensorflow.config import list_physical_devices
# Work around "Failed to get convolution algorithm" cuDNN errors: ask
# TensorFlow to grow GPU memory on demand instead of pre-allocating it all.
# Enabling growth on every visible GPU (rather than indexing [0]) also keeps
# the script from crashing with IndexError on a CPU-only machine and covers
# multi-GPU hosts.
physical_devices = list_physical_devices('GPU')
for gpu in physical_devices:
    set_memory_growth(gpu, True)
# Hyperparameters
from cycle.config import ORIG_IMG_SIZE
from cycle.config import INPUT_IMG_SIZE
from cycle.config import BUFFER_SIZE
from cycle.config import BATCH_SIZE
from cycle import CycleGAN
from cycle.callbacks import GANMonitor
from cycle.loss import generator_loss_fn
from cycle.loss import discriminator_loss_fn
from cycle.generator import get_resnet_generator
from cycle.discriminator import get_discriminator
from cycle.preprocessing import preprocess_train_image
from cycle.preprocessing import preprocess_test_image
def _build_pipeline(dataset, map_fn):
    """Apply the standard input pipeline: map, cache, shuffle, batch.

    Args:
        dataset: a tf.data.Dataset of (image, label) pairs from tfds.
        map_fn: per-element preprocessing function (train or test variant).

    Returns:
        A batched tf.data.Dataset ready for training/evaluation.
    """
    return (
        dataset.map(map_fn, num_parallel_calls=AUTOTUNE)
        # Cache after the (deterministic) preprocessing so it runs only once.
        .cache()
        .shuffle(BUFFER_SIZE)
        .batch(BATCH_SIZE)
    )


datasets = tfds.load('cycle_gan/horse2zebra', as_supervised=True)
train_horses, train_zebras = datasets['trainA'], datasets['trainB']
test_horses, test_zebras = datasets['testA'], datasets['testB']
# Apply preprocessing — the same pipeline shape for all four splits.
# NOTE(review): shuffling the test splits is kept for behavioral parity with
# the original script, though it is usually unnecessary for evaluation.
train_horses = _build_pipeline(train_horses, preprocess_train_image)
train_zebras = _build_pipeline(train_zebras, preprocess_train_image)
test_horses = _build_pipeline(test_horses, preprocess_test_image)
test_zebras = _build_pipeline(test_zebras, preprocess_test_image)
def _denormalize(image_tensor):
    """Map an image tensor from the model's [-1, 1] range back to uint8 RGB."""
    return (((image_tensor * 127.5) + 127.5).numpy()).astype(np.uint8)


def _plot_sample_pairs(horses_ds, zebras_ds, n_rows=4):
    """Show *n_rows* horse/zebra pairs side by side, one dataset per column.

    Takes the first image of each of the first *n_rows* batches from both
    datasets and blocks on plt.show() until the window is closed.
    """
    _, axes = plt.subplots(n_rows, 2, figsize=(10, 15))
    paired = zip(horses_ds.take(n_rows), zebras_ds.take(n_rows))
    for row, (horse_batch, zebra_batch) in enumerate(paired):
        axes[row, 0].imshow(_denormalize(horse_batch[0]))
        axes[row, 1].imshow(_denormalize(zebra_batch[0]))
    plt.show()


# Visual sanity check of the preprocessed training and test data.
_plot_sample_pairs(train_horses, train_zebras)
_plot_sample_pairs(test_horses, test_zebras)
# Putting all together: two generators (G: horse->zebra, F: zebra->horse) and
# one discriminator per image domain, as in the standard CycleGAN setup.
generator_G = get_resnet_generator(name='generator_G')
generator_F = get_resnet_generator(name='generator_F')
discriminator_X = get_discriminator(name='discriminator_X')
discriminator_Y = get_discriminator(name='discriminator_Y')
cycle_model = CycleGAN(
    generator_G=generator_G,
    generator_F=generator_F,
    discriminator_X=discriminator_X,
    discriminator_Y=discriminator_Y)
# Adam with lr=2e-4, beta_1=0.5 — the optimizer settings from the CycleGAN
# paper — one independent optimizer per sub-network.
cycle_model.compile(
    generator_G_opt=Adam(learning_rate=2e-4, beta_1=0.5),
    generator_F_opt=Adam(learning_rate=2e-4, beta_1=0.5),
    discriminator_X_opt=Adam(learning_rate=2e-4, beta_1=0.5),
    discriminator_Y_opt=Adam(learning_rate=2e-4, beta_1=0.5),
    generator_loss_fn=generator_loss_fn,
    discriminator_loss_fn=discriminator_loss_fn)
# Callback that renders translated test images after each epoch.
plotter = GANMonitor(data=test_horses)
checkpoint_filepath = "./model_checkpoints/cyclegan_checkpoints.{epoch:03d}"
# save_weights_only=True: CycleGAN is a subclassed model with a custom train
# step; a full SavedModel export via ModelCheckpoint fails without get_config,
# so only the weights are checkpointed (matching the canonical Keras example).
model_checkpoint_callback = ModelCheckpoint(
    filepath=checkpoint_filepath,
    save_weights_only=True,
)
# Train on paired streams of (horse, zebra) batches; pairing is arbitrary
# since CycleGAN is trained on unpaired data.
cycle_model.fit(
    Dataset.zip((train_horses, train_zebras)),
    epochs=90,
    callbacks=[plotter, model_checkpoint_callback],
)