-
Notifications
You must be signed in to change notification settings - Fork 0
/
training_utils.py
79 lines (50 loc) · 2 KB
/
training_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import tensorflow as tf
import numpy as np
import yaml
import os
import sys
sys.path.append('models')
from generator import Generator
from callback import Callback
def load_config_file(file_path, create_folder_results):
    """Load a YAML experiment config and derive the results-folder path.

    Parameters
    ----------
    file_path : str
        Path to the YAML configuration file. The file must define
        'results_path', 'prefix', 'ablation' (a list of config key names),
        and one entry per name listed in 'ablation'.
    create_folder_results : bool
        When True, create the derived results directory on disk.

    Returns
    -------
    dict
        The parsed config with an added 'r_path' key:
        results_path/prefix followed by '_<value>' for each ablation key.
    """
    # Context manager guarantees the handle is closed even if parsing raises
    # (the original open()/close() pair leaked the handle on a YAML error).
    with open(file_path, 'r') as stream:
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # prefer yaml.safe_load if configs may ever come from untrusted
        # sources — confirm the threat model before changing.
        config = yaml.load(stream, Loader=yaml.Loader)
    config["r_path"] = os.path.join(config['results_path'], config['prefix'])
    # Encode each ablation setting's value into the folder name so that
    # runs with different hyperparameters do not overwrite each other.
    for p in config['ablation']:
        config['r_path'] += '_' + str(config[p])
    # model name
    if create_folder_results:
        os.makedirs(config['r_path'], exist_ok=True)
    return config
def get_callbacks(config, vae_m, data):
    """Build the evaluation callback list for training.

    With config['formal_training'] set, returns two callbacks: one on the
    validation split (which saves the model) and one on the test split
    (monitoring only). Otherwise returns a single test-split callback that
    also saves the model.

    Parameters
    ----------
    config : dict
        Experiment configuration; 'formal_training' selects the mode.
    vae_m :
        The model handed to each Callback.
    data :
        Object exposing X_vali/Y_vali and X_test/Y_test splits.

    Returns
    -------
    list
        One or two Callback instances, as described above.
    """
    if not config['formal_training']:
        # Informal runs: evaluate and checkpoint directly on the test split.
        return [Callback(vae_m,
                         data.X_test, data.Y_test,
                         config,
                         save_model=True)]
    vali_cb = Callback(vae_m,
                       data.X_vali, data.Y_vali,
                       config,
                       save_model=True)
    test_cb = Callback(vae_m,
                       data.X_test, data.Y_test,
                       config,
                       save_model=False)
    return [vali_cb, test_cb]
def get_optimizer(config):
    """Create the training optimizer selected by the config.

    Parameters
    ----------
    config : dict
        Must contain 'sgd' (bool) and 'lr'; when 'sgd' is true, also
        'momentum' and 'decay'.

    Returns
    -------
    A tf.keras optimizer: SGD with momentum and learning-rate decay when
    config['sgd'] is truthy, Adam otherwise.
    """
    # Use the tf.keras.optimizers namespace for both branches; the original
    # mixed tf.optimizers.SGD (an alias of the same class) with
    # tf.keras.optimizers.Adam, which read as two different APIs.
    if config['sgd']:
        # NOTE(review): the 'decay' kwarg is the legacy-optimizer spelling;
        # newer TF/Keras releases removed it in favor of learning-rate
        # schedules — confirm the pinned TensorFlow version supports it.
        optimizer = tf.keras.optimizers.SGD(learning_rate=config['lr'],
                                            momentum=config['momentum'],
                                            decay=config['decay'])
    else:
        optimizer = tf.keras.optimizers.Adam(learning_rate=config['lr'])
    return optimizer
def save_history(path, r, callbacks, name_model):
    """Fold callback metric histories into the training history and save it.

    Parameters
    ----------
    path : str
        Directory where the history file is written.
    r :
        Keras History-like object; its .history dict is mutated in place.
    callbacks : list
        Either [vali_callback, test_callback] or [test_callback], each
        exposing a .history attribute.
    name_model : str
        Basename of the saved file; '<name_model>_history.npy' is created.
    """
    # The last callback is always the test-split one, regardless of length.
    r.history['test_metrics'] = callbacks[-1].history
    if len(callbacks) == 2:
        # Formal-training mode also carries a validation-split callback first.
        r.history['vali_metrics'] = callbacks[0].history
    np.save(os.path.join(path, name_model + '_history'), r.history)