Skip to content

Commit

Permalink
add PRBCD and AirGNN
Browse files Browse the repository at this point in the history
  • Loading branch information
ChandlerBang committed Feb 26, 2023
1 parent 1c0ef07 commit 122588e
Show file tree
Hide file tree
Showing 16 changed files with 1,623 additions and 7 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,8 @@ If our work could help your research, please cite:
```

# Changelog
* [02/2023] DeepRobust 0.2.7 Released. Please try `pip install deeprobust==0.2.7`! We have added a scalable attack [PRBCD, NeurIPS'21](https://arxiv.org/abs/2110.14038) to graph package. We can now use PRBCD to attack large-scale graphs such as ogb-arxiv (see example in [test_prbcd.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_prbcd.py))!
* [02/2023] Add a robust model [AirGNN, NeurIPS'21](https://proceedings.neurips.cc/paper/2021/file/50abc3e730e36b387ca8e02c26dc0a22-Paper.pdf) to graph package. Try `python examples/graph/test_airgnn.py`! See details in [test_airgnn.py](https://github.com/DSE-MSU/DeepRobust/blob/master/examples/graph/test_airgnn.py)
* [11/2022] DeepRobust 0.2.6 Released. Please try `pip install deeprobust==0.2.6`! We have more updates coming. Please stay tuned!
* [11/2021] A subpackage that includes popular black box attacks in image domain is released. Find it here. [Link](https://github.com/I-am-Bot/Black-Box-Attacks)
* [11/2021] DeepRobust 0.2.4 Released. Please try `pip install deeprobust==0.2.4`!
Expand Down
10 changes: 5 additions & 5 deletions deeprobust/graph/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def __init__(self, root, name, setting='nettack', seed=None, require_mask=False)
self.seed = seed
# self.url = 'https://raw.githubusercontent.com/danielzuegner/nettack/master/data/%s.npz' % self.name
self.url = 'https://raw.githubusercontent.com/danielzuegner/gnn-meta-attack/master/data/%s.npz' % self.name
if platform.system() == 'Windows':
root = root
else:

if platform.system() == 'Windows':
root = root
else:
self.root = osp.expanduser(osp.normpath(root))

self.data_folder = osp.join(root, self.name)
self.data_filename = self.data_folder + '.npz'
self.require_mask = require_mask
Expand Down
15 changes: 15 additions & 0 deletions deeprobust/graph/defense_pyg/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Lazy imports for the PyG-based defense models.  All of these models require
# torch_geometric; if it (or one of the models' other dependencies) is missing,
# we emit a warning instead of failing at package-import time.
import warnings  # bug fix: was used below without being imported (NameError)

try:
    from .gcn import GCN
    from .gat import GAT
    from .appnp import APPNP
    from .sage import SAGE
    from .gpr import GPRGNN
    from .airgnn import AirGNN
except ImportError as e:
    # Surface the underlying import error, then point the user at the
    # torch_geometric installation instructions.
    print(e)
    warnings.warn("Please install pytorch geometric if you " +
              "would like to use the datasets from pytorch " +
              "geometric. See details in https://pytorch-geom" +
              "etric.readthedocs.io/en/latest/notes/installation.html")

__all__ = ["GCN", "GAT", "APPNP", "SAGE", "GPRGNN", "AirGNN"]
186 changes: 186 additions & 0 deletions deeprobust/graph/defense_pyg/airgnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn.conv import MessagePassing
from typing import Optional, Tuple
from torch_geometric.typing import Adj, OptTensor
from torch import Tensor
from torch_sparse import SparseTensor, matmul
from .base_model import BaseModel
import torch.nn as nn

class AirGNN(BaseModel):
    """AirGNN: a GNN with adaptive residual propagation.

    Reference: Liu et al., "Graph Neural Networks with Adaptive Residual"
    (NeurIPS 2021).

    Architecture (as built below): a stack of ``nlayers`` Linear layers
    (optionally interleaved with BatchNorm1d) followed by a single
    ``AdaptiveMessagePassing`` propagation step whose mode and alpha are
    taken from ``args``.

    Parameters
    ----------
    nfeat : int
        Dimensionality of the input node features.
    nhid : int
        Hidden-layer width.
    nclass : int
        Number of output classes.
    nlayers : int
        Total number of Linear layers; ``nlayers - 2`` hidden-to-hidden
        layers are inserted between the input and output projections.
    K : int
        Number of propagation steps performed by AdaptiveMessagePassing.
    dropout : float
        Dropout probability applied around the Linear layers in ``forward``.
    lr : float
        Learning rate (stored for use by BaseModel's training loop).
    with_bn : bool
        If True, apply BatchNorm1d after every hidden Linear layer.
    weight_decay : float
        L2 regularization coefficient (stored for BaseModel's training loop).
    with_bias : bool
        Accepted for signature parity with sibling models; not used below.
    device : str or torch.device
        Device to run on; must be provided (asserted).
    args : Namespace
        Must provide ``alpha``, ``model`` and (inside AdaptiveMessagePassing)
        ``lambda_amp``.  Required in practice despite the ``None`` default —
        TODO confirm all callers pass it.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers=2, K=2, dropout=0.5, lr=0.01,
                with_bn=False, weight_decay=5e-4, with_bias=True, device=None, args=None):

        super(AirGNN, self).__init__()
        assert device is not None, "Please specify 'device'!"
        self.device = device

        # Build the MLP part: input projection, (nlayers - 2) hidden layers,
        # then the output projection.  Registration order matters for the
        # parameter ordering seen by optimizers.
        self.lins = nn.ModuleList([])
        self.lins.append(Linear(nfeat, nhid))
        if with_bn:
            self.bns = nn.ModuleList([])
            self.bns.append(nn.BatchNorm1d(nhid))
        for i in range(nlayers-2):
            self.lins.append(Linear(nhid, nhid))
            if with_bn:
                self.bns.append(nn.BatchNorm1d(nhid))
        self.lins.append(Linear(nhid, nclass))

        # Propagation operator; args.model selects the mode ('AirGNN',
        # 'APPNP' or 'MLP' — see AdaptiveMessagePassing.forward).
        self.prop = AdaptiveMessagePassing(K=K, alpha=args.alpha, mode=args.model, args=args)
        print(self.prop)

        self.dropout = dropout
        self.weight_decay = weight_decay
        self.lr = lr
        self.name = args.model
        self.with_bn = with_bn

    def initialize(self):
        """Re-initialize all learnable parameters (alias for reset_parameters)."""
        self.reset_parameters()

    def reset_parameters(self):
        """Reset every Linear/BatchNorm layer and clear the propagation cache."""
        for lin in self.lins:
            lin.reset_parameters()
        if self.with_bn:
            for bn in self.bns:
                bn.reset_parameters()
        self.prop.reset_parameters()

    def forward(self, x, edge_index, edge_weight=None):
        """Return per-node log-softmax class scores.

        ``edge_index`` is converted to a transposed SparseTensor adjacency
        (square, sized by the number of nodes) before propagation.
        """
        x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight)
        edge_index = SparseTensor.from_edge_index(edge_index, edge_weight,
                sparse_sizes=2 * x.shape[:1]).t()
        # MLP with dropout (+ optional batch norm) on all but the last layer.
        for ii, lin in enumerate(self.lins[:-1]):
            x = F.dropout(x, p=self.dropout, training=self.training)
            x = lin(x)
            if self.with_bn:
                x = self.bns[ii](x)
            x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        # Adaptive message passing on the class logits.
        x = self.prop(x, edge_index)
        return F.log_softmax(x, dim=1)

    def get_embed(self, x, edge_index, edge_weight=None):
        """Return node embeddings: hidden MLP layers (no dropout, no final
        projection) followed by propagation."""
        x, edge_index, edge_weight = self._ensure_contiguousness(x, edge_index, edge_weight)
        edge_index = SparseTensor.from_edge_index(edge_index, edge_weight,
                sparse_sizes=2 * x.shape[:1]).t()
        for ii, lin in enumerate(self.lins[:-1]):
            x = lin(x)
            if self.with_bn:
                x = self.bns[ii](x)
            x = F.relu(x)
        x = self.prop(x, edge_index)
        return x


class AdaptiveMessagePassing(MessagePassing):
    """Propagation operator behind AirGNN.

    Depending on ``mode`` it performs:
      * ``'MLP'``    — no propagation (identity);
      * ``'APPNP'``  — K steps of personalized-PageRank propagation with
        teleport probability ``alpha``;
      * ``'AirGNN'`` — K steps of adaptive message passing (AMP), the
        proximal-gradient scheme of Equations (9)–(12) in the AirGNN paper,
        controlled by ``args.lambda_amp``.

    Only SparseTensor adjacency input is supported (see ``forward``).
    """

    _cached_edge_index: Optional[Tuple[Tensor, Tensor]]
    _cached_adj_t: Optional[SparseTensor]

    def __init__(self,
                 K: int,
                 alpha: float,
                 dropout: float = 0.,
                 cached: bool = False,
                 add_self_loops: bool = True,
                 normalize: bool = True,
                 mode: str = None,
                 node_num: int = None,
                 args=None,
                 **kwargs):
        """See class docstring.  ``args`` must provide ``lambda_amp`` when
        mode is 'AirGNN' (and for ``__repr__``) — required in practice
        despite the ``None`` default."""
        super(AdaptiveMessagePassing, self).__init__(aggr='add', **kwargs)
        self.K = K
        self.alpha = alpha
        self.mode = mode
        self.dropout = dropout
        self.cached = cached
        self.add_self_loops = add_self_loops
        self.normalize = normalize
        self._cached_edge_index = None
        self.node_num = node_num
        self.args = args
        self._cached_adj_t = None

    def reset_parameters(self):
        # No learnable parameters; only the normalized-adjacency caches are cleared.
        self._cached_edge_index = None
        self._cached_adj_t = None

    def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None, mode=None) -> Tensor:
        """Propagate ``x`` over the graph; ``mode`` overrides the default set
        at construction time.  Raises ValueError for dense Tensor adjacency
        or an unknown mode."""
        if self.normalize:
            if isinstance(edge_index, Tensor):
                raise ValueError('Only support SparseTensor now')

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    # Symmetric GCN normalization D^-1/2 (A+I) D^-1/2.
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, x.size(self.node_dim), False,
                        add_self_loops=self.add_self_loops, dtype=x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        # Bug fix (idiom): compare to None with `is`, not `==` (PEP 8).
        if mode is None:
            mode = self.mode

        if self.K <= 0:
            return x
        hh = x  # initial features, reused as the residual anchor

        if mode == 'MLP':
            return x

        elif mode == 'APPNP':
            x = self.appnp_forward(x=x, hh=hh, edge_index=edge_index, K=self.K, alpha=self.alpha)

        elif mode in ['AirGNN']:
            x = self.amp_forward(x=x, hh=hh, edge_index=edge_index, K=self.K)
        else:
            raise ValueError('wrong propagate mode')
        return x

    def appnp_forward(self, x, hh, edge_index, K, alpha):
        """K steps of x <- (1 - alpha) * A_hat x + alpha * hh (personalized PageRank)."""
        for k in range(K):
            x = self.propagate(edge_index, x=x, edge_weight=None, size=None)
            x = x * (1 - alpha)
            x += alpha * hh
        return x

    # NOTE(review): parameter order (K before edge_index) differs from
    # appnp_forward; kept as-is since it is called with keyword arguments.
    def amp_forward(self, x, hh, K, edge_index):
        """K steps of adaptive message passing (proximal gradient descent)."""
        lambda_amp = self.args.lambda_amp
        gamma = 1 / (2 * (1 - lambda_amp))  ## or simply gamma = 1

        for k in range(K):
            y = x - gamma * 2 * (1 - lambda_amp) * self.compute_LX(x=x, edge_index=edge_index)  # Equation (9)
            x = hh + self.proximal_L21(x=y - hh, lambda_=gamma * lambda_amp)  # Equation (11) and (12)
        return x

    def proximal_L21(self, x: Tensor, lambda_):
        """Row-wise L21 proximal operator: shrink each row's L2 norm by lambda_."""
        row_norm = torch.norm(x, p=2, dim=1)
        score = torch.clamp(row_norm - lambda_, min=0)
        index = torch.where(row_norm > 0)  # Deal with the case when the row_norm is 0
        score[index] = score[index] / row_norm[index]  # score is the adaptive score in Equation (14)
        return score.unsqueeze(1) * x

    def compute_LX(self, x, edge_index, edge_weight=None):
        """Return (I - A_hat) x, i.e. the normalized graph Laplacian applied to x."""
        x = x - self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None)
        return x

    def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor:
        # Edge-wise message: weight each neighbor feature by the edge weight.
        return edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t: SparseTensor, x: Tensor) -> Tensor:
        # Fused sparse matrix-matrix product path for SparseTensor adjacency.
        return matmul(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        return '{}(K={}, alpha={}, mode={}, dropout={}, lambda_amp={})'.format(self.__class__.__name__, self.K,
                                                                               self.alpha, self.mode, self.dropout,
                                                                               self.args.lambda_amp)


79 changes: 79 additions & 0 deletions deeprobust/graph/defense_pyg/appnp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import torch.nn as nn
import torch.nn.functional as F
import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Module
from torch_geometric.nn import APPNP as APPNPConv
from torch.nn import Linear
from .base_model import BaseModel


class APPNP(BaseModel):
    """APPNP: Approximate Personalized Propagation of Neural Predictions.

    A two-layer MLP (optionally with batch norm) whose output logits are
    smoothed by torch_geometric's parameter-free APPNP propagation operator
    (K power-iteration steps, teleport probability ``alpha``).
    """

    def __init__(self, nfeat, nhid, nclass, K=10, alpha=0.1, dropout=0.5, lr=0.01,
                 with_bn=False, weight_decay=5e-4, with_bias=True, device=None):
        super(APPNP, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device

        # Feature transformation: nfeat -> nhid -> nclass.
        self.lin1 = Linear(nfeat, nhid)
        if with_bn:
            self.bn1 = nn.BatchNorm1d(nhid)
            self.bn2 = nn.BatchNorm1d(nclass)
        self.lin2 = Linear(nhid, nclass)
        # Propagation has no learnable parameters.
        self.prop1 = APPNPConv(K, alpha)

        # Hyper-parameters and bookkeeping consumed by BaseModel's training loop.
        self.dropout = dropout
        self.weight_decay = weight_decay
        self.lr = lr
        self.output = None
        self.best_model = None
        self.best_output = None
        self.name = 'APPNP'
        self.with_bn = with_bn

    def forward(self, x, edge_index, edge_weight=None):
        """Return per-node log-softmax class scores."""
        h = F.dropout(x, p=self.dropout, training=self.training)
        h = self.lin1(h)
        if self.with_bn:
            h = self.bn1(h)
        h = F.dropout(F.relu(h), p=self.dropout, training=self.training)
        h = self.lin2(h)
        if self.with_bn:
            h = self.bn2(h)
        h = self.prop1(h, edge_index, edge_weight)
        return F.log_softmax(h, dim=1)

    def initialize(self):
        """Re-initialize every learnable sub-module."""
        modules = [self.lin1, self.lin2]
        if self.with_bn:
            modules.extend([self.bn1, self.bn2])
        for m in modules:
            m.reset_parameters()


if __name__ == "__main__":
    # Smoke test: train APPNP on Cora using DeepRobust's dataset wrappers.
    from deeprobust.graph.data import Dataset, Dpr2Pyg
    data = Dataset(root='/tmp/', name='cora', setting='gcn')
    adj, features, labels = data.adj, data.features, data.labels
    idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test
    # Bug fix: the original instantiated `GCN`, which is undefined in this
    # module (NameError) — this file defines APPNP.
    model = APPNP(nfeat=features.shape[1],
                  nhid=16,
                  nclass=labels.max().item() + 1,
                  dropout=0.5, device='cuda')
    model = model.to('cuda')
    pyg_data = Dpr2Pyg(data)[0]

    # Bug fix: removed leftover `import ipdb; ipdb.set_trace()` debugger break.
    model.fit(pyg_data, verbose=True)  # train with earlystopping
    model.test()
    print(model.predict())
Loading

0 comments on commit 122588e

Please sign in to comment.