-
Notifications
You must be signed in to change notification settings - Fork 0
/
parse.py
73 lines (57 loc) · 3.68 KB
/
parse.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import argparse
def parse_kgat_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options for a KGAT run.

    Besides the options declared below, the returned namespace carries an
    extra ``save_dir`` attribute: the output directory derived from the
    dataset name, embedding sizes, aggregation type, layer sizes, learning
    rate and pretrain mode.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) ``argparse`` reads ``sys.argv[1:]``, which is the
            behavior of the original zero-argument call.

    Returns:
        The populated ``argparse.Namespace`` with ``save_dir`` attached.

    Raises:
        SystemExit: Raised by ``argparse`` on invalid options, including a
            ``--conv_dim_list`` value that is not a literal list/tuple.
    """
    # Local import keeps this function self-contained; only needed here.
    import ast

    parser = argparse.ArgumentParser(description="Run KGAT.")

    parser.add_argument('--local_rank', type=int, default=0,
                        help='Local rank for using multi GPUs.')
    parser.add_argument('--seed', type=int, default=123,
                        help='Random seed.')

    parser.add_argument('--data_name', nargs='?', default='amazon-book',
                        help='Choose a dataset from {yelp2018, last-fm, amazon-book}')
    parser.add_argument('--data_dir', nargs='?', default='datasets/',
                        help='Input data path.')

    parser.add_argument('--use_pretrain', type=int, default=0,
                        help='0: No pretrain, 1: Pretrain with the learned embeddings, 2: Pretrain with stored model.')
    parser.add_argument('--pretrain_embedding_dir', nargs='?', default='datasets/pretrain/',
                        help='Path of learned embeddings.')
    parser.add_argument('--pretrain_model_path', nargs='?', default='trained_model/model.pth',
                        help='Path of stored model.')

    parser.add_argument('--cf_batch_size', type=int, default=1024,
                        help='CF batch size.')
    parser.add_argument('--kg_batch_size', type=int, default=2048,
                        help='KG batch size.')
    parser.add_argument('--test_batch_size', type=int, default=10000,
                        help='Test batch size (the user number to test every batch).')

    parser.add_argument('--entity_dim', type=int, default=64,
                        help='User / entity Embedding size.')
    parser.add_argument('--relation_dim', type=int, default=64,
                        help='Relation Embedding size.')

    parser.add_argument('--aggregation_type', nargs='?', default='bi-interaction',
                        help='Specify the type of the aggregation layer from {gcn, graphsage, bi-interaction}.')
    # The two list-valued options stay as raw strings on the namespace;
    # only save_dir construction below interprets conv_dim_list.
    parser.add_argument('--conv_dim_list', nargs='?', default='[64, 32, 16]',
                        help='Output sizes of every aggregation layer.')
    parser.add_argument('--mess_dropout', nargs='?', default='[0.1, 0.1, 0.1]',
                        help='Dropout probability w.r.t. message dropout for each deep layer. 0: no dropout.')

    parser.add_argument('--kg_l2loss_lambda', type=float, default=1e-5,
                        help='Lambda when calculating KG l2 loss.')
    parser.add_argument('--cf_l2loss_lambda', type=float, default=1e-5,
                        help='Lambda when calculating CF l2 loss.')

    parser.add_argument('--lr', type=float, default=0.0001,
                        help='Learning rate.')
    parser.add_argument('--n_epoch', type=int, default=1000,
                        help='Number of epoch.')
    parser.add_argument('--stopping_steps', type=int, default=10,
                        help='Number of epoch for early stopping')

    parser.add_argument('--cf_print_every', type=int, default=1,
                        help='Iter interval of printing CF loss.')
    parser.add_argument('--kg_print_every', type=int, default=1,
                        help='Iter interval of printing KG loss.')
    parser.add_argument('--evaluate_every', type=int, default=1,
                        help='Epoch interval of evaluating CF.')

    parser.add_argument('--K', type=int, default=20,
                        help='Calculate metric@K when evaluating.')

    args = parser.parse_args(argv)

    # Parse the layer-size list as a literal only. The previous eval() would
    # execute any expression supplied on the command line.
    try:
        conv_dims = ast.literal_eval(args.conv_dim_list)
    except (ValueError, SyntaxError, TypeError):
        parser.error(
            "--conv_dim_list must be a literal list such as '[64, 32, 16]', "
            "got {!r}".format(args.conv_dim_list))
    if not isinstance(conv_dims, (list, tuple)):
        parser.error(
            "--conv_dim_list must be a literal list such as '[64, 32, 16]', "
            "got {!r}".format(args.conv_dim_list))

    save_dir = 'trained_model/KGAT/{}/entitydim{}_relationdim{}_{}_{}_lr{}_pretrain{}/'.format(
        args.data_name, args.entity_dim, args.relation_dim, args.aggregation_type,
        '-'.join(str(dim) for dim in conv_dims), args.lr, args.use_pretrain)
    args.save_dir = save_dir

    return args