forked from jkalogero/scalegmn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
scalegmn_relu.yml
127 lines (108 loc) · 3.06 KB
/
scalegmn_relu.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Mini-batch size.
batch_size: 64

# Dataset identity, on-disk locations and input-graph settings.
data:
  dataset: cifar10
  dataset_path: './data/cifar10/'
  data_path: './data/cifar10/weights.npy'
  metrics_path: './data/cifar10/metrics.csv.gz'
  layout_path: './data/cifar10/layout.csv'
  idcs_file: './data/cifar10/cifar10_split.csv'
  activation_function: relu
  # Anchored here; scalegmn_args aliases these two flags.
  node_pos_embed: &node_pos_embed true
  edge_pos_embed: &edge_pos_embed false
  # Could be extracted per datapoint, but it is identical for all of them,
  # so it is defined once here.
  layer_layout: [1, 16, 16, 16, 10]
# Training-run settings: epoch count, random seed and loss name.
train_args:
  num_epochs: 200
  seed: 0
  loss: MSE
# Arguments for the ScaleGMN model and its sub-modules.
# NOTE(review): nesting was reconstructed from a copy whose indentation was
# lost; sub-sections are assumed to be direct children of scalegmn_args —
# confirm against the upstream file.
scalegmn_args:
  d_in_v: &d_in_v 1  # initial dimension of the input NN biases
  d_in_e: &d_in_e 1  # initial dimension of the input NN weights
  d_hid: &d_hid 128  # hidden dimension (aliased by the sub-sections below)
  num_layers: 4  # number of GNN layers to apply
  direction: forward
  equivariant: false
  symmetry: scale  # symmetry
  jit: false  # prefer `compile` - compile the GNN to optimize performance
  compile: false  # compile the GNN to optimize performance

  readout_range: last_layer  # or full_graph
  gnn_skip_connections: false

  concat_mlp_directions: false
  reciprocal: true

  # Use positional encodings; values are aliased from the `data` section.
  node_pos_embed: *node_pos_embed
  edge_pos_embed: *edge_pos_embed

  _max_kernel_height: 3
  _max_kernel_width: 3

  graph_init:
    d_in_v: *d_in_v
    d_in_e: *d_in_e
    project_node_feats: true
    project_edge_feats: true
    d_node: *d_hid
    d_edge: *d_hid

  positional_encodings:
    final_linear_pos_embed: false
    sum_pos_enc: false
    po_as_different_linear: false
    equiv_net: false
    # Arguments for the equiv_net option.
    sum_on_io: true
    equiv_on_hidden: true
    num_mlps: 3
    layer_equiv_on_hidden: false

  gnn_args:
    d_hid: *d_hid
    message_fn_layers: 1
    message_fn_skip_connections: false
    update_node_feats_fn_layers: 1
    update_node_feats_fn_skip_connections: false
    update_edge_attr: true
    dropout: 0.0
    # false: dropout only in between the GNN layers; true: + all MLP layers.
    dropout_all: false
    update_as_act: false
    update_as_act_arg: sum
    mlp_on_io: true
    msg_equiv_on_hidden: true
    upd_equiv_on_hidden: true
    layer_msg_equiv_on_hidden: false
    layer_upd_equiv_on_hidden: false
    msg_num_mlps: 2
    upd_num_mlps: 2
    pos_embed_msg: false
    pos_embed_upd: false
    layer_norm: false
    aggregator: add
    sign_symmetrization: true

  mlp_args:
    d_k: [*d_hid]
    activation: silu
    dropout: 0.0
    final_activation: identity
    batch_norm: false  # check again
    layer_norm: true
    bias: true
    skip: false

  readout_args:
    d_out: 1  # output dimension of the model
    # Intermediate dimension within the Readout module - only used in
    # PermutationInvariantSignNet.
    d_rho: *d_hid
# Gradient clipping, optimizer and learning-rate scheduler settings.
optimization:
  clip_grad: true
  clip_grad_max_norm: 10.0
  optimizer_name: AdamW
  optimizer_args:
    # Written with an explicit mantissa dot: YAML 1.1 loaders (e.g. PyYAML)
    # resolve a bare `1e-3` to the string "1e-3" instead of a float.
    lr: 1.0e-3
    weight_decay: 0.01
  scheduler_args:
    scheduler: WarmupLRScheduler
    warmup_steps: 1000
    scheduler_mode: min
    decay_rate: 0
    decay_steps: 0
    # `null` is the YAML null value; a bare `None` loads as the string "None".
    patience: null
    min_lr: null
# Weights & Biases run settings; every field except `project` is null here.
wandb_args:
  project: cifar10_relu
  entity: null
  group: null
  name: null
  tags: null