diff --git a/experiments/attack_defense_test.py b/experiments/attack_defense_test.py index cbcb6e5..ce772bc 100644 --- a/experiments/attack_defense_test.py +++ b/experiments/attack_defense_test.py @@ -174,6 +174,75 @@ def test_attack_defense(): gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro')]) print(metric_loc) +def test_meta(): + # side-effect import: registers the MetaAttack classes as PoisonAttacker subclasses + from attacks.poison_attacks_collection.metattack import meta_gradient_attack + my_device = device('cpu') + full_name = ("single-graph", "Planetoid", 'Cora') + + dataset, data, results_dataset_path = DatasetManager.get_by_full_name( + full_name=full_name, + dataset_ver_ind=0 + ) + gnn = model_configs_zoo(dataset=dataset, model_name='gcn_gcn') + manager_config = ConfigPattern( + _config_class="ModelManagerConfig", + _config_kwargs={ + "mask_features": [], + "optimizer": { + # "_config_class": "Config", + "_class_name": "Adam", + # "_import_path": OPTIMIZERS_PARAMETERS_PATH, + # "_class_import_info": ["torch.optim"], + "_config_kwargs": {}, + } + } + ) + steps_epochs = 200 + gnn_model_manager = FrameworkGNNModelManager( + gnn=gnn, + dataset_path=results_dataset_path, + manager_config=manager_config, + modification=ModelModificationConfig(model_ver_ind=0, epochs=steps_epochs) + ) + save_model_flag = False + gnn_model_manager.gnn.to(my_device) + data = data.to(my_device) + + poison_attack_config = ConfigPattern( + _class_name="MetaAttackApprox", + _import_path=POISON_ATTACK_PARAMETERS_PATH, + _config_class="PoisonAttackConfig", + _config_kwargs={ + "num_nodes": dataset.dataset.x.shape[0] + } + ) + gnn_model_manager.set_poison_attacker(poison_attack_config=poison_attack_config) + + warnings.warn("Start training") + dataset.train_test_split(percent_train_class=0.1) + + try: + # force retraining from scratch: skip loading a saved model executor + raise FileNotFoundError() + # gnn_model_manager.load_model_executor() + except FileNotFoundError: + gnn_model_manager.epochs = gnn_model_manager.modification.epochs = 0 + train_test_split_path = gnn_model_manager.train_model(gen_dataset=dataset, steps=steps_epochs, + save_model_flag=save_model_flag, + metrics=[Metric("F1", mask='train', average=None)]) + + if train_test_split_path is not None: + dataset.save_train_test_mask(train_test_split_path) + train_mask, val_mask, test_mask, train_test_sizes = torch.load(train_test_split_path / 'train_test_split')[ + :] + dataset.train_mask, dataset.val_mask, dataset.test_mask = train_mask, val_mask, test_mask + data.percent_train_class, data.percent_test_class = train_test_sizes + + warnings.warn("Training was successful") + + metric_loc = gnn_model_manager.evaluate_model( + gen_dataset=dataset, metrics=[Metric("F1", mask='test', average='macro'), + Metric("Accuracy", mask='test')]) + print(metric_loc) def test_nettack_evasion(): my_device = device('cpu') @@ -276,8 +345,6 @@ def test_nettack_evasion(): if __name__ == '__main__': - # test_attack_defense() - test_nettack_evasion() - - - + # test_attack_defense() + torch.manual_seed(5000) + test_meta() diff --git a/experiments/metattack_exp.py b/experiments/metattack_exp.py new file mode 100644 index 0000000..e69de29 diff --git a/experiments/test_dataset_api/processed/pre_filter.pt b/experiments/test_dataset_api/processed/pre_filter.pt deleted file mode 100644 index 5d9470f..0000000 Binary files a/experiments/test_dataset_api/processed/pre_filter.pt and /dev/null differ diff --git a/experiments/test_dataset_api/processed/pre_transform.pt b/experiments/test_dataset_api/processed/pre_transform.pt deleted file mode 100644 index d133f9a..0000000 Binary files
a/experiments/test_dataset_api/processed/pre_transform.pt and /dev/null differ diff --git a/experiments/test_dataset_converted_dgl/processed/data.pt b/experiments/test_dataset_converted_dgl/processed/data.pt deleted file mode 100644 index a526886..0000000 Binary files a/experiments/test_dataset_converted_dgl/processed/data.pt and /dev/null differ diff --git a/experiments/test_dataset_converted_dgl/processed/pre_filter.pt b/experiments/test_dataset_converted_dgl/processed/pre_filter.pt deleted file mode 100644 index 5d9470f..0000000 Binary files a/experiments/test_dataset_converted_dgl/processed/pre_filter.pt and /dev/null differ diff --git a/experiments/test_dataset_converted_dgl/processed/pre_transform.pt b/experiments/test_dataset_converted_dgl/processed/pre_transform.pt deleted file mode 100644 index d133f9a..0000000 Binary files a/experiments/test_dataset_converted_dgl/processed/pre_transform.pt and /dev/null differ diff --git a/metainfo/poison_attack_parameters.json b/metainfo/poison_attack_parameters.json index 980cabd..4f5da26 100644 --- a/metainfo/poison_attack_parameters.json +++ b/metainfo/poison_attack_parameters.json @@ -3,7 +3,18 @@ }, "RandomPoisonAttack": { "n_edges_percent": ["n_edges_percent", "float", 0.1, {"min": 0.0001, "step": 0.01}, "?"] + }, + "MetaAttackFull": { + "lambda_": ["Lambda", "float", 0.5, {"min": 0, "max": 1, "step": 0.05}, "lambda coefficient from Eq. (10) in the paper"], + "train_iters": ["Train iters (surrogate)", "int", 200, {"min": 0, "step": 1}, "Training iterations for the surrogate model"], + "attack_structure": ["Attack structure", "bool", true, {}, "whether the attack perturbs the graph structure"], + "attack_features": ["Attack features", "bool", false, {}, "whether the attack perturbs node features"] + }, + "MetaAttackApprox": { + "lambda_": ["Lambda", "float", 0.5, {"min": 0, "max": 1, "step": 0.05}, "lambda coefficient from Eq. (10) in the paper"], + "train_iters": ["Train iters (surrogate)", "int", 200, {"min": 0, "step": 1}, "Training iterations for the surrogate model"], + "attack_structure": ["Attack structure", "bool", true, {}, "whether the attack perturbs the graph structure"], + "attack_features": ["Attack features", "bool", false, {}, "whether the attack perturbs node features"] } - } diff --git a/src/attacks/poison_attacks.py b/src/attacks/poison_attacks.py index f8dc546..0d85927 100644 --- a/src/attacks/poison_attacks.py +++ b/src/attacks/poison_attacks.py @@ -1,8 +1,11 @@ import numpy as np +import importlib import torch from attacks.attack_base import Attacker +from pathlib import Path +POISON_ATTACKS_DIR = Path(__file__).parent.resolve() / 'poison_attacks_collection' class PoisonAttacker(Attacker): def __init__(self, **kwargs): @@ -42,3 +45,20 @@ def attack(self, gen_dataset): def attack_diff(self): return self.attack_diff + +class EmptyPoisonAttacker(PoisonAttacker): + name = "EmptyPoisonAttacker" + + def attack(self, **kwargs): + pass + +# for attack_name in POISON_ATTACKS_DIR.rglob("*_attack.py"): +# try: +# importlib.import_module(str(attack_name)) +# except ImportError: +# print(f"Couldn't import Attack: {attack_name}") + +# import attacks.poison_attacks_collection.metattack.meta_gradient_attack + +# # TODO importing the attack here is not best practice; a registry / auto-discovery mechanism would be cleaner +# from attacks.poison_attacks_collection.metattack.meta_gradient_attack import BaseMeta diff --git a/src/attacks/poison_attacks_collection/metattack/meta_gradient_attack.py b/src/attacks/poison_attacks_collection/metattack/meta_gradient_attack.py new file mode 100644 index 0000000..96bd378 --- /dev/null
+++ b/src/attacks/poison_attacks_collection/metattack/meta_gradient_attack.py @@ -0,0 +1,556 @@ +import math +import torch +import numpy as np +import scipy.sparse as sp +import attacks.poison_attacks_collection.metattack.utils as utils +from torch.nn import functional as F +from torch.nn.parameter import Parameter +from torch import optim +from tqdm import tqdm +from models_builder.gnn_models import FrameworkGNNModelManager +from models_builder.models_zoo import model_configs_zoo +from aux.configs import ModelManagerConfig, ModelModificationConfig, DatasetConfig, DatasetVarConfig, ConfigPattern +from aux.utils import OPTIMIZERS_PARAMETERS_PATH +from torch_geometric.utils import to_dense_adj, to_torch_csr_tensor, to_torch_coo_tensor, dense_to_sparse, to_edge_index + +from attacks.poison_attacks import PoisonAttacker + + +class BaseMeta(PoisonAttacker): + """ + Base class for Metattack on GNNs. + Parameters + ---------- + model: + surrogate model that is trained and then differentiated through by the attack + num_nodes : int + number of nodes in the input graph + train_iters : int + number of initial training iterations for the surrogate model + attack_iters: int + number of inner training iterations of the surrogate model used for the meta-gradient computation + lambda_ : float + lambda_ is used to weight the two objectives in Eq. (10) in the paper. + lr: float + learning rate for surrogate meta-training + feature_shape : tuple + shape of the input node features + attack_structure : bool + whether to attack the graph structure + attack_features : bool + whether to attack node features + undirected : bool + whether the graph is undirected + device: str + 'cpu' or 'cuda' + """ + + name = "BaseMeta" + + def __init__(self, num_nodes=None, feature_shape=None, lambda_=0.5, train_iters=200, attack_iters=100, lr=0.1, + attack_structure=True, attack_features=False, undirected=False, device='cpu'): + super().__init__() + self.model = None + self.num_nodes = num_nodes + self.feature_shape = feature_shape + self.lambda_ = lambda_ + self.train_iters = train_iters + self.attack_iters = attack_iters + self.lr = lr + self.device = device + + self.attack_structure = attack_structure + self.attack_features = attack_features + self.undirected = undirected + assert attack_features or attack_structure, 'at least one of attack_features and attack_structure must be True' + + self.modified_adj = None + self.modified_features = None + + if attack_structure: + assert num_nodes is not None, 'num_nodes must be given' + self.adj_changes = Parameter(torch.FloatTensor(num_nodes, num_nodes)) + self.adj_changes.data.fill_(0) + + if attack_features: + assert feature_shape is not None, 'feature_shape must be given' + self.feature_changes = Parameter(torch.FloatTensor(feature_shape)) + self.feature_changes.data.fill_(0) + + def attack(self, gen_dataset): + # TODO model choice by user to be implemented.
+ # note: this is a fairly involved task + + # Initial surrogate model training + + self.model = model_configs_zoo(gen_dataset, 'gcn_gcn_linearized') + default_config = ModelModificationConfig( + model_ver_ind=0, + ) + manager_config = ConfigPattern( + _config_class="ModelManagerConfig", + _config_kwargs={ + "mask_features": [], + "optimizer": { + "_config_class": "Config", + "_class_name": "Adam", + "_import_path": OPTIMIZERS_PARAMETERS_PATH, + "_class_import_info": ["torch.optim"], + "_config_kwargs": {"weight_decay": 5e-4}, + } + } + ) + gnn_model_manager_surrogate = FrameworkGNNModelManager( + gnn=self.model, + dataset_path=gen_dataset, + modification=default_config, + manager_config=manager_config, + ) + + gnn_model_manager_surrogate.train_model(gen_dataset=gen_dataset, steps=self.train_iters) + + self.pred_labels = gnn_model_manager_surrogate.run_model(gen_dataset=gen_dataset, mask='all', out='answers') + + + def get_modified_adj(self, ori_adj): + adj_changes_square = self.adj_changes - torch.diag(torch.diag(self.adj_changes, 0)) + if self.undirected: + adj_changes_square = adj_changes_square + torch.transpose(adj_changes_square, 1, 0) + adj_changes_square = torch.clamp(adj_changes_square, -1, 1) + modified_adj = adj_changes_square + ori_adj + return modified_adj + + def get_modified_features(self, ori_features): + return ori_features + self.feature_changes + + def filter_potential_singletons(self, modified_adj): + """ + Computes a mask for entries potentially leading to singleton nodes, i.e. one of the two nodes corresponding to + the entry has degree 1 and there is an edge between the two nodes. + """ + + degrees = modified_adj.sum(0) + degree_one = (degrees == 1) + resh = degree_one.repeat(modified_adj.shape[0], 1).float() + l_and = resh * modified_adj + if self.undirected: + l_and = l_and + l_and.t() + flat_mask = 1 - l_and + return flat_mask + + def self_training_label(self, labels, idx_train): + # Predict the labels of the unlabeled nodes to use them for self-training. + labels_self_training = self.pred_labels.long().clone().detach() + labels_self_training[idx_train] = labels[idx_train] + return labels_self_training + + + def log_likelihood_constraint(self, modified_adj, ori_adj, ll_cutoff): + """ + Computes a mask for entries that, if the edge corresponding to the entry is added/removed, would lead to the + log likelihood constraint being violated. + + Note that different data types (float, double) can affect the final results. + """ + t_d_min = torch.tensor(2.0).to(self.device) + if self.undirected: + t_possible_edges = np.array(np.triu(np.ones((self.num_nodes, self.num_nodes)), k=1).nonzero()).T + else: + t_possible_edges = np.array((np.ones((self.num_nodes, self.num_nodes)) - np.eye(self.num_nodes)).nonzero()).T + allowed_mask, current_ratio = utils.likelihood_ratio_filter(t_possible_edges, + modified_adj, + ori_adj, t_d_min, + ll_cutoff, undirected=self.undirected) + return allowed_mask, current_ratio + + def get_adj_score(self, adj_grad, modified_adj, ori_adj, ll_constraint, ll_cutoff): + # (-2 * A + 1) flips the gradient sign per entry: a positive meta-gradient on a non-edge favors insertion, on an existing edge it favors deletion. + adj_meta_grad = adj_grad * (-2 * modified_adj + 1) + # Make sure that the minimum entry is 0. + adj_meta_grad = adj_meta_grad - adj_meta_grad.min() + # Filter self-loops + adj_meta_grad = adj_meta_grad - torch.diag(torch.diag(adj_meta_grad, 0)) + # Set entries to 0 that could lead to singleton nodes.
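+ # filter_potential_singletons returns a {0, 1} mask: 0 where flipping the entry would leave a degree-1 node isolated, 1 elsewhere.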
+ singleton_mask = self.filter_potential_singletons(modified_adj) + adj_meta_grad = adj_meta_grad * singleton_mask + + if ll_constraint: + allowed_mask, self.ll_ratio = self.log_likelihood_constraint(modified_adj, ori_adj, ll_cutoff) + allowed_mask = allowed_mask.to(self.device) + adj_meta_grad = adj_meta_grad * allowed_mask + return adj_meta_grad + + def get_feature_score(self, feature_grad, modified_features): + feature_meta_grad = feature_grad * (-2 * modified_features + 1) + feature_meta_grad -= feature_meta_grad.min() + return feature_meta_grad + + # def train_surrogate(self, gen_dataset, initialize=True): + # if initialize: + # pass + + def reset_parameters(self): + pass + +class MetaAttackFull(BaseMeta): + """ + Attack GNNs with meta gradients + """ + name = "MetaAttackFull" + + def __init__(self, num_nodes=None, feature_shape=None, lambda_=0.5, train_iters=200, attack_iters=100, lr=0.1, + momentum=0.9, attack_structure=True, attack_features=False, undirected=False, device='cpu', + with_bias=False, with_relu=False): + super().__init__(num_nodes=num_nodes, feature_shape=feature_shape, lambda_=lambda_, train_iters=train_iters, + attack_iters=attack_iters, lr=lr, attack_features=attack_features, + attack_structure=attack_structure, undirected=undirected, device=device) + self.with_bias = with_bias + self.with_relu = with_relu + + self.weights = [] + self.biases = [] + self.w_velocities = [] + self.b_velocities = [] + self.momentum = momentum + + def attack(self, gen_dataset, attack_budget=10, ll_constraint=True, ll_cutoff=0.004): + super().attack(gen_dataset=gen_dataset) + + self.hidden_sizes = [16] # FIXME get from model architecture + self.nfeat = gen_dataset.num_node_features + self.nclass = gen_dataset.num_classes + + previous_size = self.nfeat + for ix, nhid in enumerate(self.hidden_sizes): + weight = Parameter(torch.FloatTensor(previous_size, nhid).to(self.device)) + w_velocity = torch.zeros(weight.shape).to(self.device) + self.weights.append(weight) + self.w_velocities.append(w_velocity) + + if self.with_bias: + bias = Parameter(torch.FloatTensor(nhid).to(self.device)) + b_velocity = torch.zeros(bias.shape).to(self.device) + self.biases.append(bias) + self.b_velocities.append(b_velocity) + + previous_size = nhid + + output_weight = Parameter(torch.FloatTensor(previous_size, self.nclass).to(self.device)) + output_w_velocity = torch.zeros(output_weight.shape).to(self.device) + self.weights.append(output_weight) + self.w_velocities.append(output_w_velocity) + + if self.with_bias: + output_bias = Parameter(torch.FloatTensor(self.nclass).to(self.device)) + output_b_velocity = torch.zeros(output_bias.shape).to(self.device) + self.biases.append(output_bias) + self.b_velocities.append(output_b_velocity) + + self._initialize() + + ori_features = gen_dataset.dataset.data.x + ori_adj = gen_dataset.dataset.data.edge_index + labels = gen_dataset.dataset.data.y + idx_train = gen_dataset.train_mask + idx_unlabeled = gen_dataset.test_mask + + self.sparse_features = sp.issparse(ori_features) + ori_adj, ori_features, labels = utils.to_tensor(ori_adj, ori_features, labels, device=self.device) + + labels_self_training = self.self_training_label(labels, idx_train) + modified_adj = ori_adj + modified_features = ori_features + + for i in tqdm(range(attack_budget), desc="Perturbing graph"): + if self.attack_structure: + modified_adj = self.get_modified_adj(ori_adj) + + if self.attack_features: + modified_features = ori_features + self.feature_changes + + adj_norm = 
utils.normalize_adj_tensor(modified_adj) + self.inner_train(modified_features, adj_norm, idx_train, idx_unlabeled, labels) + + adj_grad, feature_grad = self.get_meta_grad(modified_features, adj_norm, idx_train, idx_unlabeled, labels, labels_self_training) + + adj_meta_score = torch.tensor(0.0).to(self.device) + feature_meta_score = torch.tensor(0.0).to(self.device) + if self.attack_structure: + adj_meta_score = self.get_adj_score(adj_grad, modified_adj, ori_adj, ll_constraint, ll_cutoff) + if self.attack_features: + feature_meta_score = self.get_feature_score(feature_grad, modified_features) + + if adj_meta_score.max() >= feature_meta_score.max(): + adj_meta_argmax = torch.argmax(adj_meta_score) + row_idx, col_idx = utils.unravel_index(adj_meta_argmax, ori_adj.shape) + self.adj_changes.data[row_idx][col_idx] += (-2 * modified_adj[row_idx][col_idx] + 1) + if self.undirected: + self.adj_changes.data[col_idx][row_idx] += (-2 * modified_adj[row_idx][col_idx] + 1) + else: + feature_meta_argmax = torch.argmax(feature_meta_score) + row_idx, col_idx = utils.unravel_index(feature_meta_argmax, ori_features.shape) + self.feature_changes.data[row_idx][col_idx] += (-2 * modified_features[row_idx][col_idx] + 1) + + if self.attack_structure: + self.modified_adj = self.get_modified_adj(ori_adj).detach() + gen_dataset.dataset.data.edge_index = dense_to_sparse(self.modified_adj.int())[0] + if self.attack_features: + self.modified_features = self.get_modified_features(ori_features).detach() + gen_dataset.dataset.data.x = self.modified_features + + def _initialize(self): + for w, v in zip(self.weights, self.w_velocities): + stdv = 1. / math.sqrt(w.size(1)) + w.data.uniform_(-stdv, stdv) + v.data.fill_(0) + + if self.with_bias: + for b, v, w in zip(self.biases, self.b_velocities, self.weights): + stdv = 1. / math.sqrt(w.size(1)) + b.data.uniform_(-stdv, stdv) + v.data.fill_(0) + + def inner_train(self, features, adj_norm, idx_train, idx_unlabeled, labels): + self._initialize() + + for ix in range(len(self.hidden_sizes) + 1): + self.weights[ix] = self.weights[ix].detach() + self.weights[ix].requires_grad = True + self.w_velocities[ix] = self.w_velocities[ix].detach() + self.w_velocities[ix].requires_grad = True + + if self.with_bias: + self.biases[ix] = self.biases[ix].detach() + self.biases[ix].requires_grad = True + self.b_velocities[ix] = self.b_velocities[ix].detach() + self.b_velocities[ix].requires_grad = True + + for j in range(self.attack_iters): + hidden = features + for ix, w in enumerate(self.weights): + b = self.biases[ix] if self.with_bias else 0 + if self.sparse_features: + hidden = adj_norm @ torch.spmm(hidden, w) + b + else: + hidden = adj_norm @ hidden @ w + b + + if self.with_relu and ix != len(self.weights) - 1: + hidden = F.relu(hidden) + + output = F.log_softmax(hidden, dim=1) + loss_labeled = F.nll_loss(output[idx_train], labels[idx_train]) + + weight_grads = torch.autograd.grad(loss_labeled, self.weights, create_graph=True) + self.w_velocities = [self.momentum * v + g for v, g in zip(self.w_velocities, weight_grads)] + if self.with_bias: + bias_grads = torch.autograd.grad(loss_labeled, self.biases, create_graph=True) + self.b_velocities = [self.momentum * v + g for v, g in zip(self.b_velocities, bias_grads)] + + self.weights = [w - self.lr * v for w, v in zip(self.weights, self.w_velocities)] + if self.with_bias: + self.biases = [b - self.lr * v for b, v in zip(self.biases, self.b_velocities)] + + def get_meta_grad(self, features, adj_norm, idx_train, idx_unlabeled, labels, labels_self_training): + + hidden = features + for ix, w in
enumerate(self.weights): + b = self.biases[ix] if self.with_bias else 0 + if self.sparse_features: + hidden = adj_norm @ torch.spmm(hidden, w) + b + else: + hidden = adj_norm @ hidden @ w + b + if self.with_relu and ix != len(self.weights) - 1: + hidden = F.relu(hidden) + + output = F.log_softmax(hidden, dim=1) + + loss_labeled = F.nll_loss(output[idx_train], labels[idx_train]) + loss_unlabeled = F.nll_loss(output[idx_unlabeled], labels_self_training[idx_unlabeled]) + loss_test_val = F.nll_loss(output[idx_unlabeled], labels[idx_unlabeled]) + + if self.lambda_ == 1: + attack_loss = loss_labeled + elif self.lambda_ == 0: + attack_loss = loss_unlabeled + else: + attack_loss = self.lambda_ * loss_labeled + (1 - self.lambda_) * loss_unlabeled + + print('GCN loss on unlabeled data: {}'.format(loss_test_val.item())) + print('GCN acc on unlabeled data: {}'.format(utils.accuracy(output[idx_unlabeled], labels[idx_unlabeled]).item())) + print('attack loss: {}'.format(attack_loss.item())) + + adj_grad, feature_grad = None, None + if self.attack_structure: + adj_grad = torch.autograd.grad(attack_loss, self.adj_changes, retain_graph=True)[0] + if self.attack_features: + feature_grad = torch.autograd.grad(attack_loss, self.feature_changes, retain_graph=True)[0] + return adj_grad, feature_grad + + +class MetaAttackApprox(BaseMeta): + """ + Attack GNNs with approximate meta gradients + """ + name = "MetaAttackApprox" + + def __init__(self, num_nodes=None, feature_shape=None, attack_structure=True, attack_features=False, + undirected=False, device='cpu', with_bias=False, lambda_=0.5, train_iters=200, attack_iters=10, + lr=0.01, with_relu=False): + super().__init__(num_nodes=num_nodes, feature_shape=feature_shape, lambda_=lambda_, train_iters=train_iters, + attack_iters=attack_iters, lr=lr, attack_features=attack_features, + attack_structure=attack_structure, undirected=undirected, device=device) + + self.lr = lr + self.train_iters = train_iters + self.attack_iters = attack_iters + self.adj_meta_grad = None + self.features_meta_grad = None + if self.attack_structure: + self.adj_grad_sum = torch.zeros(num_nodes, num_nodes).to(device) + if self.attack_features: + self.feature_grad_sum = torch.zeros(feature_shape).to(device) + + self.with_bias = with_bias + self.with_relu = with_relu + + self.weights = [] + self.biases = [] + + def attack(self, gen_dataset, attack_budget=500, ll_constraint=True, ll_cutoff=0.004): + super().attack(gen_dataset=gen_dataset) + + self.hidden_sizes = [16] # FIXME get from model architecture + self.nfeat = gen_dataset.num_node_features + self.nclass = gen_dataset.num_classes + + previous_size = self.nfeat + for ix, nhid in enumerate(self.hidden_sizes): + weight = Parameter(torch.FloatTensor(previous_size, nhid).to(self.device)) + bias = Parameter(torch.FloatTensor(nhid).to(self.device)) # bias is per output channel + previous_size = nhid + + self.weights.append(weight) + self.biases.append(bias) + + output_weight = Parameter(torch.FloatTensor(previous_size, self.nclass).to(self.device)) + output_bias = Parameter(torch.FloatTensor(self.nclass).to(self.device)) + self.weights.append(output_weight) + self.biases.append(output_bias) + + self.optimizer = optim.Adam(self.weights + self.biases, lr=self.lr) # , weight_decay=5e-4) + self._initialize() + + ori_features = gen_dataset.dataset.data.x + ori_adj = gen_dataset.dataset.data.edge_index + labels = gen_dataset.dataset.data.y + idx_train = gen_dataset.train_mask + idx_unlabeled = gen_dataset.test_mask + + ori_adj, ori_features, labels =
utils.to_tensor(ori_adj, ori_features, labels, device=self.device) + labels_self_training = self.self_training_label(labels, idx_train) + self.sparse_features = sp.issparse(ori_features) + modified_adj = ori_adj + modified_features = ori_features + + for i in tqdm(range(attack_budget), desc="Perturbing graph"): + self._initialize() + + if self.attack_structure: + modified_adj = self.get_modified_adj(ori_adj) + self.adj_grad_sum.data.fill_(0) + if self.attack_features: + modified_features = ori_features + self.feature_changes + self.feature_grad_sum.data.fill_(0) + + self.inner_train(modified_features, modified_adj, idx_train, idx_unlabeled, labels, labels_self_training) + + adj_meta_score = torch.tensor(0.0).to(self.device) + feature_meta_score = torch.tensor(0.0).to(self.device) + + if self.attack_structure: + adj_meta_score = self.get_adj_score(self.adj_grad_sum, modified_adj, ori_adj, ll_constraint, ll_cutoff) + if self.attack_features: + feature_meta_score = self.get_feature_score(self.feature_grad_sum, modified_features) + + if adj_meta_score.max() >= feature_meta_score.max(): + adj_meta_argmax = torch.argmax(adj_meta_score) + row_idx, col_idx = utils.unravel_index(adj_meta_argmax, ori_adj.shape) + self.adj_changes.data[row_idx][col_idx] += (-2 * modified_adj[row_idx][col_idx] + 1) + if self.undirected: + self.adj_changes.data[col_idx][row_idx] += (-2 * modified_adj[row_idx][col_idx] + 1) + else: + feature_meta_argmax = torch.argmax(feature_meta_score) + row_idx, col_idx = utils.unravel_index(feature_meta_argmax, ori_features.shape) + self.feature_changes.data[row_idx][col_idx] += (-2 * modified_features[row_idx][col_idx] + 1) + + if self.attack_structure: + self.modified_adj = self.get_modified_adj(ori_adj).detach() + gen_dataset.dataset.data.edge_index = dense_to_sparse(self.modified_adj.int())[0] + if self.attack_features: + self.modified_features = self.get_modified_features(ori_features).detach() + gen_dataset.dataset.data.x = self.modified_features + + def _initialize(self): + for w, b in zip(self.weights, self.biases): + stdv = 1.
/ math.sqrt(w.size(1)) + w.data.uniform_(-stdv, stdv) + b.data.uniform_(-stdv, stdv) + + self.optimizer = optim.Adam(self.weights + self.biases, lr=self.lr) + + def inner_train(self, features, modified_adj, idx_train, idx_unlabeled, labels, labels_self_training): + adj_norm = utils.normalize_adj_tensor(modified_adj) + for j in range(self.attack_iters): + hidden = features + for ix, w in enumerate(self.weights): + b = self.biases[ix] if self.with_bias else 0 + if self.sparse_features: + hidden = adj_norm @ torch.spmm(hidden, w) + b + else: + hidden = adj_norm @ hidden @ w + b + if self.with_relu: + hidden = F.relu(hidden) + + output = F.log_softmax(hidden, dim=1) + loss_labeled = F.nll_loss(output[idx_train], labels[idx_train]) + loss_unlabeled = F.nll_loss(output[idx_unlabeled], labels_self_training[idx_unlabeled]) + + if self.lambda_ == 1: + attack_loss = loss_labeled + elif self.lambda_ == 0: + attack_loss = loss_unlabeled + else: + attack_loss = self.lambda_ * loss_labeled + (1 - self.lambda_) * loss_unlabeled + + self.optimizer.zero_grad() + loss_labeled.backward(retain_graph=True) + + if self.attack_structure: + self.adj_changes.grad.zero_() + self.adj_grad_sum += torch.autograd.grad(attack_loss, self.adj_changes, retain_graph=True)[0] + if self.attack_features: + self.feature_changes.grad.zero_() + self.feature_grad_sum += torch.autograd.grad(attack_loss, self.feature_changes, retain_graph=True)[0] + + self.optimizer.step() + + + loss_test_val = F.nll_loss(output[idx_unlabeled], labels[idx_unlabeled]) + print('GCN loss on unlabeled data: {}'.format(loss_test_val.item())) + print('GCN acc on unlabeled data: {}'.format(utils.accuracy(output[idx_unlabeled], labels[idx_unlabeled]).item())) \ No newline at end of file diff --git a/src/attacks/poison_attacks_collection/metattack/utils.py b/src/attacks/poison_attacks_collection/metattack/utils.py new file mode 100644 index 0000000..201f301 --- /dev/null +++ b/src/attacks/poison_attacks_collection/metattack/utils.py @@ -0,0 +1,786 @@ +import numpy as np +import scipy.sparse as sp +import torch +from sklearn.model_selection import train_test_split +import torch.sparse as ts +import torch.nn.functional as F +import warnings + +from torch_geometric.utils import to_dense_adj, to_torch_csr_tensor, to_torch_coo_tensor + +def encode_onehot(labels): + """Convert labels to one-hot format. + + Parameters + ---------- + labels : numpy.array + node labels + + Returns + ------- + numpy.array + one-hot labels + """ + eye = np.eye(labels.max() + 1) + onehot_mx = eye[labels] + return onehot_mx + +def tensor2onehot(labels): + """Convert a label tensor to a one-hot label tensor. + + Parameters + ---------- + labels : torch.LongTensor + node labels + + Returns + ------- + torch.LongTensor + one-hot label tensor + """ + + eye = torch.eye(labels.max() + 1) + onehot_mx = eye[labels] + return onehot_mx.to(labels.device) + +def preprocess(adj, features, labels, preprocess_adj=False, preprocess_feature=False, sparse=False, device='cpu'): + """Convert adj, features, labels from array or sparse matrix to + torch Tensor, and normalize the input data. + + Parameters + ---------- + adj : scipy.sparse.csr_matrix + the adjacency matrix.
+ features : scipy.sparse.csr_matrix + node features + labels : numpy.array + node labels + preprocess_adj : bool + whether to normalize the adjacency matrix + preprocess_feature : bool + whether to normalize the feature matrix + sparse : bool + whether to return sparse tensors + device : str + 'cpu' or 'cuda' + """ + + if preprocess_adj: + adj = normalize_adj(adj) + + if preprocess_feature: + features = normalize_feature(features) + + labels = torch.LongTensor(labels) + if sparse: + adj = sparse_mx_to_torch_sparse_tensor(adj) + features = sparse_mx_to_torch_sparse_tensor(features) + else: + if sp.issparse(features): + features = torch.FloatTensor(np.array(features.todense())) + else: + features = torch.FloatTensor(features) + adj = torch.FloatTensor(adj.todense()) + return adj.to(device), features.to(device), labels.to(device) + +def to_tensor(adj, features, labels=None, device='cpu'): + """Convert adj, features, labels from array or sparse matrix to + torch Tensor. + + Parameters + ---------- + adj : torch.Tensor + edge_index tensor of the graph; it is converted to a dense adjacency matrix + features : scipy.sparse.csr_matrix + node features + labels : numpy.array + node labels + device : str + 'cpu' or 'cuda' + """ + adj = to_dense_adj(adj).squeeze() + if sp.issparse(features): + features = sparse_mx_to_torch_sparse_tensor(features) + else: + features = torch.FloatTensor(np.array(features)) + + if labels is None: + return adj.to(device), features.to(device) + else: + labels = torch.LongTensor(labels) + return adj.to(device), features.to(device), labels.to(device) + +def normalize_feature(mx): + """Row-normalize a sparse or dense matrix + + Parameters + ---------- + mx : scipy.sparse.csr_matrix or numpy.array + matrix to be normalized + + Returns + ------- + scipy.sparse.lil_matrix + normalized matrix + """ + if type(mx) is not sp.lil.lil_matrix: + try: + mx = mx.tolil() + except AttributeError: + pass + rowsum = np.array(mx.sum(1)) + r_inv = np.power(rowsum, -1).flatten() + r_inv[np.isinf(r_inv)] = 0. + r_mat_inv = sp.diags(r_inv) + mx = r_mat_inv.dot(mx) + return mx + +def normalize_adj(mx): + """Symmetrically normalize a sparse adjacency matrix, + A' = (D + I)^-1/2 * ( A + I ) * (D + I)^-1/2 + + Parameters + ---------- + mx : scipy.sparse.csr_matrix + matrix to be normalized + + Returns + ------- + scipy.sparse.lil_matrix + normalized matrix + """ + + # TODO: maybe using coo format would be better? + if type(mx) is not sp.lil.lil_matrix: + mx = mx.tolil() + if mx[0, 0] == 0: + mx = mx + sp.eye(mx.shape[0]) + rowsum = np.array(mx.sum(1)) + r_inv = np.power(rowsum, -1/2).flatten() + r_inv[np.isinf(r_inv)] = 0. + r_mat_inv = sp.diags(r_inv) + mx = r_mat_inv.dot(mx) + mx = mx.dot(r_mat_inv) + return mx + +def normalize_sparse_tensor(adj, fill_value=1): + """Normalize sparse tensor.
Need to import torch_scatter + """ + edge_index = adj._indices() + edge_weight = adj._values() + num_nodes = adj.size(0) + edge_index, edge_weight = add_self_loops( + edge_index, edge_weight, fill_value, num_nodes) + + row, col = edge_index + from torch_scatter import scatter_add + deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + + values = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + + shape = adj.shape + return torch.sparse.FloatTensor(edge_index, values, shape) + +def add_self_loops(edge_index, edge_weight=None, fill_value=1, num_nodes=None): + # num_nodes = maybe_num_nodes(edge_index, num_nodes) + + loop_index = torch.arange(0, num_nodes, dtype=torch.long, + device=edge_index.device) + loop_index = loop_index.unsqueeze(0).repeat(2, 1) + + if edge_weight is not None: + assert edge_weight.numel() == edge_index.size(1) + loop_weight = edge_weight.new_full((num_nodes, ), fill_value) + edge_weight = torch.cat([edge_weight, loop_weight], dim=0) + + edge_index = torch.cat([edge_index, loop_index], dim=1) + + return edge_index, edge_weight + +def normalize_adj_tensor(adj, sparse=False): + """Normalize an adjacency tensor. + """ + device = adj.device + if sparse: + # warnings.warn('If you find the training process is too slow, you can uncomment line 207 in deeprobust/graph/utils.py. Note that you need to install torch_sparse') + # TODO if this is too slow, uncomment the following code, + # but you need to install torch_scatter + # return normalize_sparse_tensor(adj) + adj = to_scipy(adj) + mx = normalize_adj(adj) + return sparse_mx_to_torch_sparse_tensor(mx).to(device) + else: + mx = adj + torch.eye(adj.shape[0]).to(device) + rowsum = mx.sum(1) + r_inv = rowsum.pow(-1/2).flatten() + r_inv[torch.isinf(r_inv)] = 0. + r_mat_inv = torch.diag(r_inv) + mx = r_mat_inv @ mx + mx = mx @ r_mat_inv + return mx + +def degree_normalize_adj(mx): + """Row-normalize a sparse matrix""" + mx = mx.tolil() + if mx[0, 0] == 0: + mx = mx + sp.eye(mx.shape[0]) + rowsum = np.array(mx.sum(1)) + r_inv = np.power(rowsum, -1).flatten() + r_inv[np.isinf(r_inv)] = 0. + r_mat_inv = sp.diags(r_inv) + # mx = mx.dot(r_mat_inv) + mx = r_mat_inv.dot(mx) + return mx + +def degree_normalize_sparse_tensor(adj, fill_value=1): + """Degree-normalize a sparse tensor. + """ + edge_index = adj._indices() + edge_weight = adj._values() + num_nodes = adj.size(0) + + edge_index, edge_weight = add_self_loops( + edge_index, edge_weight, fill_value, num_nodes) + + row, col = edge_index + from torch_scatter import scatter_add + deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes) + deg_inv = deg.pow(-1) + deg_inv[deg_inv == float('inf')] = 0 + + values = deg_inv[row] * edge_weight + shape = adj.shape + return torch.sparse.FloatTensor(edge_index, values, shape) + +def degree_normalize_adj_tensor(adj, sparse=True): + """Degree-normalize an adjacency tensor. + """ + + device = adj.device + if sparse: + # return degree_normalize_sparse_tensor(adj) + adj = to_scipy(adj) + mx = degree_normalize_adj(adj) + return sparse_mx_to_torch_sparse_tensor(mx).to(device) + else: + mx = adj + torch.eye(adj.shape[0]).to(device) + rowsum = mx.sum(1) + r_inv = rowsum.pow(-1).flatten() + r_inv[torch.isinf(r_inv)] = 0. + r_mat_inv = torch.diag(r_inv) + mx = r_mat_inv @ mx + return mx + +def accuracy(output, labels): + """Return accuracy of output compared to labels.
+ + Parameters + ---------- + output : torch.Tensor + output from model + labels : torch.Tensor or numpy.array + node labels + + Returns + ------- + float + accuracy + """ + if not hasattr(labels, '__len__'): + labels = [labels] + if type(labels) is not torch.Tensor: + labels = torch.LongTensor(labels) + preds = output.max(1)[1].type_as(labels) + correct = preds.eq(labels).double() + correct = correct.sum() + return correct / len(labels) + +def loss_acc(output, labels, targets, avg_loss=True): + if type(labels) is not torch.Tensor: + labels = torch.LongTensor(labels) + preds = output.max(1)[1].type_as(labels) + correct = preds.eq(labels).double()[targets] + loss = F.nll_loss(output[targets], labels[targets], reduction='mean' if avg_loss else 'none') + + if avg_loss: + return loss, correct.sum() / len(targets) + return loss, correct + +def get_perf(output, labels, mask, verbose=True): + """Evaluate performance on the masked (test) data""" + loss = F.nll_loss(output[mask], labels[mask]) + acc = accuracy(output[mask], labels[mask]) + if verbose: + print("loss= {:.4f}".format(loss.item()), + "accuracy= {:.4f}".format(acc.item())) + return loss.item(), acc.item() + + +def classification_margin(output, true_label): + """Calculate classification margin for outputs, + `probs_true_label - probs_best_second_class` + + Parameters + ---------- + output: torch.Tensor + output vector (1 dimension) + true_label: int + true label for this node + + Returns + ------- + float + classification margin for this node + """ + + probs = torch.exp(output) + probs_true_label = probs[true_label].clone() + probs[true_label] = 0 + probs_best_second_class = probs[probs.argmax()] + return (probs_true_label - probs_best_second_class).item() + +def sparse_mx_to_torch_sparse_tensor(sparse_mx): + """Convert a scipy sparse matrix to a torch sparse tensor.""" + sparse_mx = sparse_mx.tocoo().astype(np.float32) + sparserow = torch.LongTensor(sparse_mx.row).unsqueeze(1) + sparsecol = torch.LongTensor(sparse_mx.col).unsqueeze(1) + sparseconcat = torch.cat((sparserow, sparsecol), 1) + sparsedata = torch.FloatTensor(sparse_mx.data) + return torch.sparse.FloatTensor(sparseconcat.t(), sparsedata, torch.Size(sparse_mx.shape)) + + +def to_scipy(tensor): + """Convert a dense/sparse tensor to a scipy matrix""" + if is_sparse_tensor(tensor): + values = tensor._values() + indices = tensor._indices() + return sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()), shape=tensor.shape) + else: + indices = tensor.nonzero().t() + values = tensor[indices[0], indices[1]] + return sp.csr_matrix((values.cpu().numpy(), indices.cpu().numpy()), shape=tensor.shape) + +def is_sparse_tensor(tensor): + """Check if a tensor is a sparse tensor.
+ + Parameters + ---------- + tensor : torch.Tensor + given tensor + + Returns + ------- + bool + whether the tensor is a sparse tensor + """ + # if hasattr(tensor, 'nnz'): + return tensor.layout == torch.sparse_coo + +def get_train_val_test(nnodes, val_size=0.1, test_size=0.8, stratify=None, seed=None): + """This setting follows nettack/mettack, where we split the nodes + into 10% training, 10% validation and 80% testing data + + Parameters + ---------- + nnodes : int + number of nodes in total + val_size : float + size of validation set + test_size : float + size of test set + stratify : + data is expected to be split in a stratified fashion, so stratify should be the labels + seed : int or None + random seed + + Returns + ------- + idx_train : + node training indices + idx_val : + node validation indices + idx_test : + node test indices + """ + + assert stratify is not None, 'stratify cannot be None!' + + if seed is not None: + np.random.seed(seed) + + idx = np.arange(nnodes) + train_size = 1 - val_size - test_size + idx_train_and_val, idx_test = train_test_split(idx, + random_state=None, + train_size=train_size + val_size, + test_size=test_size, + stratify=stratify) + + if stratify is not None: + stratify = stratify[idx_train_and_val] + + idx_train, idx_val = train_test_split(idx_train_and_val, + random_state=None, + train_size=(train_size / (train_size + val_size)), + test_size=(val_size / (train_size + val_size)), + stratify=stratify) + + return idx_train, idx_val, idx_test + +def get_train_test(nnodes, test_size=0.8, stratify=None, seed=None): + """This function returns training and test sets without a validation set. + It can be used for settings with different label rates. + + Parameters + ---------- + nnodes : int + number of nodes in total + test_size : float + size of test set + stratify : + data is expected to be split in a stratified fashion, so stratify should be the labels + seed : int or None + random seed + + Returns + ------- + idx_train : + node training indices + idx_test : + node test indices + """ + assert stratify is not None, 'stratify cannot be None!' + + if seed is not None: + np.random.seed(seed) + + idx = np.arange(nnodes) + train_size = 1 - test_size + idx_train, idx_test = train_test_split(idx, random_state=None, + train_size=train_size, + test_size=test_size, + stratify=stratify) + + return idx_train, idx_test + +def get_train_val_test_gcn(labels, seed=None): + """This setting follows the GCN paper, where we randomly sample 20 instances for each class + as training data, 500 instances as validation data, and 1000 instances as test data. + Note here we are not using fixed splits. When the random seed changes, the splits + will also change.
+ + Parameters + ---------- + labels : numpy.array + node labels + seed : int or None + random seed + + Returns + ------- + idx_train : + node training indices + idx_val : + node validation indices + idx_test : + node test indices + """ + if seed is not None: + np.random.seed(seed) + + idx = np.arange(len(labels)) + nclass = labels.max() + 1 + idx_train = [] + idx_unlabeled = [] + for i in range(nclass): + labels_i = idx[labels == i] + labels_i = np.random.permutation(labels_i) + idx_train = np.hstack((idx_train, labels_i[: 20])).astype(int) + idx_unlabeled = np.hstack((idx_unlabeled, labels_i[20: ])).astype(int) + + idx_unlabeled = np.random.permutation(idx_unlabeled) + idx_val = idx_unlabeled[: 500] + idx_test = idx_unlabeled[500: 1500] + return idx_train, idx_val, idx_test + +def get_train_test_labelrate(labels, label_rate): + """Get train/test splits according to a given label rate. + """ + nclass = labels.max() + 1 + train_size = int(round(len(labels) * label_rate / nclass)) + print("=== train_size = %s ===" % train_size) + idx_train, idx_val, idx_test = get_splits_each_class(labels, train_size=train_size) + return idx_train, idx_test + +def get_splits_each_class(labels, train_size): + """We randomly sample n instances per class, where n = train_size. + """ + idx = np.arange(len(labels)) + nclass = labels.max() + 1 + idx_train = [] + idx_val = [] + idx_test = [] + for i in range(nclass): + labels_i = idx[labels == i] + labels_i = np.random.permutation(labels_i) + idx_train = np.hstack((idx_train, labels_i[: train_size])).astype(int) + idx_val = np.hstack((idx_val, labels_i[train_size: 2*train_size])).astype(int) + idx_test = np.hstack((idx_test, labels_i[2*train_size: ])).astype(int) + + return np.random.permutation(idx_train), np.random.permutation(idx_val), \ + np.random.permutation(idx_test) + + +def unravel_index(index, array_shape): + rows = torch.div(index, array_shape[1], rounding_mode='trunc') + cols = index % array_shape[1] + return rows, cols + + +def get_degree_sequence(adj): + try: + return adj.sum(0) + except Exception: + return ts.sum(adj, dim=1).to_dense() + +def likelihood_ratio_filter(node_pairs, modified_adjacency, original_adjacency, d_min, threshold=0.004, undirected=True): + """ + Filter the input node pairs based on the likelihood ratio test proposed by Zügner et al. 2018, see + https://dl.acm.org/citation.cfm?id=3220078. In essence, for each node pair return 1 if adding/removing the edge + between the two nodes does not violate the unnoticeability constraint, and return 0 otherwise. Assumes unweighted + and undirected graphs. + """ + + N = int(modified_adjacency.shape[0]) + # original_degree_sequence = get_degree_sequence(original_adjacency) + # current_degree_sequence = get_degree_sequence(modified_adjacency) + original_degree_sequence = original_adjacency.sum(0) + current_degree_sequence = modified_adjacency.sum(0) + + concat_degree_sequence = torch.cat((current_degree_sequence, original_degree_sequence)) + + # Compute the log likelihood values of the original, modified, and combined degree sequences.
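+ # A power law is fitted to each sequence via the approximate MLE alpha = 1 + n / (sum_i log d_i - n * log(d_min - 0.5)); the ratio -2 * ll_comb + 2 * (ll_orig + ll_current) is then compared against the cutoff.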
+ ll_orig, alpha_orig, n_orig, sum_log_degrees_original = degree_sequence_log_likelihood(original_degree_sequence, d_min) + ll_current, alpha_current, n_current, sum_log_degrees_current = degree_sequence_log_likelihood(current_degree_sequence, d_min) + + ll_comb, alpha_comb, n_comb, sum_log_degrees_combined = degree_sequence_log_likelihood(concat_degree_sequence, d_min) + + # Compute the log likelihood ratio + current_ratio = -2 * ll_comb + 2 * (ll_orig + ll_current) + + # Compute new log likelihood values that would arise if we add/remove the edges corresponding to each node pair. + new_lls, new_alphas, new_ns, new_sum_log_degrees = updated_log_likelihood_for_edge_changes(node_pairs, + modified_adjacency, d_min) + + # Combination of the original degree distribution with the distributions corresponding to each node pair. + n_combined = n_orig + new_ns + new_sum_log_degrees_combined = sum_log_degrees_original + new_sum_log_degrees + alpha_combined = compute_alpha(n_combined, new_sum_log_degrees_combined, d_min) + new_ll_combined = compute_log_likelihood(n_combined, alpha_combined, new_sum_log_degrees_combined, d_min) + new_ratios = -2 * new_ll_combined + 2 * (new_lls + ll_orig) + + # Allowed edges are only those for which the resulting likelihood ratio measure is smaller than the threshold + allowed_edges = new_ratios < threshold + + if allowed_edges.is_cuda: + filtered_edges = node_pairs[allowed_edges.cpu().numpy().astype(bool)] + else: + filtered_edges = node_pairs[allowed_edges.numpy().astype(bool)] + + allowed_mask = torch.zeros(modified_adjacency.shape) + allowed_mask[filtered_edges.T] = 1 + if undirected: + allowed_mask += allowed_mask.t() + return allowed_mask, current_ratio + + +def degree_sequence_log_likelihood(degree_sequence, d_min): + """ + Compute the (maximum) log likelihood of the power-law distribution fit on a degree distribution. + """ + + # Determine which degrees are to be considered, i.e. >= d_min. + D_G = degree_sequence[(degree_sequence >= d_min.item())] + try: + sum_log_degrees = torch.log(D_G).sum() + except Exception: + sum_log_degrees = np.log(D_G).sum() + n = len(D_G) + + alpha = compute_alpha(n, sum_log_degrees, d_min) + ll = compute_log_likelihood(n, alpha, sum_log_degrees, d_min) + return ll, alpha, n, sum_log_degrees + +def updated_log_likelihood_for_edge_changes(node_pairs, adjacency_matrix, d_min): + """ Adapted from https://github.com/danielzuegner/nettack + """ + # For each node pair find out whether there is an edge or not in the input adjacency matrix.
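+ # Below, deltas = -2 * a_ij + 1 encodes the flip direction: +1 for a currently absent edge (addition) and -1 for a present one (removal).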
+ + #edge_entries_before = adjacency_matrix[node_pairs.T] + edge_entries_before = adjacency_matrix[node_pairs.T[0, :], node_pairs.T[1, :]] + degree_sequence = adjacency_matrix.sum(1) + D_G = degree_sequence[degree_sequence >= d_min.item()] + sum_log_degrees = torch.log(D_G).sum() + n = len(D_G) + deltas = -2 * edge_entries_before + 1 + d_edges_before = degree_sequence[node_pairs] + + d_edges_after = degree_sequence[node_pairs] + deltas[:, None] + + # Sum the log of the degrees after the potential changes which are >= d_min + sum_log_degrees_after, new_n = update_sum_log_degrees(sum_log_degrees, n, d_edges_before, d_edges_after, d_min) + # Updated estimates of the power-law exponents + new_alpha = compute_alpha(new_n, sum_log_degrees_after, d_min) + # Updated log likelihood values for the power-law distributions + new_ll = compute_log_likelihood(new_n, new_alpha, sum_log_degrees_after, d_min) + return new_ll, new_alpha, new_n, sum_log_degrees_after + + +def update_sum_log_degrees(sum_log_degrees_before, n_old, d_old, d_new, d_min): + # Find out whether the degrees before and after the change are above the threshold d_min. + old_in_range = d_old >= d_min + new_in_range = d_new >= d_min + d_old_in_range = d_old * old_in_range.float() + d_new_in_range = d_new * new_in_range.float() + + # Update the sum by subtracting the old values and then adding the updated logs of the degrees. + sum_log_degrees_after = sum_log_degrees_before - (torch.log(torch.clamp(d_old_in_range, min=1))).sum(1) \ + + (torch.log(torch.clamp(d_new_in_range, min=1))).sum(1) + + # Update the number of degrees >= d_min + new_n = n_old - (old_in_range != 0).sum(1) + (new_in_range != 0).sum(1) + new_n = new_n.float() + return sum_log_degrees_after, new_n + +def compute_alpha(n, sum_log_degrees, d_min): + try: + alpha = 1 + n / (sum_log_degrees - n * torch.log(d_min - 0.5)) + except Exception: + alpha = 1 + n / (sum_log_degrees - n * np.log(d_min - 0.5)) + return alpha + +def compute_log_likelihood(n, alpha, sum_log_degrees, d_min): + # Log likelihood under alpha + try: + ll = n * torch.log(alpha) + n * alpha * torch.log(d_min) + (alpha + 1) * sum_log_degrees + except Exception: + ll = n * np.log(alpha) + n * alpha * np.log(d_min) + (alpha + 1) * sum_log_degrees + + return ll + +def ravel_multiple_indices(ixs, shape, reverse=False): + """ + "Flattens" multiple 2D input indices into indices on the flattened matrix, similar to np.ravel_multi_index. + Does the same as ravel_index but for multiple indices at once. + Parameters + ---------- + ixs: array of ints shape (n, 2) + The array of n indices that will be flattened. + + shape: list or tuple of ints of length 2 + The shape of the corresponding matrix. + + Returns + ------- + array of n ints between 0 and shape[0]*shape[1]-1 + The indices on the flattened matrix corresponding to the 2D input indices.
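+ + Example: for shape (4, 5), the index pair (1, 2) maps to 1 * 5 + 2 = 7.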
+ + """ + if reverse: + return ixs[:, 1] * shape[1] + ixs[:, 0] + + return ixs[:, 0] * shape[1] + ixs[:, 1] + +# def visualize(your_var): +# """visualize computation graph""" +# from graphviz import Digraph +# import torch +# from torch.autograd import Variable +# from torchviz import make_dot +# make_dot(your_var).view() + +def reshape_mx(mx, shape): + indices = mx.nonzero() + return sp.csr_matrix((mx.data, (indices[0], indices[1])), shape=shape) + +def add_mask(data, dataset): + """data: ogb-arxiv pyg data format""" + # for arxiv + split_idx = dataset.get_idx_split() + train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"] + n = data.x.shape[0] + data.train_mask = index_to_mask(train_idx, n) + data.val_mask = index_to_mask(valid_idx, n) + data.test_mask = index_to_mask(test_idx, n) + data.y = data.y.squeeze() + # data.edge_index = to_undirected(data.edge_index, data.num_nodes) + +def index_to_mask(index, size): + mask = torch.zeros((size, ), dtype=torch.bool) + mask[index] = 1 + return mask + +def add_feature_noise(data, noise_ratio, seed): + np.random.seed(seed) + n, d = data.x.shape + # noise = torch.normal(mean=torch.zeros(int(noise_ratio*n), d), std=1) + noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio*n), d))).to(data.x.device) + indices = np.arange(n) + indices = np.random.permutation(indices)[: int(noise_ratio*n)] + delta_feat = torch.zeros_like(data.x) + delta_feat[indices] = noise - data.x[indices] + data.x[indices] = noise + mask = np.zeros(n) + mask[indices] = 1 + mask = torch.tensor(mask).bool().to(data.x.device) + return delta_feat, mask + +def add_feature_noise_test(data, noise_ratio, seed): + np.random.seed(seed) + n, d = data.x.shape + indices = np.arange(n) + test_nodes = indices[data.test_mask.cpu()] + selected = np.random.permutation(test_nodes)[: int(noise_ratio*len(test_nodes))] + noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio*len(test_nodes)), d))) + noise = noise.to(data.x.device) + + delta_feat = torch.zeros_like(data.x) + delta_feat[selected] = noise - data.x[selected] + data.x[selected] = noise + # mask = np.zeros(len(test_nodes)) + mask = np.zeros(n) + mask[selected] = 1 + mask = torch.tensor(mask).bool().to(data.x.device) + return delta_feat, mask + +# def check_path(file_path): +# if not osp.exists(file_path): +# os.system(f'mkdir -p {file_path}') + diff --git a/src/aux/utils.py b/src/aux/utils.py index 5532522..2acd503 100644 --- a/src/aux/utils.py +++ b/src/aux/utils.py @@ -142,3 +142,7 @@ def setting_class_default_parameters(class_name: str, class_kwargs: dict, defaul return class_kwargs_for_save, class_kwargs_for_init + +def all_subclasses(cls): + return set(cls.__subclasses__()).union( + [s for c in cls.__subclasses__() for s in all_subclasses(c)]) diff --git a/src/models_builder/gnn_models.py b/src/models_builder/gnn_models.py index 975aa5e..3f3316b 100644 --- a/src/models_builder/gnn_models.py +++ b/src/models_builder/gnn_models.py @@ -15,7 +15,7 @@ from aux.configs import ModelManagerConfig, ModelModificationConfig, ModelConfig, CONFIG_CLASS_NAME from aux.data_info import UserCodeInfo -from aux.utils import import_by_name, FRAMEWORK_PARAMETERS_PATH, model_managers_info_by_names_list, hash_data_sha256, \ +from aux.utils import import_by_name, all_subclasses, FRAMEWORK_PARAMETERS_PATH, model_managers_info_by_names_list, hash_data_sha256, \ TECHNICAL_PARAMETER_KEY, IMPORT_INFO_KEY, OPTIMIZERS_PARAMETERS_PATH, FUNCTIONS_PARAMETERS_PATH from aux.declaration import Declare 
from explainers.explainer import ProgressBar @@ -321,11 +321,13 @@ def set_poison_attacker(self, poison_attack_config=None, poison_attack_name: str elif poison_attack_name != self.poison_attack_config._class_name: raise Exception(f"poison_attack_name and self.poison_attack_config._class_name should be equal, " f"but now poison_attack_name is {poison_attack_name}, " f"self.poison_attack_config._class_name is {self.poison_attack_config._class_name}") self.poison_attack_name = poison_attack_name poison_attack_kwargs = getattr(self.poison_attack_config, CONFIG_OBJ).to_dict() - name_klass = {e.name: e for e in PoisonAttacker.__subclasses__()} + # name_klass = {e.name: e for e in PoisonAttacker.__subclasses__()} + name_klass = {e.name: e for e in all_subclasses(PoisonAttacker)} + klass = name_klass[self.poison_attack_name] self.poison_attacker = klass( # device=self.device, diff --git a/src/models_builder/models_zoo.py b/src/models_builder/models_zoo.py index 07dad7f..2c639c0 100644 --- a/src/models_builder/models_zoo.py +++ b/src/models_builder/models_zoo.py @@ -311,6 +311,52 @@ def model_configs_zoo(dataset, model_name): ) ) + gcn_gcn_linearized = FrameworkGNNConstructor( + model_config=ModelConfig( + structure=ModelStructureConfig( + [ + { + 'label': 'n', + 'layer': { + 'layer_name': 'GCNConv', + 'layer_kwargs': { + 'in_channels': dataset.num_node_features, + 'out_channels': 16, + 'bias': True, + }, + }, + # no inter-layer activation: the surrogate must be linearized for Metattack + # 'activation': { + # 'activation_name': 'ReLU', + # 'activation_kwargs': None, + # }, + 'dropout': { + 'dropout_name': 'Dropout', + 'dropout_kwargs': { + 'p': 0.5, + } + } + }, + + { + 'label': 'n', + 'layer': { + 'layer_name': 'GCNConv', + 'layer_kwargs': { + 'in_channels': 16, + 'out_channels': dataset.num_classes, + 'bias': True, + }, + }, + 'activation': { + 'activation_name': 'LogSoftmax', + 'activation_kwargs': None, + }, + }, + ] + ) + ) + ) + gcn = FrameworkGNNConstructor( + model_config=ModelConfig( + structure=ModelStructureConfig( diff --git a/tests/attacks_test.py b/tests/attacks_test.py new file mode 100644 index 0000000..acee652 --- /dev/null +++ b/tests/attacks_test.py @@ -0,0 +1,113 @@ +import unittest + +import torch + +from base.datasets_processing import DatasetManager +from models_builder.gnn_models import FrameworkGNNModelManager, Metric +from aux.configs import ModelManagerConfig, ModelModificationConfig, DatasetConfig, DatasetVarConfig, ConfigPattern +from models_builder.models_zoo import model_configs_zoo + +from aux.utils import POISON_ATTACK_PARAMETERS_PATH, POISON_DEFENSE_PARAMETERS_PATH, EVASION_ATTACK_PARAMETERS_PATH, \ + EVASION_DEFENSE_PARAMETERS_PATH, OPTIMIZERS_PARAMETERS_PATH + +class AttacksTest(unittest.TestCase): + def setUp(self): + # side-effect import: registers the MetaAttack classes as PoisonAttacker subclasses + from attacks.poison_attacks_collection.metattack import meta_gradient_attack + + # Init datasets + # Single-Graph - Example + self.dataset_sg_example, _, results_dataset_path_sg_example = DatasetManager.get_by_full_name( + full_name=("single-graph", "custom", "example",), + features={'attr': {'a': 'as_is'}}, + labeling='binary', + dataset_ver_ind=0 + ) + + self.gen_dataset_sg_example = DatasetManager.get_by_config( + DatasetConfig( + domain="single-graph", + group="custom", + graph="example"), + DatasetVarConfig(features={'attr': {'a': 'as_is'}}, + labeling='binary', + dataset_ver_ind=0) + ) + self.gen_dataset_sg_example.train_test_split(percent_train_class=0.6, percent_test_class=0.4) +
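# 60/40 train/test split; Metattack later uses train_mask as the labeled set and test_mask as the unlabeled set. +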
self.results_dataset_path_sg_example = self.gen_dataset_sg_example.results_dir + + self.default_config = ModelModificationConfig( + model_ver_ind=0, + ) + + self.manager_config = ConfigPattern( + _config_class="ModelManagerConfig", + _config_kwargs={ + "mask_features": [], + "optimizer": { + "_config_class": "Config", + "_class_name": "Adam", + "_import_path": OPTIMIZERS_PARAMETERS_PATH, + "_class_import_info": ["torch.optim"], + "_config_kwargs": {"weight_decay": 5e-4}, + } + } + ) + + def test_metattack_full(self): + poison_attack_config = ConfigPattern( + _class_name="MetaAttackFull", + _import_path=POISON_ATTACK_PARAMETERS_PATH, + _config_class="PoisonAttackConfig", + _config_kwargs={ + "num_nodes": self.gen_dataset_sg_example.dataset.x.shape[0] # is there a cleaner way to get this? + } + ) + + gat_gat_sg_example = model_configs_zoo(dataset=self.gen_dataset_sg_example, model_name='gat_gat') + + gnn_model_manager_sg_example = FrameworkGNNModelManager( + gnn=gat_gat_sg_example, + dataset_path=self.results_dataset_path_sg_example, + modification=self.default_config, + manager_config=self.manager_config, + ) + + gnn_model_manager_sg_example.set_poison_attacker(poison_attack_config=poison_attack_config) + + gnn_model_manager_sg_example.train_model(gen_dataset=self.gen_dataset_sg_example, steps=100, metrics=[Metric("Accuracy", mask='test')]) + metric_loc = gnn_model_manager_sg_example.evaluate_model(gen_dataset=self.gen_dataset_sg_example, metrics=[Metric("F1", mask='test', average='macro')]) + print(metric_loc) + + def test_metattack_approx(self): + torch.manual_seed(100) # fixed seed for reproducibility + + poison_attack_config = ConfigPattern( + _class_name="MetaAttackApprox", + _import_path=POISON_ATTACK_PARAMETERS_PATH, + _config_class="PoisonAttackConfig", + _config_kwargs={ + "num_nodes": self.gen_dataset_sg_example.dataset.x.shape[0] # is there a cleaner way to get this? + } + ) + + gat_gat_sg_example = model_configs_zoo(dataset=self.gen_dataset_sg_example, model_name='gat_gat') + + gnn_model_manager_sg_example = FrameworkGNNModelManager( + gnn=gat_gat_sg_example, + dataset_path=self.results_dataset_path_sg_example, + modification=self.default_config, + manager_config=self.manager_config, + ) + + gnn_model_manager_sg_example.set_poison_attacker(poison_attack_config=poison_attack_config) + + gnn_model_manager_sg_example.train_model(gen_dataset=self.gen_dataset_sg_example, steps=100, metrics=[Metric("Accuracy", mask='test')]) + metric_loc = gnn_model_manager_sg_example.evaluate_model(gen_dataset=self.gen_dataset_sg_example, + metrics=[Metric("F1", mask='test', average='macro'), + Metric("Accuracy", mask='test')]) + print(metric_loc) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file