forked from wubinzzu/DER
-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
73 lines (60 loc) · 3.76 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse
from six.moves import configparser
import logging.config
from solver import *
from argparse import Namespace
class ExcuteExperiments:
    """Minimal runner that executes one already-constructed solver.

    The class and method names keep their original spelling because other
    code refers to them.
    """

    def __init__(self, s):
        # Only store the solver here; no work happens until excute() is called.
        self.solver = s

    def excute(self):
        """Invoke the wrapped solver's run() and return its result unchanged."""
        return self.solver.run()
def update_args(args, key, value):
    """Return a new Namespace equal to *args* but with attribute *key* = *value*.

    *args* itself is not modified. *key* may name an existing attribute, which
    is overridden, or a new one, which is added.
    """
    merged = dict(vars(args))  # shallow copy of the attribute mapping
    merged[key] = value
    return Namespace(**merged)
# Default settings file and results directory. These are the Colab-style
# locations this script has always used; change them here when running
# elsewhere. Individual settings can still be overridden on the command line.
DEFAULT_CONF_PATH = "/content/DER/conf/default_setting.conf"
RESULTS_DIR = "/content/DER/results"


def build_arg_parser(cf):
    """Build the command-line parser whose defaults come from *cf*.

    Every option is optional. When omitted on the command line, it falls back
    to the value stored in the ``[path]`` or ``[parameters]`` section of the
    already-loaded config.

    Args:
        cf: a ``ConfigParser`` that has already read the settings file.

    Returns:
        An ``argparse.ArgumentParser`` ready for ``parse_args()``.

    Raises:
        configparser.NoSectionError / configparser.NoOptionError: if *cf* is
            missing a section or key that is looked up here.
        ValueError: if a numeric setting cannot be converted.
    """
    parser = argparse.ArgumentParser()

    # [path] section: filesystem locations (plain strings).
    for name in ("root_path", "input_data_type", "output_path", "log_conf_path"):
        parser.add_argument("--" + name, type=str, default=cf.get("path", name), help=name)

    # [parameters] section: model / training hyper-parameters.
    parser.add_argument("--global_dimension", type=int,
                        default=cf.getint("parameters", "global_dimension"),
                        help="number of latent factors")
    parser.add_argument("--word_dimension", type=int,
                        default=cf.getint("parameters", "word_dimension"), help="word_dimension")
    parser.add_argument("--batch_size", type=int,
                        default=cf.getint("parameters", "batch_size"), help="batch_size")
    parser.add_argument("--K", type=int, default=cf.getint("parameters", "K"), help="K")
    parser.add_argument("--epoch", type=int, default=cf.getint("parameters", "epoch"), help="epoch")
    parser.add_argument("--learning_rate", type=float,
                        default=cf.getfloat("parameters", "learning_rate"), help="learning_rate")
    parser.add_argument("--reg", type=float, default=cf.getfloat("parameters", "reg"), help="reg")
    # The help text previously said "reg" here (copy-paste slip).
    parser.add_argument("--mode", type=str, default=cf.get("parameters", "mode"), help="mode")
    parser.add_argument("--merge", type=str, default=cf.get("parameters", "merge"), help="merge")
    parser.add_argument("--concat", type=int, default=cf.getint("parameters", "concat"), help="concat")
    parser.add_argument("--item_review_combine", type=str,
                        default=cf.get("parameters", "item_review_combine"),
                        help="item_review_combine")
    parser.add_argument("--item_review_combine_c", type=float,
                        default=cf.getfloat("parameters", "item_review_combine_c"),
                        help="item_review_combine_c")
    parser.add_argument("--lmd", type=float, default=cf.getfloat("parameters", "lmd"), help="lmd")
    parser.add_argument("--drop_out_rate", type=float,
                        default=cf.getfloat("parameters", "drop_out_rate"), help="drop_out_rate")
    return parser


def main():
    """Run one experiment end to end.

    Steps:
    1. Load defaults from ``DEFAULT_CONF_PATH`` and apply command-line overrides.
    2. Configure logging from ``root_path/log_conf_path``.
    3. Run the solver once inside a TF variable scope named after ``--mode``.
    4. Write the parameters and the result to a file under ``RESULTS_DIR``.

    Raises:
        FileNotFoundError: if the settings file cannot be read.
    """
    import os  # previously only available implicitly via `from solver import *`

    cf = configparser.ConfigParser()
    # ConfigParser.read() silently skips missing files. Without this check,
    # a missing file would only show up later as a confusing NoSectionError.
    if not cf.read(DEFAULT_CONF_PATH):
        raise FileNotFoundError("config file not found: " + DEFAULT_CONF_PATH)
    args = build_arg_parser(cf).parse_args()

    log_conf = os.path.join(args.root_path, args.log_conf_path)
    logging.config.fileConfig(log_conf)
    print(log_conf)
    args.logger = logging.getLogger()

    os.makedirs(RESULTS_DIR, exist_ok=True)
    # No separator between mode and dataset name: kept so file names match
    # results produced by earlier runs.
    result_file = os.path.join(
        RESULTS_DIR, "DER_result_" + args.mode + args.input_data_type.split("/")[-1])

    # newline="" leaves "\n" untranslated, so the bytes match the former
    # 'wb' + encode() output. `with` closes the file even if the solver raises.
    with open(result_file, "w", encoding="utf-8", newline="") as f:
        f.write(str(args.mode) + " parameters:")
        f.write(str(args))
        f.write("\n")
        f.write(str(args.mode) + " result:")
        f.write("\n")
        with tf.compat.v1.variable_scope(args.mode):
            # Set after the parameters were written, as before, so it does not
            # appear in the "parameters:" line.
            args.namespace = args.mode
            solver = Solver(args, 0)
            result = ExcuteExperiments(solver).excute()
        f.write(str(result))
        f.write("\n")


if __name__ == "__main__":
    main()