-
Notifications
You must be signed in to change notification settings - Fork 12
/
1_train.py
30 lines (23 loc) · 1.04 KB
/
1_train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
#-*- coding: utf-8 -*-
from src.model_builder import CamModelBuilder
from keras.optimizers import Adam
from keras.applications.resnet50 import preprocess_input
from src.keras_utils import build_generator, create_callbacks
def freeze_all_but_last(model, num_trainable):
    """Freeze every layer of *model* except the last *num_trainable* ones.

    Sets ``layer.trainable = False`` on the leading layers. In Keras this has
    to happen before ``model.compile`` for the change to take effect.

    Unlike slicing with ``model.layers[:-num_trainable]``, this computes the
    cut-off explicitly. So ``num_trainable == 0`` freezes every layer rather
    than none, and a value larger than the layer count freezes nothing.

    Args:
        model: Object exposing a ``layers`` sequence whose items have
            ``trainable`` and ``name`` attributes (e.g. a Keras model).
        num_trainable: Number of trailing layers to leave trainable.

    Returns:
        list: Names of the layers that were frozen, in model order.

    Raises:
        ValueError: If ``num_trainable`` is negative.
    """
    if num_trainable < 0:
        raise ValueError("num_trainable must be non-negative")
    layers = list(model.layers)
    cutoff = max(len(layers) - num_trainable, 0)
    frozen = []
    for layer in layers[:cutoff]:
        layer.trainable = False
        frozen.append(layer.name)
    return frozen


def main(train_dir="dataset//train", epochs=20, learning_rate=1e-3,
         num_trainable_layers=None):
    """Build the CAM classification model, compile it with Adam, and train it.

    Defaults reproduce the original script's hard-coded settings.

    Args:
        train_dir: Directory handed to ``build_generator`` for training data.
        epochs: Number of training epochs.
        learning_rate: Initial Adam learning rate.
        num_trainable_layers: If not ``None``, freeze every layer except the
            last ``num_trainable_layers`` before compiling, and print the
            frozen layer names. ``None`` (the default) trains all layers.

    Returns:
        Whatever ``model.fit_generator`` returns.
    """
    model_builder = CamModelBuilder()
    model = model_builder.get_cls_model()
    model.summary()

    # Freezing must precede compile() so Keras picks up the trainable flags.
    if num_trainable_layers is not None:
        print(freeze_all_but_last(model, num_trainable_layers))

    # NOTE(review): `lr`, `decay` and `fit_generator` are the Keras 2.x API
    # this project was written against. Newer Keras releases deprecate or
    # reject `lr`/`decay` (use `learning_rate` / a schedule) and replace
    # `fit_generator` with `fit` (it is removed in Keras 3). Confirm the
    # pinned Keras version before migrating.
    optimizer = Adam(lr=learning_rate, beta_1=0.9, beta_2=0.999,
                     epsilon=1e-08, decay=0.005)
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizer,
                  metrics=['accuracy'])

    # augment=True is forwarded to build_generator (src.keras_utils); the
    # exact augmentation it enables is defined there.
    train_generator = build_generator(train_dir, preprocess_input, augment=True)
    # One epoch = one full pass over the generator's batches.
    return model.fit_generator(train_generator,
                               steps_per_epoch=len(train_generator),
                               callbacks=create_callbacks(),
                               epochs=epochs)


if __name__ == "__main__":
    main()