
BUG: adding graphormer graph token breaks node permutation invariance #40

Open
isefos opened this issue Nov 15, 2023 · 3 comments · May be fixed by #46
isefos commented Nov 15, 2023

I was experimenting with the graphormer model, specifically for graph classification using the virtual node for global pooling (graph_pooling: graph_token).

Problem

I noticed that the model was producing different outputs for the same input graph with permuted node order. The problem should be easy to replicate; here is an example:

import torch
from torch_geometric.data import Batch

# given some data batch, e.g. inside the training loop
# create a copy of the first graph
data = Batch.from_data_list([batch.get_example(0).clone()])
data_p = Batch.from_data_list([batch.get_example(0).clone()])

# and permute the nodes: 
# here we simply put the previously last node in first place of the first graph
n = data_p.x.size(0)
p = torch.arange(n, dtype=torch.long) - 1
p[0] = n - 1
data_p.x = data_p.x[p]
assert (data_p.x[0, :] == data.x[-1, :]).all()
assert (data_p.x[1:, :] == data.x[:-1, :]).all()

# make sure to permute the other node features as well
data_p.batch = data_p.batch[p]
data_p.in_degrees = data_p.in_degrees[p]
data_p.out_degrees = data_p.out_degrees[p]

# and change the indices accordingly (all increase by one, just the last one gets set to zero)
n = data_p.x.size(0)
data_p.edge_index += 1
data_p.edge_index[data_p.edge_index == n] = 0
data_p.graph_index += 1
data_p.graph_index[data_p.graph_index == n] = 0

# then get the model outputs for each graph
model.eval()
with torch.no_grad():
    output, _ = model(data)
    output_p, _ = model(data_p)

# check if outputs are equal
assert torch.allclose(output, output_p), "Permuted graph produces different output!"

This is unexpected (and worrisome) behavior. In theory, the model architecture should be invariant to such permutations, as any GNN should be.

Cause

The cause turned out to be in the add_graph_token function, in these lines:

data.batch, sort_idx = torch.sort(data.batch)
data.x = data.x[sort_idx]

torch.sort is called to group the newly concatenated virtual nodes together with the other nodes of their respective graphs in the batch.

However, it is called without the stable argument, which means the default stable=False is used. As a result, the nodes within each graph (i.e. with the same batch index) don't keep their original relative order; instead, each graph gets its nodes permuted by the sorting algorithm. This by itself would not necessarily be a problem, as the model should be invariant to such permutations. However, all the indices used in the other data attributes (edge_index, in_degrees, att_bias, etc.) still reference the old node order and would then also need to be permuted/remapped.
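
To illustrate (a minimal sketch of my own, not the actual add_graph_token code): the virtual graph tokens are appended after all ordinary nodes, and sorting by batch index moves each token next to the nodes of its graph. With stable=True the relative order of the original nodes within each graph is guaranteed to be preserved, whereas with the default unstable sort it is not.

import torch

# batch of 2 graphs with 3 and 2 nodes; one virtual graph token is appended per graph
batch = torch.tensor([0, 0, 0, 1, 1])
batch = torch.cat([batch, torch.tensor([0, 1])])  # -> [0, 0, 0, 1, 1, 0, 1]

# stable sort: within each graph the original nodes keep their relative order,
# so node-level attributes and edge_index stay consistent
_, idx_stable = torch.sort(batch, stable=True)
# idx_stable == tensor([0, 1, 2, 5, 3, 4, 6])

# default (unstable) sort: the grouping is the same, but the order of equal
# entries within a group is unspecified and may differ, silently permuting nodes
_, idx_default = torch.sort(batch)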

Fix

Of course, the much simpler solution is to use a stable sort and change the line to:

data.batch, sort_idx = torch.sort(data.batch, stable=True)

When running the example from above again with this change, the outputs are indeed the same!

I haven't done any testing yet on how this bug fix affects training and classification performance, but I could imagine that being node-permutation invariant, and not having the node features "randomly" permuted, would make things a bit easier for the model...

@migalkin (Collaborator) commented:

Did you also permute positional encodings (on both node and edge level) that are needed for Graphormer?


isefos commented Nov 16, 2023

Yes, I believe so, at least for the case of using the "normal" graphormer preprocessing and encoder.
The node-level positional encodings would be the in- and out-degrees, which I permute in my example above with:

data_p.in_degrees = data_p.in_degrees[p]
data_p.out_degrees = data_p.out_degrees[p]

And the edge-level encodings are stored "sparsely" (spatial_types and shortest_path_types are indexed by graph_index, analogously to edge_attr with edge_index). So to reflect the node permutation in the edges, I ensured that the index arrays get remapped to the new node labels. In my (very specific) example this is accomplished by:

n = data_p.x.size(0)
data_p.edge_index += 1
data_p.edge_index[data_p.edge_index == n] = 0
data_p.graph_index += 1
data_p.graph_index[data_p.graph_index == n] = 0

The example data I am using has no edge_attr and therefore also no shortest_path_types, but the above approach should still work correctly for data with edge attributes (because of the change to edge_index and graph_index).
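
More generally, the same relabelling can be done with the inverse permutation instead of the "+= 1 and wrap around" trick (a sketch, assuming p maps new node positions to old ones as constructed in the example above), which then covers arbitrary permutations:

import torch

# inverse of the permutation p used above: inv[old_index] = new_index
inv = torch.empty_like(p)
inv[p] = torch.arange(p.numel())

# relabel index-valued attributes (replaces the "+= 1 and wrap around" lines)
data_p.edge_index = inv[data_p.edge_index]
data_p.graph_index = inv[data_p.graph_index]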

Also, the fact that the model outputs are the same for the original and permuted graph when using stable sorting indicates that the positional encodings were probably permuted correctly as well...

@luis-mueller (Collaborator) commented:

@isefos Thank you very much for raising this issue. Indeed, this seems to be a bug. Further, I believe that stable sorting is exactly what we need here. I will test this on my side and also run a few experiments to see whether there is any (positive or negative) impact on performance.

@luis-mueller luis-mueller self-assigned this Nov 20, 2023
@luis-mueller luis-mueller added the bug Something isn't working label Nov 20, 2023
@luis-mueller luis-mueller linked a pull request Jan 5, 2024 that will close this issue