-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
395 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import torch | ||
import random | ||
from torch.nn import functional as F | ||
from attacks.poison_attacks import PoisonAttacker | ||
|
||
from torch_geometric.utils import to_dense_adj | ||
|
||
|
||
class CLGAAttack(PoisonAttacker): | ||
name = "CLGAAttack" | ||
|
||
def __init__(self, num_nodes, feature_shape, encoder, augmentation_set, threshold, device="cpu"): | ||
super().__init__() | ||
self.num_nodes = num_nodes | ||
self.feature_shape = feature_shape | ||
self.encoder = encoder # Differentiable encoder (e.g., GCN) | ||
self.augmentation_set = augmentation_set # Set of augmentation methods | ||
self.threshold = threshold # Maximum number of edge changes | ||
self.device = device | ||
|
||
self.modified_adj = None | ||
self.augmented_graph = None | ||
|
||
def attack(self, adj_matrix, features): | ||
""" | ||
Execute the CLGA attack on the graph to maximize contrastive loss. | ||
""" | ||
adj_matrix = to_dense_adj(adj_matrix).squeeze() | ||
current_adj = adj_matrix.clone().to(self.device) | ||
for iteration in range(self.threshold): | ||
gradients_sum = torch.zeros_like(current_adj) | ||
|
||
for _ in range(len(self.augmentation_set)): | ||
# Generate augmented views | ||
t1, t2 = random.sample(self.augmentation_set, 2) | ||
adj_view1, features_view1 = t1(current_adj, features) | ||
adj_view2, features_view2 = t2(current_adj, features) | ||
|
||
# Forward pass and compute contrastive loss | ||
embeddings1 = self.encoder(adj_view1, features_view1) | ||
embeddings2 = self.encoder(adj_view2, features_view2) | ||
loss = self.contrastive_loss(embeddings1, embeddings2) | ||
|
||
# Backpropagate to compute gradients | ||
adj_grad1 = torch.autograd.grad(loss, adj_view1, retain_graph=True)[0] | ||
adj_grad2 = torch.autograd.grad(loss, adj_view2, retain_graph=True)[0] | ||
gradients_sum += adj_grad1 + adj_grad2 | ||
|
||
# Flip the edge with the largest gradient | ||
max_gradient_index = torch.argmax(gradients_sum.abs()) | ||
row, col = divmod(max_gradient_index, current_adj.shape[1]) | ||
current_adj[row, col] = 1 - current_adj[row, col] # Flip edge | ||
current_adj[col, row] = current_adj[row, col] # Ensure symmetry | ||
|
||
# Save updated adjacency | ||
self.modified_adj = current_adj.detach() | ||
|
||
def contrastive_loss(self, embeddings1, embeddings2): | ||
""" | ||
Compute the contrastive loss based on two embeddings. | ||
""" | ||
pos_loss = F.cosine_similarity(embeddings1, embeddings2).mean() | ||
neg_loss = F.cosine_similarity(embeddings1, embeddings2.roll(shifts=1, dims=0)).mean() | ||
return -pos_loss + neg_loss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
import torch | ||
from torch_geometric.utils import dropout_adj, dense_to_sparse | ||
from attacks.poison_attacks import PoisonAttacker | ||
from attacks.CLGA.differentiable_models.gcn import GCN | ||
from attacks.CLGA.differentiable_models.model import GRACE | ||
|
||
from torch_geometric.utils import to_dense_adj | ||
from torch_geometric.nn import MessagePassing | ||
from models_builder.models_utils import apply_decorator_to_graph_layers | ||
|
||
class CLGAAttack(PoisonAttacker): | ||
name = "CLGAAttack" | ||
|
||
def __init__( | ||
self, num_nodes, feature_shape, learning_rate=0.01, num_hidden=256, num_proj_hidden=32, activation="prelu", | ||
drop_edge_rate_1=0.3, drop_edge_rate_2=0.4, tau=0.4, num_epochs=3000, weight_decay=1e-5, | ||
drop_scheme="degree", device="cpu" | ||
): | ||
super().__init__() | ||
self.num_nodes = num_nodes | ||
self.feature_shape = feature_shape | ||
self.learning_rate = learning_rate | ||
self.num_hidden = num_hidden | ||
self.num_proj_hidden = num_proj_hidden | ||
self.activation = activation | ||
self.drop_edge_rate_1 = drop_edge_rate_1 | ||
self.drop_edge_rate_2 = drop_edge_rate_2 | ||
self.tau = tau | ||
self.num_epochs = num_epochs | ||
self.weight_decay = weight_decay | ||
self.drop_scheme = drop_scheme | ||
self.device = device | ||
|
||
self.modified_adj = None | ||
self.model = None | ||
self.optimizer = None | ||
|
||
def drop_edge(self, edge_index, p): | ||
""" | ||
Perform edge dropout based on the chosen scheme. | ||
""" | ||
return dropout_adj(edge_index, p=p)[0] | ||
|
||
def train_gcn(self, data): | ||
""" | ||
Train the GCN model with augmented graphs. | ||
""" | ||
self.model.train() | ||
self.optimizer.zero_grad() | ||
edge_index_1 = self.drop_edge(data.edge_index, self.drop_edge_rate_1) | ||
edge_index_2 = self.drop_edge(data.edge_index, self.drop_edge_rate_2) | ||
x_1 = data.x.clone() | ||
x_2 = data.x.clone() | ||
|
||
z1 = self.model(x_1, edge_index_1) | ||
z2 = self.model(x_2, edge_index_2) | ||
|
||
loss = self.model.loss(z1, z2) | ||
loss.backward() | ||
self.optimizer.step() | ||
return loss.item() | ||
|
||
def compute_gradient(self, data): | ||
""" | ||
Compute gradients of the contrastive loss w.r.t. adjacency matrix. | ||
""" | ||
self.model.eval() | ||
edge_index_1 = self.drop_edge(data.edge_index, self.drop_edge_rate_1) | ||
edge_index_2 = self.drop_edge(data.edge_index, self.drop_edge_rate_2) | ||
|
||
# adj_dense_1 = torch.sparse.FloatTensor( | ||
# edge_index_1, torch.ones(edge_index_1.shape[1], device=self.device), | ||
# (self.num_nodes, self.num_nodes) | ||
# ).to_dense().requires_grad_(True) | ||
# | ||
# adj_dense_2 = torch.sparse.FloatTensor( | ||
# edge_index_2, torch.ones(edge_index_2.shape[1], device=self.device), | ||
# (self.num_nodes, self.num_nodes) | ||
# ).to_dense().requires_grad_(True) | ||
|
||
# z1 = self.model(data.x, adj_dense_1) | ||
# z2 = self.model(data.x, adj_dense_2) | ||
|
||
z1 = self.model(data.x, edge_index_1) | ||
z2 = self.model(data.x, edge_index_2) | ||
|
||
loss = self.model.loss(z1, z2) | ||
loss.backward() | ||
|
||
grad = torch.zeros_like() | ||
for name, layer in self.model.encoder.named_children(): | ||
if isinstance(layer, MessagePassing): | ||
#print(f"{name}: {layer.get_message_gradients()}") | ||
for l_name, l_grad in layer.get_message_gradients().items(): | ||
grad += 1 | ||
|
||
return edge_index_1.grad, edge_index_2.grad | ||
|
||
def attack(self, gen_dataset): | ||
""" | ||
Execute the CLGA attack. | ||
""" | ||
self.model = GRACE( | ||
encoder=GCN(self.feature_shape, self.num_hidden, 'prelu'), | ||
num_hidden=self.num_hidden, | ||
num_proj_hidden=self.num_proj_hidden, | ||
tau=self.tau | ||
).to(self.device) | ||
|
||
apply_decorator_to_graph_layers(self.model) | ||
apply_decorator_to_graph_layers(self.model.encoder) | ||
|
||
self.optimizer = torch.optim.Adam( | ||
self.model.parameters(), | ||
lr=self.learning_rate, | ||
weight_decay=self.weight_decay | ||
) | ||
|
||
perturbed_edges = [] | ||
|
||
# adj = torch.sparse.FloatTensor( | ||
# gen_dataset.dataset.data.edge_index, torch.ones(gen_dataset.dataset.data.edge_index.shape[1], device=self.device), | ||
# (self.num_nodes, self.num_nodes) | ||
# ).to_dense() | ||
|
||
adj = to_dense_adj(gen_dataset.dataset.data.edge_index).squeeze() | ||
|
||
for epoch in range(self.num_epochs): | ||
self.train_gcn(gen_dataset.dataset.data) | ||
|
||
# grad_1, grad_2 = self.compute_gradient(gen_dataset.dataset.data) | ||
# grad_sum = grad_1 + grad_2 | ||
grad_sum = self.compute_gradient(gen_dataset.dataset.data) | ||
|
||
max_grad_index = torch.argmax(torch.abs(grad_sum.view(-1))) | ||
row, col = divmod(max_grad_index.item(), self.num_nodes) | ||
|
||
if grad_sum[row, col] > 0 and adj[row, col] == 0: | ||
adj[row, col] = 1 | ||
adj[col, row] = 1 | ||
elif grad_sum[row, col] < 0 and adj[row, col] == 1: | ||
adj[row, col] = 0 | ||
adj[col, row] = 0 | ||
|
||
perturbed_edges.append((row, col)) | ||
gen_dataset.dataset.data.edge_index = dense_to_sparse(adj)[0] | ||
|
||
self.modified_adj = adj | ||
return adj |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
import torch | ||
import torch.nn as nn | ||
from attacks.metattack import utils | ||
|
||
# differentiable | ||
# class GCN(nn.Module): | ||
# def __init__(self, in_ft, out_ft, act, dropout=0, bias=True): | ||
# super(GCN, self).__init__() | ||
# self.fc1 = nn.Linear(in_ft, 2*out_ft, bias=False) | ||
# self.dropout = nn.Dropout(p=dropout) | ||
# self.fc2 = nn.Linear(2*out_ft, out_ft, bias=False) | ||
# self.act = nn.PReLU() if act == 'prelu' else act | ||
# | ||
# if bias: | ||
# self.bias1 = nn.Parameter(torch.FloatTensor(2*out_ft)) | ||
# self.bias1.data.fill_(0.0) | ||
# self.bias2 = nn.Parameter(torch.FloatTensor(out_ft)) | ||
# self.bias2.data.fill_(0.0) | ||
# else: | ||
# self.register_parameter('bias1', None) | ||
# self.register_parameter('bias2', None) | ||
# | ||
# for m in self.modules(): | ||
# self.weights_init(m) | ||
# | ||
# def weights_init(self, m): | ||
# if isinstance(m, nn.Linear): | ||
# torch.nn.init.xavier_uniform_(m.weight.data) | ||
# if m.bias is not None: | ||
# m.bias.data.fill_(0.0) | ||
# | ||
# # Shape of seq: (nodes, features) | ||
# def forward(self, seq, adj, sparse=False): | ||
# adj_norm = utils.normalize_adj_tensor(adj, sparse=sparse) | ||
# seq_fts1 = self.fc1(seq) | ||
# if sparse: | ||
# out1 = torch.spmm(adj_norm, seq_fts1) | ||
# else: | ||
# out1 = torch.mm(adj_norm, seq_fts1) | ||
# if self.bias1 is not None: | ||
# out1 += self.bias1 | ||
# out1 = self.act(out1) | ||
# out1 = self.dropout(out1) | ||
# | ||
# seq_fts2 = self.fc2(out1) | ||
# if sparse: | ||
# out2 = torch.spmm(adj_norm, seq_fts2) | ||
# else: | ||
# out2 = torch.mm(adj_norm, seq_fts2) | ||
# if self.bias2 is not None: | ||
# out2 += self.bias2 | ||
# return self.act(out2) | ||
|
||
from torch_geometric.nn import GCNConv | ||
|
||
class GCN(nn.Module): | ||
def __init__(self, in_ft, out_ft, act='prelu', dropout=0.0, bias=True): | ||
super(GCN, self).__init__() | ||
self.conv1 = GCNConv(in_ft, 2 * out_ft, bias=bias) | ||
self.conv2 = GCNConv(2 * out_ft, out_ft, bias=bias) | ||
self.dropout = nn.Dropout(p=dropout) | ||
self.act = nn.PReLU() if act == 'prelu' else nn.ReLU() if act == 'relu' else nn.Identity() | ||
|
||
def forward(self, x, edge_index): | ||
""" | ||
Forward pass of the GCN. | ||
Args: | ||
x (Tensor): Input feature matrix of shape [num_nodes, num_features]. | ||
edge_index (Tensor): Edge indices of shape [2, num_edges]. | ||
Returns: | ||
Tensor: Node embeddings of shape [num_nodes, out_ft]. | ||
""" | ||
# First GCN layer | ||
x = self.conv1(x, edge_index) | ||
x = self.act(x) | ||
x = self.dropout(x) | ||
|
||
# Second GCN layer | ||
x = self.conv2(x, edge_index) | ||
x = self.act(x) | ||
return x |
Oops, something went wrong.