From e078557c6fd2867f535bd527d983c88fa3cc99a6 Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Sun, 13 Oct 2024 18:52:15 +0300 Subject: [PATCH 01/13] GNNGuard poison defense --- src/defense/GNNGuard/gnnguard.py | 382 +++++++++++++++ src/defense/GNNGuard/utils.py | 777 +++++++++++++++++++++++++++++++ src/defense/RGCN/rgcn.py | 324 +++++++++++++ src/defense/RGCN/utils.py | 777 +++++++++++++++++++++++++++++++ 4 files changed, 2260 insertions(+) create mode 100644 src/defense/GNNGuard/gnnguard.py create mode 100644 src/defense/GNNGuard/utils.py create mode 100644 src/defense/RGCN/rgcn.py create mode 100644 src/defense/RGCN/utils.py diff --git a/src/defense/GNNGuard/gnnguard.py b/src/defense/GNNGuard/gnnguard.py new file mode 100644 index 0000000..337853f --- /dev/null +++ b/src/defense/GNNGuard/gnnguard.py @@ -0,0 +1,382 @@ +import torch.nn as nn +import torch.nn.functional as F +import math +import torch +from torch.nn.parameter import Parameter +from torch.nn.modules.module import Module +from torch_geometric.nn import GCNConv + +# from defense.GNNGuard.base_model import BaseModel +from defense.poison_defense import PoisonDefender + +from models_builder.gnn_models import FrameworkGNNModelManager +from models_builder.gnn_constructor import FrameworkGNNConstructor +from models_builder.models_zoo import model_configs_zoo +from aux.configs import ModelManagerConfig, ModelModificationConfig, DatasetConfig, DatasetVarConfig, ConfigPattern +from aux.utils import import_by_name, CUSTOM_LAYERS_INFO_PATH, MODULES_PARAMETERS_PATH, hash_data_sha256, \ + TECHNICAL_PARAMETER_KEY, IMPORT_INFO_KEY, OPTIMIZERS_PARAMETERS_PATH + + +import warnings +import types +# from torch_sparse import coalesce, SparseTensor, matmul + +from sklearn.metrics.pairwise import cosine_similarity +from sklearn.preprocessing import normalize + +from scipy.sparse import lil_matrix +import scipy.sparse as sp + +import numpy as np + +from src.aux.configs import ModelConfig + + +# class BaseGCN(BaseModel): + +# def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, +# with_bn=False, weight_decay=5e-4, with_bias=True, device=None): + +# super(BaseGCN, self).__init__() + +# assert device is not None, "Please specify 'device'!" 
+# self.device = device + +# self.layers = nn.ModuleList([]) +# if with_bn: +# self.bns = nn.ModuleList() + +# if nlayers == 1: +# self.layers.append(GCNConv(nfeat, nclass, bias=with_bias)) +# else: +# self.layers.append(GCNConv(nfeat, nhid, bias=with_bias)) +# if with_bn: +# self.bns.append(nn.BatchNorm1d(nhid)) +# for i in range(nlayers-2): +# self.layers.append(GCNConv(nhid, nhid, bias=with_bias)) +# if with_bn: +# self.bns.append(nn.BatchNorm1d(nhid)) +# self.layers.append(GCNConv(nhid, nclass, bias=with_bias)) + +# self.dropout = dropout +# self.weight_decay = weight_decay +# self.lr = lr +# self.output = None +# self.best_model = None +# self.best_output = None +# self.with_bn = with_bn +# self.name = 'GCN' + +# def forward(self, x, edge_index, edge_weight=None): +# x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight) +# for ii, layer in enumerate(self.layers): +# if edge_weight is not None: +# adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t() +# x = layer(x, adj) +# else: +# x = layer(x, edge_index) +# if ii != len(self.layers) - 1: +# if self.with_bn: +# x = self.bns[ii](x) +# x = F.relu(x) +# x = F.dropout(x, p=self.dropout, training=self.training) +# return F.log_softmax(x, dim=1) + +# def get_embed(self, x, edge_index, edge_weight=None): +# x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight) +# for ii, layer in enumerate(self.layers): +# if ii == len(self.layers) - 1: +# return x +# if edge_weight is not None: +# adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t() +# x = layer(x, adj) +# else: +# x = layer(x, edge_index) +# if ii != len(self.layers) - 1: +# if self.with_bn: +# x = self.bns[ii](x) +# x = F.relu(x) +# return x + +# def initialize(self): +# for m in self.layers: +# m.reset_parameters() +# if self.with_bn: +# for bn in self.bns: +# bn.reset_parameters() + + +class BaseGNNGuard(PoisonDefender): + name = "BaseGNNGuard" + + def __init__(self, lr=0.1, train_iters=200, device='cpu'): + super().__init__() + self.model = None + self.lr = lr + self.device = device + self.train_iters = train_iters + + def defense(self, gen_dataset): + self.model = model_configs_zoo(gen_dataset, 'gcn_gcn_linearized') + default_config = ModelModificationConfig( + model_ver_ind=0, + ) + manager_config = ConfigPattern( + _config_class="ModelManagerConfig", + _config_kwargs={ + "mask_features": [], + "optimizer": { + "_config_class": "Config", + "_class_name": "Adam", + "_import_path": OPTIMIZERS_PARAMETERS_PATH, + "_class_import_info": ["torch.optim"], + "_config_kwargs": {"weight_decay": 5e-4}, + } + } + ) + gnn_model_manager_surrogate = FrameworkGNNModelManager( + gnn=self.model, + dataset_path=gen_dataset, + modification=default_config, + manager_config=manager_config, + ) + + gnn_model_manager_surrogate.train_model(gen_dataset=gen_dataset, steps=self.train_iters) + + self.pred_labels = gnn_model_manager_surrogate.run_model(gen_dataset=gen_dataset, mask='all', out='answers') + + +class GNNGuard(BaseGNNGuard): + name = 'GNNGuard' + + def __init__(self, lr=0.01, attention=False, drop=False, train_iters=200, device='cpu', with_bias=False, with_relu=False): + super().__init__(lr=lr, train_iters=train_iters, device=device) + assert device is not None, "Please specify 'device'!" 
+ self.with_bias = with_bias + self.with_relu = with_relu + self.attention = attention + self.drop = drop + + def defense(self, gen_dataset): + super().defense(gen_dataset=gen_dataset) + + self.hidden_sizes = [16] + self.nfeat = gen_dataset.num_node_features + self.nclass = gen_dataset.num_classes + # self.attention = attention + # self.drop = drop + + self.wrap() + self.initialize() + + def wrap(self): + def att_coef(self, fea, edge_index, is_lil=False, i=0): + if is_lil == False: + edge_index = edge_index._indices() + else: + edge_index = edge_index.tocoo() + + n_node = fea.shape[0] + row, col = edge_index[0].cpu().data.numpy()[:], edge_index[1].cpu().data.numpy()[:] + + fea_copy = fea.cpu().data.numpy() + sim_matrix = cosine_similarity(X=fea_copy, Y=fea_copy) # try cosine similarity + sim = sim_matrix[row, col] + sim[sim<0.1] = 0 + + """build a attention matrix""" + att_dense = lil_matrix((n_node, n_node), dtype=np.float32) + att_dense[row, col] = sim + if att_dense[0, 0] == 1: + att_dense = att_dense - sp.diags(att_dense.diagonal(), offsets=0, format="lil") + # normalization, make the sum of each row is 1 + att_dense_norm = normalize(att_dense, axis=1, norm='l1') + + + """add learnable dropout, make character vector""" + if self.drop: + character = np.vstack((att_dense_norm[row, col].A1, + att_dense_norm[col, row].A1)) + character = torch.from_numpy(character.T) + drop_score = self.drop_learn_1(character) + drop_score = torch.sigmoid(drop_score) # do not use softmax since we only have one element + mm = torch.nn.Threshold(0.5, 0) + drop_score = mm(drop_score) + mm_2 = torch.nn.Threshold(-0.49, 1) + drop_score = mm_2(-drop_score) + drop_decision = drop_score.clone().requires_grad_() + # print('rate of left edges', drop_decision.sum().data/drop_decision.shape[0]) + drop_matrix = lil_matrix((n_node, n_node), dtype=np.float32) + drop_matrix[row, col] = drop_decision.cpu().data.numpy().squeeze(-1) + att_dense_norm = att_dense_norm.multiply(drop_matrix.tocsr()) # update, remove the 0 edges + + if att_dense_norm[0, 0] == 0: # add the weights of self-loop only add self-loop at the first layer + degree = (att_dense_norm != 0).sum(1).A1 + lam = 1 / (degree + 1) # degree +1 is to add itself + self_weight = sp.diags(np.array(lam), offsets=0, format="lil") + att = att_dense_norm + self_weight # add the self loop + else: + att = att_dense_norm + + row, col = att.nonzero() + att_adj = np.vstack((row, col)) + att_edge_weight = att[row, col] + att_edge_weight = np.exp(att_edge_weight) # exponent, kind of softmax + att_edge_weight = torch.tensor(np.array(att_edge_weight)[0], dtype=torch.float32)#.cuda() + att_adj = torch.tensor(att_adj, dtype=torch.int64)#.cuda() + + shape = (n_node, n_node) + new_adj = torch.sparse.FloatTensor(att_adj, att_edge_weight, shape) + return new_adj + + def forward(self, *args, **kwargs): + """we don't change the edge_index, just update the edge_weight; + some edge_weight are regarded as removed if it equals to zero""" + layer_ind = -1 + tensor_storage = {} + dim_cat = 0 + layer_emb_dict = {} + save_emb_flag = self._save_emb_flag + + x, edge_index, batch = self.arguments_read(*args, **kwargs) + feat = x + adj = edge_index.tocoo() + adj_memory = None + # print(list(self.__dict__['_modules'].items())) + for elem in list(self.__dict__['_modules'].items()): + layer_name, curr_layer_ind = elem[0].split('_') + curr_layer_ind = int(curr_layer_ind) + inp = torch.clone(x) + loc_flag = False + if curr_layer_ind != layer_ind: + if save_emb_flag: + loc_flag = True + zeroing_x_flag = 
False + for key, value in self.conn_dict.items(): + if key[0] == layer_ind and layer_ind not in tensor_storage: + tensor_storage[layer_ind] = torch.clone(x) + layer_ind = curr_layer_ind + x_copy = torch.clone(x) + connection_tensor = torch.Tensor() + for key, value in self.conn_dict.items(): + + if key[1] == curr_layer_ind: + if key[1] - key[0] == 1: + zeroing_x_flag = True + for con in self.conn_dict[key]: + if self.embedding_levels_by_layers[key[1]] == 'n' and \ + self.embedding_levels_by_layers[key[0]] == 'n': + connection_tensor = torch.cat((connection_tensor, + tensor_storage[key[0]]), 1) + dim_cat = 1 + elif self.embedding_levels_by_layers[key[1]] == 'g' and \ + self.embedding_levels_by_layers[key[0]] == 'g': + connection_tensor = torch.cat((connection_tensor, + tensor_storage[key[0]]), 0) + dim_cat = 0 + elif self.embedding_levels_by_layers[key[1]] == 'g' and \ + self.embedding_levels_by_layers[key[0]] == 'n': + con_pool = import_by_name(con['pool']['pool_type'], + ["torch_geometric.nn"]) + tensor_after_pool = con_pool(tensor_storage[key[0]], batch) + connection_tensor = torch.cat((connection_tensor, + tensor_after_pool), 1) + dim_cat = 1 + else: + raise Exception( + "Connection from layer type " + f"{self.embedding_levels_by_layers[curr_layer_ind - 1]} to" + f" layer type {self.embedding_levels_by_layers[curr_layer_ind]}" + "is not supported now") + + + if zeroing_x_flag: + x = connection_tensor + else: + x = torch.cat((x_copy, connection_tensor), dim_cat) + + + if self.attention: + if layer_name == 'GINConv': + if adj_memory is None: + adj = self.att_coef(x, adj, is_lil=False,i=layer_ind) + edge_index = adj._indices() + edge_weight = adj._values() + adj_memory = adj + elif adj_memory is not None: + adj = self.att_coef(x, adj_memory, is_lil=False, i=layer_ind) + edge_weight = self.gate * adj_memory._values() + (1 - self.gate) * adj._values() + adj_memory = adj + elif layer_name == 'GCNConv' or layer_name == 'GATConv': + if adj_memory is None: + adj = self.att_coef(x, adj, i=0) + edge_index = adj._indices() + edge_weight = adj._values() + adj_memory = adj + elif adj_memory is not None: + adj = self.att_coef(x, adj_memory, i=layer_ind).to_dense() + row, col = adj.nonzero()[:,0], adj.nonzero()[:,1] + edge_index = torch.stack((row, col), dim=0) + edge_weight = adj[row, col] + adj_memory = adj + else: + edge_index = adj._indices() + edge_weight = adj._values() + + + # QUE Kirill, maybe we should not off UserWarning + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning) + # mid = x + if layer_name in self.modules_info: + code_str = f"getattr(self, elem[0])" \ + f"({self.modules_info[layer_name][TECHNICAL_PARAMETER_KEY]['forward_parameters']}," \ + f" edge_weight=edge_weight)" + x = eval(f"{code_str}") + else: + x = getattr(self, elem[0])(x) + if loc_flag: + layer_emb_dict[layer_ind] = torch.clone(x) + + if save_emb_flag: + return layer_emb_dict + return x + + self.model.drop = self.drop + self.model.attention = self.attention + self.model.gate = Parameter(torch.rand(1)) + self.model.drop_learn_1 = nn.Linear(2, 1) + self.model.drop_learn_2 = nn.Linear(2, 1) + self.model.forward = types.MethodType(forward, self.model) + self.model.att_coef = types.MethodType(att_coef, self.model) + + def initialize(self): + self.model.drop_learn_1.reset_parameters() + self.model.drop_learn_2.reset_parameters() + # self.model.gate.reset_parameters() + +# if __name__ == "__main__": +# from deeprobust.graph.data import Dataset, Dpr2Pyg +# # from deeprobust.graph.defense 
import GCN +# data = Dataset(root='/tmp/', name='citeseer', setting='prognn') +# adj, features, labels = data.adj, data.features, data.labels +# idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test +# model = GCN(nfeat=features.shape[1], +# nhid=16, +# nclass=labels.max().item() + 1, +# dropout=0.5, device='cuda') +# model = model.to('cuda') +# pyg_data = Dpr2Pyg(data)[0] + +# # model.fit(features, adj, labels, idx_train, train_iters=200, verbose=True) +# # model.test(idx_test) + +# from utils import get_dataset +# pyg_data = get_dataset('citeseer', True, if_dpr=False)[0] + +# import ipdb +# ipdb.set_trace() + +# model.fit(pyg_data, verbose=True) # train with earlystopping +# model.test() +# print(model.predict()) \ No newline at end of file diff --git a/src/defense/GNNGuard/utils.py b/src/defense/GNNGuard/utils.py new file mode 100644 index 0000000..f7be00a --- /dev/null +++ b/src/defense/GNNGuard/utils.py @@ -0,0 +1,777 @@ +import numpy as np +import scipy.sparse as sp +import torch +from sklearn.model_selection import train_test_split +import torch.sparse as ts +import torch.nn.functional as F +import warnings + +def encode_onehot(labels): + """Convert label to onehot format. + + Parameters + ---------- + labels : numpy.array + node labels + + Returns + ------- + numpy.array + onehot labels + """ + eye = np.eye(labels.max() + 1) + onehot_mx = eye[labels] + return onehot_mx + +def tensor2onehot(labels): + """Convert label tensor to label onehot tensor. + + Parameters + ---------- + labels : torch.LongTensor + node labels + + Returns + ------- + torch.LongTensor + onehot labels tensor + + """ + + eye = torch.eye(labels.max() + 1) + onehot_mx = eye[labels] + return onehot_mx.to(labels.device) + +def preprocess(adj, features, labels, preprocess_adj=False, preprocess_feature=False, sparse=False, device='cpu'): + """Convert adj, features, labels from array or sparse matrix to + torch Tensor, and normalize the input data. + + Parameters + ---------- + adj : scipy.sparse.csr_matrix + the adjacency matrix. + features : scipy.sparse.csr_matrix + node features + labels : numpy.array + node labels + preprocess_adj : bool + whether to normalize the adjacency matrix + preprocess_feature : bool + whether to normalize the feature matrix + sparse : bool + whether to return sparse tensor + device : str + 'cpu' or 'cuda' + """ + + if preprocess_adj: + adj = normalize_adj(adj) + + if preprocess_feature: + features = normalize_feature(features) + + labels = torch.LongTensor(labels) + if sparse: + adj = sparse_mx_to_torch_sparse_tensor(adj) + features = sparse_mx_to_torch_sparse_tensor(features) + else: + if sp.issparse(features): + features = torch.FloatTensor(np.array(features.todense())) + else: + features = torch.FloatTensor(features) + adj = torch.FloatTensor(adj.todense()) + return adj.to(device), features.to(device), labels.to(device) + +def to_tensor(adj, features, labels=None, device='cpu'): + """Convert adj, features, labels from array or sparse matrix to + torch Tensor. + + Parameters + ---------- + adj : scipy.sparse.csr_matrix + the adjacency matrix. 
+ features : scipy.sparse.csr_matrix + node features + labels : numpy.array + node labels + device : str + 'cpu' or 'cuda' + """ + if sp.issparse(adj): + adj = sparse_mx_to_torch_sparse_tensor(adj) + else: + adj = torch.FloatTensor(adj) + if sp.issparse(features): + features = sparse_mx_to_torch_sparse_tensor(features) + else: + features = torch.FloatTensor(np.array(features)) + + if labels is None: + return adj.to(device), features.to(device) + else: + labels = torch.LongTensor(labels) + return adj.to(device), features.to(device), labels.to(device) + +def normalize_feature(mx): + """Row-normalize sparse matrix or dense matrix + + Parameters + ---------- + mx : scipy.sparse.csr_matrix or numpy.array + matrix to be normalized + + Returns + ------- + scipy.sprase.lil_matrix + normalized matrix + """ + if type(mx) is not sp.lil.lil_matrix: + try: + mx = mx.tolil() + except AttributeError: + pass + rowsum = np.array(mx.sum(1)) + r_inv = np.power(rowsum, -1).flatten() + r_inv[np.isinf(r_inv)] = 0. + r_mat_inv = sp.diags(r_inv) + mx = r_mat_inv.dot(mx) + return mx + +def normalize_adj(mx): + """Normalize sparse adjacency matrix, + A' = (D + I)^-1/2 * ( A + I ) * (D + I)^-1/2 + Row-normalize sparse matrix + + Parameters + ---------- + mx : scipy.sparse.csr_matrix + matrix to be normalized + + Returns + ------- + scipy.sprase.lil_matrix + normalized matrix + """ + + # TODO: maybe using coo format would be better? + if type(mx) is not sp.lil.lil_matrix: + mx = mx.tolil() + if mx[0, 0] == 0 : + mx = mx + sp.eye(mx.shape[0]) + rowsum = np.array(mx.sum(1)) + r_inv = np.power(rowsum, -1/2).flatten() + r_inv[np.isinf(r_inv)] = 0. + r_mat_inv = sp.diags(r_inv) + mx = r_mat_inv.dot(mx) + mx = mx.dot(r_mat_inv) + return mx + +def normalize_sparse_tensor(adj, fill_value=1): + """Normalize sparse tensor. Need to import torch_scatter + """ + edge_index = adj._indices() + edge_weight = adj._values() + num_nodes= adj.size(0) + edge_index, edge_weight = add_self_loops( + edge_index, edge_weight, fill_value, num_nodes) + + row, col = edge_index + from torch_scatter import scatter_add + deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + + values = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + + shape = adj.shape + return torch.sparse.FloatTensor(edge_index, values, shape) + +def add_self_loops(edge_index, edge_weight=None, fill_value=1, num_nodes=None): + # num_nodes = maybe_num_nodes(edge_index, num_nodes) + + loop_index = torch.arange(0, num_nodes, dtype=torch.long, + device=edge_index.device) + loop_index = loop_index.unsqueeze(0).repeat(2, 1) + + if edge_weight is not None: + assert edge_weight.numel() == edge_index.size(1) + loop_weight = edge_weight.new_full((num_nodes, ), fill_value) + edge_weight = torch.cat([edge_weight, loop_weight], dim=0) + + edge_index = torch.cat([edge_index, loop_index], dim=1) + + return edge_index, edge_weight + +def normalize_adj_tensor(adj, sparse=False): + """Normalize adjacency tensor matrix. + """ + device = adj.device + if sparse: + # warnings.warn('If you find the training process is too slow, you can uncomment line 207 in deeprobust/graph/utils.py. 
Note that you need to install torch_sparse') + # TODO if this is too slow, uncomment the following code, + # but you need to install torch_scatter + # return normalize_sparse_tensor(adj) + adj = to_scipy(adj) + mx = normalize_adj(adj) + return sparse_mx_to_torch_sparse_tensor(mx).to(device) + else: + mx = adj + torch.eye(adj.shape[0]).to(device) + rowsum = mx.sum(1) + r_inv = rowsum.pow(-1/2).flatten() + r_inv[torch.isinf(r_inv)] = 0. + r_mat_inv = torch.diag(r_inv) + mx = r_mat_inv @ mx + mx = mx @ r_mat_inv + return mx + +def degree_normalize_adj(mx): + """Row-normalize sparse matrix""" + mx = mx.tolil() + if mx[0, 0] == 0 : + mx = mx + sp.eye(mx.shape[0]) + rowsum = np.array(mx.sum(1)) + r_inv = np.power(rowsum, -1).flatten() + r_inv[np.isinf(r_inv)] = 0. + r_mat_inv = sp.diags(r_inv) + # mx = mx.dot(r_mat_inv) + mx = r_mat_inv.dot(mx) + return mx + +def degree_normalize_sparse_tensor(adj, fill_value=1): + """degree_normalize_sparse_tensor. + """ + edge_index = adj._indices() + edge_weight = adj._values() + num_nodes= adj.size(0) + + edge_index, edge_weight = add_self_loops( + edge_index, edge_weight, fill_value, num_nodes) + + row, col = edge_index + from torch_scatter import scatter_add + deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) + deg_inv_sqrt = deg.pow(-1) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + + values = deg_inv_sqrt[row] * edge_weight + shape = adj.shape + return torch.sparse.FloatTensor(edge_index, values, shape) + +def degree_normalize_adj_tensor(adj, sparse=True): + """degree_normalize_adj_tensor. + """ + + device = adj.device + if sparse: + # return degree_normalize_sparse_tensor(adj) + adj = to_scipy(adj) + mx = degree_normalize_adj(adj) + return sparse_mx_to_torch_sparse_tensor(mx).to(device) + else: + mx = adj + torch.eye(adj.shape[0]).to(device) + rowsum = mx.sum(1) + r_inv = rowsum.pow(-1).flatten() + r_inv[torch.isinf(r_inv)] = 0. + r_mat_inv = torch.diag(r_inv) + mx = r_mat_inv @ mx + return mx + +def accuracy(output, labels): + """Return accuracy of output compared to labels. + + Parameters + ---------- + output : torch.Tensor + output from model + labels : torch.Tensor or numpy.array + node labels + + Returns + ------- + float + accuracy + """ + if not hasattr(labels, '__len__'): + labels = [labels] + if type(labels) is not torch.Tensor: + labels = torch.LongTensor(labels) + preds = output.max(1)[1].type_as(labels) + correct = preds.eq(labels).double() + correct = correct.sum() + return correct / len(labels) + +def loss_acc(output, labels, targets, avg_loss=True): + if type(labels) is not torch.Tensor: + labels = torch.LongTensor(labels) + preds = output.max(1)[1].type_as(labels) + correct = preds.eq(labels).double()[targets] + loss = F.nll_loss(output[targets], labels[targets], reduction='mean' if avg_loss else 'none') + + if avg_loss: + return loss, correct.sum() / len(targets) + return loss, correct + # correct = correct.sum() + # return loss, correct / len(labels) + +def get_perf(output, labels, mask, verbose=True): + """evalute performance for test masked data""" + loss = F.nll_loss(output[mask], labels[mask]) + acc = accuracy(output[mask], labels[mask]) + if verbose: + print("loss= {:.4f}".format(loss.item()), + "accuracy= {:.4f}".format(acc.item())) + return loss.item(), acc.item() + + +def classification_margin(output, true_label): + """Calculate classification margin for outputs. 
+ `probs_true_label - probs_best_second_class` + + Parameters + ---------- + output: torch.Tensor + output vector (1 dimension) + true_label: int + true label for this node + + Returns + ------- + list + classification margin for this node + """ + + probs = torch.exp(output) + probs_true_label = probs[true_label].clone() + probs[true_label] = 0 + probs_best_second_class = probs[probs.argmax()] + return (probs_true_label - probs_best_second_class).item() + +def sparse_mx_to_torch_sparse_tensor(sparse_mx): + """Convert a scipy sparse matrix to a torch sparse tensor.""" + sparse_mx = sparse_mx.tocoo().astype(np.float32) + sparserow=torch.LongTensor(sparse_mx.row).unsqueeze(1) + sparsecol=torch.LongTensor(sparse_mx.col).unsqueeze(1) + sparseconcat=torch.cat((sparserow, sparsecol),1) + sparsedata=torch.FloatTensor(sparse_mx.data) + return torch.sparse.FloatTensor(sparseconcat.t(),sparsedata,torch.Size(sparse_mx.shape)) + + # slower version.... + # sparse_mx = sparse_mx.tocoo().astype(np.float32) + # indices = torch.from_numpy( + # np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) + # values = torch.from_numpy(sparse_mx.data) + # shape = torch.Size(sparse_mx.shape) + # return torch.sparse.FloatTensor(indices, values, shape) + + + +def to_scipy(tensor): + """Convert a dense/sparse tensor to scipy matrix""" + if is_sparse_tensor(tensor): + values = tensor._values() + indices = tensor._indices() + return sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()), shape=tensor.shape) + else: + indices = tensor.nonzero().t() + values = tensor[indices[0], indices[1]] + return sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()), shape=tensor.shape) + +def is_sparse_tensor(tensor): + """Check if a tensor is sparse tensor. + + Parameters + ---------- + tensor : torch.Tensor + given tensor + + Returns + ------- + bool + whether a tensor is sparse tensor + """ + # if hasattr(tensor, 'nnz'): + if tensor.layout == torch.sparse_coo: + return True + else: + return False + +def get_train_val_test(nnodes, val_size=0.1, test_size=0.8, stratify=None, seed=None): + """This setting follows nettack/mettack, where we split the nodes + into 10% training, 10% validation and 80% testing data + + Parameters + ---------- + nnodes : int + number of nodes in total + val_size : float + size of validation set + test_size : float + size of test set + stratify : + data is expected to split in a stratified fashion. So stratify should be labels. + seed : int or None + random seed + + Returns + ------- + idx_train : + node training indices + idx_val : + node validation indices + idx_test : + node test indices + """ + + assert stratify is not None, 'stratify cannot be None!' + + if seed is not None: + np.random.seed(seed) + + idx = np.arange(nnodes) + train_size = 1 - val_size - test_size + idx_train_and_val, idx_test = train_test_split(idx, + random_state=None, + train_size=train_size + val_size, + test_size=test_size, + stratify=stratify) + + if stratify is not None: + stratify = stratify[idx_train_and_val] + + idx_train, idx_val = train_test_split(idx_train_and_val, + random_state=None, + train_size=(train_size / (train_size + val_size)), + test_size=(val_size / (train_size + val_size)), + stratify=stratify) + + return idx_train, idx_val, idx_test + +def get_train_test(nnodes, test_size=0.8, stratify=None, seed=None): + """This function returns training and test set without validation. + It can be used for settings of different label rates. 
+ + Parameters + ---------- + nnodes : int + number of nodes in total + test_size : float + size of test set + stratify : + data is expected to split in a stratified fashion. So stratify should be labels. + seed : int or None + random seed + + Returns + ------- + idx_train : + node training indices + idx_test : + node test indices + """ + assert stratify is not None, 'stratify cannot be None!' + + if seed is not None: + np.random.seed(seed) + + idx = np.arange(nnodes) + train_size = 1 - test_size + idx_train, idx_test = train_test_split(idx, random_state=None, + train_size=train_size, + test_size=test_size, + stratify=stratify) + + return idx_train, idx_test + +def get_train_val_test_gcn(labels, seed=None): + """This setting follows gcn, where we randomly sample 20 instances for each class + as training data, 500 instances as validation data, 1000 instances as test data. + Note here we are not using fixed splits. When random seed changes, the splits + will also change. + + Parameters + ---------- + labels : numpy.array + node labels + seed : int or None + random seed + + Returns + ------- + idx_train : + node training indices + idx_val : + node validation indices + idx_test : + node test indices + """ + if seed is not None: + np.random.seed(seed) + + idx = np.arange(len(labels)) + nclass = labels.max() + 1 + idx_train = [] + idx_unlabeled = [] + for i in range(nclass): + labels_i = idx[labels==i] + labels_i = np.random.permutation(labels_i) + idx_train = np.hstack((idx_train, labels_i[: 20])).astype(np.int) + idx_unlabeled = np.hstack((idx_unlabeled, labels_i[20: ])).astype(np.int) + + idx_unlabeled = np.random.permutation(idx_unlabeled) + idx_val = idx_unlabeled[: 500] + idx_test = idx_unlabeled[500: 1500] + return idx_train, idx_val, idx_test + +def get_train_test_labelrate(labels, label_rate): + """Get train test according to given label rate. + """ + nclass = labels.max() + 1 + train_size = int(round(len(labels) * label_rate / nclass)) + print("=== train_size = %s ===" % train_size) + idx_train, idx_val, idx_test = get_splits_each_class(labels, train_size=train_size) + return idx_train, idx_test + +def get_splits_each_class(labels, train_size): + """We randomly sample n instances for class, where n = train_size. + """ + idx = np.arange(len(labels)) + nclass = labels.max() + 1 + idx_train = [] + idx_val = [] + idx_test = [] + for i in range(nclass): + labels_i = idx[labels==i] + labels_i = np.random.permutation(labels_i) + idx_train = np.hstack((idx_train, labels_i[: train_size])).astype(np.int) + idx_val = np.hstack((idx_val, labels_i[train_size: 2*train_size])).astype(np.int) + idx_test = np.hstack((idx_test, labels_i[2*train_size: ])).astype(np.int) + + return np.random.permutation(idx_train), np.random.permutation(idx_val), \ + np.random.permutation(idx_test) + + +def unravel_index(index, array_shape): + rows = torch.div(index, array_shape[1], rounding_mode='trunc') + cols = index % array_shape[1] + return rows, cols + + +def get_degree_squence(adj): + try: + return adj.sum(0) + except: + return ts.sum(adj, dim=1).to_dense() + +def likelihood_ratio_filter(node_pairs, modified_adjacency, original_adjacency, d_min, threshold=0.004, undirected=True): + """ + Filter the input node pairs based on the likelihood ratio test proposed by Zügner et al. 2018, see + https://dl.acm.org/citation.cfm?id=3220078. In essence, for each node pair return 1 if adding/removing the edge + between the two nodes does not violate the unnoticeability constraint, and return 0 otherwise. 
Assumes unweighted + and undirected graphs. + """ + + N = int(modified_adjacency.shape[0]) + # original_degree_sequence = get_degree_squence(original_adjacency) + # current_degree_sequence = get_degree_squence(modified_adjacency) + original_degree_sequence = original_adjacency.sum(0) + current_degree_sequence = modified_adjacency.sum(0) + + concat_degree_sequence = torch.cat((current_degree_sequence, original_degree_sequence)) + + # Compute the log likelihood values of the original, modified, and combined degree sequences. + ll_orig, alpha_orig, n_orig, sum_log_degrees_original = degree_sequence_log_likelihood(original_degree_sequence, d_min) + ll_current, alpha_current, n_current, sum_log_degrees_current = degree_sequence_log_likelihood(current_degree_sequence, d_min) + + ll_comb, alpha_comb, n_comb, sum_log_degrees_combined = degree_sequence_log_likelihood(concat_degree_sequence, d_min) + + # Compute the log likelihood ratio + current_ratio = -2 * ll_comb + 2 * (ll_orig + ll_current) + + # Compute new log likelihood values that would arise if we add/remove the edges corresponding to each node pair. + new_lls, new_alphas, new_ns, new_sum_log_degrees = updated_log_likelihood_for_edge_changes(node_pairs, + modified_adjacency, d_min) + + # Combination of the original degree distribution with the distributions corresponding to each node pair. + n_combined = n_orig + new_ns + new_sum_log_degrees_combined = sum_log_degrees_original + new_sum_log_degrees + alpha_combined = compute_alpha(n_combined, new_sum_log_degrees_combined, d_min) + new_ll_combined = compute_log_likelihood(n_combined, alpha_combined, new_sum_log_degrees_combined, d_min) + new_ratios = -2 * new_ll_combined + 2 * (new_lls + ll_orig) + + # Allowed edges are only those for which the resulting likelihood ratio measure is < than the threshold + allowed_edges = new_ratios < threshold + + if allowed_edges.is_cuda: + filtered_edges = node_pairs[allowed_edges.cpu().numpy().astype(np.bool)] + else: + filtered_edges = node_pairs[allowed_edges.numpy().astype(np.bool)] + + allowed_mask = torch.zeros(modified_adjacency.shape) + allowed_mask[filtered_edges.T] = 1 + if undirected: + allowed_mask += allowed_mask.t() + return allowed_mask, current_ratio + + +def degree_sequence_log_likelihood(degree_sequence, d_min): + """ + Compute the (maximum) log likelihood of the Powerlaw distribution fit on a degree distribution. + """ + + # Determine which degrees are to be considered, i.e. >= d_min. + D_G = degree_sequence[(degree_sequence >= d_min.item())] + try: + sum_log_degrees = torch.log(D_G).sum() + except: + sum_log_degrees = np.log(D_G).sum() + n = len(D_G) + + alpha = compute_alpha(n, sum_log_degrees, d_min) + ll = compute_log_likelihood(n, alpha, sum_log_degrees, d_min) + return ll, alpha, n, sum_log_degrees + +def updated_log_likelihood_for_edge_changes(node_pairs, adjacency_matrix, d_min): + """ Adopted from https://github.com/danielzuegner/nettack + """ + # For each node pair find out whether there is an edge or not in the input adjacency matrix. 
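+    # deltas = -2 * e + 1 below maps a present edge (e = 1) to a degree change of -1 (removal)
+    # and an absent edge (e = 0) to +1 (insertion), i.e. the effect of flipping each candidate pair.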
+ + edge_entries_before = adjacency_matrix[node_pairs.T] + degree_sequence = adjacency_matrix.sum(1) + D_G = degree_sequence[degree_sequence >= d_min.item()] + sum_log_degrees = torch.log(D_G).sum() + n = len(D_G) + deltas = -2 * edge_entries_before + 1 + d_edges_before = degree_sequence[node_pairs] + + d_edges_after = degree_sequence[node_pairs] + deltas[:, None] + + # Sum the log of the degrees after the potential changes which are >= d_min + sum_log_degrees_after, new_n = update_sum_log_degrees(sum_log_degrees, n, d_edges_before, d_edges_after, d_min) + # Updated estimates of the Powerlaw exponents + new_alpha = compute_alpha(new_n, sum_log_degrees_after, d_min) + # Updated log likelihood values for the Powerlaw distributions + new_ll = compute_log_likelihood(new_n, new_alpha, sum_log_degrees_after, d_min) + return new_ll, new_alpha, new_n, sum_log_degrees_after + + +def update_sum_log_degrees(sum_log_degrees_before, n_old, d_old, d_new, d_min): + # Find out whether the degrees before and after the change are above the threshold d_min. + old_in_range = d_old >= d_min + new_in_range = d_new >= d_min + d_old_in_range = d_old * old_in_range.float() + d_new_in_range = d_new * new_in_range.float() + + # Update the sum by subtracting the old values and then adding the updated logs of the degrees. + sum_log_degrees_after = sum_log_degrees_before - (torch.log(torch.clamp(d_old_in_range, min=1))).sum(1) \ + + (torch.log(torch.clamp(d_new_in_range, min=1))).sum(1) + + # Update the number of degrees >= d_min + + new_n = n_old - (old_in_range!=0).sum(1) + (new_in_range!=0).sum(1) + new_n = new_n.float() + return sum_log_degrees_after, new_n + +def compute_alpha(n, sum_log_degrees, d_min): + try: + alpha = 1 + n / (sum_log_degrees - n * torch.log(d_min - 0.5)) + except: + alpha = 1 + n / (sum_log_degrees - n * np.log(d_min - 0.5)) + return alpha + +def compute_log_likelihood(n, alpha, sum_log_degrees, d_min): + # Log likelihood under alpha + try: + ll = n * torch.log(alpha) + n * alpha * torch.log(d_min) + (alpha + 1) * sum_log_degrees + except: + ll = n * np.log(alpha) + n * alpha * np.log(d_min) + (alpha + 1) * sum_log_degrees + + return ll + +def ravel_multiple_indices(ixs, shape, reverse=False): + """ + "Flattens" multiple 2D input indices into indices on the flattened matrix, similar to np.ravel_multi_index. + Does the same as ravel_index but for multiple indices at once. + Parameters + ---------- + ixs: array of ints shape (n, 2) + The array of n indices that will be flattened. + + shape: list or tuple of ints of length 2 + The shape of the corresponding matrix. + + Returns + ------- + array of n ints between 0 and shape[0]*shape[1]-1 + The indices on the flattened matrix corresponding to the 2D input indices. 
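+        For example, with shape (3, 4) the pair (2, 1) flattens to 2 * 4 + 1 = 9.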
+ + """ + if reverse: + return ixs[:, 1] * shape[1] + ixs[:, 0] + + return ixs[:, 0] * shape[1] + ixs[:, 1] + +def visualize(your_var): + """visualize computation graph""" + from graphviz import Digraph + import torch + from torch.autograd import Variable + from torchviz import make_dot + make_dot(your_var).view() + +def reshape_mx(mx, shape): + indices = mx.nonzero() + return sp.csr_matrix((mx.data, (indices[0], indices[1])), shape=shape) + +def add_mask(data, dataset): + """data: ogb-arxiv pyg data format""" + # for arxiv + split_idx = dataset.get_idx_split() + train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] + n = data.x.shape[0] + data.train_mask = index_to_mask(train_idx, n) + data.val_mask = index_to_mask(valid_idx, n) + data.test_mask = index_to_mask(test_idx, n) + data.y = data.y.squeeze() + # data.edge_index = to_undirected(data.edge_index, data.num_nodes) + +def index_to_mask(index, size): + mask = torch.zeros((size, ), dtype=torch.bool) + mask[index] = 1 + return mask + +def add_feature_noise(data, noise_ratio, seed): + np.random.seed(seed) + n, d = data.x.shape + # noise = torch.normal(mean=torch.zeros(int(noise_ratio*n), d), std=1) + noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio*n), d))).to(data.x.device) + indices = np.arange(n) + indices = np.random.permutation(indices)[: int(noise_ratio*n)] + delta_feat = torch.zeros_like(data.x) + delta_feat[indices] = noise - data.x[indices] + data.x[indices] = noise + mask = np.zeros(n) + mask[indices] = 1 + mask = torch.tensor(mask).bool().to(data.x.device) + return delta_feat, mask + +def add_feature_noise_test(data, noise_ratio, seed): + np.random.seed(seed) + n, d = data.x.shape + indices = np.arange(n) + test_nodes = indices[data.test_mask.cpu()] + selected = np.random.permutation(test_nodes)[: int(noise_ratio*len(test_nodes))] + noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio*len(test_nodes)), d))) + noise = noise.to(data.x.device) + + delta_feat = torch.zeros_like(data.x) + delta_feat[selected] = noise - data.x[selected] + data.x[selected] = noise + # mask = np.zeros(len(test_nodes)) + mask = np.zeros(n) + mask[selected] = 1 + mask = torch.tensor(mask).bool().to(data.x.device) + return delta_feat, mask + +# def check_path(file_path): +# if not osp.exists(file_path): +# os.system(f'mkdir -p {file_path}') \ No newline at end of file diff --git a/src/defense/RGCN/rgcn.py b/src/defense/RGCN/rgcn.py new file mode 100644 index 0000000..8ec8206 --- /dev/null +++ b/src/defense/RGCN/rgcn.py @@ -0,0 +1,324 @@ + +# import torch.nn.functional as F +# import math +# import torch +# from torch.nn.parameter import Parameter +# from torch.nn.modules.module import Module +# from torch.distributions.multivariate_normal import MultivariateNormal +# # from deeprobust.graph import utils +# import defense.RGCN.utils as utils +# import torch.optim as optim +# from copy import deepcopy + +# # TODO sparse implementation + +# class GGCL_F(Module): +# """Graph Gaussian Convolution Layer (GGCL) when the input is feature""" + +# def __init__(self, in_features, out_features, dropout=0.6): +# super(GGCL_F, self).__init__() +# self.in_features = in_features +# self.out_features = out_features +# self.dropout = dropout +# self.weight_miu = Parameter(torch.FloatTensor(in_features, out_features)) +# self.weight_sigma = Parameter(torch.FloatTensor(in_features, out_features)) +# self.reset_parameters() + +# def reset_parameters(self): +# 
torch.nn.init.xavier_uniform_(self.weight_miu) +# torch.nn.init.xavier_uniform_(self.weight_sigma) + +# def forward(self, features, adj_norm1, adj_norm2, gamma=1): +# features = F.dropout(features, self.dropout, training=self.training) +# self.miu = F.elu(torch.mm(features, self.weight_miu)) +# self.sigma = F.relu(torch.mm(features, self.weight_sigma)) +# # torch.mm(previous_sigma, self.weight_sigma) +# Att = torch.exp(-gamma * self.sigma) +# miu_out = adj_norm1 @ (self.miu * Att) +# sigma_out = adj_norm2 @ (self.sigma * Att * Att) +# return miu_out, sigma_out + +# class GGCL_D(Module): + +# """Graph Gaussian Convolution Layer (GGCL) when the input is distribution""" +# def __init__(self, in_features, out_features, dropout): +# super(GGCL_D, self).__init__() +# self.in_features = in_features +# self.out_features = out_features +# self.dropout = dropout +# self.weight_miu = Parameter(torch.FloatTensor(in_features, out_features)) +# self.weight_sigma = Parameter(torch.FloatTensor(in_features, out_features)) +# # self.register_parameter('bias', None) +# self.reset_parameters() + +# def reset_parameters(self): +# torch.nn.init.xavier_uniform_(self.weight_miu) +# torch.nn.init.xavier_uniform_(self.weight_sigma) + +# def forward(self, miu, sigma, adj_norm1, adj_norm2, gamma=1): +# miu = F.dropout(miu, self.dropout, training=self.training) +# sigma = F.dropout(sigma, self.dropout, training=self.training) +# miu = F.elu(miu @ self.weight_miu) +# sigma = F.relu(sigma @ self.weight_sigma) + +# Att = torch.exp(-gamma * sigma) +# mean_out = adj_norm1 @ (miu * Att) +# sigma_out = adj_norm2 @ (sigma * Att * Att) +# return mean_out, sigma_out + + +# class GaussianConvolution(Module): +# """[Deprecated] Alternative gaussion convolution layer. +# """ + +# def __init__(self, in_features, out_features): +# super(GaussianConvolution, self).__init__() +# self.in_features = in_features +# self.out_features = out_features +# self.weight_miu = Parameter(torch.FloatTensor(in_features, out_features)) +# self.weight_sigma = Parameter(torch.FloatTensor(in_features, out_features)) +# # self.sigma = Parameter(torch.FloatTensor(out_features)) +# # self.register_parameter('bias', None) +# self.reset_parameters() + +# def reset_parameters(self): +# # TODO +# torch.nn.init.xavier_uniform_(self.weight_miu) +# torch.nn.init.xavier_uniform_(self.weight_sigma) + +# def forward(self, previous_miu, previous_sigma, adj_norm1=None, adj_norm2=None, gamma=1): + +# if adj_norm1 is None and adj_norm2 is None: +# return torch.mm(previous_miu, self.weight_miu), \ +# torch.mm(previous_miu, self.weight_miu) +# # torch.mm(previous_sigma, self.weight_sigma) + +# Att = torch.exp(-gamma * previous_sigma) +# M = adj_norm1 @ (previous_miu * Att) @ self.weight_miu +# Sigma = adj_norm2 @ (previous_sigma * Att * Att) @ self.weight_sigma +# return M, Sigma + +# # M = torch.mm(torch.mm(adj, previous_miu * A), self.weight_miu) +# # Sigma = torch.mm(torch.mm(adj, previous_sigma * A * A), self.weight_sigma) + +# # TODO sparse implemention +# # support = torch.mm(input, self.weight) +# # output = torch.spmm(adj, support) +# # return output + self.bias + +# def __repr__(self): +# return self.__class__.__name__ + ' (' \ +# + str(self.in_features) + ' -> ' \ +# + str(self.out_features) + ')' + + +# class RGCN(Module): +# """Robust Graph Convolutional Networks Against Adversarial Attacks. KDD 2019. 
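+# Hidden representations are modelled as Gaussians; the attention term Att = exp(-gamma * sigma)
+# in the GGCL layers down-weights high-variance (likely perturbed) neighbours during aggregation.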
+ +# Parameters +# ---------- +# nnodes : int +# number of nodes in the input grpah +# nfeat : int +# size of input feature dimension +# nhid : int +# number of hidden units +# nclass : int +# size of output dimension +# gamma : float +# hyper-parameter for RGCN. See more details in the paper. +# beta1 : float +# hyper-parameter for RGCN. See more details in the paper. +# beta2 : float +# hyper-parameter for RGCN. See more details in the paper. +# lr : float +# learning rate for GCN +# dropout : float +# dropout rate for GCN +# device: str +# 'cpu' or 'cuda'. + +# """ + +# def __init__(self, nnodes, nfeat, nhid, nclass, gamma=1.0, beta1=5e-4, beta2=5e-4, lr=0.01, dropout=0.6, device='cpu'): +# super(RGCN, self).__init__() + +# self.device = device +# # adj_norm = normalize(adj) +# # first turn original features to distribution +# self.lr = lr +# self.gamma = gamma +# self.beta1 = beta1 +# self.beta2 = beta2 +# self.nclass = nclass +# self.nhid = nhid // 2 +# # self.gc1 = GaussianConvolution(nfeat, nhid, dropout=dropout) +# # self.gc2 = GaussianConvolution(nhid, nclass, dropout) +# self.gc1 = GGCL_F(nfeat, nhid, dropout=dropout) +# self.gc2 = GGCL_D(nhid, nclass, dropout=dropout) + +# self.dropout = dropout +# # self.gaussian = MultivariateNormal(torch.zeros(self.nclass), torch.eye(self.nclass)) +# self.gaussian = MultivariateNormal(torch.zeros(nnodes, self.nclass), +# torch.diag_embed(torch.ones(nnodes, self.nclass))) +# self.adj_norm1, self.adj_norm2 = None, None +# self.features, self.labels = None, None + +# def forward(self): +# features = self.features +# miu, sigma = self.gc1(features, self.adj_norm1, self.adj_norm2, self.gamma) +# miu, sigma = self.gc2(miu, sigma, self.adj_norm1, self.adj_norm2, self.gamma) +# output = miu + self.gaussian.sample().to(self.device) * torch.sqrt(sigma + 1e-8) +# return F.log_softmax(output, dim=1) + +# def fit(self, features, adj, labels, idx_train, idx_val=None, train_iters=200, verbose=True, **kwargs): +# """Train RGCN. + +# Parameters +# ---------- +# features : +# node features +# adj : +# the adjacency matrix. The format could be torch.tensor or scipy matrix +# labels : +# node labels +# idx_train : +# node training indices +# idx_val : +# node validation indices. If not given (None), GCN training process will not adpot early stopping +# train_iters : int +# number of training epochs +# verbose : bool +# whether to show verbose logs + +# Examples +# -------- +# We can first load dataset and then train RGCN. 
+ +# >>> from deeprobust.graph.data import PrePtbDataset, Dataset +# >>> from deeprobust.graph.defense import RGCN +# >>> # load clean graph data +# >>> data = Dataset(root='/tmp/', name='cora', seed=15) +# >>> adj, features, labels = data.adj, data.features, data.labels +# >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test +# >>> # load perturbed graph data +# >>> perturbed_data = PrePtbDataset(root='/tmp/', name='cora') +# >>> perturbed_adj = perturbed_data.adj +# >>> # train defense model +# >>> model = RGCN(nnodes=perturbed_adj.shape[0], nfeat=features.shape[1], +# nclass=labels.max()+1, nhid=32, device='cpu') +# >>> model.fit(features, perturbed_adj, labels, idx_train, idx_val, +# train_iters=200, verbose=True) +# >>> model.test(idx_test) + +# """ + +# adj, features, labels = utils.to_tensor(adj.todense(), features.todense(), labels, device=self.device) + +# self.features, self.labels = features, labels +# self.adj_norm1 = self._normalize_adj(adj, power=-1/2) +# self.adj_norm2 = self._normalize_adj(adj, power=-1) +# print('=== training rgcn model ===') +# self._initialize() +# if idx_val is None: +# self._train_without_val(labels, idx_train, train_iters, verbose) +# else: +# self._train_with_val(labels, idx_train, idx_val, train_iters, verbose) + +# def _train_without_val(self, labels, idx_train, train_iters, verbose=True): +# optimizer = optim.Adam(self.parameters(), lr=self.lr) +# self.train() +# for i in range(train_iters): +# optimizer.zero_grad() +# output = self.forward() +# loss_train = self._loss(output[idx_train], labels[idx_train]) +# loss_train.backward() +# optimizer.step() +# if verbose and i % 10 == 0: +# print('Epoch {}, training loss: {}'.format(i, loss_train.item())) + +# self.eval() +# output = self.forward() +# self.output = output + +# def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose): +# optimizer = optim.Adam(self.parameters(), lr=self.lr) + +# best_loss_val = 100 +# best_acc_val = 0 + +# for i in range(train_iters): +# self.train() +# optimizer.zero_grad() +# output = self.forward() +# loss_train = self._loss(output[idx_train], labels[idx_train]) +# loss_train.backward() +# optimizer.step() +# if verbose and i % 10 == 0: +# print('Epoch {}, training loss: {}'.format(i, loss_train.item())) + +# self.eval() +# output = self.forward() +# loss_val = F.nll_loss(output[idx_val], labels[idx_val]) +# acc_val = utils.accuracy(output[idx_val], labels[idx_val]) + +# if best_loss_val > loss_val: +# best_loss_val = loss_val +# self.output = output + +# if acc_val > best_acc_val: +# best_acc_val = acc_val +# self.output = output + +# print('=== picking the best model according to the performance on validation ===') + + +# def test(self, idx_test): +# """Evaluate the peformance on test set +# """ +# self.eval() +# # output = self.forward() +# output = self.output +# loss_test = F.nll_loss(output[idx_test], self.labels[idx_test]) +# acc_test = utils.accuracy(output[idx_test], self.labels[idx_test]) +# print("Test set results:", +# "loss= {:.4f}".format(loss_test.item()), +# "accuracy= {:.4f}".format(acc_test.item())) +# return acc_test.item() + +# def predict(self): +# """ +# Returns +# ------- +# torch.FloatTensor +# output (log probabilities) of RGCN +# """ + +# self.eval() +# return self.forward() + +# def _loss(self, input, labels): +# loss = F.nll_loss(input, labels) +# miu1 = self.gc1.miu +# sigma1 = self.gc1.sigma +# kl_loss = 0.5 * (miu1.pow(2) + sigma1 - torch.log(1e-8 + sigma1)).mean(1) +# kl_loss = kl_loss.sum() 
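+#         # kl_loss regularises the first layer's Gaussians towards N(0, I) (the beta1 term);
+#         # norm2 below is the L2 penalty on the first layer's weights (the beta2 term).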
+# norm2 = torch.norm(self.gc1.weight_miu, 2).pow(2) + \ +# torch.norm(self.gc1.weight_sigma, 2).pow(2) + +# # print(f'gcn_loss: {loss.item()}, kl_loss: {self.beta1 * kl_loss.item()}, norm2: {self.beta2 * norm2.item()}') +# return loss + self.beta1 * kl_loss + self.beta2 * norm2 + +# def _initialize(self): +# self.gc1.reset_parameters() +# self.gc2.reset_parameters() + +# def _normalize_adj(self, adj, power=-1/2): + +# """Row-normalize sparse matrix""" +# A = adj + torch.eye(len(adj)).to(self.device) +# D_power = (A.sum(1)).pow(power) +# D_power[torch.isinf(D_power)] = 0. +# D_power = torch.diag(D_power) +# return D_power @ A @ D_power + diff --git a/src/defense/RGCN/utils.py b/src/defense/RGCN/utils.py new file mode 100644 index 0000000..f7be00a --- /dev/null +++ b/src/defense/RGCN/utils.py @@ -0,0 +1,777 @@ +import numpy as np +import scipy.sparse as sp +import torch +from sklearn.model_selection import train_test_split +import torch.sparse as ts +import torch.nn.functional as F +import warnings + +def encode_onehot(labels): + """Convert label to onehot format. + + Parameters + ---------- + labels : numpy.array + node labels + + Returns + ------- + numpy.array + onehot labels + """ + eye = np.eye(labels.max() + 1) + onehot_mx = eye[labels] + return onehot_mx + +def tensor2onehot(labels): + """Convert label tensor to label onehot tensor. + + Parameters + ---------- + labels : torch.LongTensor + node labels + + Returns + ------- + torch.LongTensor + onehot labels tensor + + """ + + eye = torch.eye(labels.max() + 1) + onehot_mx = eye[labels] + return onehot_mx.to(labels.device) + +def preprocess(adj, features, labels, preprocess_adj=False, preprocess_feature=False, sparse=False, device='cpu'): + """Convert adj, features, labels from array or sparse matrix to + torch Tensor, and normalize the input data. + + Parameters + ---------- + adj : scipy.sparse.csr_matrix + the adjacency matrix. + features : scipy.sparse.csr_matrix + node features + labels : numpy.array + node labels + preprocess_adj : bool + whether to normalize the adjacency matrix + preprocess_feature : bool + whether to normalize the feature matrix + sparse : bool + whether to return sparse tensor + device : str + 'cpu' or 'cuda' + """ + + if preprocess_adj: + adj = normalize_adj(adj) + + if preprocess_feature: + features = normalize_feature(features) + + labels = torch.LongTensor(labels) + if sparse: + adj = sparse_mx_to_torch_sparse_tensor(adj) + features = sparse_mx_to_torch_sparse_tensor(features) + else: + if sp.issparse(features): + features = torch.FloatTensor(np.array(features.todense())) + else: + features = torch.FloatTensor(features) + adj = torch.FloatTensor(adj.todense()) + return adj.to(device), features.to(device), labels.to(device) + +def to_tensor(adj, features, labels=None, device='cpu'): + """Convert adj, features, labels from array or sparse matrix to + torch Tensor. + + Parameters + ---------- + adj : scipy.sparse.csr_matrix + the adjacency matrix. 
+ features : scipy.sparse.csr_matrix + node features + labels : numpy.array + node labels + device : str + 'cpu' or 'cuda' + """ + if sp.issparse(adj): + adj = sparse_mx_to_torch_sparse_tensor(adj) + else: + adj = torch.FloatTensor(adj) + if sp.issparse(features): + features = sparse_mx_to_torch_sparse_tensor(features) + else: + features = torch.FloatTensor(np.array(features)) + + if labels is None: + return adj.to(device), features.to(device) + else: + labels = torch.LongTensor(labels) + return adj.to(device), features.to(device), labels.to(device) + +def normalize_feature(mx): + """Row-normalize sparse matrix or dense matrix + + Parameters + ---------- + mx : scipy.sparse.csr_matrix or numpy.array + matrix to be normalized + + Returns + ------- + scipy.sprase.lil_matrix + normalized matrix + """ + if type(mx) is not sp.lil.lil_matrix: + try: + mx = mx.tolil() + except AttributeError: + pass + rowsum = np.array(mx.sum(1)) + r_inv = np.power(rowsum, -1).flatten() + r_inv[np.isinf(r_inv)] = 0. + r_mat_inv = sp.diags(r_inv) + mx = r_mat_inv.dot(mx) + return mx + +def normalize_adj(mx): + """Normalize sparse adjacency matrix, + A' = (D + I)^-1/2 * ( A + I ) * (D + I)^-1/2 + Row-normalize sparse matrix + + Parameters + ---------- + mx : scipy.sparse.csr_matrix + matrix to be normalized + + Returns + ------- + scipy.sprase.lil_matrix + normalized matrix + """ + + # TODO: maybe using coo format would be better? + if type(mx) is not sp.lil.lil_matrix: + mx = mx.tolil() + if mx[0, 0] == 0 : + mx = mx + sp.eye(mx.shape[0]) + rowsum = np.array(mx.sum(1)) + r_inv = np.power(rowsum, -1/2).flatten() + r_inv[np.isinf(r_inv)] = 0. + r_mat_inv = sp.diags(r_inv) + mx = r_mat_inv.dot(mx) + mx = mx.dot(r_mat_inv) + return mx + +def normalize_sparse_tensor(adj, fill_value=1): + """Normalize sparse tensor. Need to import torch_scatter + """ + edge_index = adj._indices() + edge_weight = adj._values() + num_nodes= adj.size(0) + edge_index, edge_weight = add_self_loops( + edge_index, edge_weight, fill_value, num_nodes) + + row, col = edge_index + from torch_scatter import scatter_add + deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + + values = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + + shape = adj.shape + return torch.sparse.FloatTensor(edge_index, values, shape) + +def add_self_loops(edge_index, edge_weight=None, fill_value=1, num_nodes=None): + # num_nodes = maybe_num_nodes(edge_index, num_nodes) + + loop_index = torch.arange(0, num_nodes, dtype=torch.long, + device=edge_index.device) + loop_index = loop_index.unsqueeze(0).repeat(2, 1) + + if edge_weight is not None: + assert edge_weight.numel() == edge_index.size(1) + loop_weight = edge_weight.new_full((num_nodes, ), fill_value) + edge_weight = torch.cat([edge_weight, loop_weight], dim=0) + + edge_index = torch.cat([edge_index, loop_index], dim=1) + + return edge_index, edge_weight + +def normalize_adj_tensor(adj, sparse=False): + """Normalize adjacency tensor matrix. + """ + device = adj.device + if sparse: + # warnings.warn('If you find the training process is too slow, you can uncomment line 207 in deeprobust/graph/utils.py. 
Note that you need to install torch_sparse') + # TODO if this is too slow, uncomment the following code, + # but you need to install torch_scatter + # return normalize_sparse_tensor(adj) + adj = to_scipy(adj) + mx = normalize_adj(adj) + return sparse_mx_to_torch_sparse_tensor(mx).to(device) + else: + mx = adj + torch.eye(adj.shape[0]).to(device) + rowsum = mx.sum(1) + r_inv = rowsum.pow(-1/2).flatten() + r_inv[torch.isinf(r_inv)] = 0. + r_mat_inv = torch.diag(r_inv) + mx = r_mat_inv @ mx + mx = mx @ r_mat_inv + return mx + +def degree_normalize_adj(mx): + """Row-normalize sparse matrix""" + mx = mx.tolil() + if mx[0, 0] == 0 : + mx = mx + sp.eye(mx.shape[0]) + rowsum = np.array(mx.sum(1)) + r_inv = np.power(rowsum, -1).flatten() + r_inv[np.isinf(r_inv)] = 0. + r_mat_inv = sp.diags(r_inv) + # mx = mx.dot(r_mat_inv) + mx = r_mat_inv.dot(mx) + return mx + +def degree_normalize_sparse_tensor(adj, fill_value=1): + """degree_normalize_sparse_tensor. + """ + edge_index = adj._indices() + edge_weight = adj._values() + num_nodes= adj.size(0) + + edge_index, edge_weight = add_self_loops( + edge_index, edge_weight, fill_value, num_nodes) + + row, col = edge_index + from torch_scatter import scatter_add + deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) + deg_inv_sqrt = deg.pow(-1) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + + values = deg_inv_sqrt[row] * edge_weight + shape = adj.shape + return torch.sparse.FloatTensor(edge_index, values, shape) + +def degree_normalize_adj_tensor(adj, sparse=True): + """degree_normalize_adj_tensor. + """ + + device = adj.device + if sparse: + # return degree_normalize_sparse_tensor(adj) + adj = to_scipy(adj) + mx = degree_normalize_adj(adj) + return sparse_mx_to_torch_sparse_tensor(mx).to(device) + else: + mx = adj + torch.eye(adj.shape[0]).to(device) + rowsum = mx.sum(1) + r_inv = rowsum.pow(-1).flatten() + r_inv[torch.isinf(r_inv)] = 0. + r_mat_inv = torch.diag(r_inv) + mx = r_mat_inv @ mx + return mx + +def accuracy(output, labels): + """Return accuracy of output compared to labels. + + Parameters + ---------- + output : torch.Tensor + output from model + labels : torch.Tensor or numpy.array + node labels + + Returns + ------- + float + accuracy + """ + if not hasattr(labels, '__len__'): + labels = [labels] + if type(labels) is not torch.Tensor: + labels = torch.LongTensor(labels) + preds = output.max(1)[1].type_as(labels) + correct = preds.eq(labels).double() + correct = correct.sum() + return correct / len(labels) + +def loss_acc(output, labels, targets, avg_loss=True): + if type(labels) is not torch.Tensor: + labels = torch.LongTensor(labels) + preds = output.max(1)[1].type_as(labels) + correct = preds.eq(labels).double()[targets] + loss = F.nll_loss(output[targets], labels[targets], reduction='mean' if avg_loss else 'none') + + if avg_loss: + return loss, correct.sum() / len(targets) + return loss, correct + # correct = correct.sum() + # return loss, correct / len(labels) + +def get_perf(output, labels, mask, verbose=True): + """evalute performance for test masked data""" + loss = F.nll_loss(output[mask], labels[mask]) + acc = accuracy(output[mask], labels[mask]) + if verbose: + print("loss= {:.4f}".format(loss.item()), + "accuracy= {:.4f}".format(acc.item())) + return loss.item(), acc.item() + + +def classification_margin(output, true_label): + """Calculate classification margin for outputs. 
+ `probs_true_label - probs_best_second_class` + + Parameters + ---------- + output: torch.Tensor + output vector (1 dimension) + true_label: int + true label for this node + + Returns + ------- + list + classification margin for this node + """ + + probs = torch.exp(output) + probs_true_label = probs[true_label].clone() + probs[true_label] = 0 + probs_best_second_class = probs[probs.argmax()] + return (probs_true_label - probs_best_second_class).item() + +def sparse_mx_to_torch_sparse_tensor(sparse_mx): + """Convert a scipy sparse matrix to a torch sparse tensor.""" + sparse_mx = sparse_mx.tocoo().astype(np.float32) + sparserow=torch.LongTensor(sparse_mx.row).unsqueeze(1) + sparsecol=torch.LongTensor(sparse_mx.col).unsqueeze(1) + sparseconcat=torch.cat((sparserow, sparsecol),1) + sparsedata=torch.FloatTensor(sparse_mx.data) + return torch.sparse.FloatTensor(sparseconcat.t(),sparsedata,torch.Size(sparse_mx.shape)) + + # slower version.... + # sparse_mx = sparse_mx.tocoo().astype(np.float32) + # indices = torch.from_numpy( + # np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) + # values = torch.from_numpy(sparse_mx.data) + # shape = torch.Size(sparse_mx.shape) + # return torch.sparse.FloatTensor(indices, values, shape) + + + +def to_scipy(tensor): + """Convert a dense/sparse tensor to scipy matrix""" + if is_sparse_tensor(tensor): + values = tensor._values() + indices = tensor._indices() + return sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()), shape=tensor.shape) + else: + indices = tensor.nonzero().t() + values = tensor[indices[0], indices[1]] + return sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()), shape=tensor.shape) + +def is_sparse_tensor(tensor): + """Check if a tensor is sparse tensor. + + Parameters + ---------- + tensor : torch.Tensor + given tensor + + Returns + ------- + bool + whether a tensor is sparse tensor + """ + # if hasattr(tensor, 'nnz'): + if tensor.layout == torch.sparse_coo: + return True + else: + return False + +def get_train_val_test(nnodes, val_size=0.1, test_size=0.8, stratify=None, seed=None): + """This setting follows nettack/mettack, where we split the nodes + into 10% training, 10% validation and 80% testing data + + Parameters + ---------- + nnodes : int + number of nodes in total + val_size : float + size of validation set + test_size : float + size of test set + stratify : + data is expected to split in a stratified fashion. So stratify should be labels. + seed : int or None + random seed + + Returns + ------- + idx_train : + node training indices + idx_val : + node validation indices + idx_test : + node test indices + """ + + assert stratify is not None, 'stratify cannot be None!' + + if seed is not None: + np.random.seed(seed) + + idx = np.arange(nnodes) + train_size = 1 - val_size - test_size + idx_train_and_val, idx_test = train_test_split(idx, + random_state=None, + train_size=train_size + val_size, + test_size=test_size, + stratify=stratify) + + if stratify is not None: + stratify = stratify[idx_train_and_val] + + idx_train, idx_val = train_test_split(idx_train_and_val, + random_state=None, + train_size=(train_size / (train_size + val_size)), + test_size=(val_size / (train_size + val_size)), + stratify=stratify) + + return idx_train, idx_val, idx_test + +def get_train_test(nnodes, test_size=0.8, stratify=None, seed=None): + """This function returns training and test set without validation. + It can be used for settings of different label rates. 
+ + Parameters + ---------- + nnodes : int + number of nodes in total + test_size : float + size of test set + stratify : + data is expected to split in a stratified fashion. So stratify should be labels. + seed : int or None + random seed + + Returns + ------- + idx_train : + node training indices + idx_test : + node test indices + """ + assert stratify is not None, 'stratify cannot be None!' + + if seed is not None: + np.random.seed(seed) + + idx = np.arange(nnodes) + train_size = 1 - test_size + idx_train, idx_test = train_test_split(idx, random_state=None, + train_size=train_size, + test_size=test_size, + stratify=stratify) + + return idx_train, idx_test + +def get_train_val_test_gcn(labels, seed=None): + """This setting follows gcn, where we randomly sample 20 instances for each class + as training data, 500 instances as validation data, 1000 instances as test data. + Note here we are not using fixed splits. When random seed changes, the splits + will also change. + + Parameters + ---------- + labels : numpy.array + node labels + seed : int or None + random seed + + Returns + ------- + idx_train : + node training indices + idx_val : + node validation indices + idx_test : + node test indices + """ + if seed is not None: + np.random.seed(seed) + + idx = np.arange(len(labels)) + nclass = labels.max() + 1 + idx_train = [] + idx_unlabeled = [] + for i in range(nclass): + labels_i = idx[labels==i] + labels_i = np.random.permutation(labels_i) + idx_train = np.hstack((idx_train, labels_i[: 20])).astype(np.int) + idx_unlabeled = np.hstack((idx_unlabeled, labels_i[20: ])).astype(np.int) + + idx_unlabeled = np.random.permutation(idx_unlabeled) + idx_val = idx_unlabeled[: 500] + idx_test = idx_unlabeled[500: 1500] + return idx_train, idx_val, idx_test + +def get_train_test_labelrate(labels, label_rate): + """Get train test according to given label rate. + """ + nclass = labels.max() + 1 + train_size = int(round(len(labels) * label_rate / nclass)) + print("=== train_size = %s ===" % train_size) + idx_train, idx_val, idx_test = get_splits_each_class(labels, train_size=train_size) + return idx_train, idx_test + +def get_splits_each_class(labels, train_size): + """We randomly sample n instances for class, where n = train_size. + """ + idx = np.arange(len(labels)) + nclass = labels.max() + 1 + idx_train = [] + idx_val = [] + idx_test = [] + for i in range(nclass): + labels_i = idx[labels==i] + labels_i = np.random.permutation(labels_i) + idx_train = np.hstack((idx_train, labels_i[: train_size])).astype(np.int) + idx_val = np.hstack((idx_val, labels_i[train_size: 2*train_size])).astype(np.int) + idx_test = np.hstack((idx_test, labels_i[2*train_size: ])).astype(np.int) + + return np.random.permutation(idx_train), np.random.permutation(idx_val), \ + np.random.permutation(idx_test) + + +def unravel_index(index, array_shape): + rows = torch.div(index, array_shape[1], rounding_mode='trunc') + cols = index % array_shape[1] + return rows, cols + + +def get_degree_squence(adj): + try: + return adj.sum(0) + except: + return ts.sum(adj, dim=1).to_dense() + +def likelihood_ratio_filter(node_pairs, modified_adjacency, original_adjacency, d_min, threshold=0.004, undirected=True): + """ + Filter the input node pairs based on the likelihood ratio test proposed by Zügner et al. 2018, see + https://dl.acm.org/citation.cfm?id=3220078. In essence, for each node pair return 1 if adding/removing the edge + between the two nodes does not violate the unnoticeability constraint, and return 0 otherwise. 
Assumes unweighted + and undirected graphs. + """ + + N = int(modified_adjacency.shape[0]) + # original_degree_sequence = get_degree_squence(original_adjacency) + # current_degree_sequence = get_degree_squence(modified_adjacency) + original_degree_sequence = original_adjacency.sum(0) + current_degree_sequence = modified_adjacency.sum(0) + + concat_degree_sequence = torch.cat((current_degree_sequence, original_degree_sequence)) + + # Compute the log likelihood values of the original, modified, and combined degree sequences. + ll_orig, alpha_orig, n_orig, sum_log_degrees_original = degree_sequence_log_likelihood(original_degree_sequence, d_min) + ll_current, alpha_current, n_current, sum_log_degrees_current = degree_sequence_log_likelihood(current_degree_sequence, d_min) + + ll_comb, alpha_comb, n_comb, sum_log_degrees_combined = degree_sequence_log_likelihood(concat_degree_sequence, d_min) + + # Compute the log likelihood ratio + current_ratio = -2 * ll_comb + 2 * (ll_orig + ll_current) + + # Compute new log likelihood values that would arise if we add/remove the edges corresponding to each node pair. + new_lls, new_alphas, new_ns, new_sum_log_degrees = updated_log_likelihood_for_edge_changes(node_pairs, + modified_adjacency, d_min) + + # Combination of the original degree distribution with the distributions corresponding to each node pair. + n_combined = n_orig + new_ns + new_sum_log_degrees_combined = sum_log_degrees_original + new_sum_log_degrees + alpha_combined = compute_alpha(n_combined, new_sum_log_degrees_combined, d_min) + new_ll_combined = compute_log_likelihood(n_combined, alpha_combined, new_sum_log_degrees_combined, d_min) + new_ratios = -2 * new_ll_combined + 2 * (new_lls + ll_orig) + + # Allowed edges are only those for which the resulting likelihood ratio measure is < than the threshold + allowed_edges = new_ratios < threshold + + if allowed_edges.is_cuda: + filtered_edges = node_pairs[allowed_edges.cpu().numpy().astype(np.bool)] + else: + filtered_edges = node_pairs[allowed_edges.numpy().astype(np.bool)] + + allowed_mask = torch.zeros(modified_adjacency.shape) + allowed_mask[filtered_edges.T] = 1 + if undirected: + allowed_mask += allowed_mask.t() + return allowed_mask, current_ratio + + +def degree_sequence_log_likelihood(degree_sequence, d_min): + """ + Compute the (maximum) log likelihood of the Powerlaw distribution fit on a degree distribution. + """ + + # Determine which degrees are to be considered, i.e. >= d_min. + D_G = degree_sequence[(degree_sequence >= d_min.item())] + try: + sum_log_degrees = torch.log(D_G).sum() + except: + sum_log_degrees = np.log(D_G).sum() + n = len(D_G) + + alpha = compute_alpha(n, sum_log_degrees, d_min) + ll = compute_log_likelihood(n, alpha, sum_log_degrees, d_min) + return ll, alpha, n, sum_log_degrees + +def updated_log_likelihood_for_edge_changes(node_pairs, adjacency_matrix, d_min): + """ Adopted from https://github.com/danielzuegner/nettack + """ + # For each node pair find out whether there is an edge or not in the input adjacency matrix. 
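+    # (In the computation below, `deltas` is +1 for node pairs that currently have no
+    #  edge, i.e. the edge would be added, and -1 for pairs whose existing edge would
+    #  be removed.)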
+ + edge_entries_before = adjacency_matrix[node_pairs.T] + degree_sequence = adjacency_matrix.sum(1) + D_G = degree_sequence[degree_sequence >= d_min.item()] + sum_log_degrees = torch.log(D_G).sum() + n = len(D_G) + deltas = -2 * edge_entries_before + 1 + d_edges_before = degree_sequence[node_pairs] + + d_edges_after = degree_sequence[node_pairs] + deltas[:, None] + + # Sum the log of the degrees after the potential changes which are >= d_min + sum_log_degrees_after, new_n = update_sum_log_degrees(sum_log_degrees, n, d_edges_before, d_edges_after, d_min) + # Updated estimates of the Powerlaw exponents + new_alpha = compute_alpha(new_n, sum_log_degrees_after, d_min) + # Updated log likelihood values for the Powerlaw distributions + new_ll = compute_log_likelihood(new_n, new_alpha, sum_log_degrees_after, d_min) + return new_ll, new_alpha, new_n, sum_log_degrees_after + + +def update_sum_log_degrees(sum_log_degrees_before, n_old, d_old, d_new, d_min): + # Find out whether the degrees before and after the change are above the threshold d_min. + old_in_range = d_old >= d_min + new_in_range = d_new >= d_min + d_old_in_range = d_old * old_in_range.float() + d_new_in_range = d_new * new_in_range.float() + + # Update the sum by subtracting the old values and then adding the updated logs of the degrees. + sum_log_degrees_after = sum_log_degrees_before - (torch.log(torch.clamp(d_old_in_range, min=1))).sum(1) \ + + (torch.log(torch.clamp(d_new_in_range, min=1))).sum(1) + + # Update the number of degrees >= d_min + + new_n = n_old - (old_in_range!=0).sum(1) + (new_in_range!=0).sum(1) + new_n = new_n.float() + return sum_log_degrees_after, new_n + +def compute_alpha(n, sum_log_degrees, d_min): + try: + alpha = 1 + n / (sum_log_degrees - n * torch.log(d_min - 0.5)) + except: + alpha = 1 + n / (sum_log_degrees - n * np.log(d_min - 0.5)) + return alpha + +def compute_log_likelihood(n, alpha, sum_log_degrees, d_min): + # Log likelihood under alpha + try: + ll = n * torch.log(alpha) + n * alpha * torch.log(d_min) + (alpha + 1) * sum_log_degrees + except: + ll = n * np.log(alpha) + n * alpha * np.log(d_min) + (alpha + 1) * sum_log_degrees + + return ll + +def ravel_multiple_indices(ixs, shape, reverse=False): + """ + "Flattens" multiple 2D input indices into indices on the flattened matrix, similar to np.ravel_multi_index. + Does the same as ravel_index but for multiple indices at once. + Parameters + ---------- + ixs: array of ints shape (n, 2) + The array of n indices that will be flattened. + + shape: list or tuple of ints of length 2 + The shape of the corresponding matrix. + + Returns + ------- + array of n ints between 0 and shape[0]*shape[1]-1 + The indices on the flattened matrix corresponding to the 2D input indices. 
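+
+    Example (illustrative): for shape (3, 4), ixs = [[0, 1], [2, 3]] map to
+    [0*4 + 1, 2*4 + 3] = [1, 11], and to [1*4 + 0, 3*4 + 2] = [4, 14] with reverse=True.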
+ + """ + if reverse: + return ixs[:, 1] * shape[1] + ixs[:, 0] + + return ixs[:, 0] * shape[1] + ixs[:, 1] + +def visualize(your_var): + """visualize computation graph""" + from graphviz import Digraph + import torch + from torch.autograd import Variable + from torchviz import make_dot + make_dot(your_var).view() + +def reshape_mx(mx, shape): + indices = mx.nonzero() + return sp.csr_matrix((mx.data, (indices[0], indices[1])), shape=shape) + +def add_mask(data, dataset): + """data: ogb-arxiv pyg data format""" + # for arxiv + split_idx = dataset.get_idx_split() + train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] + n = data.x.shape[0] + data.train_mask = index_to_mask(train_idx, n) + data.val_mask = index_to_mask(valid_idx, n) + data.test_mask = index_to_mask(test_idx, n) + data.y = data.y.squeeze() + # data.edge_index = to_undirected(data.edge_index, data.num_nodes) + +def index_to_mask(index, size): + mask = torch.zeros((size, ), dtype=torch.bool) + mask[index] = 1 + return mask + +def add_feature_noise(data, noise_ratio, seed): + np.random.seed(seed) + n, d = data.x.shape + # noise = torch.normal(mean=torch.zeros(int(noise_ratio*n), d), std=1) + noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio*n), d))).to(data.x.device) + indices = np.arange(n) + indices = np.random.permutation(indices)[: int(noise_ratio*n)] + delta_feat = torch.zeros_like(data.x) + delta_feat[indices] = noise - data.x[indices] + data.x[indices] = noise + mask = np.zeros(n) + mask[indices] = 1 + mask = torch.tensor(mask).bool().to(data.x.device) + return delta_feat, mask + +def add_feature_noise_test(data, noise_ratio, seed): + np.random.seed(seed) + n, d = data.x.shape + indices = np.arange(n) + test_nodes = indices[data.test_mask.cpu()] + selected = np.random.permutation(test_nodes)[: int(noise_ratio*len(test_nodes))] + noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio*len(test_nodes)), d))) + noise = noise.to(data.x.device) + + delta_feat = torch.zeros_like(data.x) + delta_feat[selected] = noise - data.x[selected] + data.x[selected] = noise + # mask = np.zeros(len(test_nodes)) + mask = np.zeros(n) + mask[selected] = 1 + mask = torch.tensor(mask).bool().to(data.x.device) + return delta_feat, mask + +# def check_path(file_path): +# if not osp.exists(file_path): +# os.system(f'mkdir -p {file_path}') \ No newline at end of file From 02dd26905866c8c9ac182873f7a6576dada9f09c Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Sun, 13 Oct 2024 18:53:41 +0300 Subject: [PATCH 02/13] get all subclasses for PoisonDefender --- src/models_builder/gnn_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index a93aad7..2aec451 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -435,7 +435,7 @@ def set_poison_defender(self, poison_defense_config=None, poison_defense_name: s self.poison_defense_name = poison_defense_name poison_defense_kwargs = getattr(self.poison_defense_config, CONFIG_OBJ).to_dict() - name_klass = {e.name: e for e in PoisonDefender.__subclasses__()} + name_klass = {e.name: e for e in all_subclasses(PoisonDefender)} klass = name_klass[self.poison_defense_name] self.poison_defender = klass( # device=self.device, From d906e6d4105e09bd2da950bb10981cf84bcc1270 Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Sun, 13 Oct 2024 18:54:33 +0300 Subject: [PATCH 03/13] poison defense tests --- 
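Note: this patch, like PATCH 02, relies on `all_subclasses` imported from `aux.utils`; its
definition is not part of this series. It is assumed to be a recursive variant of
`cls.__subclasses__()`, which is needed because defenders such as GNNGuard derive from
PoisonDefender only indirectly (via BaseGNNGuard) and would not be found by a plain
`__subclasses__()` call. A minimal sketch of the assumed helper:

    def all_subclasses(cls):
        # Recursively collect direct and indirect subclasses of cls
        # (assumed helper; not necessarily the actual aux.utils implementation).
        result = set(cls.__subclasses__())
        for sub in cls.__subclasses__():
            result |= all_subclasses(sub)
        return result
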
experiments/attack_defense_test.py | 97 ++++++++++++++++++++++++++---- tests/attacks_test.py | 4 ++ tests/defense_test.py | 84 ++++++++++++++++++++++++++ tests/explainers_test.py | 4 ++ 4 files changed, 178 insertions(+), 11 deletions(-) create mode 100644 tests/defense_test.py diff --git a/experiments/attack_defense_test.py b/experiments/attack_defense_test.py index ce772bc..27864b1 100644 --- a/experiments/attack_defense_test.py +++ b/experiments/attack_defense_test.py @@ -2,6 +2,10 @@ import warnings +import sys +import os +sys.path.append(f"{os.getcwd()}/src") + from torch import device from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \ @@ -12,6 +16,7 @@ from src.models_builder.models_zoo import model_configs_zoo + def test_attack_defense(): # my_device = device('cuda' if is_available() else 'cpu') my_device = device('cpu') @@ -111,22 +116,16 @@ def test_attack_defense(): } ) - # poison_defense_config = ConfigPattern( - # _class_name="BadRandomPoisonDefender", - # _import_path=POISON_DEFENSE_PARAMETERS_PATH, - # _config_class="PoisonDefenseConfig", - # _config_kwargs={ - # "n_edges_percent": 0.1, - # } - # ) poison_defense_config = ConfigPattern( - _class_name="EmptyPoisonDefender", + _class_name="GNNGuard", _import_path=POISON_DEFENSE_PARAMETERS_PATH, _config_class="PoisonDefenseConfig", _config_kwargs={ + "n_edges_percent": 0.1, } ) + evasion_attack_config = ConfigPattern( _class_name="FGSM", _import_path=EVASION_ATTACK_PARAMETERS_PATH, @@ -145,7 +144,7 @@ def test_attack_defense(): ) gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config) - # gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config) + gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config) # gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config) # gnn_model_manager.set_evasion_defender(evasion_defense_config=evasion_defense_config) @@ -344,7 +343,83 @@ def test_nettack_evasion(): print(f"info_after_evasion_attack: {info_after_evasion_attack}") +def test_gnnguard(): + # from attacks.poison_attacks_collection.metattack import meta_gradient_attack + from defense.GNNGuard import gnnguard + + my_device = device('cpu') + full_name = ("single-graph", "Planetoid", 'Cora') + + dataset, data, results_dataset_path = DatasetManager.get_by_full_name( + full_name=full_name, + dataset_ver_ind=0 + ) + gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn') + manager_config = ConfigPattern( + _config_class="ModelManagerConfig", + _config_kwargs={ + "mask_features": [], + "optimizer": { + # "_config_class": "Config", + "_class_name": "Adam", + # "_import_path": OPTIMIZERS_PARAMETERS_PATH, + # "_class_import_info": ["torch.optim"], + "_config_kwargs": {}, + } + } + ) + steps_epochs = 200 + gnn_model_manager = FrameworkGNNModelManager( + gnn=gnn, + dataset_path=results_dataset_path, + manager_config=manager_config, + modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs) + ) + save_model_flag = False + gnn_model_manager.gnn.to(my_device) + data = data.to(my_device) + + poison_defense_config = ConfigPattern( + _class_name="GNNGuard", + _import_path=POISON_DEFENSE_PARAMETERS_PATH, + _config_class="PoisonDefenseConfig", + _config_kwargs={ + # "num_nodes": dataset.dataset.x.shape[0] + } + ) + from defense.poison_defense import PoisonDefender + from src.aux.utils import all_subclasses + print([e.name for e in 
all_subclasses(PoisonDefender)]) + gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config) + + warnings.warn("Start training") + dataset.train_test_split(percent_train_class=0.1) + + try: + raise FileNotFoundError() + # gnn_model_manager.load_model_executor() + except FileNotFoundError: + gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 + train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs, + save_model_flag=save_model_flag, + metrics=[Metric("F1", mask='train', average=None)]) + + if train_test_split_path is not None: + dataset.save_train_test_mask(train_test_split_path) + train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[ + :] + dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask + data.percent_train_class, data.percent_test_class = train_test_sizes + + warnings.warn("Training was successful") + + metric_loc = gnn_model_manager.evaluate_model( + gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro'), + Metric("Accuracy", mask='test')]) + print(metric_loc) + if __name__ == '__main__': #test_attack_defense() torch.manual_seed(5000) - test_meta() + test_gnnguard() + # test_attack_defense() diff --git a/tests/attacks_test.py b/tests/attacks_test.py index acee652..252f63a 100644 --- a/tests/attacks_test.py +++ b/tests/attacks_test.py @@ -1,3 +1,7 @@ +import sys +import os +sys.path.append(f'/home/igor/Documents/graphs/GNN-AID/src') + import unittest import torch diff --git a/tests/defense_test.py b/tests/defense_test.py new file mode 100644 index 0000000..796befc --- /dev/null +++ b/tests/defense_test.py @@ -0,0 +1,84 @@ +import unittest + +import torch + +from base.datasets_processing import DatasetManager +from models_builder.gnn_models import FrameworkGNNModelManager, Metric +from aux.configs import ModelManagerConfig, ModelModificationConfig, DatasetConfig, DatasetVarConfig, ConfigPattern +from models_builder.models_zoo import model_configs_zoo + +from aux.utils import POISON_DEFENSE_PARAMETERS_PATH, \ + EVASION_DEFENSE_PARAMETERS_PATH, OPTIMIZERS_PARAMETERS_PATH + +class DefenseTest(unittest.TestCase): + def setUp(self): + print('setup') + + from defense.poison_defense.GNNGuard_defense import GNNGuard_defense + # Init datasets + # Single-Graph - Example + self.dataset_sg_example, _, results_dataset_path_sg_example = DatasetManager.get_by_full_name( + full_name=("single-graph", "custom", "example",), + features={'attr': {'a': 'as_is'}}, + labeling='binary', + dataset_ver_ind=0 + ) + + self.gen_dataset_sg_example = DatasetManager.get_by_config( + DatasetConfig( + domain="single-graph", + group="custom", + graph="example"), + DatasetVarConfig(features={'attr': {'a': 'as_is'}}, + labeling='binary', + dataset_ver_ind=0) + ) + self.gen_dataset_sg_example.train_test_split(percent_train_class=0.6, percent_test_class=0.4) + self.results_dataset_path_sg_example = self.gen_dataset_sg_example.results_dir + + self.default_config = ModelModificationConfig( + model_ver_ind=0, + ) + + self.manager_config = ConfigPattern( + _config_class="ModelManagerConfig", + _config_kwargs={ + "mask_features": [], + "optimizer": { + "_config_class": "Config", + "_class_name": "Adam", + "_import_path": OPTIMIZERS_PARAMETERS_PATH, + "_class_import_info": ["torch.optim"], + "_config_kwargs": {"weight_decay": 5e-4}, + } + } + ) + + def test_gnnguard(self): + poison_defense_config = ConfigPattern( + _class_name="GNNGuard", + 
_import_path=POISON_DEFENSE_PARAMETERS_PATH, + _config_class="PoisonDefenseConfig", + _config_kwargs={ + # "num_nodes": self.gen_dataset_sg_example.dataset.x.shape[0] # is there more fancy way? + } + ) + + gat_gat_sg_example = model_configs_zoo(dataset=self.gen_dataset_sg_example, model_name='gat_gat') + + gnn_model_manager_sg_example = FrameworkGNNModelManager( + gnn=gat_gat_sg_example, + dataset_path=self.results_dataset_path_sg_example, + modification=self.default_config, + manager_config=self.manager_config, + ) + + gnn_model_manager_sg_example.set_poison_defender(poison_defense_config=poison_defense_config) + + gnn_model_manager_sg_example.train_model(gen_dataset=self.gen_dataset_sg_example, steps=100, metrics=[Metric("Accuracy", mask='test')]) + metric_loc = gnn_model_manager_sg_example.evaluate_model(gen_dataset=self.gen_dataset_sg_example, metrics=[Metric("F1", mask='test', average='macro')]) + print(metric_loc) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/explainers_test.py b/tests/explainers_test.py index ae3fe4b..45640e0 100644 --- a/tests/explainers_test.py +++ b/tests/explainers_test.py @@ -8,6 +8,10 @@ import signal from time import time +import sys +import os +sys.path.append(f'/home/igor/Documents/graphs/GNN-AID/src') + from aux import utils from aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH, EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, \ EXPLAINERS_GLOBAL_RUN_PARAMETERS_PATH From 6f423e224f05e7edd780533b92f1dc2543d19414 Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Sun, 13 Oct 2024 18:55:17 +0300 Subject: [PATCH 04/13] defense configs --- metainfo/poison_defense_parameters.json | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/metainfo/poison_defense_parameters.json b/metainfo/poison_defense_parameters.json index e9fcb92..7a12f90 100644 --- a/metainfo/poison_defense_parameters.json +++ b/metainfo/poison_defense_parameters.json @@ -3,7 +3,11 @@ }, "BadRandomPoisonDefender": { "n_edges_percent": ["n_edges_percent", "float", 0.1, {"min": 0.0001, "step": 0.01}, "?"] + }, + "GNNGuard": { + "lr": ["lr", "float", 0.01, {"min": 0.0001, "step": 0.005}, "?"], + "attention": ["attention", "bool", true, {}, "?"], + "drop": ["drop", "bool", true, {}, "?"] } - } From 6acb051cedba8848472d0e543de8225677bbe1f6 Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Mon, 14 Oct 2024 18:17:45 +0300 Subject: [PATCH 05/13] adversarial training added --- experiments/attack_defense_test.py | 80 +++++++++++++++++++++++- metainfo/evasion_defense_parameters.json | 3 + src/defense/evasion_defense.py | 75 ++++++++++++++++++++++ 3 files changed, 155 insertions(+), 3 deletions(-) diff --git a/experiments/attack_defense_test.py b/experiments/attack_defense_test.py index 27864b1..5a3c062 100644 --- a/experiments/attack_defense_test.py +++ b/experiments/attack_defense_test.py @@ -378,7 +378,7 @@ def test_gnnguard(): save_model_flag = False gnn_model_manager.gnn.to(my_device) data = data.to(my_device) - + print(type(data)) poison_defense_config = ConfigPattern( _class_name="GNNGuard", _import_path=POISON_DEFENSE_PARAMETERS_PATH, @@ -418,8 +418,82 @@ def test_gnnguard(): Metric("Accuracy", mask='test')]) print(metric_loc) +def test_adv_training(): + from defense.evasion_defense import AdvTraining + + my_device = device('cpu') + full_name = ("single-graph", "Planetoid", 'Cora') + + dataset, data, results_dataset_path = DatasetManager.get_by_full_name( + full_name=full_name, + dataset_ver_ind=0 + ) + gnn = model_configs_zoo(dataset=dataset, 
model_name='gcn_gcn') + manager_config = ConfigPattern( + _config_class="ModelManagerConfig", + _config_kwargs={ + "mask_features": [], + "optimizer": { + # "_config_class": "Config", + "_class_name": "Adam", + # "_import_path": OPTIMIZERS_PARAMETERS_PATH, + # "_class_import_info": ["torch.optim"], + "_config_kwargs": {}, + } + } + ) + steps_epochs = 200 + gnn_model_manager = FrameworkGNNModelManager( + gnn=gnn, + dataset_path=results_dataset_path, + manager_config=manager_config, + modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs) + ) + save_model_flag = False + gnn_model_manager.gnn.to(my_device) + data = data.to(my_device) + + evasion_defense_config = ConfigPattern( + _class_name="AdvTraining", + _import_path=EVASION_DEFENSE_PARAMETERS_PATH, + _config_class="EvasionDefenseConfig", + _config_kwargs={ + # "num_nodes": dataset.dataset.x.shape[0] + } + ) + from defense.evasion_defense import EvasionDefender + from src.aux.utils import all_subclasses + print([e.name for e in all_subclasses(EvasionDefender)]) + gnn_model_manager.set_evasion_defender(evasion_defense_config=evasion_defense_config) + + warnings.warn("Start training") + dataset.train_test_split(percent_train_class=0.1) + + try: + raise FileNotFoundError() + # gnn_model_manager.load_model_executor() + except FileNotFoundError: + gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 + train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs, + save_model_flag=save_model_flag, + metrics=[Metric("F1", mask='train', average=None)]) + + if train_test_split_path is not None: + dataset.save_train_test_mask(train_test_split_path) + train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[ + :] + dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask + data.percent_train_class, data.percent_test_class = train_test_sizes + + warnings.warn("Training was successful") + + metric_loc = gnn_model_manager.evaluate_model( + gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro'), + Metric("Accuracy", mask='test')]) + print(metric_loc) + if __name__ == '__main__': #test_attack_defense() torch.manual_seed(5000) - test_gnnguard() - # test_attack_defense() + test_adv_training() + # test_gnnguard() diff --git a/metainfo/evasion_defense_parameters.json b/metainfo/evasion_defense_parameters.json index 3551390..62b30d2 100644 --- a/metainfo/evasion_defense_parameters.json +++ b/metainfo/evasion_defense_parameters.json @@ -6,6 +6,9 @@ }, "QuantizationDefender": { "qbit": ["qbit", "int", 8, {"min": 1, "step": 1}, "?"] + }, + "AdvTraining": { + "epsilon": ["epsilon", "float", 0.1, {"min": 0.001, "step": 0.005}] } } diff --git a/src/defense/evasion_defense.py b/src/defense/evasion_defense.py index 3f30d15..b06f371 100644 --- a/src/defense/evasion_defense.py +++ b/src/defense/evasion_defense.py @@ -1,7 +1,11 @@ import torch from defense.defense_base import Defender +from src.aux.utils import import_by_name +from attacks.evasion_attacks import FGSMAttacker +from torch_geometric import data +import copy class EvasionDefender(Defender): def __init__(self, **kwargs): @@ -52,3 +56,74 @@ def __init__(self, qbit=8): def pre_batch(self, **kwargs): pass + + +class AdvTraining(EvasionDefender): + name = "AdvTraining" + + def __init__(self, epsilon=0.1, attack_type="FGSM", device='cpu'): + super().__init__() + assert device is not None, "Please specify 'device'!" 
+ # if attack_type=="NettackEvasion": + # self.attacker = evasion_attacks.NettackEvasionAttacker() + # elif attack_type=="FGSM": + # self.attacker = evasion_attacks.FGSMAttacker(epsilon=epsilon) + # self.attacker = import_by_name(attack_type, ['attacks.evasion_attacks'])() + self.attacker = FGSMAttacker(epsilon=epsilon) + + def pre_batch(self, model_manager, batch): + super().pre_batch(model_manager=model_manager, batch=batch) + # print(batch) + gen_data = data.Data() + gen_data.data = copy.deepcopy(batch) + # print(gen_data) + # print(batch.batch) + attacked_batch = self.attacker.attack(model_manager, gen_data, batch.train_mask).data + new_batch = self.merge_batches(batch, attacked_batch) + # print(attacked_batch.x.mean() - batch.x.mean()) + + # print(batch) + batch = new_batch + # print(batch) + + def post_batch(self, model_manager, batch, loss) -> dict: + super().post_batch(model_manager=model_manager, batch=batch, loss=loss) + + @staticmethod + def merge_batches(batch1, batch2): + merged_x = torch.cat([batch1.x, batch2.x], dim=0) + + adj_edge_index = batch2.edge_index + batch1.x.size(0) + merged_edge_index = torch.cat([batch1.edge_index, adj_edge_index], dim=1) + + merged_y = torch.cat([batch1.y, batch2.y], dim=0) + + merged_train_mask = torch.cat([batch1.train_mask, batch2.train_mask], dim=0) + merged_val_mask = torch.cat([batch1.val_mask, batch2.val_mask], dim=0) + merged_test_mask = torch.cat([batch1.test_mask, batch2.test_mask], dim=0) + + merged_n_id = torch.cat([batch1.n_id, batch2.n_id], dim=0) + merged_e_id = torch.cat([batch1.e_id, batch2.e_id], dim=0) + + merged_num_sampled_nodes = batch1.num_sampled_nodes + batch2.num_sampled_nodes + merged_num_sampled_edges = batch1.num_sampled_edges + batch2.num_sampled_edges + merged_input_id = torch.cat([batch1.input_id, batch2.input_id], dim=0) + merged_batch_size = batch1.batch_size + batch2.batch_size + + merged_batch = None + if batch1.batch and batch2.batch: + merged_batch = torch.cat([batch1.batch, batch2.batch + batch1.batch.max() + 1], dim=0) + + return data.Data(x=merged_x, + edge_index=merged_edge_index, + y=merged_y, + train_mask=merged_train_mask, + val_mask=merged_val_mask, + test_mask=merged_test_mask, + n_id=merged_n_id, + e_id=merged_e_id, + num_sampled_nodes=merged_num_sampled_nodes, + num_sampled_edges=merged_num_sampled_edges, + input_id=merged_input_id, + batch_size=merged_batch_size, + batch=merged_batch) \ No newline at end of file From 11d8bb7aa0502a2d4e08973e581893d9a37435e7 Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Mon, 14 Oct 2024 19:17:42 +0300 Subject: [PATCH 06/13] adv training fix --- src/defense/evasion_defense.py | 33 +++++++++++++++------------------ 1 file changed, 15 insertions(+), 18 deletions(-) diff --git a/src/defense/evasion_defense.py b/src/defense/evasion_defense.py index b06f371..765dd78 100644 --- a/src/defense/evasion_defense.py +++ b/src/defense/evasion_defense.py @@ -64,30 +64,27 @@ class AdvTraining(EvasionDefender): def __init__(self, epsilon=0.1, attack_type="FGSM", device='cpu'): super().__init__() assert device is not None, "Please specify 'device'!" 
- # if attack_type=="NettackEvasion": - # self.attacker = evasion_attacks.NettackEvasionAttacker() - # elif attack_type=="FGSM": - # self.attacker = evasion_attacks.FGSMAttacker(epsilon=epsilon) - # self.attacker = import_by_name(attack_type, ['attacks.evasion_attacks'])() - self.attacker = FGSMAttacker(epsilon=epsilon) + self.epsilon = epsilon def pre_batch(self, model_manager, batch): super().pre_batch(model_manager=model_manager, batch=batch) - # print(batch) - gen_data = data.Data() - gen_data.data = copy.deepcopy(batch) - # print(gen_data) - # print(batch.batch) - attacked_batch = self.attacker.attack(model_manager, gen_data, batch.train_mask).data - new_batch = self.merge_batches(batch, attacked_batch) - # print(attacked_batch.x.mean() - batch.x.mean()) - - # print(batch) - batch = new_batch - # print(batch) + batch.x.requires_grad = True + outputs = model_manager.gnn(batch.x, batch.edge_index) + loss_loc = model_manager.loss_function(outputs, batch.y) + gradients = torch.autograd.grad(outputs=loss_loc, inputs=batch.x, + grad_outputs=torch.ones_like(loss_loc), + create_graph=True, retain_graph=True, only_inputs=True)[0] + perturbed_data_x = batch.x + self.epsilon*gradients.sign() + perturbed_data_x = torch.clamp(perturbed_data_x, 0, 1) + self.attacked_batch = copy.deepcopy(batch) + self.attacked_batch.x = perturbed_data_x + def post_batch(self, model_manager, batch, loss) -> dict: super().post_batch(model_manager=model_manager, batch=batch, loss=loss) + outputs = model_manager.gnn(self.attacked_batch.x, self.attacked_batch.edge_index) + loss_loc = model_manager.loss_function(outputs, batch.y) + return {"loss": loss + loss_loc} @staticmethod def merge_batches(batch1, batch2): From a5592060d80cba815f868b90e46742c5aedb33f8 Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Tue, 15 Oct 2024 18:04:25 +0300 Subject: [PATCH 07/13] gnnguard update --- src/defense/GNNGuard/gnnguard.py | 100 ------------------------------- 1 file changed, 100 deletions(-) diff --git a/src/defense/GNNGuard/gnnguard.py b/src/defense/GNNGuard/gnnguard.py index 337853f..6a284bf 100644 --- a/src/defense/GNNGuard/gnnguard.py +++ b/src/defense/GNNGuard/gnnguard.py @@ -32,80 +32,6 @@ from src.aux.configs import ModelConfig -# class BaseGCN(BaseModel): - -# def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, -# with_bn=False, weight_decay=5e-4, with_bias=True, device=None): - -# super(BaseGCN, self).__init__() - -# assert device is not None, "Please specify 'device'!" 
-# self.device = device - -# self.layers = nn.ModuleList([]) -# if with_bn: -# self.bns = nn.ModuleList() - -# if nlayers == 1: -# self.layers.append(GCNConv(nfeat, nclass, bias=with_bias)) -# else: -# self.layers.append(GCNConv(nfeat, nhid, bias=with_bias)) -# if with_bn: -# self.bns.append(nn.BatchNorm1d(nhid)) -# for i in range(nlayers-2): -# self.layers.append(GCNConv(nhid, nhid, bias=with_bias)) -# if with_bn: -# self.bns.append(nn.BatchNorm1d(nhid)) -# self.layers.append(GCNConv(nhid, nclass, bias=with_bias)) - -# self.dropout = dropout -# self.weight_decay = weight_decay -# self.lr = lr -# self.output = None -# self.best_model = None -# self.best_output = None -# self.with_bn = with_bn -# self.name = 'GCN' - -# def forward(self, x, edge_index, edge_weight=None): -# x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight) -# for ii, layer in enumerate(self.layers): -# if edge_weight is not None: -# adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t() -# x = layer(x, adj) -# else: -# x = layer(x, edge_index) -# if ii != len(self.layers) - 1: -# if self.with_bn: -# x = self.bns[ii](x) -# x = F.relu(x) -# x = F.dropout(x, p=self.dropout, training=self.training) -# return F.log_softmax(x, dim=1) - -# def get_embed(self, x, edge_index, edge_weight=None): -# x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight) -# for ii, layer in enumerate(self.layers): -# if ii == len(self.layers) - 1: -# return x -# if edge_weight is not None: -# adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t() -# x = layer(x, adj) -# else: -# x = layer(x, edge_index) -# if ii != len(self.layers) - 1: -# if self.with_bn: -# x = self.bns[ii](x) -# x = F.relu(x) -# return x - -# def initialize(self): -# for m in self.layers: -# m.reset_parameters() -# if self.with_bn: -# for bn in self.bns: -# bn.reset_parameters() - - class BaseGNNGuard(PoisonDefender): name = "BaseGNNGuard" @@ -354,29 +280,3 @@ def initialize(self): self.model.drop_learn_1.reset_parameters() self.model.drop_learn_2.reset_parameters() # self.model.gate.reset_parameters() - -# if __name__ == "__main__": -# from deeprobust.graph.data import Dataset, Dpr2Pyg -# # from deeprobust.graph.defense import GCN -# data = Dataset(root='/tmp/', name='citeseer', setting='prognn') -# adj, features, labels = data.adj, data.features, data.labels -# idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test -# model = GCN(nfeat=features.shape[1], -# nhid=16, -# nclass=labels.max().item() + 1, -# dropout=0.5, device='cuda') -# model = model.to('cuda') -# pyg_data = Dpr2Pyg(data)[0] - -# # model.fit(features, adj, labels, idx_train, train_iters=200, verbose=True) -# # model.test(idx_test) - -# from utils import get_dataset -# pyg_data = get_dataset('citeseer', True, if_dpr=False)[0] - -# import ipdb -# ipdb.set_trace() - -# model.fit(pyg_data, verbose=True) # train with earlystopping -# model.test() -# print(model.predict()) \ No newline at end of file From fd39eb6137584a797ee7474fd4aa6cf7ababffb8 Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Wed, 16 Oct 2024 20:29:44 +0300 Subject: [PATCH 08/13] adv training update --- experiments/attack_defense_test.py | 66 ++++++++++++++--- src/defense/evasion_defense.py | 114 ++++++++++++++++------------- 2 files changed, 117 insertions(+), 63 deletions(-) diff --git a/experiments/attack_defense_test.py b/experiments/attack_defense_test.py index 5a3c062..e2fcc19 100644 --- 
a/experiments/attack_defense_test.py +++ b/experiments/attack_defense_test.py @@ -18,6 +18,8 @@ def test_attack_defense(): + from attacks.QAttack import qattack + # from attacks.poison_attacks_collection.metattack import meta_gradient_attack # my_device = device('cuda' if is_available() else 'cpu') my_device = device('cpu') @@ -107,15 +109,33 @@ def test_attack_defense(): gnn_model_manager.gnn.to(my_device) data = data.to(my_device) + # poison_attack_config = ConfigPattern( + # _class_name="RandomPoisonAttack", + # _import_path=POISON_ATTACK_PARAMETERS_PATH, + # _config_class="PoisonAttackConfig", + # _config_kwargs={ + # "n_edges_percent": 0.1, + # } + # ) + poison_attack_config = ConfigPattern( - _class_name="RandomPoisonAttack", + _class_name="MetaAttackFull", _import_path=POISON_ATTACK_PARAMETERS_PATH, _config_class="PoisonAttackConfig", _config_kwargs={ - "n_edges_percent": 0.1, + "num_nodes": dataset.dataset.x.shape[0] } ) + # poison_attack_config = ConfigPattern( + # _class_name="RandomPoisonAttack", + # _import_path=POISON_ATTACK_PARAMETERS_PATH, + # _config_class="PoisonAttackConfig", + # _config_kwargs={ + # "n_edges_percent": 0.1, + # } + # ) + poison_defense_config = ConfigPattern( _class_name="GNNGuard", _import_path=POISON_DEFENSE_PARAMETERS_PATH, @@ -127,25 +147,47 @@ def test_attack_defense(): evasion_attack_config = ConfigPattern( - _class_name="FGSM", + _class_name="QAttack", _import_path=EVASION_ATTACK_PARAMETERS_PATH, _config_class="EvasionAttackConfig", _config_kwargs={ - "epsilon": 0.01 * 1, + "population_size": 50, + "individual_size": 30, + "generations": 50, + "prob_cross": 0.5, + "prob_mutate": 0.02 } ) + # evasion_attack_config = ConfigPattern( + # _class_name="FGSM", + # _import_path=EVASION_ATTACK_PARAMETERS_PATH, + # _config_class="EvasionAttackConfig", + # _config_kwargs={ + # "epsilon": 0.01 * 1, + # } + # ) + + # evasion_defense_config = ConfigPattern( + # _class_name="GradientRegularizationDefender", + # _import_path=EVASION_DEFENSE_PARAMETERS_PATH, + # _config_class="EvasionDefenseConfig", + # _config_kwargs={ + # "regularization_strength": 0.1 * 10 + # } + # ) evasion_defense_config = ConfigPattern( - _class_name="GradientRegularizationDefender", + _class_name="AdvTraining", _import_path=EVASION_DEFENSE_PARAMETERS_PATH, _config_class="EvasionDefenseConfig", _config_kwargs={ - "regularization_strength": 0.1 * 10 + "attack_name": None, + "attack_config": evasion_attack_config # evasion_attack_config } ) - gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config) - gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config) - # gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config) + # gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config) + # gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config) + gnn_model_manager.set_evasion_attacker(evasion_attack_config=evasion_attack_config) # gnn_model_manager.set_evasion_defender(evasion_defense_config=evasion_defense_config) warnings.warn("Start training") @@ -493,7 +535,9 @@ def test_adv_training(): print(metric_loc) if __name__ == '__main__': - #test_attack_defense() + import random + random.seed(10) + test_attack_defense() torch.manual_seed(5000) - test_adv_training() + # test_adv_training() # test_gnnguard() diff --git a/src/defense/evasion_defense.py b/src/defense/evasion_defense.py index 765dd78..84d9656 100644 --- a/src/defense/evasion_defense.py +++ 
b/src/defense/evasion_defense.py @@ -2,7 +2,11 @@ from defense.defense_base import Defender from src.aux.utils import import_by_name +from src.aux.configs import ModelModificationConfig, ConfigPattern +from src.aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \ + EVASION_DEFENSE_PARAMETERS_PATH from attacks.evasion_attacks import FGSMAttacker +from attacks.QAttack import qattack from torch_geometric import data import copy @@ -58,69 +62,75 @@ def pre_batch(self, **kwargs): pass +class DataWrap: + def __init__(self, batch) -> None: + self.data = batch + self.dataset = self + class AdvTraining(EvasionDefender): name = "AdvTraining" - def __init__(self, epsilon=0.1, attack_type="FGSM", device='cpu'): + def __init__(self, attack_name=None, attack_config=None, attack_type=None, device='cpu'): super().__init__() assert device is not None, "Please specify 'device'!" - self.epsilon = epsilon + if not attack_config: + # build default config + assert attack_name is not None + if attack_type == "POISON": + self.attack_type = "POISON" + PARAM_PATH = POISON_ATTACK_PARAMETERS_PATH + else: + self.attack_type = "EVASION" + PARAM_PATH = EVASION_ATTACK_PARAMETERS_PATH + attack_config = ConfigPattern( + _class_name=attack_name, + _import_path=PARAM_PATH, + _config_class="EvasionAttackConfig", + _config_kwargs={} + ) + self.attack_config = attack_config + if self.attack_config._class_name == "FGSM": + self.attack_type = "EVASION" + # get attack params + self.epsilon = self.attack_config._config_kwargs.epsilon + # set attacker + self.attacker = FGSMAttacker(self.epsilon) + elif self.attack_config._class_name == "QAttack": + self.attack_type = "EVASION" + # get attack params + self.population_size = self.attack_config._config_kwargs["population_size"] + self.individual_size = self.attack_config._config_kwargs["individual_size"] + self.generations = self.attack_config._config_kwargs["generations"] + self.prob_cross = self.attack_config._config_kwargs["prob_cross"] + self.prob_mutate = self.attack_config._config_kwargs["prob_mutate"] + # set attacker + self.attacker = qattack.QAttacker(self.population_size, self.individual_size, + self.generations, self.prob_cross, + self.prob_mutate) + elif self.attack_config._class_name == "MetaAttackFull": + # from attacks.poison_attacks_collection.metattack import meta_gradient_attack + # self.attack_type = "POISON" + # self.num_nodes = self.attack_config._config_kwargs["num_nodes"] + # self.attacker = meta_gradient_attack.MetaAttackFull(num_nodes=self.num_nodes) + pass + else: + raise KeyError(f"There is no {self.attack_config._class_name} class") def pre_batch(self, model_manager, batch): super().pre_batch(model_manager=model_manager, batch=batch) - batch.x.requires_grad = True - outputs = model_manager.gnn(batch.x, batch.edge_index) - loss_loc = model_manager.loss_function(outputs, batch.y) - gradients = torch.autograd.grad(outputs=loss_loc, inputs=batch.x, - grad_outputs=torch.ones_like(loss_loc), - create_graph=True, retain_graph=True, only_inputs=True)[0] - perturbed_data_x = batch.x + self.epsilon*gradients.sign() - perturbed_data_x = torch.clamp(perturbed_data_x, 0, 1) - self.attacked_batch = copy.deepcopy(batch) - self.attacked_batch.x = perturbed_data_x + self.perturbed_gen_dataset = data.Data() + self.perturbed_gen_dataset.data = copy.deepcopy(batch) + self.perturbed_gen_dataset.dataset = self.perturbed_gen_dataset.data + self.perturbed_gen_dataset.dataset.data = self.perturbed_gen_dataset.data + if 
self.attack_type == "EVASION": + self.perturbed_gen_dataset = self.attacker.attack(model_manager=model_manager, + gen_dataset=self.perturbed_gen_dataset, + mask_tensor=self.perturbed_gen_dataset.data.train_mask) def post_batch(self, model_manager, batch, loss) -> dict: super().post_batch(model_manager=model_manager, batch=batch, loss=loss) - outputs = model_manager.gnn(self.attacked_batch.x, self.attacked_batch.edge_index) + # Output on perturbed data + outputs = model_manager.gnn(self.perturbed_gen_dataset.data.x, self.perturbed_gen_dataset.data.edge_index) loss_loc = model_manager.loss_function(outputs, batch.y) return {"loss": loss + loss_loc} - - @staticmethod - def merge_batches(batch1, batch2): - merged_x = torch.cat([batch1.x, batch2.x], dim=0) - - adj_edge_index = batch2.edge_index + batch1.x.size(0) - merged_edge_index = torch.cat([batch1.edge_index, adj_edge_index], dim=1) - - merged_y = torch.cat([batch1.y, batch2.y], dim=0) - - merged_train_mask = torch.cat([batch1.train_mask, batch2.train_mask], dim=0) - merged_val_mask = torch.cat([batch1.val_mask, batch2.val_mask], dim=0) - merged_test_mask = torch.cat([batch1.test_mask, batch2.test_mask], dim=0) - - merged_n_id = torch.cat([batch1.n_id, batch2.n_id], dim=0) - merged_e_id = torch.cat([batch1.e_id, batch2.e_id], dim=0) - - merged_num_sampled_nodes = batch1.num_sampled_nodes + batch2.num_sampled_nodes - merged_num_sampled_edges = batch1.num_sampled_edges + batch2.num_sampled_edges - merged_input_id = torch.cat([batch1.input_id, batch2.input_id], dim=0) - merged_batch_size = batch1.batch_size + batch2.batch_size - - merged_batch = None - if batch1.batch and batch2.batch: - merged_batch = torch.cat([batch1.batch, batch2.batch + batch1.batch.max() + 1], dim=0) - - return data.Data(x=merged_x, - edge_index=merged_edge_index, - y=merged_y, - train_mask=merged_train_mask, - val_mask=merged_val_mask, - test_mask=merged_test_mask, - n_id=merged_n_id, - e_id=merged_e_id, - num_sampled_nodes=merged_num_sampled_nodes, - num_sampled_edges=merged_num_sampled_edges, - input_id=merged_input_id, - batch_size=merged_batch_size, - batch=merged_batch) \ No newline at end of file From d066119df4ae317bdcf5ad7660a09354fa12cd11 Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Wed, 16 Oct 2024 20:41:48 +0300 Subject: [PATCH 09/13] adv training defense config --- metainfo/evasion_attack_parameters.json | 10 +- metainfo/evasion_defense_parameters.json | 2 +- src/attacks/QAttack/qattack.py | 243 +++++++++++++++++++++++ src/attacks/QAttack/utils.py | 46 +++++ 4 files changed, 298 insertions(+), 3 deletions(-) create mode 100644 src/attacks/QAttack/qattack.py create mode 100644 src/attacks/QAttack/utils.py diff --git a/metainfo/evasion_attack_parameters.json b/metainfo/evasion_attack_parameters.json index 3d54f1b..11a479d 100644 --- a/metainfo/evasion_attack_parameters.json +++ b/metainfo/evasion_attack_parameters.json @@ -11,6 +11,12 @@ "perturb_structure": ["perturb_structure", "bool", true, {}, "Indicates whether the structure can be changed"], "direct": ["direct", "bool", true, {}, "Indicates whether to directly modify edges/features of the node attacked or only those of influencers"], "n_influencers": ["n_influencers", "int", 0, {"min": 0, "step": 1}, "Number of influencing nodes. 
Will be ignored if direct is True"] - } + }, + "QAttack": { + "population_size": ["Population size", "int", 50, {"min": 1, "step": 1}, "Number of genes in population"], + "individual_size": ["Individual size", "int", 30, {"min": 1, "step": 1}, "Number of rewiring operations within one gene"], + "generations" : ["Generations", "int", 50, {"min": 0, "step": 1}, "Number of generations for genetic algorithm"], + "prob_cross": ["Probability for crossover", "float", 0.5, {"min": 0, "max": 1, "step": 0.01}, "Probability of crossover between two genes"], + "prob_mutate": ["Probability for mutation", "float", 0.02, {"min": 0, "max": 1, "step": 0.01}, "Probability of gene mutation"] } - +} \ No newline at end of file diff --git a/metainfo/evasion_defense_parameters.json b/metainfo/evasion_defense_parameters.json index 62b30d2..9b218b6 100644 --- a/metainfo/evasion_defense_parameters.json +++ b/metainfo/evasion_defense_parameters.json @@ -8,7 +8,7 @@ "qbit": ["qbit", "int", 8, {"min": 1, "step": 1}, "?"] }, "AdvTraining": { - "epsilon": ["epsilon", "float", 0.1, {"min": 0.001, "step": 0.005}] + "attack_name": ["attack_name", "str", "FGSM", {}, "?"] } } diff --git a/src/attacks/QAttack/qattack.py b/src/attacks/QAttack/qattack.py new file mode 100644 index 0000000..5e9ab3f --- /dev/null +++ b/src/attacks/QAttack/qattack.py @@ -0,0 +1,243 @@ +import copy +import math +import numpy as np +import random + +from tqdm import tqdm +from attacks.evasion_attacks import EvasionAttacker +from attacks.QAttack.utils import get_adj_list, from_adj_list, adj_list_oriented_to_non_oriented + +class QAttacker(EvasionAttacker): + name = "QAttack" + + def __init__(self, population_size, individual_size, generations, prob_cross, prob_mutate, **kwargs): + super().__init__(**kwargs) + self.population_size = population_size + self.individual_size = individual_size + self.generations = generations + self.prob_cross = prob_cross + self.prob_mutate = prob_mutate + + def init(self, gen_dataset): + """ + Init first population: + gen_dataset - graph-dataset + population_size - size of population + individual_size - amount of rewiring actions in one gene/individual + """ + self.population = [] + + self.adj_list = get_adj_list(gen_dataset) + + for i in tqdm(range(self.population_size), desc='Init first population:'): + non_isolated_nodes = set(gen_dataset.dataset.edge_index[0].tolist()).union( + set(gen_dataset.dataset.edge_index[1].tolist())) + selected_nodes = np.random.choice(list(self.adj_list.keys()), size=self.individual_size, replace=False) + gene = {} + for n in selected_nodes: + connected_nodes = set(self.adj_list[n]) + connected_nodes.add(n) + addition_nodes = non_isolated_nodes.difference(connected_nodes) + gene[n] = {'add': np.random.choice(list(addition_nodes), size=1), + 'del': np.random.choice(list(self.adj_list[n]), size=1)} + self.population.append(gene) + + def fitness(self, model, gen_dataset): + """ + Calculate fitness function with node classification + """ + + fit_scores = [] + for i in range(self.population_size): + # Get rewired dataset + dataset = copy.deepcopy(gen_dataset.dataset) + rewiring = self.population[i] + adj_list = get_adj_list(dataset) + for n in rewiring.keys(): + adj_list[n] = list(set(adj_list[n]).union({int(rewiring[n]['add'])}).difference({int(rewiring[n]['del'])})) + dataset.edge_index = from_adj_list(adj_list) + + # Get labels from black-box + labels = model.gnn.get_answer(dataset.x, dataset.edge_index) + labeled_nodes = dict(enumerate(labels.tolist())) + # labeled_nodes = {n: 
labels.tolist()[n-1] for n in adj_list.keys()} # FIXME check order for labels and node id consistency + + # Calculate modularity + Q = self.modularity(adj_list, labeled_nodes) + fit_scores.append(1 / math.exp(Q)) + return fit_scores + + def fitness_individual(self, model, gen_dataset, gene): + dataset = copy.deepcopy(gen_dataset.dataset) + rewiring = gene + adj_list = get_adj_list(dataset) + for n in rewiring.keys(): + adj_list[n] = list(set(adj_list[n]).union(set(rewiring[n]['add'])).difference(set(rewiring[n]['del']))) + dataset.edge_index = from_adj_list(adj_list) + + # Get labels from black-box + labels = model.gnn.get_answer(dataset.x, dataset.edge_index) + labeled_nodes = dict(enumerate(labels.tolist())) + # labeled_nodes = {n: labels.tolist()[n-1] for n in adj_list.keys()} # FIXME check order for labels and node id consistency + + # Calculate modularity + Q = self.modularity(adj_list, labeled_nodes) + return 1 / math.exp(Q) + + @staticmethod + def modularity(adj_list, labeled_nodes): + """ + Calculation of graph modularity with specified node partition on communities + """ + # TODO implement oriented-modularity + + inc = dict([]) + deg = dict([]) + + links = 0 + non_oriented_adj_list = adj_list_oriented_to_non_oriented(adj_list) + for k, v in non_oriented_adj_list.items(): + links += len(v) + if links == 0: + raise ValueError("A graph without link has an undefined modularity") + links //= 2 + + for node, edges in non_oriented_adj_list.items(): + com = labeled_nodes[node] + deg[com] = deg.get(com, 0.) + len(non_oriented_adj_list[node]) + for neighbor in edges: + edge_weight = 1 # TODO weighted graph to be implemented + if labeled_nodes[neighbor] == com: + if neighbor == node: + inc[com] = inc.get(com, 0.) + float(edge_weight) + else: + inc[com] = inc.get(com, 0.) + float(edge_weight) / 2. + + res = 0. + for com in set(labeled_nodes.values()): + res += (inc.get(com, 0.) / links) - \ + (deg.get(com, 0.) / (2. 
* links)) ** 2 + return res + + def selection(self, model_manager, gen_dataset): + fit_scores = self.fitness(model_manager, gen_dataset) + probs = [i / sum(fit_scores) for i in fit_scores] + selected_population = copy.deepcopy(self.population) + for i in range(self.population_size): + selected_population[i] = copy.deepcopy(self.population[np.random.choice( + self.population_size, 1, False, probs)[0]]) + self.population = selected_population + + def crossover(self): + for i in range(0, self.population_size // 2, 2): + parent_1 = self.population[i] + parent_2 = self.population[i + 1] + crossover_prob = np.random.random() + if crossover_prob <= self.prob_cross: + self.population[i * 2], self.population[i * 2 + 1] = self.gene_crossover(parent_1, parent_2) + else: + self.population[i * 2], self.population[i * 2 + 1] = (copy.deepcopy(self.population[i * 2]), + copy.deepcopy(self.population[i * 2 + 1])) + + def gene_crossover(self, parent_1, parent_2): + parent_1_set = set(parent_1.keys()) + parent_2_set = set(parent_2.keys()) + + parent_1_unique = parent_1_set.difference(parent_2_set) + parent_2_unique = parent_2_set.difference(parent_1_set) + + parent_1_cross = list(parent_1_unique) + parent_2_cross = list(parent_2_unique) + + assert len(parent_1_cross) == len(parent_2_cross) + if len(parent_1_cross) == 0: + return parent_1, parent_2 + n = np.random.randint(1, len(parent_1_cross) + 1) + parent_1_cross = random.sample(parent_1_cross, n) + parent_2_cross = random.sample(parent_2_cross, n) + + parent_1_set.difference_update(parent_1_cross) + parent_2_set.difference_update(parent_2_cross) + + parent_1_set.update(parent_2_cross) + parent_2_set.update(parent_1_cross) + + child_1 = {} + child_2 = {} + for n in parent_1_set: + if n in parent_1.keys(): + child_1[n] = parent_1[n] + else: + child_1[n] = parent_2[n] + for n in parent_2_set: + if n in parent_2.keys(): + child_2[n] = parent_2[n] + else: + child_2[n] = parent_1[n] + + return child_1,child_2 + + def mutation(self, gen_dataset): + for i in range(self.population_size): + keys = self.population[i].keys() + for n in list(keys): + mutation_prob = np.random.random() + if mutation_prob <= self.prob_mutate: + mut_type = np.random.randint(3) + dataset = copy.deepcopy(gen_dataset.dataset) + rewiring = self.population[i] + adj_list = get_adj_list(dataset) + for n in rewiring.keys(): + adj_list[n] = list( + set(adj_list[n]).union(set([int(rewiring[n]['add'])])).difference(set([int(rewiring[n]['del'])]))) + dataset.edge_index = from_adj_list(adj_list) + non_isolated_nodes = set(gen_dataset.dataset.edge_index[0].tolist()).union( + set(gen_dataset.dataset.edge_index[1].tolist())) + non_drain_nodes = set(gen_dataset.dataset.edge_index[0].tolist()) + if mut_type == 0: + # add mutation + connected_nodes = set(self.adj_list[n]) + connected_nodes.add(n) + addition_nodes = non_isolated_nodes.difference(connected_nodes) + self.population[i][n]['add'] = np.random.choice(list(addition_nodes), 1) + elif mut_type == 1: + # del mutation + self.population[i][n]['del'] = np.random.choice(list(adj_list[n]), 1) + else: + selected_nodes = set(self.population[i].keys()) + #non_selected_nodes = non_isolated_nodes.difference(selected_nodes) + non_drain_nodes = non_drain_nodes.difference(selected_nodes) + new_node = np.random.choice(list(non_drain_nodes), size=1, replace=False)[0] + self.population[i].pop(n) + addition_nodes = non_isolated_nodes.difference(set(self.adj_list[new_node])) + self.population[i][new_node] = {} + self.population[i][new_node]['add'] = 
np.random.choice(list(addition_nodes), 1) + self.population[i][new_node]['del'] = np.random.choice(list(adj_list[new_node]), 1) + + def elitism(self, model, gen_dataset): + fit_scores = list(enumerate(self.fitness(model, gen_dataset))) + fit_scores = sorted(fit_scores, key=lambda x: x[1]) + sort_order = [x[0] for x in fit_scores] + self.population = [self.population[i] for i in sort_order] + elitism_size = int(0.1 * self.population_size) + self.population[:elitism_size] = self.population[-elitism_size:] + return self.population[-1] + + + def attack(self, model_manager, gen_dataset, mask_tensor): + self.init(gen_dataset) + + for i in tqdm(range(self.generations), desc='Attack iterations:', position=0, leave=True): + self.selection(model_manager, gen_dataset) + self.crossover() + self.mutation(gen_dataset) + best_offspring = self.elitism(model_manager, gen_dataset) + + rewiring = best_offspring + adj_list = get_adj_list(gen_dataset) + for n in rewiring.keys(): + adj_list[n] = list( + set(adj_list[n]).union(set([int(rewiring[n]['add'])])).difference(set([int(rewiring[n]['del'])]))) + + gen_dataset.dataset.data.edge_index = from_adj_list(adj_list) + return gen_dataset \ No newline at end of file diff --git a/src/attacks/QAttack/utils.py b/src/attacks/QAttack/utils.py new file mode 100644 index 0000000..31a8b35 --- /dev/null +++ b/src/attacks/QAttack/utils.py @@ -0,0 +1,46 @@ +import copy + +import torch + +def get_adj_list(gen_dataset): + """ + Get adjacency list from gen_dataset of GNN-AID format + """ + if hasattr(gen_dataset, 'dataset'): + gen_dataset = gen_dataset.dataset + adj_list = {} + for u, v in zip(gen_dataset.edge_index[0].tolist(), gen_dataset.edge_index[1].tolist()): + if u in adj_list.keys(): + adj_list[u].append(v) + else: + adj_list[u] = [v] + + return adj_list + +def from_adj_list(adj_list): + """ + Get edge_index in COO-format from adjacency list + """ + in_nodes = [] + out_nodes = [] + for n, edges in adj_list.items(): + for e in edges: + in_nodes.append(n) + out_nodes.append(e) + return torch.tensor([in_nodes, out_nodes], dtype=torch.int) + +def adj_list_oriented_to_non_oriented(adj_list): + non_oriented_adj_list = {} + for node, neighs in adj_list.items(): + if node not in non_oriented_adj_list.keys(): + non_oriented_adj_list[node] = copy.deepcopy(adj_list[node]) + else: + non_oriented_adj_list[node] += copy.deepcopy(adj_list[node]) + for in_node in adj_list[node]: + if in_node not in non_oriented_adj_list.keys(): + non_oriented_adj_list[in_node] = [node] + else: + non_oriented_adj_list[in_node].append(node) + for k in non_oriented_adj_list.keys(): + non_oriented_adj_list[k] = list(set(non_oriented_adj_list[k])) + return non_oriented_adj_list \ No newline at end of file From 94eb91f3e8dc85ff1dc8321963a6a1c8535dc137 Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Wed, 16 Oct 2024 20:49:19 +0300 Subject: [PATCH 10/13] rm GNNGuard --- src/defense/GNNGuard/gnnguard.py | 282 ----------- src/defense/GNNGuard/utils.py | 777 ------------------------------- 2 files changed, 1059 deletions(-) delete mode 100644 src/defense/GNNGuard/gnnguard.py delete mode 100644 src/defense/GNNGuard/utils.py diff --git a/src/defense/GNNGuard/gnnguard.py b/src/defense/GNNGuard/gnnguard.py deleted file mode 100644 index 6a284bf..0000000 --- a/src/defense/GNNGuard/gnnguard.py +++ /dev/null @@ -1,282 +0,0 @@ -import torch.nn as nn -import torch.nn.functional as F -import math -import torch -from torch.nn.parameter import Parameter -from torch.nn.modules.module import Module -from torch_geometric.nn 
import GCNConv - -# from defense.GNNGuard.base_model import BaseModel -from defense.poison_defense import PoisonDefender - -from models_builder.gnn_models import FrameworkGNNModelManager -from models_builder.gnn_constructor import FrameworkGNNConstructor -from models_builder.models_zoo import model_configs_zoo -from aux.configs import ModelManagerConfig, ModelModificationConfig, DatasetConfig, DatasetVarConfig, ConfigPattern -from aux.utils import import_by_name, CUSTOM_LAYERS_INFO_PATH, MODULES_PARAMETERS_PATH, hash_data_sha256, \ - TECHNICAL_PARAMETER_KEY, IMPORT_INFO_KEY, OPTIMIZERS_PARAMETERS_PATH - - -import warnings -import types -# from torch_sparse import coalesce, SparseTensor, matmul - -from sklearn.metrics.pairwise import cosine_similarity -from sklearn.preprocessing import normalize - -from scipy.sparse import lil_matrix -import scipy.sparse as sp - -import numpy as np - -from src.aux.configs import ModelConfig - - -class BaseGNNGuard(PoisonDefender): - name = "BaseGNNGuard" - - def __init__(self, lr=0.1, train_iters=200, device='cpu'): - super().__init__() - self.model = None - self.lr = lr - self.device = device - self.train_iters = train_iters - - def defense(self, gen_dataset): - self.model = model_configs_zoo(gen_dataset, 'gcn_gcn_linearized') - default_config = ModelModificationConfig( - model_ver_ind=0, - ) - manager_config = ConfigPattern( - _config_class="ModelManagerConfig", - _config_kwargs={ - "mask_features": [], - "optimizer": { - "_config_class": "Config", - "_class_name": "Adam", - "_import_path": OPTIMIZERS_PARAMETERS_PATH, - "_class_import_info": ["torch.optim"], - "_config_kwargs": {"weight_decay": 5e-4}, - } - } - ) - gnn_model_manager_surrogate = FrameworkGNNModelManager( - gnn=self.model, - dataset_path=gen_dataset, - modification=default_config, - manager_config=manager_config, - ) - - gnn_model_manager_surrogate.train_model(gen_dataset=gen_dataset, steps=self.train_iters) - - self.pred_labels = gnn_model_manager_surrogate.run_model(gen_dataset=gen_dataset, mask='all', out='answers') - - -class GNNGuard(BaseGNNGuard): - name = 'GNNGuard' - - def __init__(self, lr=0.01, attention=False, drop=False, train_iters=200, device='cpu', with_bias=False, with_relu=False): - super().__init__(lr=lr, train_iters=train_iters, device=device) - assert device is not None, "Please specify 'device'!" 
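# Illustrative sketch (not part of the deleted file): the att_coef routine further below
# re-weights every edge by the cosine similarity of its endpoint features and prunes
# near-orthogonal pairs before row-normalising. A minimal standalone version of that
# idea, assuming a dense feature matrix and a toy COO edge list:
import numpy as np
from scipy.sparse import lil_matrix
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize

feat = np.random.rand(4, 8)                      # 4 nodes, 8 features (toy data)
row = np.array([0, 1, 2, 3])
col = np.array([1, 2, 3, 0])
sim = cosine_similarity(feat)[row, col]          # one similarity score per edge
sim[sim < 0.1] = 0                               # drop edges between dissimilar nodes
att = lil_matrix((4, 4), dtype=np.float32)
att[row, col] = sim
att = normalize(att, axis=1, norm='l1')          # each non-empty row now sums to 1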
- self.with_bias = with_bias - self.with_relu = with_relu - self.attention = attention - self.drop = drop - - def defense(self, gen_dataset): - super().defense(gen_dataset=gen_dataset) - - self.hidden_sizes = [16] - self.nfeat = gen_dataset.num_node_features - self.nclass = gen_dataset.num_classes - # self.attention = attention - # self.drop = drop - - self.wrap() - self.initialize() - - def wrap(self): - def att_coef(self, fea, edge_index, is_lil=False, i=0): - if is_lil == False: - edge_index = edge_index._indices() - else: - edge_index = edge_index.tocoo() - - n_node = fea.shape[0] - row, col = edge_index[0].cpu().data.numpy()[:], edge_index[1].cpu().data.numpy()[:] - - fea_copy = fea.cpu().data.numpy() - sim_matrix = cosine_similarity(X=fea_copy, Y=fea_copy) # try cosine similarity - sim = sim_matrix[row, col] - sim[sim<0.1] = 0 - - """build a attention matrix""" - att_dense = lil_matrix((n_node, n_node), dtype=np.float32) - att_dense[row, col] = sim - if att_dense[0, 0] == 1: - att_dense = att_dense - sp.diags(att_dense.diagonal(), offsets=0, format="lil") - # normalization, make the sum of each row is 1 - att_dense_norm = normalize(att_dense, axis=1, norm='l1') - - - """add learnable dropout, make character vector""" - if self.drop: - character = np.vstack((att_dense_norm[row, col].A1, - att_dense_norm[col, row].A1)) - character = torch.from_numpy(character.T) - drop_score = self.drop_learn_1(character) - drop_score = torch.sigmoid(drop_score) # do not use softmax since we only have one element - mm = torch.nn.Threshold(0.5, 0) - drop_score = mm(drop_score) - mm_2 = torch.nn.Threshold(-0.49, 1) - drop_score = mm_2(-drop_score) - drop_decision = drop_score.clone().requires_grad_() - # print('rate of left edges', drop_decision.sum().data/drop_decision.shape[0]) - drop_matrix = lil_matrix((n_node, n_node), dtype=np.float32) - drop_matrix[row, col] = drop_decision.cpu().data.numpy().squeeze(-1) - att_dense_norm = att_dense_norm.multiply(drop_matrix.tocsr()) # update, remove the 0 edges - - if att_dense_norm[0, 0] == 0: # add the weights of self-loop only add self-loop at the first layer - degree = (att_dense_norm != 0).sum(1).A1 - lam = 1 / (degree + 1) # degree +1 is to add itself - self_weight = sp.diags(np.array(lam), offsets=0, format="lil") - att = att_dense_norm + self_weight # add the self loop - else: - att = att_dense_norm - - row, col = att.nonzero() - att_adj = np.vstack((row, col)) - att_edge_weight = att[row, col] - att_edge_weight = np.exp(att_edge_weight) # exponent, kind of softmax - att_edge_weight = torch.tensor(np.array(att_edge_weight)[0], dtype=torch.float32)#.cuda() - att_adj = torch.tensor(att_adj, dtype=torch.int64)#.cuda() - - shape = (n_node, n_node) - new_adj = torch.sparse.FloatTensor(att_adj, att_edge_weight, shape) - return new_adj - - def forward(self, *args, **kwargs): - """we don't change the edge_index, just update the edge_weight; - some edge_weight are regarded as removed if it equals to zero""" - layer_ind = -1 - tensor_storage = {} - dim_cat = 0 - layer_emb_dict = {} - save_emb_flag = self._save_emb_flag - - x, edge_index, batch = self.arguments_read(*args, **kwargs) - feat = x - adj = edge_index.tocoo() - adj_memory = None - # print(list(self.__dict__['_modules'].items())) - for elem in list(self.__dict__['_modules'].items()): - layer_name, curr_layer_ind = elem[0].split('_') - curr_layer_ind = int(curr_layer_ind) - inp = torch.clone(x) - loc_flag = False - if curr_layer_ind != layer_ind: - if save_emb_flag: - loc_flag = True - zeroing_x_flag = 
False - for key, value in self.conn_dict.items(): - if key[0] == layer_ind and layer_ind not in tensor_storage: - tensor_storage[layer_ind] = torch.clone(x) - layer_ind = curr_layer_ind - x_copy = torch.clone(x) - connection_tensor = torch.Tensor() - for key, value in self.conn_dict.items(): - - if key[1] == curr_layer_ind: - if key[1] - key[0] == 1: - zeroing_x_flag = True - for con in self.conn_dict[key]: - if self.embedding_levels_by_layers[key[1]] == 'n' and \ - self.embedding_levels_by_layers[key[0]] == 'n': - connection_tensor = torch.cat((connection_tensor, - tensor_storage[key[0]]), 1) - dim_cat = 1 - elif self.embedding_levels_by_layers[key[1]] == 'g' and \ - self.embedding_levels_by_layers[key[0]] == 'g': - connection_tensor = torch.cat((connection_tensor, - tensor_storage[key[0]]), 0) - dim_cat = 0 - elif self.embedding_levels_by_layers[key[1]] == 'g' and \ - self.embedding_levels_by_layers[key[0]] == 'n': - con_pool = import_by_name(con['pool']['pool_type'], - ["torch_geometric.nn"]) - tensor_after_pool = con_pool(tensor_storage[key[0]], batch) - connection_tensor = torch.cat((connection_tensor, - tensor_after_pool), 1) - dim_cat = 1 - else: - raise Exception( - "Connection from layer type " - f"{self.embedding_levels_by_layers[curr_layer_ind - 1]} to" - f" layer type {self.embedding_levels_by_layers[curr_layer_ind]}" - "is not supported now") - - - if zeroing_x_flag: - x = connection_tensor - else: - x = torch.cat((x_copy, connection_tensor), dim_cat) - - - if self.attention: - if layer_name == 'GINConv': - if adj_memory is None: - adj = self.att_coef(x, adj, is_lil=False,i=layer_ind) - edge_index = adj._indices() - edge_weight = adj._values() - adj_memory = adj - elif adj_memory is not None: - adj = self.att_coef(x, adj_memory, is_lil=False, i=layer_ind) - edge_weight = self.gate * adj_memory._values() + (1 - self.gate) * adj._values() - adj_memory = adj - elif layer_name == 'GCNConv' or layer_name == 'GATConv': - if adj_memory is None: - adj = self.att_coef(x, adj, i=0) - edge_index = adj._indices() - edge_weight = adj._values() - adj_memory = adj - elif adj_memory is not None: - adj = self.att_coef(x, adj_memory, i=layer_ind).to_dense() - row, col = adj.nonzero()[:,0], adj.nonzero()[:,1] - edge_index = torch.stack((row, col), dim=0) - edge_weight = adj[row, col] - adj_memory = adj - else: - edge_index = adj._indices() - edge_weight = adj._values() - - - # QUE Kirill, maybe we should not off UserWarning - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning) - # mid = x - if layer_name in self.modules_info: - code_str = f"getattr(self, elem[0])" \ - f"({self.modules_info[layer_name][TECHNICAL_PARAMETER_KEY]['forward_parameters']}," \ - f" edge_weight=edge_weight)" - x = eval(f"{code_str}") - else: - x = getattr(self, elem[0])(x) - if loc_flag: - layer_emb_dict[layer_ind] = torch.clone(x) - - if save_emb_flag: - return layer_emb_dict - return x - - self.model.drop = self.drop - self.model.attention = self.attention - self.model.gate = Parameter(torch.rand(1)) - self.model.drop_learn_1 = nn.Linear(2, 1) - self.model.drop_learn_2 = nn.Linear(2, 1) - self.model.forward = types.MethodType(forward, self.model) - self.model.att_coef = types.MethodType(att_coef, self.model) - - def initialize(self): - self.model.drop_learn_1.reset_parameters() - self.model.drop_learn_2.reset_parameters() - # self.model.gate.reset_parameters() diff --git a/src/defense/GNNGuard/utils.py b/src/defense/GNNGuard/utils.py deleted file mode 100644 index f7be00a..0000000 
--- a/src/defense/GNNGuard/utils.py +++ /dev/null @@ -1,777 +0,0 @@ -import numpy as np -import scipy.sparse as sp -import torch -from sklearn.model_selection import train_test_split -import torch.sparse as ts -import torch.nn.functional as F -import warnings - -def encode_onehot(labels): - """Convert label to onehot format. - - Parameters - ---------- - labels : numpy.array - node labels - - Returns - ------- - numpy.array - onehot labels - """ - eye = np.eye(labels.max() + 1) - onehot_mx = eye[labels] - return onehot_mx - -def tensor2onehot(labels): - """Convert label tensor to label onehot tensor. - - Parameters - ---------- - labels : torch.LongTensor - node labels - - Returns - ------- - torch.LongTensor - onehot labels tensor - - """ - - eye = torch.eye(labels.max() + 1) - onehot_mx = eye[labels] - return onehot_mx.to(labels.device) - -def preprocess(adj, features, labels, preprocess_adj=False, preprocess_feature=False, sparse=False, device='cpu'): - """Convert adj, features, labels from array or sparse matrix to - torch Tensor, and normalize the input data. - - Parameters - ---------- - adj : scipy.sparse.csr_matrix - the adjacency matrix. - features : scipy.sparse.csr_matrix - node features - labels : numpy.array - node labels - preprocess_adj : bool - whether to normalize the adjacency matrix - preprocess_feature : bool - whether to normalize the feature matrix - sparse : bool - whether to return sparse tensor - device : str - 'cpu' or 'cuda' - """ - - if preprocess_adj: - adj = normalize_adj(adj) - - if preprocess_feature: - features = normalize_feature(features) - - labels = torch.LongTensor(labels) - if sparse: - adj = sparse_mx_to_torch_sparse_tensor(adj) - features = sparse_mx_to_torch_sparse_tensor(features) - else: - if sp.issparse(features): - features = torch.FloatTensor(np.array(features.todense())) - else: - features = torch.FloatTensor(features) - adj = torch.FloatTensor(adj.todense()) - return adj.to(device), features.to(device), labels.to(device) - -def to_tensor(adj, features, labels=None, device='cpu'): - """Convert adj, features, labels from array or sparse matrix to - torch Tensor. - - Parameters - ---------- - adj : scipy.sparse.csr_matrix - the adjacency matrix. - features : scipy.sparse.csr_matrix - node features - labels : numpy.array - node labels - device : str - 'cpu' or 'cuda' - """ - if sp.issparse(adj): - adj = sparse_mx_to_torch_sparse_tensor(adj) - else: - adj = torch.FloatTensor(adj) - if sp.issparse(features): - features = sparse_mx_to_torch_sparse_tensor(features) - else: - features = torch.FloatTensor(np.array(features)) - - if labels is None: - return adj.to(device), features.to(device) - else: - labels = torch.LongTensor(labels) - return adj.to(device), features.to(device), labels.to(device) - -def normalize_feature(mx): - """Row-normalize sparse matrix or dense matrix - - Parameters - ---------- - mx : scipy.sparse.csr_matrix or numpy.array - matrix to be normalized - - Returns - ------- - scipy.sprase.lil_matrix - normalized matrix - """ - if type(mx) is not sp.lil.lil_matrix: - try: - mx = mx.tolil() - except AttributeError: - pass - rowsum = np.array(mx.sum(1)) - r_inv = np.power(rowsum, -1).flatten() - r_inv[np.isinf(r_inv)] = 0. 
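# Illustrative note (not from the original utils.py): for row sums [2., 0., 4.] the lines
# above give r_inv = [0.5, 0., 0.25], and sp.diags(r_inv).dot(mx) below rescales every
# non-empty row to sum to 1 while leaving all-zero rows untouched, e.g.
#     mx = sp.lil_matrix([[2., 0.], [0., 0.], [1., 3.]])
#     normalize_feature(mx).toarray()   # -> [[1., 0.], [0., 0.], [0.25, 0.75]]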
- r_mat_inv = sp.diags(r_inv) - mx = r_mat_inv.dot(mx) - return mx - -def normalize_adj(mx): - """Normalize sparse adjacency matrix, - A' = (D + I)^-1/2 * ( A + I ) * (D + I)^-1/2 - Row-normalize sparse matrix - - Parameters - ---------- - mx : scipy.sparse.csr_matrix - matrix to be normalized - - Returns - ------- - scipy.sprase.lil_matrix - normalized matrix - """ - - # TODO: maybe using coo format would be better? - if type(mx) is not sp.lil.lil_matrix: - mx = mx.tolil() - if mx[0, 0] == 0 : - mx = mx + sp.eye(mx.shape[0]) - rowsum = np.array(mx.sum(1)) - r_inv = np.power(rowsum, -1/2).flatten() - r_inv[np.isinf(r_inv)] = 0. - r_mat_inv = sp.diags(r_inv) - mx = r_mat_inv.dot(mx) - mx = mx.dot(r_mat_inv) - return mx - -def normalize_sparse_tensor(adj, fill_value=1): - """Normalize sparse tensor. Need to import torch_scatter - """ - edge_index = adj._indices() - edge_weight = adj._values() - num_nodes= adj.size(0) - edge_index, edge_weight = add_self_loops( - edge_index, edge_weight, fill_value, num_nodes) - - row, col = edge_index - from torch_scatter import scatter_add - deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) - deg_inv_sqrt = deg.pow(-0.5) - deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 - - values = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] - - shape = adj.shape - return torch.sparse.FloatTensor(edge_index, values, shape) - -def add_self_loops(edge_index, edge_weight=None, fill_value=1, num_nodes=None): - # num_nodes = maybe_num_nodes(edge_index, num_nodes) - - loop_index = torch.arange(0, num_nodes, dtype=torch.long, - device=edge_index.device) - loop_index = loop_index.unsqueeze(0).repeat(2, 1) - - if edge_weight is not None: - assert edge_weight.numel() == edge_index.size(1) - loop_weight = edge_weight.new_full((num_nodes, ), fill_value) - edge_weight = torch.cat([edge_weight, loop_weight], dim=0) - - edge_index = torch.cat([edge_index, loop_index], dim=1) - - return edge_index, edge_weight - -def normalize_adj_tensor(adj, sparse=False): - """Normalize adjacency tensor matrix. - """ - device = adj.device - if sparse: - # warnings.warn('If you find the training process is too slow, you can uncomment line 207 in deeprobust/graph/utils.py. Note that you need to install torch_sparse') - # TODO if this is too slow, uncomment the following code, - # but you need to install torch_scatter - # return normalize_sparse_tensor(adj) - adj = to_scipy(adj) - mx = normalize_adj(adj) - return sparse_mx_to_torch_sparse_tensor(mx).to(device) - else: - mx = adj + torch.eye(adj.shape[0]).to(device) - rowsum = mx.sum(1) - r_inv = rowsum.pow(-1/2).flatten() - r_inv[torch.isinf(r_inv)] = 0. - r_mat_inv = torch.diag(r_inv) - mx = r_mat_inv @ mx - mx = mx @ r_mat_inv - return mx - -def degree_normalize_adj(mx): - """Row-normalize sparse matrix""" - mx = mx.tolil() - if mx[0, 0] == 0 : - mx = mx + sp.eye(mx.shape[0]) - rowsum = np.array(mx.sum(1)) - r_inv = np.power(rowsum, -1).flatten() - r_inv[np.isinf(r_inv)] = 0. - r_mat_inv = sp.diags(r_inv) - # mx = mx.dot(r_mat_inv) - mx = r_mat_inv.dot(mx) - return mx - -def degree_normalize_sparse_tensor(adj, fill_value=1): - """degree_normalize_sparse_tensor. 
- """ - edge_index = adj._indices() - edge_weight = adj._values() - num_nodes= adj.size(0) - - edge_index, edge_weight = add_self_loops( - edge_index, edge_weight, fill_value, num_nodes) - - row, col = edge_index - from torch_scatter import scatter_add - deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) - deg_inv_sqrt = deg.pow(-1) - deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 - - values = deg_inv_sqrt[row] * edge_weight - shape = adj.shape - return torch.sparse.FloatTensor(edge_index, values, shape) - -def degree_normalize_adj_tensor(adj, sparse=True): - """degree_normalize_adj_tensor. - """ - - device = adj.device - if sparse: - # return degree_normalize_sparse_tensor(adj) - adj = to_scipy(adj) - mx = degree_normalize_adj(adj) - return sparse_mx_to_torch_sparse_tensor(mx).to(device) - else: - mx = adj + torch.eye(adj.shape[0]).to(device) - rowsum = mx.sum(1) - r_inv = rowsum.pow(-1).flatten() - r_inv[torch.isinf(r_inv)] = 0. - r_mat_inv = torch.diag(r_inv) - mx = r_mat_inv @ mx - return mx - -def accuracy(output, labels): - """Return accuracy of output compared to labels. - - Parameters - ---------- - output : torch.Tensor - output from model - labels : torch.Tensor or numpy.array - node labels - - Returns - ------- - float - accuracy - """ - if not hasattr(labels, '__len__'): - labels = [labels] - if type(labels) is not torch.Tensor: - labels = torch.LongTensor(labels) - preds = output.max(1)[1].type_as(labels) - correct = preds.eq(labels).double() - correct = correct.sum() - return correct / len(labels) - -def loss_acc(output, labels, targets, avg_loss=True): - if type(labels) is not torch.Tensor: - labels = torch.LongTensor(labels) - preds = output.max(1)[1].type_as(labels) - correct = preds.eq(labels).double()[targets] - loss = F.nll_loss(output[targets], labels[targets], reduction='mean' if avg_loss else 'none') - - if avg_loss: - return loss, correct.sum() / len(targets) - return loss, correct - # correct = correct.sum() - # return loss, correct / len(labels) - -def get_perf(output, labels, mask, verbose=True): - """evalute performance for test masked data""" - loss = F.nll_loss(output[mask], labels[mask]) - acc = accuracy(output[mask], labels[mask]) - if verbose: - print("loss= {:.4f}".format(loss.item()), - "accuracy= {:.4f}".format(acc.item())) - return loss.item(), acc.item() - - -def classification_margin(output, true_label): - """Calculate classification margin for outputs. - `probs_true_label - probs_best_second_class` - - Parameters - ---------- - output: torch.Tensor - output vector (1 dimension) - true_label: int - true label for this node - - Returns - ------- - list - classification margin for this node - """ - - probs = torch.exp(output) - probs_true_label = probs[true_label].clone() - probs[true_label] = 0 - probs_best_second_class = probs[probs.argmax()] - return (probs_true_label - probs_best_second_class).item() - -def sparse_mx_to_torch_sparse_tensor(sparse_mx): - """Convert a scipy sparse matrix to a torch sparse tensor.""" - sparse_mx = sparse_mx.tocoo().astype(np.float32) - sparserow=torch.LongTensor(sparse_mx.row).unsqueeze(1) - sparsecol=torch.LongTensor(sparse_mx.col).unsqueeze(1) - sparseconcat=torch.cat((sparserow, sparsecol),1) - sparsedata=torch.FloatTensor(sparse_mx.data) - return torch.sparse.FloatTensor(sparseconcat.t(),sparsedata,torch.Size(sparse_mx.shape)) - - # slower version.... 
- # sparse_mx = sparse_mx.tocoo().astype(np.float32) - # indices = torch.from_numpy( - # np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) - # values = torch.from_numpy(sparse_mx.data) - # shape = torch.Size(sparse_mx.shape) - # return torch.sparse.FloatTensor(indices, values, shape) - - - -def to_scipy(tensor): - """Convert a dense/sparse tensor to scipy matrix""" - if is_sparse_tensor(tensor): - values = tensor._values() - indices = tensor._indices() - return sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()), shape=tensor.shape) - else: - indices = tensor.nonzero().t() - values = tensor[indices[0], indices[1]] - return sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()), shape=tensor.shape) - -def is_sparse_tensor(tensor): - """Check if a tensor is sparse tensor. - - Parameters - ---------- - tensor : torch.Tensor - given tensor - - Returns - ------- - bool - whether a tensor is sparse tensor - """ - # if hasattr(tensor, 'nnz'): - if tensor.layout == torch.sparse_coo: - return True - else: - return False - -def get_train_val_test(nnodes, val_size=0.1, test_size=0.8, stratify=None, seed=None): - """This setting follows nettack/mettack, where we split the nodes - into 10% training, 10% validation and 80% testing data - - Parameters - ---------- - nnodes : int - number of nodes in total - val_size : float - size of validation set - test_size : float - size of test set - stratify : - data is expected to split in a stratified fashion. So stratify should be labels. - seed : int or None - random seed - - Returns - ------- - idx_train : - node training indices - idx_val : - node validation indices - idx_test : - node test indices - """ - - assert stratify is not None, 'stratify cannot be None!' - - if seed is not None: - np.random.seed(seed) - - idx = np.arange(nnodes) - train_size = 1 - val_size - test_size - idx_train_and_val, idx_test = train_test_split(idx, - random_state=None, - train_size=train_size + val_size, - test_size=test_size, - stratify=stratify) - - if stratify is not None: - stratify = stratify[idx_train_and_val] - - idx_train, idx_val = train_test_split(idx_train_and_val, - random_state=None, - train_size=(train_size / (train_size + val_size)), - test_size=(val_size / (train_size + val_size)), - stratify=stratify) - - return idx_train, idx_val, idx_test - -def get_train_test(nnodes, test_size=0.8, stratify=None, seed=None): - """This function returns training and test set without validation. - It can be used for settings of different label rates. - - Parameters - ---------- - nnodes : int - number of nodes in total - test_size : float - size of test set - stratify : - data is expected to split in a stratified fashion. So stratify should be labels. - seed : int or None - random seed - - Returns - ------- - idx_train : - node training indices - idx_test : - node test indices - """ - assert stratify is not None, 'stratify cannot be None!' - - if seed is not None: - np.random.seed(seed) - - idx = np.arange(nnodes) - train_size = 1 - test_size - idx_train, idx_test = train_test_split(idx, random_state=None, - train_size=train_size, - test_size=test_size, - stratify=stratify) - - return idx_train, idx_test - -def get_train_val_test_gcn(labels, seed=None): - """This setting follows gcn, where we randomly sample 20 instances for each class - as training data, 500 instances as validation data, 1000 instances as test data. - Note here we are not using fixed splits. When random seed changes, the splits - will also change. 
- - Parameters - ---------- - labels : numpy.array - node labels - seed : int or None - random seed - - Returns - ------- - idx_train : - node training indices - idx_val : - node validation indices - idx_test : - node test indices - """ - if seed is not None: - np.random.seed(seed) - - idx = np.arange(len(labels)) - nclass = labels.max() + 1 - idx_train = [] - idx_unlabeled = [] - for i in range(nclass): - labels_i = idx[labels==i] - labels_i = np.random.permutation(labels_i) - idx_train = np.hstack((idx_train, labels_i[: 20])).astype(np.int) - idx_unlabeled = np.hstack((idx_unlabeled, labels_i[20: ])).astype(np.int) - - idx_unlabeled = np.random.permutation(idx_unlabeled) - idx_val = idx_unlabeled[: 500] - idx_test = idx_unlabeled[500: 1500] - return idx_train, idx_val, idx_test - -def get_train_test_labelrate(labels, label_rate): - """Get train test according to given label rate. - """ - nclass = labels.max() + 1 - train_size = int(round(len(labels) * label_rate / nclass)) - print("=== train_size = %s ===" % train_size) - idx_train, idx_val, idx_test = get_splits_each_class(labels, train_size=train_size) - return idx_train, idx_test - -def get_splits_each_class(labels, train_size): - """We randomly sample n instances for class, where n = train_size. - """ - idx = np.arange(len(labels)) - nclass = labels.max() + 1 - idx_train = [] - idx_val = [] - idx_test = [] - for i in range(nclass): - labels_i = idx[labels==i] - labels_i = np.random.permutation(labels_i) - idx_train = np.hstack((idx_train, labels_i[: train_size])).astype(np.int) - idx_val = np.hstack((idx_val, labels_i[train_size: 2*train_size])).astype(np.int) - idx_test = np.hstack((idx_test, labels_i[2*train_size: ])).astype(np.int) - - return np.random.permutation(idx_train), np.random.permutation(idx_val), \ - np.random.permutation(idx_test) - - -def unravel_index(index, array_shape): - rows = torch.div(index, array_shape[1], rounding_mode='trunc') - cols = index % array_shape[1] - return rows, cols - - -def get_degree_squence(adj): - try: - return adj.sum(0) - except: - return ts.sum(adj, dim=1).to_dense() - -def likelihood_ratio_filter(node_pairs, modified_adjacency, original_adjacency, d_min, threshold=0.004, undirected=True): - """ - Filter the input node pairs based on the likelihood ratio test proposed by Zügner et al. 2018, see - https://dl.acm.org/citation.cfm?id=3220078. In essence, for each node pair return 1 if adding/removing the edge - between the two nodes does not violate the unnoticeability constraint, and return 0 otherwise. Assumes unweighted - and undirected graphs. - """ - - N = int(modified_adjacency.shape[0]) - # original_degree_sequence = get_degree_squence(original_adjacency) - # current_degree_sequence = get_degree_squence(modified_adjacency) - original_degree_sequence = original_adjacency.sum(0) - current_degree_sequence = modified_adjacency.sum(0) - - concat_degree_sequence = torch.cat((current_degree_sequence, original_degree_sequence)) - - # Compute the log likelihood values of the original, modified, and combined degree sequences. 
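# Worked example (illustration with made-up degrees): for a degree sequence [3, 4, 5, 8]
# with d_min = 2 every degree is kept, so
#     sum_log_degrees = ln 3 + ln 4 + ln 5 + ln 8 ≈ 6.17,  n = 4
#     alpha = 1 + n / (sum_log_degrees - n * ln(d_min - 0.5))
#           = 1 + 4 / (6.17 - 4 * ln 1.5) ≈ 1.88
# which is what compute_alpha() further below evaluates before the likelihood ratios are
# compared against the threshold.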
- ll_orig, alpha_orig, n_orig, sum_log_degrees_original = degree_sequence_log_likelihood(original_degree_sequence, d_min) - ll_current, alpha_current, n_current, sum_log_degrees_current = degree_sequence_log_likelihood(current_degree_sequence, d_min) - - ll_comb, alpha_comb, n_comb, sum_log_degrees_combined = degree_sequence_log_likelihood(concat_degree_sequence, d_min) - - # Compute the log likelihood ratio - current_ratio = -2 * ll_comb + 2 * (ll_orig + ll_current) - - # Compute new log likelihood values that would arise if we add/remove the edges corresponding to each node pair. - new_lls, new_alphas, new_ns, new_sum_log_degrees = updated_log_likelihood_for_edge_changes(node_pairs, - modified_adjacency, d_min) - - # Combination of the original degree distribution with the distributions corresponding to each node pair. - n_combined = n_orig + new_ns - new_sum_log_degrees_combined = sum_log_degrees_original + new_sum_log_degrees - alpha_combined = compute_alpha(n_combined, new_sum_log_degrees_combined, d_min) - new_ll_combined = compute_log_likelihood(n_combined, alpha_combined, new_sum_log_degrees_combined, d_min) - new_ratios = -2 * new_ll_combined + 2 * (new_lls + ll_orig) - - # Allowed edges are only those for which the resulting likelihood ratio measure is < than the threshold - allowed_edges = new_ratios < threshold - - if allowed_edges.is_cuda: - filtered_edges = node_pairs[allowed_edges.cpu().numpy().astype(np.bool)] - else: - filtered_edges = node_pairs[allowed_edges.numpy().astype(np.bool)] - - allowed_mask = torch.zeros(modified_adjacency.shape) - allowed_mask[filtered_edges.T] = 1 - if undirected: - allowed_mask += allowed_mask.t() - return allowed_mask, current_ratio - - -def degree_sequence_log_likelihood(degree_sequence, d_min): - """ - Compute the (maximum) log likelihood of the Powerlaw distribution fit on a degree distribution. - """ - - # Determine which degrees are to be considered, i.e. >= d_min. - D_G = degree_sequence[(degree_sequence >= d_min.item())] - try: - sum_log_degrees = torch.log(D_G).sum() - except: - sum_log_degrees = np.log(D_G).sum() - n = len(D_G) - - alpha = compute_alpha(n, sum_log_degrees, d_min) - ll = compute_log_likelihood(n, alpha, sum_log_degrees, d_min) - return ll, alpha, n, sum_log_degrees - -def updated_log_likelihood_for_edge_changes(node_pairs, adjacency_matrix, d_min): - """ Adopted from https://github.com/danielzuegner/nettack - """ - # For each node pair find out whether there is an edge or not in the input adjacency matrix. 
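# Note: `deltas = -2 * edge_entries_before + 1` a few lines below maps an existing edge
# (entry 1) to -1 and a missing edge (entry 0) to +1, i.e. the change in the endpoint
# degrees that flipping each candidate node pair would cause.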
- - edge_entries_before = adjacency_matrix[node_pairs.T] - degree_sequence = adjacency_matrix.sum(1) - D_G = degree_sequence[degree_sequence >= d_min.item()] - sum_log_degrees = torch.log(D_G).sum() - n = len(D_G) - deltas = -2 * edge_entries_before + 1 - d_edges_before = degree_sequence[node_pairs] - - d_edges_after = degree_sequence[node_pairs] + deltas[:, None] - - # Sum the log of the degrees after the potential changes which are >= d_min - sum_log_degrees_after, new_n = update_sum_log_degrees(sum_log_degrees, n, d_edges_before, d_edges_after, d_min) - # Updated estimates of the Powerlaw exponents - new_alpha = compute_alpha(new_n, sum_log_degrees_after, d_min) - # Updated log likelihood values for the Powerlaw distributions - new_ll = compute_log_likelihood(new_n, new_alpha, sum_log_degrees_after, d_min) - return new_ll, new_alpha, new_n, sum_log_degrees_after - - -def update_sum_log_degrees(sum_log_degrees_before, n_old, d_old, d_new, d_min): - # Find out whether the degrees before and after the change are above the threshold d_min. - old_in_range = d_old >= d_min - new_in_range = d_new >= d_min - d_old_in_range = d_old * old_in_range.float() - d_new_in_range = d_new * new_in_range.float() - - # Update the sum by subtracting the old values and then adding the updated logs of the degrees. - sum_log_degrees_after = sum_log_degrees_before - (torch.log(torch.clamp(d_old_in_range, min=1))).sum(1) \ - + (torch.log(torch.clamp(d_new_in_range, min=1))).sum(1) - - # Update the number of degrees >= d_min - - new_n = n_old - (old_in_range!=0).sum(1) + (new_in_range!=0).sum(1) - new_n = new_n.float() - return sum_log_degrees_after, new_n - -def compute_alpha(n, sum_log_degrees, d_min): - try: - alpha = 1 + n / (sum_log_degrees - n * torch.log(d_min - 0.5)) - except: - alpha = 1 + n / (sum_log_degrees - n * np.log(d_min - 0.5)) - return alpha - -def compute_log_likelihood(n, alpha, sum_log_degrees, d_min): - # Log likelihood under alpha - try: - ll = n * torch.log(alpha) + n * alpha * torch.log(d_min) + (alpha + 1) * sum_log_degrees - except: - ll = n * np.log(alpha) + n * alpha * np.log(d_min) + (alpha + 1) * sum_log_degrees - - return ll - -def ravel_multiple_indices(ixs, shape, reverse=False): - """ - "Flattens" multiple 2D input indices into indices on the flattened matrix, similar to np.ravel_multi_index. - Does the same as ravel_index but for multiple indices at once. - Parameters - ---------- - ixs: array of ints shape (n, 2) - The array of n indices that will be flattened. - - shape: list or tuple of ints of length 2 - The shape of the corresponding matrix. - - Returns - ------- - array of n ints between 0 and shape[0]*shape[1]-1 - The indices on the flattened matrix corresponding to the 2D input indices. 
- - """ - if reverse: - return ixs[:, 1] * shape[1] + ixs[:, 0] - - return ixs[:, 0] * shape[1] + ixs[:, 1] - -def visualize(your_var): - """visualize computation graph""" - from graphviz import Digraph - import torch - from torch.autograd import Variable - from torchviz import make_dot - make_dot(your_var).view() - -def reshape_mx(mx, shape): - indices = mx.nonzero() - return sp.csr_matrix((mx.data, (indices[0], indices[1])), shape=shape) - -def add_mask(data, dataset): - """data: ogb-arxiv pyg data format""" - # for arxiv - split_idx = dataset.get_idx_split() - train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] - n = data.x.shape[0] - data.train_mask = index_to_mask(train_idx, n) - data.val_mask = index_to_mask(valid_idx, n) - data.test_mask = index_to_mask(test_idx, n) - data.y = data.y.squeeze() - # data.edge_index = to_undirected(data.edge_index, data.num_nodes) - -def index_to_mask(index, size): - mask = torch.zeros((size, ), dtype=torch.bool) - mask[index] = 1 - return mask - -def add_feature_noise(data, noise_ratio, seed): - np.random.seed(seed) - n, d = data.x.shape - # noise = torch.normal(mean=torch.zeros(int(noise_ratio*n), d), std=1) - noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio*n), d))).to(data.x.device) - indices = np.arange(n) - indices = np.random.permutation(indices)[: int(noise_ratio*n)] - delta_feat = torch.zeros_like(data.x) - delta_feat[indices] = noise - data.x[indices] - data.x[indices] = noise - mask = np.zeros(n) - mask[indices] = 1 - mask = torch.tensor(mask).bool().to(data.x.device) - return delta_feat, mask - -def add_feature_noise_test(data, noise_ratio, seed): - np.random.seed(seed) - n, d = data.x.shape - indices = np.arange(n) - test_nodes = indices[data.test_mask.cpu()] - selected = np.random.permutation(test_nodes)[: int(noise_ratio*len(test_nodes))] - noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio*len(test_nodes)), d))) - noise = noise.to(data.x.device) - - delta_feat = torch.zeros_like(data.x) - delta_feat[selected] = noise - data.x[selected] - data.x[selected] = noise - # mask = np.zeros(len(test_nodes)) - mask = np.zeros(n) - mask[selected] = 1 - mask = torch.tensor(mask).bool().to(data.x.device) - return delta_feat, mask - -# def check_path(file_path): -# if not osp.exists(file_path): -# os.system(f'mkdir -p {file_path}') \ No newline at end of file From 9367fa50627c40ea55f6ecd79255fee45cf6785d Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Wed, 16 Oct 2024 20:59:21 +0300 Subject: [PATCH 11/13] rm rgcn --- src/defense/RGCN/rgcn.py | 324 ---------------- src/defense/RGCN/utils.py | 777 -------------------------------------- 2 files changed, 1101 deletions(-) delete mode 100644 src/defense/RGCN/rgcn.py delete mode 100644 src/defense/RGCN/utils.py diff --git a/src/defense/RGCN/rgcn.py b/src/defense/RGCN/rgcn.py deleted file mode 100644 index 8ec8206..0000000 --- a/src/defense/RGCN/rgcn.py +++ /dev/null @@ -1,324 +0,0 @@ - -# import torch.nn.functional as F -# import math -# import torch -# from torch.nn.parameter import Parameter -# from torch.nn.modules.module import Module -# from torch.distributions.multivariate_normal import MultivariateNormal -# # from deeprobust.graph import utils -# import defense.RGCN.utils as utils -# import torch.optim as optim -# from copy import deepcopy - -# # TODO sparse implementation - -# class GGCL_F(Module): -# """Graph Gaussian Convolution Layer (GGCL) when the input is feature""" - -# def __init__(self, in_features, 
out_features, dropout=0.6): -# super(GGCL_F, self).__init__() -# self.in_features = in_features -# self.out_features = out_features -# self.dropout = dropout -# self.weight_miu = Parameter(torch.FloatTensor(in_features, out_features)) -# self.weight_sigma = Parameter(torch.FloatTensor(in_features, out_features)) -# self.reset_parameters() - -# def reset_parameters(self): -# torch.nn.init.xavier_uniform_(self.weight_miu) -# torch.nn.init.xavier_uniform_(self.weight_sigma) - -# def forward(self, features, adj_norm1, adj_norm2, gamma=1): -# features = F.dropout(features, self.dropout, training=self.training) -# self.miu = F.elu(torch.mm(features, self.weight_miu)) -# self.sigma = F.relu(torch.mm(features, self.weight_sigma)) -# # torch.mm(previous_sigma, self.weight_sigma) -# Att = torch.exp(-gamma * self.sigma) -# miu_out = adj_norm1 @ (self.miu * Att) -# sigma_out = adj_norm2 @ (self.sigma * Att * Att) -# return miu_out, sigma_out - -# class GGCL_D(Module): - -# """Graph Gaussian Convolution Layer (GGCL) when the input is distribution""" -# def __init__(self, in_features, out_features, dropout): -# super(GGCL_D, self).__init__() -# self.in_features = in_features -# self.out_features = out_features -# self.dropout = dropout -# self.weight_miu = Parameter(torch.FloatTensor(in_features, out_features)) -# self.weight_sigma = Parameter(torch.FloatTensor(in_features, out_features)) -# # self.register_parameter('bias', None) -# self.reset_parameters() - -# def reset_parameters(self): -# torch.nn.init.xavier_uniform_(self.weight_miu) -# torch.nn.init.xavier_uniform_(self.weight_sigma) - -# def forward(self, miu, sigma, adj_norm1, adj_norm2, gamma=1): -# miu = F.dropout(miu, self.dropout, training=self.training) -# sigma = F.dropout(sigma, self.dropout, training=self.training) -# miu = F.elu(miu @ self.weight_miu) -# sigma = F.relu(sigma @ self.weight_sigma) - -# Att = torch.exp(-gamma * sigma) -# mean_out = adj_norm1 @ (miu * Att) -# sigma_out = adj_norm2 @ (sigma * Att * Att) -# return mean_out, sigma_out - - -# class GaussianConvolution(Module): -# """[Deprecated] Alternative gaussion convolution layer. 
-# """ - -# def __init__(self, in_features, out_features): -# super(GaussianConvolution, self).__init__() -# self.in_features = in_features -# self.out_features = out_features -# self.weight_miu = Parameter(torch.FloatTensor(in_features, out_features)) -# self.weight_sigma = Parameter(torch.FloatTensor(in_features, out_features)) -# # self.sigma = Parameter(torch.FloatTensor(out_features)) -# # self.register_parameter('bias', None) -# self.reset_parameters() - -# def reset_parameters(self): -# # TODO -# torch.nn.init.xavier_uniform_(self.weight_miu) -# torch.nn.init.xavier_uniform_(self.weight_sigma) - -# def forward(self, previous_miu, previous_sigma, adj_norm1=None, adj_norm2=None, gamma=1): - -# if adj_norm1 is None and adj_norm2 is None: -# return torch.mm(previous_miu, self.weight_miu), \ -# torch.mm(previous_miu, self.weight_miu) -# # torch.mm(previous_sigma, self.weight_sigma) - -# Att = torch.exp(-gamma * previous_sigma) -# M = adj_norm1 @ (previous_miu * Att) @ self.weight_miu -# Sigma = adj_norm2 @ (previous_sigma * Att * Att) @ self.weight_sigma -# return M, Sigma - -# # M = torch.mm(torch.mm(adj, previous_miu * A), self.weight_miu) -# # Sigma = torch.mm(torch.mm(adj, previous_sigma * A * A), self.weight_sigma) - -# # TODO sparse implemention -# # support = torch.mm(input, self.weight) -# # output = torch.spmm(adj, support) -# # return output + self.bias - -# def __repr__(self): -# return self.__class__.__name__ + ' (' \ -# + str(self.in_features) + ' -> ' \ -# + str(self.out_features) + ')' - - -# class RGCN(Module): -# """Robust Graph Convolutional Networks Against Adversarial Attacks. KDD 2019. - -# Parameters -# ---------- -# nnodes : int -# number of nodes in the input grpah -# nfeat : int -# size of input feature dimension -# nhid : int -# number of hidden units -# nclass : int -# size of output dimension -# gamma : float -# hyper-parameter for RGCN. See more details in the paper. -# beta1 : float -# hyper-parameter for RGCN. See more details in the paper. -# beta2 : float -# hyper-parameter for RGCN. See more details in the paper. -# lr : float -# learning rate for GCN -# dropout : float -# dropout rate for GCN -# device: str -# 'cpu' or 'cuda'. 
- -# """ - -# def __init__(self, nnodes, nfeat, nhid, nclass, gamma=1.0, beta1=5e-4, beta2=5e-4, lr=0.01, dropout=0.6, device='cpu'): -# super(RGCN, self).__init__() - -# self.device = device -# # adj_norm = normalize(adj) -# # first turn original features to distribution -# self.lr = lr -# self.gamma = gamma -# self.beta1 = beta1 -# self.beta2 = beta2 -# self.nclass = nclass -# self.nhid = nhid // 2 -# # self.gc1 = GaussianConvolution(nfeat, nhid, dropout=dropout) -# # self.gc2 = GaussianConvolution(nhid, nclass, dropout) -# self.gc1 = GGCL_F(nfeat, nhid, dropout=dropout) -# self.gc2 = GGCL_D(nhid, nclass, dropout=dropout) - -# self.dropout = dropout -# # self.gaussian = MultivariateNormal(torch.zeros(self.nclass), torch.eye(self.nclass)) -# self.gaussian = MultivariateNormal(torch.zeros(nnodes, self.nclass), -# torch.diag_embed(torch.ones(nnodes, self.nclass))) -# self.adj_norm1, self.adj_norm2 = None, None -# self.features, self.labels = None, None - -# def forward(self): -# features = self.features -# miu, sigma = self.gc1(features, self.adj_norm1, self.adj_norm2, self.gamma) -# miu, sigma = self.gc2(miu, sigma, self.adj_norm1, self.adj_norm2, self.gamma) -# output = miu + self.gaussian.sample().to(self.device) * torch.sqrt(sigma + 1e-8) -# return F.log_softmax(output, dim=1) - -# def fit(self, features, adj, labels, idx_train, idx_val=None, train_iters=200, verbose=True, **kwargs): -# """Train RGCN. - -# Parameters -# ---------- -# features : -# node features -# adj : -# the adjacency matrix. The format could be torch.tensor or scipy matrix -# labels : -# node labels -# idx_train : -# node training indices -# idx_val : -# node validation indices. If not given (None), GCN training process will not adpot early stopping -# train_iters : int -# number of training epochs -# verbose : bool -# whether to show verbose logs - -# Examples -# -------- -# We can first load dataset and then train RGCN. 
- -# >>> from deeprobust.graph.data import PrePtbDataset, Dataset -# >>> from deeprobust.graph.defense import RGCN -# >>> # load clean graph data -# >>> data = Dataset(root='/tmp/', name='cora', seed=15) -# >>> adj, features, labels = data.adj, data.features, data.labels -# >>> idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test -# >>> # load perturbed graph data -# >>> perturbed_data = PrePtbDataset(root='/tmp/', name='cora') -# >>> perturbed_adj = perturbed_data.adj -# >>> # train defense model -# >>> model = RGCN(nnodes=perturbed_adj.shape[0], nfeat=features.shape[1], -# nclass=labels.max()+1, nhid=32, device='cpu') -# >>> model.fit(features, perturbed_adj, labels, idx_train, idx_val, -# train_iters=200, verbose=True) -# >>> model.test(idx_test) - -# """ - -# adj, features, labels = utils.to_tensor(adj.todense(), features.todense(), labels, device=self.device) - -# self.features, self.labels = features, labels -# self.adj_norm1 = self._normalize_adj(adj, power=-1/2) -# self.adj_norm2 = self._normalize_adj(adj, power=-1) -# print('=== training rgcn model ===') -# self._initialize() -# if idx_val is None: -# self._train_without_val(labels, idx_train, train_iters, verbose) -# else: -# self._train_with_val(labels, idx_train, idx_val, train_iters, verbose) - -# def _train_without_val(self, labels, idx_train, train_iters, verbose=True): -# optimizer = optim.Adam(self.parameters(), lr=self.lr) -# self.train() -# for i in range(train_iters): -# optimizer.zero_grad() -# output = self.forward() -# loss_train = self._loss(output[idx_train], labels[idx_train]) -# loss_train.backward() -# optimizer.step() -# if verbose and i % 10 == 0: -# print('Epoch {}, training loss: {}'.format(i, loss_train.item())) - -# self.eval() -# output = self.forward() -# self.output = output - -# def _train_with_val(self, labels, idx_train, idx_val, train_iters, verbose): -# optimizer = optim.Adam(self.parameters(), lr=self.lr) - -# best_loss_val = 100 -# best_acc_val = 0 - -# for i in range(train_iters): -# self.train() -# optimizer.zero_grad() -# output = self.forward() -# loss_train = self._loss(output[idx_train], labels[idx_train]) -# loss_train.backward() -# optimizer.step() -# if verbose and i % 10 == 0: -# print('Epoch {}, training loss: {}'.format(i, loss_train.item())) - -# self.eval() -# output = self.forward() -# loss_val = F.nll_loss(output[idx_val], labels[idx_val]) -# acc_val = utils.accuracy(output[idx_val], labels[idx_val]) - -# if best_loss_val > loss_val: -# best_loss_val = loss_val -# self.output = output - -# if acc_val > best_acc_val: -# best_acc_val = acc_val -# self.output = output - -# print('=== picking the best model according to the performance on validation ===') - - -# def test(self, idx_test): -# """Evaluate the peformance on test set -# """ -# self.eval() -# # output = self.forward() -# output = self.output -# loss_test = F.nll_loss(output[idx_test], self.labels[idx_test]) -# acc_test = utils.accuracy(output[idx_test], self.labels[idx_test]) -# print("Test set results:", -# "loss= {:.4f}".format(loss_test.item()), -# "accuracy= {:.4f}".format(acc_test.item())) -# return acc_test.item() - -# def predict(self): -# """ -# Returns -# ------- -# torch.FloatTensor -# output (log probabilities) of RGCN -# """ - -# self.eval() -# return self.forward() - -# def _loss(self, input, labels): -# loss = F.nll_loss(input, labels) -# miu1 = self.gc1.miu -# sigma1 = self.gc1.sigma -# kl_loss = 0.5 * (miu1.pow(2) + sigma1 - torch.log(1e-8 + sigma1)).mean(1) -# kl_loss = kl_loss.sum() 
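# Illustrative note: with sigma read as a per-dimension variance, the two commented
# kl_loss lines above correspond, up to a dropped constant, to the Gaussian KL term
#     KL( N(mu, sigma) || N(0, I) ) = 0.5 * (mu^2 + sigma - log(sigma) - 1)
# averaged over embedding dimensions; the 1e-8 inside the log is there for numerical
# stability, and the result is summed over nodes before being scaled by beta1.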
-# norm2 = torch.norm(self.gc1.weight_miu, 2).pow(2) + \ -# torch.norm(self.gc1.weight_sigma, 2).pow(2) - -# # print(f'gcn_loss: {loss.item()}, kl_loss: {self.beta1 * kl_loss.item()}, norm2: {self.beta2 * norm2.item()}') -# return loss + self.beta1 * kl_loss + self.beta2 * norm2 - -# def _initialize(self): -# self.gc1.reset_parameters() -# self.gc2.reset_parameters() - -# def _normalize_adj(self, adj, power=-1/2): - -# """Row-normalize sparse matrix""" -# A = adj + torch.eye(len(adj)).to(self.device) -# D_power = (A.sum(1)).pow(power) -# D_power[torch.isinf(D_power)] = 0. -# D_power = torch.diag(D_power) -# return D_power @ A @ D_power - diff --git a/src/defense/RGCN/utils.py b/src/defense/RGCN/utils.py deleted file mode 100644 index f7be00a..0000000 --- a/src/defense/RGCN/utils.py +++ /dev/null @@ -1,777 +0,0 @@ -import numpy as np -import scipy.sparse as sp -import torch -from sklearn.model_selection import train_test_split -import torch.sparse as ts -import torch.nn.functional as F -import warnings - -def encode_onehot(labels): - """Convert label to onehot format. - - Parameters - ---------- - labels : numpy.array - node labels - - Returns - ------- - numpy.array - onehot labels - """ - eye = np.eye(labels.max() + 1) - onehot_mx = eye[labels] - return onehot_mx - -def tensor2onehot(labels): - """Convert label tensor to label onehot tensor. - - Parameters - ---------- - labels : torch.LongTensor - node labels - - Returns - ------- - torch.LongTensor - onehot labels tensor - - """ - - eye = torch.eye(labels.max() + 1) - onehot_mx = eye[labels] - return onehot_mx.to(labels.device) - -def preprocess(adj, features, labels, preprocess_adj=False, preprocess_feature=False, sparse=False, device='cpu'): - """Convert adj, features, labels from array or sparse matrix to - torch Tensor, and normalize the input data. - - Parameters - ---------- - adj : scipy.sparse.csr_matrix - the adjacency matrix. - features : scipy.sparse.csr_matrix - node features - labels : numpy.array - node labels - preprocess_adj : bool - whether to normalize the adjacency matrix - preprocess_feature : bool - whether to normalize the feature matrix - sparse : bool - whether to return sparse tensor - device : str - 'cpu' or 'cuda' - """ - - if preprocess_adj: - adj = normalize_adj(adj) - - if preprocess_feature: - features = normalize_feature(features) - - labels = torch.LongTensor(labels) - if sparse: - adj = sparse_mx_to_torch_sparse_tensor(adj) - features = sparse_mx_to_torch_sparse_tensor(features) - else: - if sp.issparse(features): - features = torch.FloatTensor(np.array(features.todense())) - else: - features = torch.FloatTensor(features) - adj = torch.FloatTensor(adj.todense()) - return adj.to(device), features.to(device), labels.to(device) - -def to_tensor(adj, features, labels=None, device='cpu'): - """Convert adj, features, labels from array or sparse matrix to - torch Tensor. - - Parameters - ---------- - adj : scipy.sparse.csr_matrix - the adjacency matrix. 
- features : scipy.sparse.csr_matrix - node features - labels : numpy.array - node labels - device : str - 'cpu' or 'cuda' - """ - if sp.issparse(adj): - adj = sparse_mx_to_torch_sparse_tensor(adj) - else: - adj = torch.FloatTensor(adj) - if sp.issparse(features): - features = sparse_mx_to_torch_sparse_tensor(features) - else: - features = torch.FloatTensor(np.array(features)) - - if labels is None: - return adj.to(device), features.to(device) - else: - labels = torch.LongTensor(labels) - return adj.to(device), features.to(device), labels.to(device) - -def normalize_feature(mx): - """Row-normalize sparse matrix or dense matrix - - Parameters - ---------- - mx : scipy.sparse.csr_matrix or numpy.array - matrix to be normalized - - Returns - ------- - scipy.sprase.lil_matrix - normalized matrix - """ - if type(mx) is not sp.lil.lil_matrix: - try: - mx = mx.tolil() - except AttributeError: - pass - rowsum = np.array(mx.sum(1)) - r_inv = np.power(rowsum, -1).flatten() - r_inv[np.isinf(r_inv)] = 0. - r_mat_inv = sp.diags(r_inv) - mx = r_mat_inv.dot(mx) - return mx - -def normalize_adj(mx): - """Normalize sparse adjacency matrix, - A' = (D + I)^-1/2 * ( A + I ) * (D + I)^-1/2 - Row-normalize sparse matrix - - Parameters - ---------- - mx : scipy.sparse.csr_matrix - matrix to be normalized - - Returns - ------- - scipy.sprase.lil_matrix - normalized matrix - """ - - # TODO: maybe using coo format would be better? - if type(mx) is not sp.lil.lil_matrix: - mx = mx.tolil() - if mx[0, 0] == 0 : - mx = mx + sp.eye(mx.shape[0]) - rowsum = np.array(mx.sum(1)) - r_inv = np.power(rowsum, -1/2).flatten() - r_inv[np.isinf(r_inv)] = 0. - r_mat_inv = sp.diags(r_inv) - mx = r_mat_inv.dot(mx) - mx = mx.dot(r_mat_inv) - return mx - -def normalize_sparse_tensor(adj, fill_value=1): - """Normalize sparse tensor. Need to import torch_scatter - """ - edge_index = adj._indices() - edge_weight = adj._values() - num_nodes= adj.size(0) - edge_index, edge_weight = add_self_loops( - edge_index, edge_weight, fill_value, num_nodes) - - row, col = edge_index - from torch_scatter import scatter_add - deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) - deg_inv_sqrt = deg.pow(-0.5) - deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 - - values = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] - - shape = adj.shape - return torch.sparse.FloatTensor(edge_index, values, shape) - -def add_self_loops(edge_index, edge_weight=None, fill_value=1, num_nodes=None): - # num_nodes = maybe_num_nodes(edge_index, num_nodes) - - loop_index = torch.arange(0, num_nodes, dtype=torch.long, - device=edge_index.device) - loop_index = loop_index.unsqueeze(0).repeat(2, 1) - - if edge_weight is not None: - assert edge_weight.numel() == edge_index.size(1) - loop_weight = edge_weight.new_full((num_nodes, ), fill_value) - edge_weight = torch.cat([edge_weight, loop_weight], dim=0) - - edge_index = torch.cat([edge_index, loop_index], dim=1) - - return edge_index, edge_weight - -def normalize_adj_tensor(adj, sparse=False): - """Normalize adjacency tensor matrix. - """ - device = adj.device - if sparse: - # warnings.warn('If you find the training process is too slow, you can uncomment line 207 in deeprobust/graph/utils.py. 
Note that you need to install torch_sparse') - # TODO if this is too slow, uncomment the following code, - # but you need to install torch_scatter - # return normalize_sparse_tensor(adj) - adj = to_scipy(adj) - mx = normalize_adj(adj) - return sparse_mx_to_torch_sparse_tensor(mx).to(device) - else: - mx = adj + torch.eye(adj.shape[0]).to(device) - rowsum = mx.sum(1) - r_inv = rowsum.pow(-1/2).flatten() - r_inv[torch.isinf(r_inv)] = 0. - r_mat_inv = torch.diag(r_inv) - mx = r_mat_inv @ mx - mx = mx @ r_mat_inv - return mx - -def degree_normalize_adj(mx): - """Row-normalize sparse matrix""" - mx = mx.tolil() - if mx[0, 0] == 0 : - mx = mx + sp.eye(mx.shape[0]) - rowsum = np.array(mx.sum(1)) - r_inv = np.power(rowsum, -1).flatten() - r_inv[np.isinf(r_inv)] = 0. - r_mat_inv = sp.diags(r_inv) - # mx = mx.dot(r_mat_inv) - mx = r_mat_inv.dot(mx) - return mx - -def degree_normalize_sparse_tensor(adj, fill_value=1): - """degree_normalize_sparse_tensor. - """ - edge_index = adj._indices() - edge_weight = adj._values() - num_nodes= adj.size(0) - - edge_index, edge_weight = add_self_loops( - edge_index, edge_weight, fill_value, num_nodes) - - row, col = edge_index - from torch_scatter import scatter_add - deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) - deg_inv_sqrt = deg.pow(-1) - deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 - - values = deg_inv_sqrt[row] * edge_weight - shape = adj.shape - return torch.sparse.FloatTensor(edge_index, values, shape) - -def degree_normalize_adj_tensor(adj, sparse=True): - """degree_normalize_adj_tensor. - """ - - device = adj.device - if sparse: - # return degree_normalize_sparse_tensor(adj) - adj = to_scipy(adj) - mx = degree_normalize_adj(adj) - return sparse_mx_to_torch_sparse_tensor(mx).to(device) - else: - mx = adj + torch.eye(adj.shape[0]).to(device) - rowsum = mx.sum(1) - r_inv = rowsum.pow(-1).flatten() - r_inv[torch.isinf(r_inv)] = 0. - r_mat_inv = torch.diag(r_inv) - mx = r_mat_inv @ mx - return mx - -def accuracy(output, labels): - """Return accuracy of output compared to labels. - - Parameters - ---------- - output : torch.Tensor - output from model - labels : torch.Tensor or numpy.array - node labels - - Returns - ------- - float - accuracy - """ - if not hasattr(labels, '__len__'): - labels = [labels] - if type(labels) is not torch.Tensor: - labels = torch.LongTensor(labels) - preds = output.max(1)[1].type_as(labels) - correct = preds.eq(labels).double() - correct = correct.sum() - return correct / len(labels) - -def loss_acc(output, labels, targets, avg_loss=True): - if type(labels) is not torch.Tensor: - labels = torch.LongTensor(labels) - preds = output.max(1)[1].type_as(labels) - correct = preds.eq(labels).double()[targets] - loss = F.nll_loss(output[targets], labels[targets], reduction='mean' if avg_loss else 'none') - - if avg_loss: - return loss, correct.sum() / len(targets) - return loss, correct - # correct = correct.sum() - # return loss, correct / len(labels) - -def get_perf(output, labels, mask, verbose=True): - """evalute performance for test masked data""" - loss = F.nll_loss(output[mask], labels[mask]) - acc = accuracy(output[mask], labels[mask]) - if verbose: - print("loss= {:.4f}".format(loss.item()), - "accuracy= {:.4f}".format(acc.item())) - return loss.item(), acc.item() - - -def classification_margin(output, true_label): - """Calculate classification margin for outputs. 
- `probs_true_label - probs_best_second_class` - - Parameters - ---------- - output: torch.Tensor - output vector (1 dimension) - true_label: int - true label for this node - - Returns - ------- - list - classification margin for this node - """ - - probs = torch.exp(output) - probs_true_label = probs[true_label].clone() - probs[true_label] = 0 - probs_best_second_class = probs[probs.argmax()] - return (probs_true_label - probs_best_second_class).item() - -def sparse_mx_to_torch_sparse_tensor(sparse_mx): - """Convert a scipy sparse matrix to a torch sparse tensor.""" - sparse_mx = sparse_mx.tocoo().astype(np.float32) - sparserow=torch.LongTensor(sparse_mx.row).unsqueeze(1) - sparsecol=torch.LongTensor(sparse_mx.col).unsqueeze(1) - sparseconcat=torch.cat((sparserow, sparsecol),1) - sparsedata=torch.FloatTensor(sparse_mx.data) - return torch.sparse.FloatTensor(sparseconcat.t(),sparsedata,torch.Size(sparse_mx.shape)) - - # slower version.... - # sparse_mx = sparse_mx.tocoo().astype(np.float32) - # indices = torch.from_numpy( - # np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)) - # values = torch.from_numpy(sparse_mx.data) - # shape = torch.Size(sparse_mx.shape) - # return torch.sparse.FloatTensor(indices, values, shape) - - - -def to_scipy(tensor): - """Convert a dense/sparse tensor to scipy matrix""" - if is_sparse_tensor(tensor): - values = tensor._values() - indices = tensor._indices() - return sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()), shape=tensor.shape) - else: - indices = tensor.nonzero().t() - values = tensor[indices[0], indices[1]] - return sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()), shape=tensor.shape) - -def is_sparse_tensor(tensor): - """Check if a tensor is sparse tensor. - - Parameters - ---------- - tensor : torch.Tensor - given tensor - - Returns - ------- - bool - whether a tensor is sparse tensor - """ - # if hasattr(tensor, 'nnz'): - if tensor.layout == torch.sparse_coo: - return True - else: - return False - -def get_train_val_test(nnodes, val_size=0.1, test_size=0.8, stratify=None, seed=None): - """This setting follows nettack/mettack, where we split the nodes - into 10% training, 10% validation and 80% testing data - - Parameters - ---------- - nnodes : int - number of nodes in total - val_size : float - size of validation set - test_size : float - size of test set - stratify : - data is expected to split in a stratified fashion. So stratify should be labels. - seed : int or None - random seed - - Returns - ------- - idx_train : - node training indices - idx_val : - node validation indices - idx_test : - node test indices - """ - - assert stratify is not None, 'stratify cannot be None!' - - if seed is not None: - np.random.seed(seed) - - idx = np.arange(nnodes) - train_size = 1 - val_size - test_size - idx_train_and_val, idx_test = train_test_split(idx, - random_state=None, - train_size=train_size + val_size, - test_size=test_size, - stratify=stratify) - - if stratify is not None: - stratify = stratify[idx_train_and_val] - - idx_train, idx_val = train_test_split(idx_train_and_val, - random_state=None, - train_size=(train_size / (train_size + val_size)), - test_size=(val_size / (train_size + val_size)), - stratify=stratify) - - return idx_train, idx_val, idx_test - -def get_train_test(nnodes, test_size=0.8, stratify=None, seed=None): - """This function returns training and test set without validation. - It can be used for settings of different label rates. 
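A small standalone sketch of the stratified 10/10/80 split that get_train_val_test above produces, using sklearn's train_test_split with made-up labels (illustrative only):

# Illustrative only: 20 toy nodes, two classes, 10%/10%/80% stratified split.
import numpy as np
from sklearn.model_selection import train_test_split

labels = np.array([0] * 10 + [1] * 10)
idx = np.arange(len(labels))

idx_train_val, idx_test = train_test_split(
    idx, train_size=0.2, test_size=0.8, stratify=labels, random_state=0)
idx_train, idx_val = train_test_split(
    idx_train_val, train_size=0.5, test_size=0.5,
    stratify=labels[idx_train_val], random_state=0)
print(len(idx_train), len(idx_val), len(idx_test))   # 2 2 16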
- - Parameters - ---------- - nnodes : int - number of nodes in total - test_size : float - size of test set - stratify : - data is expected to split in a stratified fashion. So stratify should be labels. - seed : int or None - random seed - - Returns - ------- - idx_train : - node training indices - idx_test : - node test indices - """ - assert stratify is not None, 'stratify cannot be None!' - - if seed is not None: - np.random.seed(seed) - - idx = np.arange(nnodes) - train_size = 1 - test_size - idx_train, idx_test = train_test_split(idx, random_state=None, - train_size=train_size, - test_size=test_size, - stratify=stratify) - - return idx_train, idx_test - -def get_train_val_test_gcn(labels, seed=None): - """This setting follows gcn, where we randomly sample 20 instances for each class - as training data, 500 instances as validation data, 1000 instances as test data. - Note here we are not using fixed splits. When random seed changes, the splits - will also change. - - Parameters - ---------- - labels : numpy.array - node labels - seed : int or None - random seed - - Returns - ------- - idx_train : - node training indices - idx_val : - node validation indices - idx_test : - node test indices - """ - if seed is not None: - np.random.seed(seed) - - idx = np.arange(len(labels)) - nclass = labels.max() + 1 - idx_train = [] - idx_unlabeled = [] - for i in range(nclass): - labels_i = idx[labels==i] - labels_i = np.random.permutation(labels_i) - idx_train = np.hstack((idx_train, labels_i[: 20])).astype(np.int) - idx_unlabeled = np.hstack((idx_unlabeled, labels_i[20: ])).astype(np.int) - - idx_unlabeled = np.random.permutation(idx_unlabeled) - idx_val = idx_unlabeled[: 500] - idx_test = idx_unlabeled[500: 1500] - return idx_train, idx_val, idx_test - -def get_train_test_labelrate(labels, label_rate): - """Get train test according to given label rate. - """ - nclass = labels.max() + 1 - train_size = int(round(len(labels) * label_rate / nclass)) - print("=== train_size = %s ===" % train_size) - idx_train, idx_val, idx_test = get_splits_each_class(labels, train_size=train_size) - return idx_train, idx_test - -def get_splits_each_class(labels, train_size): - """We randomly sample n instances for class, where n = train_size. - """ - idx = np.arange(len(labels)) - nclass = labels.max() + 1 - idx_train = [] - idx_val = [] - idx_test = [] - for i in range(nclass): - labels_i = idx[labels==i] - labels_i = np.random.permutation(labels_i) - idx_train = np.hstack((idx_train, labels_i[: train_size])).astype(np.int) - idx_val = np.hstack((idx_val, labels_i[train_size: 2*train_size])).astype(np.int) - idx_test = np.hstack((idx_test, labels_i[2*train_size: ])).astype(np.int) - - return np.random.permutation(idx_train), np.random.permutation(idx_val), \ - np.random.permutation(idx_test) - - -def unravel_index(index, array_shape): - rows = torch.div(index, array_shape[1], rounding_mode='trunc') - cols = index % array_shape[1] - return rows, cols - - -def get_degree_squence(adj): - try: - return adj.sum(0) - except: - return ts.sum(adj, dim=1).to_dense() - -def likelihood_ratio_filter(node_pairs, modified_adjacency, original_adjacency, d_min, threshold=0.004, undirected=True): - """ - Filter the input node pairs based on the likelihood ratio test proposed by Zügner et al. 2018, see - https://dl.acm.org/citation.cfm?id=3220078. In essence, for each node pair return 1 if adding/removing the edge - between the two nodes does not violate the unnoticeability constraint, and return 0 otherwise. 
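Note that the .astype(np.int) casts in the two helpers above rely on an alias removed in recent NumPy releases; np.int64 (or plain int) is the portable spelling. Below is a scaled-down sketch of the per-class sampling in get_train_val_test_gcn, with 50/100 standing in for the real 500/1000 validation/test sizes; all values are made up for illustration:

# Illustrative only: sample 20 labelled nodes per class, split the rest into val/test.
import numpy as np

rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=400)          # made-up labels for 400 nodes, 3 classes
idx = np.arange(len(labels))
idx_train = np.array([], dtype=np.int64)
idx_unlabeled = np.array([], dtype=np.int64)
for c in range(labels.max() + 1):
    idx_c = rng.permutation(idx[labels == c])
    idx_train = np.hstack((idx_train, idx_c[:20])).astype(np.int64)
    idx_unlabeled = np.hstack((idx_unlabeled, idx_c[20:])).astype(np.int64)
idx_unlabeled = rng.permutation(idx_unlabeled)
idx_val, idx_test = idx_unlabeled[:50], idx_unlabeled[50:150]
print(len(idx_train), len(idx_val), len(idx_test))   # 60 50 100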
Assumes unweighted - and undirected graphs. - """ - - N = int(modified_adjacency.shape[0]) - # original_degree_sequence = get_degree_squence(original_adjacency) - # current_degree_sequence = get_degree_squence(modified_adjacency) - original_degree_sequence = original_adjacency.sum(0) - current_degree_sequence = modified_adjacency.sum(0) - - concat_degree_sequence = torch.cat((current_degree_sequence, original_degree_sequence)) - - # Compute the log likelihood values of the original, modified, and combined degree sequences. - ll_orig, alpha_orig, n_orig, sum_log_degrees_original = degree_sequence_log_likelihood(original_degree_sequence, d_min) - ll_current, alpha_current, n_current, sum_log_degrees_current = degree_sequence_log_likelihood(current_degree_sequence, d_min) - - ll_comb, alpha_comb, n_comb, sum_log_degrees_combined = degree_sequence_log_likelihood(concat_degree_sequence, d_min) - - # Compute the log likelihood ratio - current_ratio = -2 * ll_comb + 2 * (ll_orig + ll_current) - - # Compute new log likelihood values that would arise if we add/remove the edges corresponding to each node pair. - new_lls, new_alphas, new_ns, new_sum_log_degrees = updated_log_likelihood_for_edge_changes(node_pairs, - modified_adjacency, d_min) - - # Combination of the original degree distribution with the distributions corresponding to each node pair. - n_combined = n_orig + new_ns - new_sum_log_degrees_combined = sum_log_degrees_original + new_sum_log_degrees - alpha_combined = compute_alpha(n_combined, new_sum_log_degrees_combined, d_min) - new_ll_combined = compute_log_likelihood(n_combined, alpha_combined, new_sum_log_degrees_combined, d_min) - new_ratios = -2 * new_ll_combined + 2 * (new_lls + ll_orig) - - # Allowed edges are only those for which the resulting likelihood ratio measure is < than the threshold - allowed_edges = new_ratios < threshold - - if allowed_edges.is_cuda: - filtered_edges = node_pairs[allowed_edges.cpu().numpy().astype(np.bool)] - else: - filtered_edges = node_pairs[allowed_edges.numpy().astype(np.bool)] - - allowed_mask = torch.zeros(modified_adjacency.shape) - allowed_mask[filtered_edges.T] = 1 - if undirected: - allowed_mask += allowed_mask.t() - return allowed_mask, current_ratio - - -def degree_sequence_log_likelihood(degree_sequence, d_min): - """ - Compute the (maximum) log likelihood of the Powerlaw distribution fit on a degree distribution. - """ - - # Determine which degrees are to be considered, i.e. >= d_min. - D_G = degree_sequence[(degree_sequence >= d_min.item())] - try: - sum_log_degrees = torch.log(D_G).sum() - except: - sum_log_degrees = np.log(D_G).sum() - n = len(D_G) - - alpha = compute_alpha(n, sum_log_degrees, d_min) - ll = compute_log_likelihood(n, alpha, sum_log_degrees, d_min) - return ll, alpha, n, sum_log_degrees - -def updated_log_likelihood_for_edge_changes(node_pairs, adjacency_matrix, d_min): - """ Adopted from https://github.com/danielzuegner/nettack - """ - # For each node pair find out whether there is an edge or not in the input adjacency matrix. 
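To make the power-law machinery concrete: degree_sequence_log_likelihood above keeps only degrees >= d_min and estimates the exponent via compute_alpha, alpha = 1 + n / (sum(log d_i) - n * log(d_min - 0.5)). A made-up degree sequence (illustrative only) gives:

# Illustrative only: approximate MLE power-law exponent for a toy degree sequence.
import numpy as np

degrees = np.array([2, 2, 3, 3, 4, 5, 8, 12], dtype=float)
d_min = 2.0
D_G = degrees[degrees >= d_min]
n = len(D_G)
sum_log_degrees = np.log(D_G).sum()
alpha = 1 + n / (sum_log_degrees - n * np.log(d_min - 0.5))
print(alpha)    # roughly 2.0 for this toy sequence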
- - edge_entries_before = adjacency_matrix[node_pairs.T] - degree_sequence = adjacency_matrix.sum(1) - D_G = degree_sequence[degree_sequence >= d_min.item()] - sum_log_degrees = torch.log(D_G).sum() - n = len(D_G) - deltas = -2 * edge_entries_before + 1 - d_edges_before = degree_sequence[node_pairs] - - d_edges_after = degree_sequence[node_pairs] + deltas[:, None] - - # Sum the log of the degrees after the potential changes which are >= d_min - sum_log_degrees_after, new_n = update_sum_log_degrees(sum_log_degrees, n, d_edges_before, d_edges_after, d_min) - # Updated estimates of the Powerlaw exponents - new_alpha = compute_alpha(new_n, sum_log_degrees_after, d_min) - # Updated log likelihood values for the Powerlaw distributions - new_ll = compute_log_likelihood(new_n, new_alpha, sum_log_degrees_after, d_min) - return new_ll, new_alpha, new_n, sum_log_degrees_after - - -def update_sum_log_degrees(sum_log_degrees_before, n_old, d_old, d_new, d_min): - # Find out whether the degrees before and after the change are above the threshold d_min. - old_in_range = d_old >= d_min - new_in_range = d_new >= d_min - d_old_in_range = d_old * old_in_range.float() - d_new_in_range = d_new * new_in_range.float() - - # Update the sum by subtracting the old values and then adding the updated logs of the degrees. - sum_log_degrees_after = sum_log_degrees_before - (torch.log(torch.clamp(d_old_in_range, min=1))).sum(1) \ - + (torch.log(torch.clamp(d_new_in_range, min=1))).sum(1) - - # Update the number of degrees >= d_min - - new_n = n_old - (old_in_range!=0).sum(1) + (new_in_range!=0).sum(1) - new_n = new_n.float() - return sum_log_degrees_after, new_n - -def compute_alpha(n, sum_log_degrees, d_min): - try: - alpha = 1 + n / (sum_log_degrees - n * torch.log(d_min - 0.5)) - except: - alpha = 1 + n / (sum_log_degrees - n * np.log(d_min - 0.5)) - return alpha - -def compute_log_likelihood(n, alpha, sum_log_degrees, d_min): - # Log likelihood under alpha - try: - ll = n * torch.log(alpha) + n * alpha * torch.log(d_min) + (alpha + 1) * sum_log_degrees - except: - ll = n * np.log(alpha) + n * alpha * np.log(d_min) + (alpha + 1) * sum_log_degrees - - return ll - -def ravel_multiple_indices(ixs, shape, reverse=False): - """ - "Flattens" multiple 2D input indices into indices on the flattened matrix, similar to np.ravel_multi_index. - Does the same as ravel_index but for multiple indices at once. - Parameters - ---------- - ixs: array of ints shape (n, 2) - The array of n indices that will be flattened. - - shape: list or tuple of ints of length 2 - The shape of the corresponding matrix. - - Returns - ------- - array of n ints between 0 and shape[0]*shape[1]-1 - The indices on the flattened matrix corresponding to the 2D input indices. 
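The flattening performed by the ravel_multiple_indices helper here can be checked against NumPy's ravel_multi_index; the (row, col) pairs below are made up for illustration:

# Illustrative only: flatten 2D indices into positions on the flattened matrix.
import numpy as np

ixs = np.array([[0, 1],
                [2, 3],
                [4, 0]])
shape = (5, 4)
flat = ixs[:, 0] * shape[1] + ixs[:, 1]
print(flat)                                   # [ 1 11 16]
print(np.ravel_multi_index(ixs.T, shape))     # same result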
- - """ - if reverse: - return ixs[:, 1] * shape[1] + ixs[:, 0] - - return ixs[:, 0] * shape[1] + ixs[:, 1] - -def visualize(your_var): - """visualize computation graph""" - from graphviz import Digraph - import torch - from torch.autograd import Variable - from torchviz import make_dot - make_dot(your_var).view() - -def reshape_mx(mx, shape): - indices = mx.nonzero() - return sp.csr_matrix((mx.data, (indices[0], indices[1])), shape=shape) - -def add_mask(data, dataset): - """data: ogb-arxiv pyg data format""" - # for arxiv - split_idx = dataset.get_idx_split() - train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] - n = data.x.shape[0] - data.train_mask = index_to_mask(train_idx, n) - data.val_mask = index_to_mask(valid_idx, n) - data.test_mask = index_to_mask(test_idx, n) - data.y = data.y.squeeze() - # data.edge_index = to_undirected(data.edge_index, data.num_nodes) - -def index_to_mask(index, size): - mask = torch.zeros((size, ), dtype=torch.bool) - mask[index] = 1 - return mask - -def add_feature_noise(data, noise_ratio, seed): - np.random.seed(seed) - n, d = data.x.shape - # noise = torch.normal(mean=torch.zeros(int(noise_ratio*n), d), std=1) - noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio*n), d))).to(data.x.device) - indices = np.arange(n) - indices = np.random.permutation(indices)[: int(noise_ratio*n)] - delta_feat = torch.zeros_like(data.x) - delta_feat[indices] = noise - data.x[indices] - data.x[indices] = noise - mask = np.zeros(n) - mask[indices] = 1 - mask = torch.tensor(mask).bool().to(data.x.device) - return delta_feat, mask - -def add_feature_noise_test(data, noise_ratio, seed): - np.random.seed(seed) - n, d = data.x.shape - indices = np.arange(n) - test_nodes = indices[data.test_mask.cpu()] - selected = np.random.permutation(test_nodes)[: int(noise_ratio*len(test_nodes))] - noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio*len(test_nodes)), d))) - noise = noise.to(data.x.device) - - delta_feat = torch.zeros_like(data.x) - delta_feat[selected] = noise - data.x[selected] - data.x[selected] = noise - # mask = np.zeros(len(test_nodes)) - mask = np.zeros(n) - mask[selected] = 1 - mask = torch.tensor(mask).bool().to(data.x.device) - return delta_feat, mask - -# def check_path(file_path): -# if not osp.exists(file_path): -# os.system(f'mkdir -p {file_path}') \ No newline at end of file From eb14351fc863fd8b01ef8ebb063bb84f2d4385de Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Wed, 16 Oct 2024 21:00:27 +0300 Subject: [PATCH 12/13] attack_defense clear --- experiments/attack_defense_test.py | 81 ++---------------------------- 1 file changed, 3 insertions(+), 78 deletions(-) diff --git a/experiments/attack_defense_test.py b/experiments/attack_defense_test.py index e2fcc19..aed544b 100644 --- a/experiments/attack_defense_test.py +++ b/experiments/attack_defense_test.py @@ -2,9 +2,9 @@ import warnings -import sys -import os -sys.path.append(f"{os.getcwd()}/src") +# import sys +# import os +# sys.path.append(f"{os.getcwd()}/src") from torch import device @@ -385,81 +385,6 @@ def test_nettack_evasion(): print(f"info_after_evasion_attack: {info_after_evasion_attack}") -def test_gnnguard(): - # from attacks.poison_attacks_collection.metattack import meta_gradient_attack - from defense.GNNGuard import gnnguard - - my_device = device('cpu') - full_name = ("single-graph", "Planetoid", 'Cora') - - dataset, data, results_dataset_path = DatasetManager.get_by_full_name( - full_name=full_name, - dataset_ver_ind=0 
- ) - gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn') - manager_config = ConfigPattern( - _config_class="ModelManagerConfig", - _config_kwargs={ - "mask_features": [], - "optimizer": { - # "_config_class": "Config", - "_class_name": "Adam", - # "_import_path": OPTIMIZERS_PARAMETERS_PATH, - # "_class_import_info": ["torch.optim"], - "_config_kwargs": {}, - } - } - ) - steps_epochs = 200 - gnn_model_manager = FrameworkGNNModelManager( - gnn=gnn, - dataset_path=results_dataset_path, - manager_config=manager_config, - modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs) - ) - save_model_flag = False - gnn_model_manager.gnn.to(my_device) - data = data.to(my_device) - print(type(data)) - poison_defense_config = ConfigPattern( - _class_name="GNNGuard", - _import_path=POISON_DEFENSE_PARAMETERS_PATH, - _config_class="PoisonDefenseConfig", - _config_kwargs={ - # "num_nodes": dataset.dataset.x.shape[0] - } - ) - from defense.poison_defense import PoisonDefender - from src.aux.utils import all_subclasses - print([e.name for e in all_subclasses(PoisonDefender)]) - gnn_model_manager.set_poison_defender(poison_defense_config=poison_defense_config) - - warnings.warn("Start training") - dataset.train_test_split(percent_train_class=0.1) - - try: - raise FileNotFoundError() - # gnn_model_manager.load_model_executor() - except FileNotFoundError: - gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 - train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs, - save_model_flag=save_model_flag, - metrics=[Metric("F1", mask='train', average=None)]) - - if train_test_split_path is not None: - dataset.save_train_test_mask(train_test_split_path) - train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[ - :] - dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask - data.percent_train_class, data.percent_test_class = train_test_sizes - - warnings.warn("Training was successful") - - metric_loc = gnn_model_manager.evaluate_model( - gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro'), - Metric("Accuracy", mask='test')]) - print(metric_loc) - def test_adv_training(): from defense.evasion_defense import AdvTraining From 981e05c885541b5728d9d1e53e355f0c7f799c10 Mon Sep 17 00:00:00 2001 From: abhhfcgjk Date: Thu, 17 Oct 2024 11:13:41 +0300 Subject: [PATCH 13/13] review --- experiments/attack_defense_test.py | 11 +++-------- tests/attacks_test.py | 4 ---- tests/explainers_test.py | 3 --- 3 files changed, 3 insertions(+), 15 deletions(-) diff --git a/experiments/attack_defense_test.py b/experiments/attack_defense_test.py index aed544b..11e8355 100644 --- a/experiments/attack_defense_test.py +++ b/experiments/attack_defense_test.py @@ -2,9 +2,6 @@ import warnings -# import sys -# import os -# sys.path.append(f"{os.getcwd()}/src") from torch import device @@ -14,14 +11,12 @@ from src.aux.configs import ModelModificationConfig, ConfigPattern from src.base.datasets_processing import DatasetManager from src.models_builder.models_zoo import model_configs_zoo - +from attacks.QAttack import qattack def test_attack_defense(): - from attacks.QAttack import qattack - # from attacks.poison_attacks_collection.metattack import meta_gradient_attack - # my_device = device('cuda' if is_available() else 'cpu') - my_device = device('cpu') + + my_device = device('cuda' if torch.cuda.is_available() else 'cpu') full_name = None diff --git 
a/tests/attacks_test.py b/tests/attacks_test.py
index 252f63a..acee652 100644
--- a/tests/attacks_test.py
+++ b/tests/attacks_test.py
@@ -1,7 +1,3 @@
-import sys
-import os
-sys.path.append(f'/home/igor/Documents/graphs/GNN-AID/src')
-
 import unittest
 
 import torch
diff --git a/tests/explainers_test.py b/tests/explainers_test.py
index 45640e0..15ddd78 100644
--- a/tests/explainers_test.py
+++ b/tests/explainers_test.py
@@ -8,9 +8,6 @@ import signal
 from time import time
 
-import sys
-import os
-sys.path.append(f'/home/igor/Documents/graphs/GNN-AID/src')
 
 from aux import utils
 from aux.utils import EXPLAINERS_INIT_PARAMETERS_PATH, EXPLAINERS_LOCAL_RUN_PARAMETERS_PATH, \