diff --git a/README.md b/README.md
index e682ab7..0d7f0fc 100644
--- a/README.md
+++ b/README.md
@@ -62,6 +62,8 @@ If our work could help your research, please cite:
 ```
 
 # Changelog
+* [02/2023] DeepRobust 0.2.7 Released. Please try `pip install deeprobust==0.2.7`! We have added a scalable attack, [PRBCD, NeurIPS'21](https://arxiv.org/abs/2110.14038), to the graph package. PRBCD can now attack large-scale graphs such as ogbn-arxiv (see the example in [test_prbcd.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_prbcd.py))!
+* [02/2023] Added a robust model, [AirGNN, NeurIPS'21](https://proceedings.neurips.cc/paper/2021/file/50abc3e730e36b387ca8e02c26dc0a22-Paper.pdf), to the graph package. Try `python examples/graph/test_airgnn.py`! See details in [test_airgnn.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_airgnn.py).
 * [11/2022] DeepRobust 0.2.6 Released. Please try `pip install deeprobust==0.2.6`! We have more updates coming. Please stay tuned!
 * [11/2021] A subpackage that includes popular black-box attacks in the image domain has been released. Find it here: [Link](https://github.com/I-am-Bot/Black-Box-Attacks).
 * [11/2021] DeepRobust 0.2.4 Released. Please try `pip install deeprobust==0.2.4`!
diff --git a/deeprobust/graph/data/dataset.py b/deeprobust/graph/data/dataset.py
index 2323300..fb7ef2e 100644
--- a/deeprobust/graph/data/dataset.py
+++ b/deeprobust/graph/data/dataset.py
@@ -62,12 +62,12 @@ def __init__(self, root, name, setting='nettack', seed=None, require_mask=False)
         self.seed = seed
         # self.url = 'https://raw.githubusercontent.com/danielzuegner/nettack/master/data/%s.npz' % self.name
         self.url = 'https://raw.githubusercontent.com/danielzuegner/gnn-meta-attack/master/data/%s.npz' % self.name
-
-        if platform.system() == 'Windows':
-            root = root
-        else:
+
+        if platform.system() == 'Windows':
+            root = root
+        else:
             self.root = osp.expanduser(osp.normpath(root))
-
+
         self.data_folder = osp.join(root, self.name)
         self.data_filename = self.data_folder + '.npz'
         self.require_mask = require_mask
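The diff below introduces the new PyG-based subpackage and the PRBCD attack advertised in the changelog. After `pip install deeprobust==0.2.7`, the new names can be imported as follows (a minimal sketch; it assumes PyTorch Geometric is installed):

```python
from deeprobust.graph.defense_pyg import GCN, GAT, APPNP, SAGE, GPRGNN, AirGNN
from deeprobust.graph.global_attack import PRBCD
```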
diff --git a/deeprobust/graph/defense_pyg/__init__.py b/deeprobust/graph/defense_pyg/__init__.py
new file mode 100644
index 0000000..b70646f
--- /dev/null
+++ b/deeprobust/graph/defense_pyg/__init__.py
@@ -0,0 +1,16 @@
+import warnings
+
+try:
+    from .gcn import GCN
+    from .gat import GAT
+    from .appnp import APPNP
+    from .sage import SAGE
+    from .gpr import GPRGNN
+    from .airgnn import AirGNN
+except ImportError as e:
+    print(e)
+    warnings.warn("Please install pytorch geometric if you " +
+                  "would like to use the models in deeprobust.graph.defense_pyg. " +
+                  "See details in https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html")
+
+__all__ = ["GCN", "GAT", "APPNP", "SAGE", "GPRGNN", "AirGNN"]
diff --git a/deeprobust/graph/defense_pyg/airgnn.py b/deeprobust/graph/defense_pyg/airgnn.py
new file mode 100644
index 0000000..0bc4e0f
--- /dev/null
+++ b/deeprobust/graph/defense_pyg/airgnn.py
@@ -0,0 +1,186 @@
+import torch
+import torch.nn.functional as F
+from torch.nn import Linear
+from torch_geometric.nn.conv.gcn_conv import gcn_norm
+from torch_geometric.nn.conv import MessagePassing
+from typing import Optional, Tuple
+from torch_geometric.typing import Adj, OptTensor
+from torch import Tensor
+from torch_sparse import SparseTensor, matmul
+from .base_model import BaseModel
+import torch.nn as nn
+
+
+class AirGNN(BaseModel):
+
+    def __init__(self, nfeat, nhid, nclass, nlayers=2, K=2, dropout=0.5, lr=0.01,
+                 with_bn=False, weight_decay=5e-4, with_bias=True, device=None, args=None):
+
+        super(AirGNN, self).__init__()
+        assert device is not None, "Please specify 'device'!"
+        self.device = device
+
+        self.lins = nn.ModuleList([])
+        self.lins.append(Linear(nfeat, nhid))
+        if with_bn:
+            self.bns = nn.ModuleList([])
+            self.bns.append(nn.BatchNorm1d(nhid))
+        for i in range(nlayers-2):
+            self.lins.append(Linear(nhid, nhid))
+            if with_bn:
+                self.bns.append(nn.BatchNorm1d(nhid))
+        self.lins.append(Linear(nhid, nclass))
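+        # `args` is required here: this constructor reads `args.alpha` and
+        # `args.model`, and the propagation below reads `args.lambda_amp`
+        # (e.g. set via argparse or a SimpleNamespace, as in examples/graph/test_airgnn.py).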
+        self.prop = AdaptiveMessagePassing(K=K, alpha=args.alpha, mode=args.model, args=args)
+        print(self.prop)
+
+        self.dropout = dropout
+        self.weight_decay = weight_decay
+        self.lr = lr
+        self.name = args.model
+        self.with_bn = with_bn
+
+    def initialize(self):
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for lin in self.lins:
+            lin.reset_parameters()
+        if self.with_bn:
+            for bn in self.bns:
+                bn.reset_parameters()
+        self.prop.reset_parameters()
+
+    def forward(self, x, edge_index, edge_weight=None):
+        x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight)
+        edge_index = SparseTensor.from_edge_index(edge_index, edge_weight,
+                                                  sparse_sizes=2 * x.shape[:1]).t()
+        for ii, lin in enumerate(self.lins[:-1]):
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = lin(x)
+            if self.with_bn:
+                x = self.bns[ii](x)
+            x = F.relu(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.lins[-1](x)
+        x = self.prop(x, edge_index)
+        return F.log_softmax(x, dim=1)
+
+    def get_embed(self, x, edge_index, edge_weight=None):
+        x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight)
+        edge_index = SparseTensor.from_edge_index(edge_index, edge_weight,
+                                                  sparse_sizes=2 * x.shape[:1]).t()
+        for ii, lin in enumerate(self.lins[:-1]):
+            x = lin(x)
+            if self.with_bn:
+                x = self.bns[ii](x)
+            x = F.relu(x)
+        x = self.prop(x, edge_index)
+        return x
+
+
+class AdaptiveMessagePassing(MessagePassing):
+    _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
+    _cached_adj_t: Optional[SparseTensor]
+
+    def __init__(self,
+                 K: int,
+                 alpha: float,
+                 dropout: float = 0.,
+                 cached: bool = False,
+                 add_self_loops: bool = True,
+                 normalize: bool = True,
+                 mode: str = None,
+                 node_num: int = None,
+                 args=None,
+                 **kwargs):
+
+        super(AdaptiveMessagePassing, self).__init__(aggr='add', **kwargs)
+        self.K = K
+        self.alpha = alpha
+        self.mode = mode
+        self.dropout = dropout
+        self.cached = cached
+        self.add_self_loops = add_self_loops
+        self.normalize = normalize
+        self._cached_edge_index = None
+        self.node_num = node_num
+        self.args = args
+        self._cached_adj_t = None
+
+    def reset_parameters(self):
+        self._cached_edge_index = None
+        self._cached_adj_t = None
+
+    def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None, mode=None) -> Tensor:
+        if self.normalize:
+            if isinstance(edge_index, Tensor):
+                raise ValueError('Only SparseTensor is supported for now')
+            elif isinstance(edge_index, SparseTensor):
+                cache = self._cached_adj_t
+                if cache is None:
+                    edge_index = gcn_norm(  # yapf: disable
+                        edge_index, edge_weight, x.size(self.node_dim), False,
+                        add_self_loops=self.add_self_loops, dtype=x.dtype)
+                    if self.cached:
+                        self._cached_adj_t = edge_index
+                else:
+                    edge_index = cache
+
+        if mode is None:
+            mode = self.mode
+
+        if self.K <= 0:
+            return x
+        hh = x
+
+        if mode == 'MLP':
+            return x
+        elif mode == 'APPNP':
+            x = self.appnp_forward(x=x, hh=hh, edge_index=edge_index, K=self.K, alpha=self.alpha)
+        elif mode in ['AirGNN']:
+            x = self.amp_forward(x=x, hh=hh, edge_index=edge_index, K=self.K)
+        else:
+            raise ValueError('wrong propagate mode')
+        return x
+
+    def appnp_forward(self, x, hh, edge_index, K, alpha):
+        for k in range(K):
+            x = self.propagate(edge_index, x=x, edge_weight=None, size=None)
+            x = x * (1 - alpha)
+            x += alpha * hh
+        return x
+
+    def amp_forward(self, x, hh, K, edge_index):
+        lambda_amp = self.args.lambda_amp
+        gamma = 1 / (2 * (1 - lambda_amp))  # or simply gamma = 1
+
+        for k in range(K):
+            y = x - gamma * 2 * (1 - lambda_amp) * self.compute_LX(x=x, edge_index=edge_index)  # Equation (9)
+            x = hh + self.proximal_L21(x=y - hh, lambda_=gamma * lambda_amp)  # Equations (11) and (12)
+        return x
+
+    def proximal_L21(self, x: Tensor, lambda_):
+        row_norm = torch.norm(x, p=2, dim=1)
+        score = torch.clamp(row_norm - lambda_, min=0)
+        index = torch.where(row_norm > 0)  # deal with the case when row_norm is 0
+        score[index] = score[index] / row_norm[index]  # score is the adaptive score in Equation (14)
+        return score.unsqueeze(1) * x
+
+    def compute_LX(self, x, edge_index, edge_weight=None):
+        x = x - self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)
+        return x
+
+    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
+        return edge_weight.view(-1, 1) * x_j
+
+    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
+        return matmul(adj_t, x, reduce=self.aggr)
+
+    def __repr__(self):
+        return '{}(K={}, alpha={}, mode={}, dropout={}, lambda_amp={})'.format(self.__class__.__name__, self.K,
+                                                                               self.alpha, self.mode, self.dropout,
+                                                                               self.args.lambda_amp)
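The `proximal_L21` step above is a row-wise soft-thresholding (Equation (14) of the AirGNN paper): rows whose l2 norm falls below the threshold are zeroed, the rest are shrunk toward zero. A self-contained sketch of the same computation, with standalone names chosen for illustration only:

```python
import torch

def soft_threshold_rows(x: torch.Tensor, lam: float) -> torch.Tensor:
    """Row-wise L21 proximal operator: shrink each row's L2 norm by lam."""
    row_norm = torch.norm(x, p=2, dim=1)
    score = torch.clamp(row_norm - lam, min=0)
    nonzero = row_norm > 0                       # avoid division by zero
    score[nonzero] = score[nonzero] / row_norm[nonzero]
    return score.unsqueeze(1) * x

x = torch.tensor([[3.0, 4.0],    # norm 5.0 -> shrunk to norm 4.0
                  [0.3, 0.4]])   # norm 0.5 -> zeroed out
print(soft_threshold_rows(x, lam=1.0))
# tensor([[2.4000, 3.2000],
#         [0.0000, 0.0000]])
```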
diff --git a/deeprobust/graph/defense_pyg/appnp.py b/deeprobust/graph/defense_pyg/appnp.py
new file mode 100644
index 0000000..569ba1d
--- /dev/null
+++ b/deeprobust/graph/defense_pyg/appnp.py
@@ -0,0 +1,79 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import torch
+from torch.nn.parameter import Parameter
+from torch.nn.modules.module import Module
+from torch_geometric.nn import APPNP as APPNPConv
+from torch.nn import Linear
+from .base_model import BaseModel
+
+
+class APPNP(BaseModel):
+
+    def __init__(self, nfeat, nhid, nclass, K=10, alpha=0.1, dropout=0.5, lr=0.01,
+                 with_bn=False, weight_decay=5e-4, with_bias=True, device=None):
+
+        super(APPNP, self).__init__()
+        assert device is not None, "Please specify 'device'!"
+        self.device = device
+
+        self.lin1 = Linear(nfeat, nhid)
+        if with_bn:
+            self.bn1 = nn.BatchNorm1d(nhid)
+            self.bn2 = nn.BatchNorm1d(nclass)
+
+        self.lin2 = Linear(nhid, nclass)
+        self.prop1 = APPNPConv(K, alpha)
+
+        self.dropout = dropout
+        self.weight_decay = weight_decay
+        self.lr = lr
+        self.output = None
+        self.best_model = None
+        self.best_output = None
+        self.name = 'APPNP'
+        self.with_bn = with_bn
+
+    def forward(self, x, edge_index, edge_weight=None):
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.lin1(x)
+        if self.with_bn:
+            x = self.bn1(x)
+        x = F.relu(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.lin2(x)
+        if self.with_bn:
+            x = self.bn2(x)
+        x = self.prop1(x, edge_index, edge_weight)
+        return F.log_softmax(x, dim=1)
+
+    def initialize(self):
+        self.lin1.reset_parameters()
+        self.lin2.reset_parameters()
+        if self.with_bn:
+            self.bn1.reset_parameters()
+            self.bn2.reset_parameters()
+
+
+if __name__ == "__main__":
+    from deeprobust.graph.data import Dataset, Dpr2Pyg
+    data = Dataset(root='/tmp/', name='cora', setting='gcn')
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    model = APPNP(nfeat=features.shape[1],
+                  nhid=16,
+                  nclass=labels.max().item() + 1,
+                  dropout=0.5, device='cuda')
+    model = model.to('cuda')
+    pyg_data = Dpr2Pyg(data)[0]
+
+    model.fit(pyg_data, verbose=True)  # train with early stopping
+    model.test()
+    print(model.predict())
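For reference, the propagation used by `APPNPConv` (and re-implemented in `appnp_forward` in `airgnn.py` above) is a personalized-PageRank smoothing of the MLP output:

$$X^{(0)} = f_\theta(X), \qquad X^{(k+1)} = (1-\alpha)\,\hat{A}\,X^{(k)} + \alpha\,X^{(0)}, \quad k = 0,\dots,K-1,$$

where $\hat{A} = \tilde{D}^{-1/2}\tilde{A}\tilde{D}^{-1/2}$ is the GCN-normalized adjacency with self-loops.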
diff --git a/deeprobust/graph/defense_pyg/base_model.py b/deeprobust/graph/defense_pyg/base_model.py
new file mode 100644
index 0000000..2c567d9
--- /dev/null
+++ b/deeprobust/graph/defense_pyg/base_model.py
@@ -0,0 +1,206 @@
+import torch.optim as optim
+import torch.nn as nn
+import torch.nn.functional as F
+from copy import deepcopy
+from deeprobust.graph import utils
+import torch
+
+
+class BaseModel(nn.Module):
+    def __init__(self):
+        super(BaseModel, self).__init__()
+
+    def fit(self, pyg_data, train_iters=1000, initialize=True, verbose=False, patience=100, **kwargs):
+        if initialize:
+            self.initialize()
+
+        # self.data = pyg_data[0].to(self.device)
+        self.data = pyg_data.to(self.device)
+        # by default, the model is trained with early stopping on validation
+        self.train_with_early_stopping(train_iters, patience, verbose)
+
+    def finetune(self, edge_index, edge_weight, feat=None, train_iters=10, verbose=True):
+        if verbose:
+            print(f'=== finetuning {self.name} model ===')
+        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
+        labels = self.data.y
+        if feat is None:
+            x = self.data.x
+        else:
+            x = feat
+        train_mask, val_mask = self.data.train_mask, self.data.val_mask
+        best_loss_val = 100
+        best_acc_val = 0
+        for i in range(train_iters):
+            self.train()
+            optimizer.zero_grad()
+            output = self.forward(x, edge_index, edge_weight)
+            loss_train = F.nll_loss(output[train_mask], labels[train_mask])
+            loss_train.backward()
+            optimizer.step()
+
+            if verbose and i % 50 == 0:
+                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
+
+            self.eval()
+            with torch.no_grad():
+                output = self.forward(x, edge_index)
+                loss_val = F.nll_loss(output[val_mask], labels[val_mask])
+                acc_val = utils.accuracy(output[val_mask], labels[val_mask])
+
+            # if best_loss_val > loss_val:
+            #     best_loss_val = loss_val
+            #     best_output = output
+            #     weights = deepcopy(self.state_dict())
+
+            if best_acc_val < acc_val:
+                best_acc_val = acc_val
+                best_output = output
+                weights = deepcopy(self.state_dict())
+
+        print('best_acc_val:', best_acc_val.item())
+        self.load_state_dict(weights)
+        return best_output
+
+    def _fit_with_val(self, pyg_data, train_iters=1000, initialize=True, verbose=False, **kwargs):
+        if initialize:
+            self.initialize()
+
+        # self.data = pyg_data[0].to(self.device)
+        self.data = pyg_data.to(self.device)
+        if verbose:
+            print(f'=== training {self.name} model ===')
+        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
+
+        labels = self.data.y
+        train_mask, val_mask = self.data.train_mask, self.data.val_mask
+
+        x, edge_index = self.data.x, self.data.edge_index
+        for i in range(train_iters):
+            self.train()
+            optimizer.zero_grad()
+            output = self.forward(x, edge_index)
+            loss_train = F.nll_loss(output[train_mask+val_mask], labels[train_mask+val_mask])
+            loss_train.backward()
+            optimizer.step()
+
+            if verbose and i % 50 == 0:
+                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
+
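+    # Note: `fit_with_val` trains on train+val1 and early-stops on val2; the
+    # data object is therefore expected to provide `val1_mask` and `val2_mask`.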
+    def fit_with_val(self, pyg_data, train_iters=1000, initialize=True, patience=100, verbose=False, **kwargs):
+        if initialize:
+            self.initialize()
+
+        self.data = pyg_data.to(self.device)
+        self.data.train_mask = self.data.train_mask + self.data.val1_mask
+        self.data.val_mask = self.data.val2_mask
+        self.train_with_early_stopping(train_iters, patience, verbose)
+
+    def train_with_early_stopping(self, train_iters, patience, verbose):
+        """Early stopping based on the validation accuracy."""
+        if verbose:
+            print(f'=== training {self.name} model ===')
+        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
+
+        labels = self.data.y
+        train_mask, val_mask = self.data.train_mask, self.data.val_mask
+
+        early_stopping = patience
+        best_loss_val = 100
+        best_acc_val = 0
+        best_epoch = 0
+
+        x, edge_index = self.data.x, self.data.edge_index
+        for i in range(train_iters):
+            self.train()
+            optimizer.zero_grad()
+
+            output = self.forward(x, edge_index)
+
+            loss_train = F.nll_loss(output[train_mask], labels[train_mask])
+            loss_train.backward()
+            optimizer.step()
+
+            if verbose and i % 50 == 0:
+                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
+
+            self.eval()
+            output = self.forward(x, edge_index)
+            loss_val = F.nll_loss(output[val_mask], labels[val_mask])
+            acc_val = utils.accuracy(output[val_mask], labels[val_mask])
+
+            # if best_loss_val > loss_val:
+            #     best_loss_val = loss_val
+            #     self.output = output
+            #     weights = deepcopy(self.state_dict())
+            #     patience = early_stopping
+            #     best_epoch = i
+            # else:
+            #     patience -= 1
+
+            if best_acc_val < acc_val:
+                best_acc_val = acc_val
+                self.output = output
+                weights = deepcopy(self.state_dict())
+                patience = early_stopping
+                best_epoch = i
+            else:
+                patience -= 1
+
+            if i > early_stopping and patience <= 0:
+                break
+
+        if verbose:
+            # print('=== early stopping at {0}, loss_val = {1} ==='.format(best_epoch, best_loss_val))
+            print('=== early stopping at {0}, acc_val = {1} ==='.format(best_epoch, best_acc_val))
+        self.load_state_dict(weights)
+
+    def test(self):
+        """Evaluate model performance on the test set (`self.data.test_mask`)."""
+        self.eval()
+        test_mask = self.data.test_mask
+        labels = self.data.y
+        output = self.forward(self.data.x, self.data.edge_index)
+        # output = self.output
+        loss_test = F.nll_loss(output[test_mask], labels[test_mask])
+        acc_test = utils.accuracy(output[test_mask], labels[test_mask])
+        print("Test set results:",
+              "loss= {:.4f}".format(loss_test.item()),
+              "accuracy= {:.4f}".format(acc_test.item()))
+        return acc_test.item()
+
+    def predict(self, x=None, edge_index=None, edge_weight=None):
+        """
+        Returns
+        -------
+        torch.FloatTensor
+            output (log probabilities)
+        """
+        self.eval()
+        if x is None or edge_index is None:
+            x, edge_index = self.data.x, self.data.edge_index
+        return self.forward(x, edge_index, edge_weight)
+
+    def _ensure_contiguousness(self, x, edge_idx, edge_weight):
+        if not x.is_sparse:
+            x = x.contiguous()
+        if hasattr(edge_idx, 'contiguous'):
+            edge_idx = edge_idx.contiguous()
+        if edge_weight is not None:
+            edge_weight = edge_weight.contiguous()
+        return x, edge_idx, edge_weight
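All of the `defense_pyg` models below inherit this `BaseModel` training loop, so they share the same `fit`/`test`/`predict` workflow. A minimal end-to-end sketch (mirroring the `__main__` demos in this PR; it assumes PyTorch Geometric is installed):

```python
from deeprobust.graph.data import Dataset, Dpr2Pyg
from deeprobust.graph.defense_pyg import GCN

data = Dataset(root='/tmp/', name='cora', setting='gcn')
pyg_data = Dpr2Pyg(data)[0]                    # DeepRobust -> PyG Data object

model = GCN(nfeat=data.features.shape[1], nhid=16,
            nclass=data.labels.max().item() + 1,
            dropout=0.5, device='cpu').to('cpu')
model.fit(pyg_data, verbose=True)              # early stopping on validation accuracy
model.test()                                   # prints loss/accuracy on data.test_mask
log_probs = model.predict()                    # log-softmax outputs on the clean graph
```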
diff --git a/deeprobust/graph/defense_pyg/gat.py b/deeprobust/graph/defense_pyg/gat.py
new file mode 100644
index 0000000..bf98d54
--- /dev/null
+++ b/deeprobust/graph/defense_pyg/gat.py
@@ -0,0 +1,99 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import torch
+from torch.nn.parameter import Parameter
+from torch.nn.modules.module import Module
+from torch_geometric.nn import GATConv
+from .base_model import BaseModel
+
+
+class GAT(BaseModel):
+
+    def __init__(self, nfeat, nhid, nclass, heads=8, output_heads=1, dropout=0.5, lr=0.01,
+                 nlayers=2, with_bn=False, weight_decay=5e-4, with_bias=True, device=None):
+
+        super(GAT, self).__init__()
+        assert device is not None, "Please specify 'device'!"
+        self.device = device
+
+        self.convs = nn.ModuleList([])
+        if with_bn:
+            self.bns = nn.ModuleList([])
+            self.bns.append(nn.BatchNorm1d(nhid*heads))
+
+        self.convs.append(GATConv(
+            nfeat,
+            nhid,
+            heads=heads,
+            dropout=dropout,
+            bias=with_bias))
+
+        for i in range(nlayers-2):
+            self.convs.append(GATConv(nhid*heads,
+                nhid, heads=heads, dropout=dropout, bias=with_bias))
+            if with_bn:
+                self.bns.append(nn.BatchNorm1d(nhid*heads))
+
+        self.convs.append(GATConv(
+            nhid * heads,
+            nclass,
+            heads=output_heads,
+            concat=False,
+            dropout=dropout,
+            bias=with_bias))
+
+        self.dropout = dropout
+        self.weight_decay = weight_decay
+        self.lr = lr
+        self.output = None
+        self.best_model = None
+        self.best_output = None
+        self.name = 'GAT'
+        self.with_bn = with_bn
+
+    def forward(self, x, edge_index, edge_weight=None):
+        for ii, conv in enumerate(self.convs[:-1]):
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = conv(x, edge_index, edge_weight)
+            if self.with_bn:
+                x = self.bns[ii](x)
+            x = F.elu(x)
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.convs[-1](x, edge_index, edge_weight)
+        return F.log_softmax(x, dim=1)
+
+    def get_embed(self, x, edge_index, edge_weight=None):
+        for ii, conv in enumerate(self.convs[:-1]):
+            x = F.dropout(x, p=self.dropout, training=self.training)
+            x = conv(x, edge_index, edge_weight)
+            if self.with_bn:
+                x = self.bns[ii](x)
+            x = F.elu(x)
+        return x
+
+    def initialize(self):
+        for conv in self.convs:
+            conv.reset_parameters()
+        if self.with_bn:
+            for bn in self.bns:
+                bn.reset_parameters()
+
+
+if __name__ == "__main__":
+    from deeprobust.graph.data import Dataset, Dpr2Pyg
+    # from deeprobust.graph.defense import GAT
+    data = Dataset(root='/tmp/', name='cora')
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    gat = GAT(nfeat=features.shape[1],
+              nhid=8, heads=8,
+              nclass=labels.max().item() + 1,
+              dropout=0.5, device='cpu')
+    gat = gat.to('cpu')
+    pyg_data = Dpr2Pyg(data)
+    gat.fit(pyg_data, verbose=True)  # train with early stopping
+    gat.test()
+    print(gat.predict())
diff --git a/deeprobust/graph/defense_pyg/gcn.py b/deeprobust/graph/defense_pyg/gcn.py
new file mode 100644
index 0000000..7ea6347
--- /dev/null
+++ b/deeprobust/graph/defense_pyg/gcn.py
@@ -0,0 +1,110 @@
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+import torch
+from torch.nn.parameter import Parameter
+from torch.nn.modules.module import Module
+from torch_geometric.nn import GCNConv
+from .base_model import BaseModel
+from torch_sparse import coalesce, SparseTensor, matmul
+
+
+class GCN(BaseModel):
+
+    def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01,
+                 with_bn=False, weight_decay=5e-4, with_bias=True, device=None):
+
+        super(GCN, self).__init__()
+        assert device is not None, "Please specify 'device'!"
+        self.device = device
+
+        self.layers = nn.ModuleList([])
+        if with_bn:
+            self.bns = nn.ModuleList()
+
+        if nlayers == 1:
+            self.layers.append(GCNConv(nfeat, nclass, bias=with_bias))
+        else:
+            self.layers.append(GCNConv(nfeat, nhid, bias=with_bias))
+            if with_bn:
+                self.bns.append(nn.BatchNorm1d(nhid))
+            for i in range(nlayers-2):
+                self.layers.append(GCNConv(nhid, nhid, bias=with_bias))
+                if with_bn:
+                    self.bns.append(nn.BatchNorm1d(nhid))
+            self.layers.append(GCNConv(nhid, nclass, bias=with_bias))
+
+        self.dropout = dropout
+        self.weight_decay = weight_decay
+        self.lr = lr
+        self.output = None
+        self.best_model = None
+        self.best_output = None
+        self.with_bn = with_bn
+        self.name = 'GCN'
+
+    def forward(self, x, edge_index, edge_weight=None):
+        x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight)
+        for ii, layer in enumerate(self.layers):
+            if edge_weight is not None:
+                adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t()
+                x = layer(x, adj)
+            else:
+                x = layer(x, edge_index)
+            if ii != len(self.layers) - 1:
+                if self.with_bn:
+                    x = self.bns[ii](x)
+                x = F.relu(x)
+                x = F.dropout(x, p=self.dropout, training=self.training)
+        return F.log_softmax(x, dim=1)
+
+    def get_embed(self, x, edge_index, edge_weight=None):
+        x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight)
+        for ii, layer in enumerate(self.layers):
+            if ii == len(self.layers) - 1:
+                return x
+            if edge_weight is not None:
+                adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t()
+                x = layer(x, adj)
+            else:
+                x = layer(x, edge_index)
+            if ii != len(self.layers) - 1:
+                if self.with_bn:
+                    x = self.bns[ii](x)
+                x = F.relu(x)
+        return x
+
+    def initialize(self):
+        for m in self.layers:
+            m.reset_parameters()
+        if self.with_bn:
+            for bn in self.bns:
+                bn.reset_parameters()
+
+
+if __name__ == "__main__":
+    from deeprobust.graph.data import Dataset, Dpr2Pyg
+    # from deeprobust.graph.defense import GCN
+    data = Dataset(root='/tmp/', name='citeseer', setting='prognn')
+    adj, features, labels = data.adj, data.features, data.labels
+    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
+    model = GCN(nfeat=features.shape[1],
+                nhid=16,
+                nclass=labels.max().item() + 1,
+                dropout=0.5, device='cuda')
+    model = model.to('cuda')
+    pyg_data = Dpr2Pyg(data)[0]
+
+    # model.fit(features, adj, labels, idx_train, train_iters=200, verbose=True)
+    # model.test(idx_test)
+
+    model.fit(pyg_data, verbose=True)  # train with early stopping
+    model.test()
+    print(model.predict())
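When `edge_weight` is given, the forward pass above routes through a `SparseTensor` so that weighted (e.g., attacked) graphs are handled. A sketch of evaluating a fitted model on a perturbed graph; `pert_edge_index`/`pert_edge_weight` are illustrative names, e.g. the tensors returned by `PRBCD.attack` below:

```python
import torch
from deeprobust.graph import utils

model.eval()
with torch.no_grad():
    # `predict` falls back to the stored clean graph when no arguments are given.
    clean_out = model.predict()
    pert_out = model.predict(model.data.x, pert_edge_index, pert_edge_weight)

mask = model.data.test_mask
print('clean    acc:', utils.accuracy(clean_out[mask], model.data.y[mask]).item())
print('attacked acc:', utils.accuracy(pert_out[mask], model.data.y[mask]).item())
```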
diff --git a/deeprobust/graph/defense_pyg/gpr.py b/deeprobust/graph/defense_pyg/gpr.py
new file mode 100644
index 0000000..80920d4
--- /dev/null
+++ b/deeprobust/graph/defense_pyg/gpr.py
@@ -0,0 +1,135 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_sparse import SparseTensor, matmul
+from torch_geometric.nn import GCNConv, SAGEConv, GATConv, APPNP, MessagePassing
+from torch_geometric.nn.conv.gcn_conv import gcn_norm
+import scipy.sparse
+import numpy as np
+from .base_model import BaseModel
+
+
+class GPRGNN(BaseModel):
+    """GPRGNN, from the original repo https://github.com/jianhao2016/GPRGNN"""
+
+    def __init__(self, in_channels, hidden_channels, out_channels, Init='PPR', dprate=.5, dropout=.5,
+                 lr=0.01, weight_decay=0, device='cpu',
+                 K=10, alpha=.1, Gamma=None, ppnp='GPR_prop'):
+        super(GPRGNN, self).__init__()
+        self.lin1 = nn.Linear(in_channels, hidden_channels)
+        self.lin2 = nn.Linear(hidden_channels, out_channels)
+
+        if ppnp == 'PPNP':
+            self.prop1 = APPNP(K, alpha)
+        elif ppnp == 'GPR_prop':
+            self.prop1 = GPR_prop(K, alpha, Init, Gamma)
+
+        self.Init = Init
+        self.dprate = dprate
+        self.dropout = dropout
+        self.name = "GPR"
+        self.weight_decay = weight_decay
+        self.lr = lr
+        self.device = device
+
+    def initialize(self):
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.lin1.reset_parameters()
+        self.lin2.reset_parameters()
+        self.prop1.reset_parameters()
+
+    def forward(self, x, edge_index, edge_weight=None):
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = F.relu(self.lin1(x))
+        x = F.dropout(x, p=self.dropout, training=self.training)
+        x = self.lin2(x)
+
+        if edge_weight is not None:
+            adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1])
+            if self.dprate == 0.0:
+                x = self.prop1(x, adj)
+            else:
+                x = F.dropout(x, p=self.dprate, training=self.training)
+                x = self.prop1(x, adj)
+        else:
+            if self.dprate == 0.0:
+                x = self.prop1(x, edge_index, edge_weight)
+            else:
+                x = F.dropout(x, p=self.dprate, training=self.training)
+                x = self.prop1(x, edge_index, edge_weight)
+
+        return F.log_softmax(x, dim=1)
+
+
+class GPR_prop(MessagePassing):
+    '''
+    Propagation class for GPR_GNN, from the original repo
+    https://github.com/jianhao2016/GPRGNN
+    '''
+
+    def __init__(self, K, alpha, Init, Gamma=None, bias=True, **kwargs):
+        super(GPR_prop, self).__init__(aggr='add', **kwargs)
+        self.K = K
+        self.Init = Init
+        self.alpha = alpha
+
+        assert Init in ['SGC', 'PPR', 'NPPR', 'Random', 'WS']
+        if Init == 'SGC':
+            # SGC-like
+            TEMP = 0.0*np.ones(K+1)
+            TEMP[alpha] = 1.0
+        elif Init == 'PPR':
+            # PPR-like
+            TEMP = alpha*(1-alpha)**np.arange(K+1)
+            TEMP[-1] = (1-alpha)**K
+        elif Init == 'NPPR':
+            # Negative PPR
+            TEMP = (alpha)**np.arange(K+1)
+            TEMP = TEMP/np.sum(np.abs(TEMP))
+        elif Init == 'Random':
+            # Random
+            bound = np.sqrt(3/(K+1))
+            TEMP = np.random.uniform(-bound, bound, K+1)
+            TEMP = TEMP/np.sum(np.abs(TEMP))
+        elif Init == 'WS':
+            # Specify Gamma
+            TEMP = Gamma
+
+        self.temp = nn.Parameter(torch.tensor(TEMP))
+
+    def reset_parameters(self):
+        nn.init.zeros_(self.temp)
+        for k in range(self.K+1):
+            self.temp.data[k] = self.alpha*(1-self.alpha)**k
+        self.temp.data[-1] = (1-self.alpha)**self.K
+
+    def forward(self, x, edge_index, edge_weight=None):
+        if isinstance(edge_index, torch.Tensor):
+            edge_index, norm = gcn_norm(
+                edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
+        elif isinstance(edge_index, SparseTensor):
+            edge_index = gcn_norm(
+                edge_index, edge_weight, num_nodes=x.size(0), dtype=x.dtype)
+            norm = None
+
+        hidden = x*(self.temp[0])
+        for k in range(self.K):
+            x = self.propagate(edge_index, x=x, norm=norm)
+            gamma = self.temp[k+1]
+            hidden = hidden + gamma*x
+        return hidden
+
+    def message(self, x_j, norm):
+        return norm.view(-1, 1) * x_j
+
+    def message_and_aggregate(self, adj_t, x):
+        return matmul(adj_t, x, reduce=self.aggr)
+
+    def __repr__(self):
+        return '{}(K={}, temp={})'.format(self.__class__.__name__, self.K,
+                                          self.temp)
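For reference, the `'PPR'` branch above (also restored by `reset_parameters`) initializes the generalized PageRank weights to

$$\gamma_k = \alpha(1-\alpha)^k \quad (0 \le k < K), \qquad \gamma_K = (1-\alpha)^K,$$

so the learned propagation $\sum_{k=0}^{K}\gamma_k \hat{A}^k H$ starts from APPNP-style personalized-PageRank smoothing before the $\gamma_k$ are adapted during training.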
diff --git a/deeprobust/graph/defense_pyg/sage.py b/deeprobust/graph/defense_pyg/sage.py
new file mode 100644
index 0000000..fb9c24b
--- /dev/null
+++ b/deeprobust/graph/defense_pyg/sage.py
@@ -0,0 +1,154 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_sparse import SparseTensor, matmul
+# from torch_geometric.nn import SAGEConv, GATConv, APPNP, MessagePassing
+from torch_geometric.nn.conv.gcn_conv import gcn_norm
+import scipy.sparse
+import numpy as np
+from .base_model import BaseModel
+
+
+class SAGE(BaseModel):
+    def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2,
+                 dropout=0.5, lr=0.01, weight_decay=0, device='cpu', with_bn=True):
+        super(SAGE, self).__init__()
+
+        self.convs = nn.ModuleList()
+        self.convs.append(
+            SAGEConv(in_channels, hidden_channels))
+
+        self.bns = nn.ModuleList()
+        self.bns.append(nn.BatchNorm1d(hidden_channels))
+        for _ in range(num_layers - 2):
+            self.convs.append(
+                SAGEConv(hidden_channels, hidden_channels))
+            self.bns.append(nn.BatchNorm1d(hidden_channels))
+
+        self.convs.append(
+            SAGEConv(hidden_channels, out_channels))
+
+        self.weight_decay = weight_decay
+        self.lr = lr
+        self.dropout = dropout
+        self.activation = F.relu
+        self.with_bn = with_bn
+        self.device = device
+        self.name = "SAGE"
+
+    def initialize(self):
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        for conv in self.convs:
+            conv.reset_parameters()
+        for bn in self.bns:
+            bn.reset_parameters()
+
+    def forward(self, x, edge_index, edge_weight=None):
+        if edge_weight is not None:
+            adj = SparseTensor.from_edge_index(edge_index, edge_weight, sparse_sizes=2 * x.shape[:1]).t()
+
+        for i, conv in enumerate(self.convs[:-1]):
+            if edge_weight is not None:
+                x = conv(x, adj)
+            else:
+                x = conv(x, edge_index, edge_weight)
+            if self.with_bn:
+                x = self.bns[i](x)
+            x = self.activation(x)
+            x = F.dropout(x, p=self.dropout, training=self.training)
+        if edge_weight is not None:
+            x = self.convs[-1](x, adj)
+        else:
+            x = self.convs[-1](x, edge_index, edge_weight)
+        return F.log_softmax(x, dim=1)
+
+
+from typing import Union, Tuple
+from torch_geometric.typing import OptPairTensor, Adj, Size
+
+from torch import Tensor
+from torch.nn import Linear
+import torch.nn.functional as F
+from torch_sparse import SparseTensor, matmul
+from torch_geometric.nn.conv import MessagePassing
+
+
+class SAGEConv(MessagePassing):
+    r"""The GraphSAGE operator from the `"Inductive Representation Learning on
+    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper
+
+    .. math::
+        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot
+        \mathrm{mean}_{j \in \mathcal{N}(i)} \mathbf{x}_j
+
+    Args:
+        in_channels (int or tuple): Size of each input sample. A tuple
+            corresponds to the sizes of source and target dimensionalities.
+        out_channels (int): Size of each output sample.
+        normalize (bool, optional): If set to :obj:`True`, output features
+            will be :math:`\ell_2`-normalized, *i.e.*,
+            :math:`\frac{\mathbf{x}^{\prime}_i}
+            {\| \mathbf{x}^{\prime}_i \|_2}`.
+            (default: :obj:`False`)
+        bias (bool, optional): If set to :obj:`False`, the layer will not learn
+            an additive bias. (default: :obj:`True`)
+        **kwargs (optional): Additional arguments of
+            :class:`torch_geometric.nn.conv.MessagePassing`.
+    """
+    def __init__(self, in_channels: Union[int, Tuple[int, int]],
+                 out_channels: int, normalize: bool = False,
+                 bias: bool = True, **kwargs):  # yapf: disable
+        kwargs.setdefault('aggr', 'mean')
+        super(SAGEConv, self).__init__(**kwargs)
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.normalize = normalize
+
+        if isinstance(in_channels, int):
+            in_channels = (in_channels, in_channels)
+
+        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
+        self.lin_r = Linear(in_channels[1], out_channels, bias=False)
+
+        self.reset_parameters()
+
+    def reset_parameters(self):
+        self.lin_l.reset_parameters()
+        self.lin_r.reset_parameters()
+
+    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
+                size: Size = None) -> Tensor:
+        """"""
+        if isinstance(x, Tensor):
+            x: OptPairTensor = (x, x)
+
+        # propagate_type: (x: OptPairTensor)
+        out = self.propagate(edge_index, x=x, size=size)
+        out = self.lin_l(out)
+
+        x_r = x[1]
+        if x_r is not None:
+            out += self.lin_r(x_r)
+
+        if self.normalize:
+            out = F.normalize(out, p=2., dim=-1)
+
+        return out
+
+    def message(self, x_j: Tensor) -> Tensor:
+        return x_j
+
+    def message_and_aggregate(self, adj_t: SparseTensor,
+                              x: OptPairTensor) -> Tensor:
+        # Deleted the following line to keep propagation differentiable:
+        # adj_t = adj_t.set_value(None, layout=None)
+        return matmul(adj_t, x[0], reduce=self.aggr)
+
+    def __repr__(self):
+        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
+                                   self.out_channels)
diff --git a/deeprobust/graph/global_attack/__init__.py b/deeprobust/graph/global_attack/__init__.py
index 6a57a08..fb809c1 100644
--- a/deeprobust/graph/global_attack/__init__.py
+++ b/deeprobust/graph/global_attack/__init__.py
@@ -6,4 +6,13 @@
 from .node_embedding_attack import NodeEmbeddingAttack, OtherNodeEmbeddingAttack
 from .nipa import NIPA
 
-__all__ = ['BaseAttack', 'DICE', 'MetaApprox', 'Metattack', 'Random', 'MinMax', 'PGDAttack', 'NIPA', 'NodeEmbeddingAttack', 'OtherNodeEmbeddingAttack']
+import warnings
+try:
+    from .prbcd import PRBCD
+except ImportError as e:
+    print(e)
+    warnings.warn("Please install pytorch geometric if you " +
+                  "would like to use the PRBCD attack. See details in " +
+                  "https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html")
+
+__all__ = ['BaseAttack', 'DICE', 'MetaApprox', 'Metattack', 'Random', 'MinMax', 'PGDAttack', 'NIPA', 'NodeEmbeddingAttack', 'OtherNodeEmbeddingAttack', 'PRBCD']
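Besides the ogbn-arxiv demo in `test_prbcd.py` at the end of this diff, the attack also runs on small PyG datasets. A minimal, untested sketch (assuming torch_geometric and a CUDA device are available):

```python
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid
from deeprobust.graph.global_attack import PRBCD

data = Planetoid('/tmp/cora', 'Cora', transform=T.NormalizeFeatures())[0]
agent = PRBCD(data, device='cuda')          # pretrains a 3-layer GCN victim internally
edge_index, edge_weight = agent.attack(ptb_rate=0.05)
```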
diff --git a/deeprobust/graph/global_attack/prbcd.py b/deeprobust/graph/global_attack/prbcd.py
new file mode 100644
index 0000000..90fcf87
--- /dev/null
+++ b/deeprobust/graph/global_attack/prbcd.py
@@ -0,0 +1,434 @@
+"""
+Robustness of Graph Neural Networks at Scale. NeurIPS 2021.
+
+Modified from https://github.com/sigeisler/robustness_of_gnns_at_scale/blob/main/rgnn_at_scale/attacks/prbcd.py
+"""
+import numpy as np
+from deeprobust.graph.defense_pyg import GCN
+import torch.nn.functional as F
+import torch
+import deeprobust.graph.utils as utils
+from torch.nn.parameter import Parameter
+from tqdm import tqdm
+import torch_sparse
+from torch_sparse import coalesce
+import math
+from torch_geometric.utils import to_scipy_sparse_matrix, from_scipy_sparse_matrix
+
+
+class PRBCD:
+
+    def __init__(self, data, make_undirected=True,
+                 eps=1e-7, search_space_size=10_000_000,
+                 max_final_samples=20,
+                 fine_tune_epochs=100,
+                 epochs=400, lr_adj=0.1,
+                 with_early_stopping=True,
+                 do_synchronize=True,
+                 device='cuda'):
+        self.device = device
+        self.data = data
+        self.model = self.pretrain_model()
+        output = self.model.predict()
+        labels = data.y.to(self.device)
+        self.get_perf(output, labels, data.test_mask)
+
+        for param in self.model.parameters():
+            param.requires_grad = False
+
+        nnodes = data.x.shape[0]
+        d = data.x.shape[1]
+
+        self.n, self.d = nnodes, d
+        self.make_undirected = make_undirected
+        self.max_final_samples = max_final_samples
+        self.search_space_size = search_space_size
+        self.eps = eps
+        self.lr_adj = lr_adj
+
+        self.modified_edge_index: torch.Tensor = None
+        self.perturbed_edge_weight: torch.Tensor = None
+        if self.make_undirected:
+            self.n_possible_edges = self.n * (self.n - 1) // 2
+        else:
+            self.n_possible_edges = self.n ** 2  # we filter self-loops later
+
+        # lr_factor = 0.1
+        # self.lr_factor = lr_factor * max(math.log2(self.n_possible_edges / self.search_space_size), 1.)
+        self.epochs = epochs
+        self.epochs_resampling = epochs - fine_tune_epochs  # TODO
+
+        self.with_early_stopping = with_early_stopping
+        self.do_synchronize = do_synchronize
+
+    def pretrain_model(self):
+        data = self.data
+        device = self.device
+        feat, labels = data.x, data.y
+        nclass = max(labels).item()+1
+
+        model = GCN(nfeat=feat.shape[1], nhid=256, dropout=0,
+                    nlayers=3, with_bn=True, weight_decay=5e-4, nclass=nclass,
+                    device=device).to(device)
+        print(model)
+
+        model.fit(data, train_iters=1000, patience=200, verbose=True)
+        model.eval()
+        model.data = data.to(self.device)
+        output = model.predict()
+        labels = labels.to(device)
+        print(f"{model.name} Test set results:", self.get_perf(output, labels, data.test_mask, verbose=0)[1])
+        self.clean_node_mask = (output.argmax(1) == labels)
+        return model
+
+    def sample_random_block(self, n_perturbations):
+        for _ in range(self.max_final_samples):
+            self.current_search_space = torch.randint(
+                self.n_possible_edges, (self.search_space_size,), device=self.device)
+            self.current_search_space = torch.unique(self.current_search_space, sorted=True)
+            if self.make_undirected:
+                self.modified_edge_index = linear_to_triu_idx(self.n, self.current_search_space)
+            else:
+                self.modified_edge_index = linear_to_full_idx(self.n, self.current_search_space)
+                is_not_self_loop = self.modified_edge_index[0] != self.modified_edge_index[1]
+                self.current_search_space = self.current_search_space[is_not_self_loop]
+                self.modified_edge_index = self.modified_edge_index[:, is_not_self_loop]
+
+            self.perturbed_edge_weight = torch.full_like(
+                self.current_search_space, self.eps, dtype=torch.float32, requires_grad=True
+            )
+            if self.current_search_space.size(0) >= n_perturbations:
+                return
+        raise RuntimeError('Sampling random block was not successful. Please decrease `n_perturbations`.')
+    @torch.no_grad()
+    def sample_final_edges(self, n_perturbations):
+        best_loss = -float('Inf')
+        perturbed_edge_weight = self.perturbed_edge_weight.detach()
+        perturbed_edge_weight[perturbed_edge_weight <= self.eps] = 0
+
+        _, feat, labels = self.edge_index, self.data.x, self.data.y
+        for i in range(self.max_final_samples):
+            if best_loss == float('Inf') or best_loss == -float('Inf'):
+                # In the first iteration, employ a top-k heuristic instead of sampling
+                sampled_edges = torch.zeros_like(perturbed_edge_weight)
+                sampled_edges[torch.topk(perturbed_edge_weight, n_perturbations).indices] = 1
+            else:
+                sampled_edges = torch.bernoulli(perturbed_edge_weight).float()
+
+            if sampled_edges.sum() > n_perturbations:
+                n_samples = sampled_edges.sum()
+                print(f'{i}-th sampling: too many samples {n_samples}')
+                continue
+            self.perturbed_edge_weight = sampled_edges
+
+            edge_index, edge_weight = self.get_modified_adj()
+            with torch.no_grad():
+                output = self.model.forward(feat, edge_index, edge_weight)
+            loss = F.nll_loss(output[self.data.val_mask], labels[self.data.val_mask]).item()
+
+            if best_loss < loss:
+                best_loss = loss
+                print('best_loss:', best_loss)
+                best_edges = self.perturbed_edge_weight.clone().cpu()
+
+        # Recover the best sample
+        self.perturbed_edge_weight.data.copy_(best_edges.to(self.device))
+
+        edge_index, edge_weight = self.get_modified_adj()
+        edge_mask = edge_weight == 1
+
+        allowed_perturbations = 2 * n_perturbations if self.make_undirected else n_perturbations
+        edges_after_attack = edge_mask.sum()
+        clean_edges = self.edge_index.shape[1]
+        assert (edges_after_attack >= clean_edges - allowed_perturbations
+                and edges_after_attack <= clean_edges + allowed_perturbations), \
+            f'{edges_after_attack} out of range with {clean_edges} clean edges and {n_perturbations} perturbations'
+        return edge_index[:, edge_mask], edge_weight[edge_mask]
+
+    def resample_random_block(self, n_perturbations: int):
+        self.keep_heuristic = 'WeightOnly'
+        if self.keep_heuristic == 'WeightOnly':
+            sorted_idx = torch.argsort(self.perturbed_edge_weight)
+            idx_keep = (self.perturbed_edge_weight <= self.eps).sum().long()
+            # Keep at most half of the block (i.e. resample low weights)
+            if idx_keep < sorted_idx.size(0) // 2:
+                idx_keep = sorted_idx.size(0) // 2
+        else:
+            raise NotImplementedError('Only keep_heuristic=`WeightOnly` is supported')
+
+        sorted_idx = sorted_idx[idx_keep:]
+        self.current_search_space = self.current_search_space[sorted_idx]
+        self.modified_edge_index = self.modified_edge_index[:, sorted_idx]
+        self.perturbed_edge_weight = self.perturbed_edge_weight[sorted_idx]
+
+        # Sample until enough edges were drawn
+        for i in range(self.max_final_samples):
+            n_edges_resample = self.search_space_size - self.current_search_space.size(0)
+            lin_index = torch.randint(self.n_possible_edges, (n_edges_resample,), device=self.device)
+
+            self.current_search_space, unique_idx = torch.unique(
+                torch.cat((self.current_search_space, lin_index)),
+                sorted=True,
+                return_inverse=True
+            )
+
+            if self.make_undirected:
+                self.modified_edge_index = linear_to_triu_idx(self.n, self.current_search_space)
+            else:
+                self.modified_edge_index = linear_to_full_idx(self.n, self.current_search_space)
+
+            # Merge existing weights with new edge weights
+            perturbed_edge_weight_old = self.perturbed_edge_weight.clone()
+            self.perturbed_edge_weight = torch.full_like(self.current_search_space, self.eps, dtype=torch.float32)
+            self.perturbed_edge_weight[
+                unique_idx[:perturbed_edge_weight_old.size(0)]
+            ] = perturbed_edge_weight_old  # unique_idx: the indices for the old edges
+
+            if not self.make_undirected:
+                is_not_self_loop = self.modified_edge_index[0] != self.modified_edge_index[1]
+                self.current_search_space = self.current_search_space[is_not_self_loop]
+                self.modified_edge_index = self.modified_edge_index[:, is_not_self_loop]
+                self.perturbed_edge_weight = self.perturbed_edge_weight[is_not_self_loop]
+
+            if self.current_search_space.size(0) > n_perturbations:
+                return
+        raise RuntimeError('Sampling random block was not successful. Please decrease `n_perturbations`.')
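+    # `project` enforces the edge-perturbation budget: when the relaxed weights
+    # sum past n_perturbations, `bisection` (defined at module level below) finds
+    # the shift `miu` with sum(clamp(values - miu, 0, 1)) == n_perturbations,
+    # and the weights are then clamped into [eps, 1 - eps].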
+    def project(self, n_perturbations, values, eps, inplace=False):
+        if not inplace:
+            values = values.clone()
+
+        if torch.clamp(values, 0, 1).sum() > n_perturbations:
+            left = (values - 1).min()
+            right = values.max()
+            miu = bisection(values, left, right, n_perturbations)
+            values.data.copy_(torch.clamp(
+                values - miu, min=eps, max=1 - eps
+            ))
+        else:
+            values.data.copy_(torch.clamp(
+                values, min=eps, max=1 - eps
+            ))
+        return values
+
+    def get_modified_adj(self):
+        if self.make_undirected:
+            modified_edge_index, modified_edge_weight = to_symmetric(
+                self.modified_edge_index, self.perturbed_edge_weight, self.n
+            )
+        else:
+            modified_edge_index, modified_edge_weight = self.modified_edge_index, self.perturbed_edge_weight
+        edge_index = torch.cat((self.edge_index.to(self.device), modified_edge_index), dim=-1)
+        edge_weight = torch.cat((self.edge_weight.to(self.device), modified_edge_weight))
+
+        edge_index, edge_weight = torch_sparse.coalesce(edge_index, edge_weight, m=self.n, n=self.n, op='sum')
+
+        # Allow removal of edges
+        edge_weight[edge_weight > 1] = 2 - edge_weight[edge_weight > 1]
+        return edge_index, edge_weight
+
+    def update_edge_weights(self, n_perturbations, epoch, gradient):
+        self.optimizer_adj.zero_grad()
+        self.perturbed_edge_weight.grad = -gradient
+        self.optimizer_adj.step()
+        self.perturbed_edge_weight.data[self.perturbed_edge_weight < self.eps] = self.eps
+
+    def _update_edge_weights(self, n_perturbations, epoch, gradient):
+        lr_factor = n_perturbations / self.n / 2 * self.lr_factor
+        lr = lr_factor / np.sqrt(max(0, epoch - self.epochs_resampling) + 1)
+        self.perturbed_edge_weight.data.add_(lr * gradient)
+        self.perturbed_edge_weight.data[self.perturbed_edge_weight < self.eps] = self.eps
+        return None
+
+    def attack(self, edge_index=None, edge_weight=None, ptb_rate=0.1):
+        data = self.data
+        epochs, lr_adj = self.epochs, self.lr_adj
+        model = self.model
+        model.eval()  # should be set to eval
+
+        self.edge_index, feat, labels = data.edge_index, data.x, data.y
+        with torch.no_grad():
+            output = model.forward(feat, self.edge_index)
+            pred = output.argmax(1)
+        gt_labels = labels
+        labels = labels.clone()  # to avoid a shallow copy
+        labels[~data.train_mask] = pred[~data.train_mask]
+
+        if edge_index is not None:
+            self.edge_index = edge_index
+
+        self.edge_weight = torch.ones(self.edge_index.shape[1]).to(self.device)
+
+        n_perturbations = int(ptb_rate * self.edge_index.shape[1] // 2)
+        print('n_perturbations:', n_perturbations)
+        self.sample_random_block(n_perturbations)
+
+        self.perturbed_edge_weight.requires_grad = True
+        self.optimizer_adj = torch.optim.Adam([self.perturbed_edge_weight], lr=lr_adj)
+        best_loss_val = -float('Inf')
+        for it in tqdm(range(epochs)):
+            self.perturbed_edge_weight.requires_grad = True
+            edge_index, edge_weight = self.get_modified_adj()
+            if torch.cuda.is_available() and self.do_synchronize:
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize()
+            output = model.forward(feat, edge_index, edge_weight)
+            loss = self.loss_attack(output, labels, type='tanhMargin')
+            gradient = grad_with_checkpoint(loss, self.perturbed_edge_weight)[0]
+
+            if torch.cuda.is_available() and self.do_synchronize:
+                torch.cuda.empty_cache()
+                torch.cuda.synchronize()
+            if it % 10 == 0:
+                print(f'Epoch {it}: {loss}')
+
+            with torch.no_grad():
+                self.update_edge_weights(n_perturbations, it, gradient)
+                self.perturbed_edge_weight = self.project(
+                    n_perturbations, self.perturbed_edge_weight, self.eps)
+
+                del edge_index, edge_weight  # , logits
+
+                if it < self.epochs_resampling - 1:
+                    self.resample_random_block(n_perturbations)
+
+                edge_index, edge_weight = self.get_modified_adj()
+                output = model.predict(feat, edge_index, edge_weight)
+                loss_val = F.nll_loss(output[data.val_mask], labels[data.val_mask])
+
+            self.perturbed_edge_weight.requires_grad = True
+            self.optimizer_adj = torch.optim.Adam([self.perturbed_edge_weight], lr=lr_adj)
+
+        # Sample the final discrete graph
+        edge_index, edge_weight = self.sample_final_edges(n_perturbations)
+        output = model.predict(feat, edge_index, edge_weight)
+        print('Test:')
+        self.get_perf(output, gt_labels, data.test_mask)
+        print('Validation:')
+        self.get_perf(output, gt_labels, data.val_mask)
+        return edge_index, edge_weight
+
+    def loss_attack(self, logits, labels, type='CE'):
+        self.loss_type = type
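+        # tanhMargin maximizes tanh(logit_best_other - logit_true): tanh saturates
+        # once a node is confidently misclassified, so the gradient budget shifts
+        # toward nodes that are still classified correctly.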
+        if self.loss_type == 'tanhMargin':
+            sorted = logits.argsort(-1)
+            best_non_target_class = sorted[sorted != labels[:, None]].reshape(logits.size(0), -1)[:, -1]
+            margin = (
+                logits[np.arange(logits.size(0)), labels]
+                - logits[np.arange(logits.size(0)), best_non_target_class]
+            )
+            loss = torch.tanh(-margin).mean()
+        elif self.loss_type == 'MCE':
+            not_flipped = logits.argmax(-1) == labels
+            loss = F.cross_entropy(logits[not_flipped], labels[not_flipped])
+        elif self.loss_type == 'NCE':
+            sorted = logits.argsort(-1)
+            best_non_target_class = sorted[sorted != labels[:, None]].reshape(logits.size(0), -1)[:, -1]
+            loss = -F.cross_entropy(logits, best_non_target_class)
+        else:
+            loss = F.cross_entropy(logits, labels)
+        return loss
+
+    def get_perf(self, output, labels, mask, verbose=True):
+        loss = F.nll_loss(output[mask], labels[mask])
+        acc = utils.accuracy(output[mask], labels[mask])
+        if verbose:
+            print("loss= {:.4f}".format(loss.item()),
+                  "accuracy= {:.4f}".format(acc.item()))
+        return loss.item(), acc.item()
+
+
+@torch.jit.script
+def softmax_entropy(x: torch.Tensor) -> torch.Tensor:
+    """Entropy of softmax distribution from **logits**."""
+    return -(x.softmax(1) * x.log_softmax(1)).sum(1)
+
+
+@torch.jit.script
+def entropy(x: torch.Tensor) -> torch.Tensor:
+    """Entropy of softmax distribution from **log_softmax**."""
+    return -(torch.exp(x) * x).sum(1)
+
+
+def to_symmetric(edge_index, edge_weight, n, op='mean'):
+    symmetric_edge_index = torch.cat(
+        (edge_index, edge_index.flip(0)), dim=-1
+    )
+    symmetric_edge_weight = edge_weight.repeat(2)
+
+    symmetric_edge_index, symmetric_edge_weight = coalesce(
+        symmetric_edge_index,
+        symmetric_edge_weight,
+        m=n,
+        n=n,
+        op=op
+    )
+    return symmetric_edge_index, symmetric_edge_weight
+
+
+def linear_to_full_idx(n: int, lin_idx: torch.Tensor) -> torch.Tensor:
+    row_idx = lin_idx // n
+    col_idx = lin_idx % n
+    return torch.stack((row_idx, col_idx))
+
+
+def linear_to_triu_idx(n: int, lin_idx: torch.Tensor) -> torch.Tensor:
+    row_idx = (
+        n
+        - 2
+        - torch.floor(torch.sqrt(-8 * lin_idx.double() + 4 * n * (n - 1) - 7) / 2.0 - 0.5)
+    ).long()
+    col_idx = (
+        lin_idx
+        + row_idx
+        + 1 - n * (n - 1) // 2
+        + (n - row_idx) * ((n - row_idx) - 1) // 2
+    )
+    return torch.stack((row_idx, col_idx))
+
+
+def grad_with_checkpoint(outputs, inputs):
+    inputs = (inputs,) if isinstance(inputs, torch.Tensor) else tuple(inputs)
+    for input in inputs:
+        if not input.is_leaf:
+            input.retain_grad()
+    torch.autograd.backward(outputs)
+
+    grad_outputs = []
+    for input in inputs:
+        grad_outputs.append(input.grad.clone())
+        input.grad.zero_()
+    return grad_outputs
+def bisection(edge_weights, a, b, n_perturbations, epsilon=1e-5, iter_max=1e5):
+    def func(x):
+        return torch.clamp(edge_weights - x, 0, 1).sum() - n_perturbations
+
+    miu = a
+    for i in range(int(iter_max)):
+        miu = (a + b) / 2
+        # Check if the middle point is a root
+        if (func(miu) == 0.0):
+            break
+        # Decide the side on which to repeat the steps
+        if (func(miu) * func(a) < 0):
+            b = miu
+        else:
+            a = miu
+        if ((b - a) <= epsilon):
+            break
+    return miu
+
+
+if __name__ == "__main__":
+    from ogb.nodeproppred import PygNodePropPredDataset
+    from torch_geometric.utils import to_undirected
+    import torch_geometric.transforms as T
+    dataset = PygNodePropPredDataset(name='ogbn-arxiv')
+    dataset.transform = T.NormalizeFeatures()
+    data = dataset[0]
+    if not hasattr(data, 'train_mask'):
+        utils.add_mask(data, dataset)
+    data.edge_index = to_undirected(data.edge_index, data.num_nodes)
+    agent = PRBCD(data)
+    edge_index, edge_weight = agent.attack()
diff --git a/deeprobust/graph/utils.py b/deeprobust/graph/utils.py
index 66eb6bc..1c4d25b 100644
--- a/deeprobust/graph/utils.py
+++ b/deeprobust/graph/utils.py
@@ -313,6 +313,16 @@ def loss_acc(output, labels, targets, avg_loss=True):
     #     correct = correct.sum()
     #     return loss, correct / len(labels)
 
+def get_perf(output, labels, mask, verbose=True):
+    """Evaluate performance on the masked data (e.g., the test mask)."""
+    loss = F.nll_loss(output[mask], labels[mask])
+    acc = accuracy(output[mask], labels[mask])
+    if verbose:
+        print("loss= {:.4f}".format(loss.item()),
+              "accuracy= {:.4f}".format(acc.item()))
+    return loss.item(), acc.item()
+
+
 def classification_margin(output, true_label):
     """Calculate classification margin for outputs.
     `probs_true_label - probs_best_second_class`
@@ -712,6 +722,56 @@ def reshape_mx(mx, shape):
         indices = mx.nonzero()
         return sp.csr_matrix((mx.data, (indices[0], indices[1])), shape=shape)
 
+def add_mask(data, dataset):
+    """data: ogbn-arxiv pyg data format"""
+    # for arxiv
+    split_idx = dataset.get_idx_split()
+    train_idx, valid_idx, test_idx = split_idx["train"], split_idx["valid"], split_idx["test"]
+    n = data.x.shape[0]
+    data.train_mask = index_to_mask(train_idx, n)
+    data.val_mask = index_to_mask(valid_idx, n)
+    data.test_mask = index_to_mask(test_idx, n)
+    data.y = data.y.squeeze()
+    # data.edge_index = to_undirected(data.edge_index, data.num_nodes)
+
+def index_to_mask(index, size):
+    mask = torch.zeros((size, ), dtype=torch.bool)
+    mask[index] = 1
+    return mask
+
+def add_feature_noise(data, noise_ratio, seed):
+    np.random.seed(seed)
+    n, d = data.x.shape
+    # noise = torch.normal(mean=torch.zeros(int(noise_ratio*n), d), std=1)
+    noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio*n), d))).to(data.x.device)
+    indices = np.arange(n)
+    indices = np.random.permutation(indices)[: int(noise_ratio*n)]
+    delta_feat = torch.zeros_like(data.x)
+    delta_feat[indices] = noise - data.x[indices]
+    data.x[indices] = noise
+    mask = np.zeros(n)
+    mask[indices] = 1
+    mask = torch.tensor(mask).bool().to(data.x.device)
+    return delta_feat, mask
+
+def add_feature_noise_test(data, noise_ratio, seed):
+    np.random.seed(seed)
+    n, d = data.x.shape
+    indices = np.arange(n)
+    test_nodes = indices[data.test_mask.cpu()]
+    selected = np.random.permutation(test_nodes)[: int(noise_ratio*len(test_nodes))]
+    noise = torch.FloatTensor(np.random.normal(0, 1, size=(int(noise_ratio*len(test_nodes)), d)))
+    noise = noise.to(data.x.device)
+
+    delta_feat = torch.zeros_like(data.x)
+    delta_feat[selected] = noise - data.x[selected]
+    data.x[selected] = noise
+    # mask = np.zeros(len(test_nodes))
+    mask = np.zeros(n)
+    mask[selected] = 1
+    mask = torch.tensor(mask).bool().to(data.x.device)
+    return delta_feat, mask
+
 # def check_path(file_path):
 #     if not osp.exists(file_path):
 #         os.system(f'mkdir -p {file_path}')
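A short sketch of how the new noise helpers are meant to be used (mirroring `test_airgnn.py` below; `data` is assumed to be a PyG `Data` object with a `test_mask`):

```python
from deeprobust.graph.utils import add_feature_noise, add_feature_noise_test

# Replace the features of a random 30% of all nodes with N(0, 1) noise:
delta_feat, noisy_mask = add_feature_noise(data, noise_ratio=0.3, seed=0)

# Or corrupt only test nodes (as done in examples/graph/test_airgnn.py):
delta_feat, noisy_test_mask = add_feature_noise_test(data, noise_ratio=0.3, seed=0)
print('corrupted nodes:', noisy_test_mask.sum().item())
```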
diff --git a/examples/graph/test_airgnn.py b/examples/graph/test_airgnn.py
new file mode 100644
index 0000000..2f83dac
--- /dev/null
+++ b/examples/graph/test_airgnn.py
@@ -0,0 +1,103 @@
+"""Test different models on noisy features."""
+import argparse
+import numpy as np
+from torch_geometric.datasets import Planetoid
+import torch_geometric.transforms as T
+from deeprobust.graph.defense_pyg import AirGNN, GCN, APPNP, GAT, SAGE, GPRGNN
+import torch
+import random
+import os.path as osp
+from deeprobust.graph.utils import add_feature_noise, add_feature_noise_test, get_perf
+import torch.nn.functional as F
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
+parser.add_argument('--dataset', type=str, default='cora')
+parser.add_argument('--epochs', type=int, default=10)
+parser.add_argument('--lr', type=float, default=0.01)
+parser.add_argument('--hidden', type=int, default=64)
+parser.add_argument('--weight_decay', type=float, default=5e-4)
+parser.add_argument('--with_bn', type=int, default=0)
+parser.add_argument('--seed', type=int, default=0, help='random seed')
+parser.add_argument('--nlayers', type=int, default=2)
+parser.add_argument('--model', type=str, default='AirGNN')
+parser.add_argument('--debug', type=float, default=0)
+parser.add_argument('--dropout', type=float, default=0.5)
+parser.add_argument('--noise_feature', type=float, default=0.3)
+parser.add_argument('--lambda_', type=float, default=0)
+args = parser.parse_args()
+
+torch.cuda.set_device(args.gpu_id)
+
+print('===========')
+
+# random seed setting
+random.seed(args.seed)
+np.random.seed(args.seed)
+torch.manual_seed(args.seed)
+torch.cuda.manual_seed(args.seed)
+
+def get_dataset(name, normalize_features=True, transform=None, if_dpr=True):
+    path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', name)
+    if name in ['cora', 'citeseer', 'pubmed']:
+        dataset = Planetoid(path, name)
+    else:
+        raise NotImplementedError
+    dataset.transform = T.NormalizeFeatures()
+    return dataset
+
+dataset = get_dataset(args.dataset)
+data = dataset[0]
+
+def pretrain_model():
+    feat, labels = data.x, data.y
+    nclass = max(labels).item()+1
+    if args.model == "AirGNN":
+        args.dropout = 0.2; args.lambda_amp = 0.5; args.alpha = 0.1
+        model = AirGNN(nfeat=feat.shape[1], nhid=args.hidden, dropout=args.dropout, with_bn=args.with_bn,
+                       K=10, weight_decay=args.weight_decay, args=args, nlayers=args.nlayers,
+                       nclass=max(labels).item()+1, device=device).to(device)
+    elif args.model == "GCN":
+        model = GCN(nfeat=feat.shape[1], nhid=args.hidden, dropout=args.dropout,
+                    nlayers=args.nlayers, with_bn=args.with_bn,
+                    weight_decay=args.weight_decay, nclass=nclass,
+                    device=device).to(device)
+    elif args.model == "GAT":
+        args.dropout = 0.5; args.hidden = 8
+        model = GAT(nfeat=feat.shape[1], nhid=args.hidden, heads=8, lr=0.005, nlayers=args.nlayers,
+                    nclass=nclass, with_bn=args.with_bn, weight_decay=args.weight_decay,
+                    dropout=args.dropout, device=device).to(device)
+    elif args.model == "SAGE":
+        model = SAGE(feat.shape[1], 32, max(labels).item()+1, num_layers=5,
+                     dropout=0.0, lr=0.01, weight_decay=0, device=device).to(device)
+    elif args.model == "GPR":
+        model = GPRGNN(feat.shape[1], 32, max(labels).item()+1, dropout=0.0,
+                       lr=0.01, weight_decay=0, device=device).to(device)
+    else:
+        raise NotImplementedError
+
+    print(model)
+    model.fit(data, train_iters=1000, patience=1000, verbose=True)
+
+    model.eval()
+    model.data = data.to(device)
+    output = model.predict()
+    labels = labels.to(device)
+    print("Test set results:", get_perf(output, labels, data.test_mask, verbose=0)[1])
+    return model
+
+device = 'cuda'
+model = pretrain_model()
+
+if args.noise_feature > 0:
+    feat_noise, noisy_nodes = add_feature_noise_test(data,
+            args.noise_feature, args.seed)
+
+output = model.predict()
+labels = data.y.to(device)
+print("After noise, test set results:", get_perf(output, labels, data.test_mask, verbose=0)[1])
+print('Validation:', get_perf(output, labels, data.val_mask, verbose=0)[1])
+print('Abnormal test nodes:', get_perf(output, labels, noisy_nodes, verbose=0)[1])
+print('Normal test nodes:', get_perf(output, labels, data.test_mask & (~noisy_nodes), verbose=0)[1])
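A typical invocation of the script above, using the flags defined in its argparse block, would be `python examples/graph/test_airgnn.py --model AirGNN --dataset cora --noise_feature 0.3`, matching the command advertised in the changelog; swapping `--model GCN` (or `GAT`, `SAGE`, `GPR`) compares the baselines under the same feature noise.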
diff --git a/examples/graph/test_prbcd.py b/examples/graph/test_prbcd.py
new file mode 100644
index 0000000..893b895
--- /dev/null
+++ b/examples/graph/test_prbcd.py
@@ -0,0 +1,24 @@
+from ogb.nodeproppred import PygNodePropPredDataset
+from torch_geometric.utils import to_undirected
+import torch_geometric.transforms as T
+import argparse
+import torch
+import deeprobust.graph.utils as utils
+from deeprobust.graph.global_attack import PRBCD
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--ptb_rate', type=float, default=0.1, help='perturbation rate')
+args = parser.parse_args()
+
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+dataset = PygNodePropPredDataset(name='ogbn-arxiv')
+dataset.transform = T.NormalizeFeatures()
+data = dataset[0]
+if not hasattr(data, 'train_mask'):
+    utils.add_mask(data, dataset)
+
+data.edge_index = to_undirected(data.edge_index, data.num_nodes)
+agent = PRBCD(data, device=device)
+edge_index, edge_weight = agent.attack(ptb_rate=args.ptb_rate)
diff --git a/setup.py b/setup.py
index d032274..2d0255d 100644
--- a/setup.py
+++ b/setup.py
@@ -6,7 +6,7 @@ long_description = fh.read()
 
 setup(name = "deeprobust",
-    version = "0.2.3",
+    version = "0.2.7",
    author='MSU-DSE',
    maintainer='MSU-DSE',
    description = "A PyTorch library for adversarial robustness learning for image and graph data.",