I was experimenting with the Graphormer model, specifically for graph classification using the virtual node for global pooling (`graph_pooling: graph_token`).
Problem
I noticed that the model produces different outputs for the same input graph when its node order is permuted. The problem should be easy to reproduce; here is an example:
```python
import torch
from torch_geometric.data import Batch

# given some data batch, e.g. inside the training loop
# create a copy of the first graph
data = Batch.from_data_list([batch.get_example(0).clone()])
data_p = Batch.from_data_list([batch.get_example(0).clone()])

# and permute the nodes:
# here we simply put the previously last node in first place of the first graph
n = data_p.x.size(0)
p = torch.arange(n, dtype=torch.long) - 1
p[0] = n - 1
data_p.x = data_p.x[p]
assert (data_p.x[0, :] == data.x[-1, :]).all()
assert (data_p.x[1:, :] == data.x[:-1, :]).all()

# make sure to permute the other node features as well
data_p.batch = data_p.batch[p]
data_p.in_degrees = data_p.in_degrees[p]
data_p.out_degrees = data_p.out_degrees[p]

# and change the indices accordingly (all increase by one, just the last one gets set to zero)
n = data_p.x.size(0)
data_p.edge_index += 1
data_p.edge_index[data_p.edge_index == n] = 0
data_p.graph_index += 1
data_p.graph_index[data_p.graph_index == n] = 0

# then get the model outputs for each graph
model.eval()
with torch.no_grad():
    output, _ = model(data)
    output_p, _ = model(data_p)

# check if outputs are equal
assert torch.allclose(output, output_p), "Permuted graph produces different output!"
```
This is unexpected (and worrisome) behavior. In theory, the model architecture should be invariant to such node permutations, as should any GNN used for graph-level prediction.
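The example above hard-codes one particular permutation (moving the last node to the front). A minimal sketch of a more general helper is shown below, assuming the attributes use the same names as in the example (`x`, `batch`, `in_degrees`, `out_degrees`, `edge_index`, `graph_index`) and are held as plain tensors in a dict for a single graph. The function name `permute_nodes` is hypothetical and not part of PyG or the Graphormer code.

```python
import torch


def permute_nodes(attrs, perm):
    """Reorder the nodes of a single graph (hypothetical helper).

    attrs: dict of tensors using the attribute names from the example
           ('x', 'batch', 'in_degrees', 'out_degrees', 'edge_index', 'graph_index').
    perm:  1-D LongTensor with perm[new] = old.
    """
    # Inverse permutation: new_pos[old] = new, used to relabel index tensors.
    new_pos = torch.empty_like(perm)
    new_pos[perm] = torch.arange(perm.numel(), dtype=perm.dtype, device=perm.device)

    out = dict(attrs)
    # Node-level attributes: gather rows in the new order.
    for key in ("x", "batch", "in_degrees", "out_degrees"):
        if key in attrs:
            out[key] = attrs[key][perm]
    # Index attributes: entries keep their positions, but the node ids they hold are relabeled.
    for key in ("edge_index", "graph_index"):
        if key in attrs:
            out[key] = new_pos[attrs[key]]
    return out
```

With `perm = torch.tensor([3, 0, 1, 2])` this reproduces the manual "add one, wrap `n` to zero" remapping from the example for a 4-node graph, and `torch.randperm(n)` gives random permutations for repeated checks. For a PyG `Data` object, the same assignments can be made attribute by attribute.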
Cause
The cause turned out to be in the `add_graph_token` function, in this line:
`torch.sort` is called to group the newly concatenated virtual nodes together with the other nodes of their respective graphs.
However, it is called without the `stable` argument, so the default `stable=False` is used. As a result, the nodes within each graph (same batch index) are not guaranteed to keep their previous order; instead, the sorting algorithm may permute the nodes of each graph. By itself this would not necessarily be a problem, since the model should be invariant to such permutations. However, the other data attributes (`edge_index`, `in_degrees`, `att_bias`, etc.) still follow the old node order, so they would also have to be permuted/remapped accordingly.
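The remapping described above (the alternative to a stable sort) can be sketched as follows. This is a simplified illustration, not the actual `add_graph_token` code: it assumes the graph tokens are concatenated in front of the node features and that `edge_index` refers to row positions after that concatenation. The helper name `reorder_with_indices` is hypothetical. Because every index tensor is mapped through the inverse of `sort_idx`, the result stays consistent however the (possibly unstable) sort breaks ties within a graph.

```python
import torch


def reorder_with_indices(x, batch, edge_index):
    """Sort rows by graph id and relabel edge_index to match (hypothetical)."""
    # Default torch.sort uses stable=False, so rows with the same graph id
    # may come out in any relative order.
    batch_sorted, sort_idx = torch.sort(batch)
    x_sorted = x[sort_idx]
    # sort_idx[new] = old; invert it so that new_pos[old] = new.
    new_pos = torch.empty_like(sort_idx)
    new_pos[sort_idx] = torch.arange(sort_idx.numel(), device=sort_idx.device)
    return x_sorted, batch_sorted, new_pos[edge_index]
```

After this step, `x_sorted[new_edge_index]` selects the same endpoint features as `x[edge_index]` did before sorting. Every attribute that points at nodes would need the same treatment, which is why the stable sort is the simpler fix.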
Fix
The much simpler solution, of course, is to use a stable sort instead, changing the line to:
With this change, running the example above again now produces identical outputs!
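What `stable=True` guarantees can be seen in a minimal, hypothetical re-implementation of the token-insertion step (`add_graph_token_sketch` is an illustration, not the actual `add_graph_token` implementation). It assumes `batch` is the usual sorted PyG graph-assignment vector, `token` has shape `[1, F]`, and the tokens are concatenated in front of the node features. Each graph then ends up with its token first, followed by its own nodes in their original relative order.

```python
import torch


def add_graph_token_sketch(x, batch, token, stable=True):
    """Prepend one token row per graph and group rows by graph id (hypothetical)."""
    num_graphs = int(batch.max()) + 1
    tokens = token.expand(num_graphs, -1)  # one copy of the token per graph
    x_cat = torch.cat([tokens, x], dim=0)
    batch_cat = torch.cat(
        [torch.arange(num_graphs, dtype=batch.dtype, device=batch.device), batch]
    )
    # With stable=True, rows with equal graph ids keep their concatenation order:
    # the token first, then that graph's nodes in their original order.
    batch_sorted, sort_idx = torch.sort(batch_cat, stable=stable)
    return x_cat[sort_idx], batch_sorted
```

For two graphs with node features `0, 1, 2` and `3, 4` and a token value of `-1`, the stable version returns the rows `-1, 0, 1, 2, -1, 3, 4`. Passing `stable=False` gives no guarantee about the order of the nodes within a graph, which is the behavior described above.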
I haven't tested yet how this fix affects training and classification performance, but I would expect that being invariant to node permutations, and not having the node features "randomly" permuted relative to the graph structure, makes things somewhat easier for the model...
Yes, I believe so, at least when using the "normal" Graphormer preprocessing and encoder.
The node-level positional encodings are the in- and out-degrees, which I permute in my example above with:
For the edge level, the encodings are stored "sparsely" (`spatial_types` and `shortest_path_types`, indexed by `graph_index`, analogous to `edge_attr` with `edge_index`). So to reflect the node permutation in the edges, I made sure that the index arrays are remapped to the new node labels. In my (very specific) example this is accomplished by:
The example data I am using has no `edge_attr` and therefore no `shortest_path_types`, but the example above should still work correctly for data with edge attributes (because of the changes to `edge_index` and `graph_index`).
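Why relabeling `graph_index` is enough for the sparsely stored edge-level encodings can be checked with a small sketch. It assumes, for illustration, that pair-wise values such as `spatial_types` are scattered into a dense per-graph matrix via `graph_index`; the helpers `to_dense` and `relabel` are hypothetical, and plain Python lists stand in for tensors. Relabeling the indices through the inverse permutation permutes the rows and columns of the dense matrix together, so every value stays attached to the same pair of nodes.

```python
def to_dense(num_nodes, graph_index, values, fill=0):
    """Scatter sparse pair values (e.g. spatial types) into a dense matrix."""
    dense = [[fill] * num_nodes for _ in range(num_nodes)]
    for src, dst, value in zip(graph_index[0], graph_index[1], values):
        dense[src][dst] = value
    return dense


def relabel(graph_index, new_pos):
    """Map every node index through new_pos, where new_pos[old] = new."""
    return [[new_pos[i] for i in row] for row in graph_index]


# Path graph 0 - 1 - 2; the values play the role of shortest-path distances.
graph_index = [[0, 0, 1, 1, 2, 2], [1, 2, 0, 2, 0, 1]]
spatial_types = [1, 2, 1, 1, 2, 1]

perm = [2, 0, 1]  # perm[new] = old
new_pos = [0] * len(perm)
for new, old in enumerate(perm):
    new_pos[old] = new

dense = to_dense(3, graph_index, spatial_types)
dense_p = to_dense(3, relabel(graph_index, new_pos), spatial_types)

# Every value stays attached to the same (relabeled) pair of nodes.
for i in range(3):
    for j in range(3):
        assert dense_p[new_pos[i]][new_pos[j]] == dense[i][j]
```

The values list itself is never reordered; only the indices change, which mirrors how the example leaves `spatial_types` untouched and only remaps `graph_index`.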
Also, the fact that the model outputs are identical for the original and the permuted graph when using stable sorting suggests that the positional encodings were permuted correctly as well...
@isefos Thank you very much for raising this issue. Indeed, this seems to be a bug, and I believe a stable sort is exactly what we need here. I will test this on my side and also run a few experiments to see whether there is any (positive or negative) impact on performance.