
Commit

Change to PyG loss layer
nyLiao committed Aug 12, 2024
1 parent 61d067d commit 1a3f09c
Showing 6 changed files with 55 additions and 11 deletions.
2 changes: 1 addition & 1 deletion benchmark/trainer/fullbatch.py
@@ -54,7 +54,7 @@ def __init__(self,
super(TrnFullbatch, self).__init__(model, data, args, **kwargs)
metric = metric_loader(args).to(self.device)
self.evaluator = {k: metric.clone(postfix='_'+k) for k in self.splits}
self.criterion = nn.CrossEntropyLoss()
self.criterion = nn.BCELoss() if self.num_classes == 1 else nn.NLLLoss()

self.mask: dict = None
self.flag_test_deg = args.test_deg if hasattr(args, 'test_deg') else False
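This change pairs nn.BCELoss with a sigmoid model output for single-class (binary) targets and nn.NLLLoss with a log-softmax output otherwise (the matching activation change is in base_nn.py below). A minimal, self-contained sketch of that pairing; the tensor names, batch size, and class count are illustrative and not taken from the repository:

    import torch
    import torch.nn as nn

    num_classes = 3  # set to 1 for the binary case
    criterion = nn.BCELoss() if num_classes == 1 else nn.NLLLoss()

    scores = torch.randn(8, num_classes)            # raw model outputs
    if num_classes == 1:
        preds = torch.sigmoid(scores).squeeze(-1)   # BCELoss expects probabilities in [0, 1]
        target = torch.randint(0, 2, (8,)).float()  # ...and float targets of the same shape
    else:
        preds = torch.log_softmax(scores, dim=-1)   # NLLLoss expects log-probabilities
        target = torch.randint(0, num_classes, (8,))

    loss = criterion(preds, target)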
15 changes: 11 additions & 4 deletions benchmark/trainer/load_metric.py
@@ -7,10 +7,10 @@
from argparse import Namespace
from torchmetrics import MetricCollection
from torchmetrics.classification import (
MulticlassAccuracy, MultilabelAccuracy,
MulticlassF1Score, MultilabelF1Score,
MulticlassAUROC, MultilabelAUROC,
MulticlassAveragePrecision, MultilabelAveragePrecision,
MulticlassAccuracy, MultilabelAccuracy, BinaryAccuracy,
MulticlassF1Score, MultilabelF1Score, BinaryF1Score,
MulticlassAUROC, MultilabelAUROC, BinaryAUROC,
MulticlassAveragePrecision, MultilabelAveragePrecision, BinaryAveragePrecision,
)


@@ -37,6 +37,13 @@ def metric_loader(args: Namespace) -> MetricCollection:
's_auroc': MultilabelAUROC(num_labels=args.num_classes),
's_ap': MultilabelAveragePrecision(num_labels=args.num_classes),
})
elif args.num_classes == 1:
metric = ResCollection({
's_acc': BinaryAccuracy(),
's_f1i': BinaryF1Score(),
's_auroc': BinaryAUROC(),
's_ap': BinaryAveragePrecision(),
})
else:
metric = ResCollection({
's_acc': MulticlassAccuracy(num_classes=args.num_classes),
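For reference, a minimal sketch of evaluating such a binary metric collection with plain torchmetrics; ResCollection is the repository's own MetricCollection subclass and is not shown in this diff, so MetricCollection stands in for it here, and the prediction/target tensors are made up:

    import torch
    from torchmetrics import MetricCollection
    from torchmetrics.classification import (
        BinaryAccuracy, BinaryF1Score, BinaryAUROC, BinaryAveragePrecision,
    )

    metric = MetricCollection({
        's_acc': BinaryAccuracy(),
        's_f1i': BinaryF1Score(),
        's_auroc': BinaryAUROC(),
        's_ap': BinaryAveragePrecision(),
    })

    preds = torch.rand(16)               # sigmoid probabilities, shape (N,)
    target = torch.randint(0, 2, (16,))  # integer {0, 1} labels
    print(metric(preds, target))         # dict with the four metric values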
2 changes: 1 addition & 1 deletion benchmark/trainer/minibatch.py
@@ -63,7 +63,7 @@ def __init__(self,

metric = metric_loader(args).to(self.device)
self.evaluator = {k: metric.clone(postfix='_'+k) for k in self.splits}
self.criterion = nn.CrossEntropyLoss()
self.criterion = nn.BCELoss() if self.num_classes == 1 else nn.NLLLoss()

self.shuffle = {'train': True, 'val': False, 'test': False}
self.embed = None
5 changes: 4 additions & 1 deletion pyg_spectral/nn/models/base_nn.py
@@ -214,7 +214,10 @@ def forward(self,
x = self.convolute(x, edge_index, batch=batch, batch_size=batch_size)
if self.out_layers > 0:
x = self.out_mlp(x, batch=batch, batch_size=batch_size)
return x
if self.out_channels == 1:
return torch.sigmoid(x)
else:
return torch.log_softmax(x, dim=-1)


class BaseNNCompose(BaseNN):
20 changes: 16 additions & 4 deletions pyg_spectral/nn/models/precomputed.py
@@ -57,7 +57,10 @@ def forward(self,
"""
if self.out_layers > 0:
x = self.out_mlp(x, batch=batch, batch_size=batch_size)
return x
if self.out_channels == 1:
return torch.sigmoid(x)
else:
return torch.log_softmax(x, dim=-1)


class PrecomputedVar(DecoupledVar):
@@ -125,7 +128,10 @@ def forward(self,
out = conv_mat['out']
if self.out_layers > 0:
out = self.out_mlp(out, batch=batch, batch_size=batch_size)
return out
if self.out_channels == 1:
return torch.sigmoid(out)
else:
return torch.log_softmax(out, dim=-1)


# ==========
@@ -190,7 +196,10 @@ def forward(self,

if self.out_layers > 0:
out = self.out_mlp(out, batch=batch, batch_size=batch_size)
return out
if self.out_channels == 1:
return torch.sigmoid(out)
else:
return torch.log_softmax(out, dim=-1)


class PrecomputedVarCompose(DecoupledVarCompose):
@@ -271,4 +280,7 @@ def forward(self,

if self.out_layers > 0:
out = self.out_mlp(out, batch=batch, batch_size=batch_size)
return out
if self.out_channels == 1:
return torch.sigmoid(out)
else:
return torch.log_softmax(out, dim=-1)
22 changes: 22 additions & 0 deletions pyg_spectral/profile/efficiency.py
@@ -34,6 +34,9 @@ def reset(self):
def data(self) -> float:
return self.elapsed_sec

def __repr__(self) -> str:
return f'{self.data:.2f} s'

def __enter__(self):
self.start()
return self
@@ -153,3 +156,22 @@ def update(self, module: Module):
mem_params = sum([p.nelement()*p.element_size() for p in module.parameters()])
mem_bufs = sum([b.nelement()*b.element_size() for b in module.buffers()])
self.set(mem_params + mem_bufs)


def log_memory(suffix: str = None, row: int = 0):
def decorator(func):
def wrapper(self, *args, **kwargs):
with self.device:
torch.cuda.empty_cache()

res = func(self, *args, **kwargs)
res.concat(
[('mem_ram', MemoryRAM()(unit='G')),
('mem_cuda', MemoryCUDA()(unit='G')),],
row=row, suffix=suffix)

with self.device:
torch.cuda.empty_cache()
return res
return wrapper
return decorator
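
The update method above sizes a module as the sum of its parameter and buffer storage, and log_memory wraps a trainer method to record RAM and CUDA usage around it. A self-contained sketch of the parameter/buffer computation on a throwaway module; the Linear layer and the MiB formatting are illustrative, not from the repository:

    import torch.nn as nn

    module = nn.Linear(128, 64)
    mem_params = sum(p.nelement() * p.element_size() for p in module.parameters())
    mem_bufs = sum(b.nelement() * b.element_size() for b in module.buffers())
    print(f'{(mem_params + mem_bufs) / 2**20:.3f} MiB')  # fp32 Linear: (128*64 + 64) * 4 bytes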
