diff --git a/benchmark/trainer/fullbatch.py b/benchmark/trainer/fullbatch.py
index c506355..58d341b 100755
--- a/benchmark/trainer/fullbatch.py
+++ b/benchmark/trainer/fullbatch.py
@@ -54,7 +54,7 @@ def __init__(self,
         super(TrnFullbatch, self).__init__(model, data, args, **kwargs)
         metric = metric_loader(args).to(self.device)
         self.evaluator = {k: metric.clone(postfix='_'+k) for k in self.splits}
-        self.criterion = nn.CrossEntropyLoss()
+        self.criterion = nn.BCELoss() if self.num_classes == 1 else nn.NLLLoss()
         self.mask: dict = None

         self.flag_test_deg = args.test_deg if hasattr(args, 'test_deg') else False
diff --git a/benchmark/trainer/load_metric.py b/benchmark/trainer/load_metric.py
index f7ed56a..1593a81 100755
--- a/benchmark/trainer/load_metric.py
+++ b/benchmark/trainer/load_metric.py
@@ -7,10 +7,10 @@
 from argparse import Namespace

 from torchmetrics import MetricCollection
 from torchmetrics.classification import (
-    MulticlassAccuracy, MultilabelAccuracy,
-    MulticlassF1Score, MultilabelF1Score,
-    MulticlassAUROC, MultilabelAUROC,
-    MulticlassAveragePrecision, MultilabelAveragePrecision,
+    MulticlassAccuracy, MultilabelAccuracy, BinaryAccuracy,
+    MulticlassF1Score, MultilabelF1Score, BinaryF1Score,
+    MulticlassAUROC, MultilabelAUROC, BinaryAUROC,
+    MulticlassAveragePrecision, MultilabelAveragePrecision, BinaryAveragePrecision,
 )
@@ -37,6 +37,13 @@ def metric_loader(args: Namespace) -> MetricCollection:
             's_auroc': MultilabelAUROC(num_classes=args.num_classes),
             's_ap': MultilabelAveragePrecision(num_classes=args.num_classes),
         })
+    elif args.num_classes == 1:
+        metric = ResCollection({
+            's_acc': BinaryAccuracy(),
+            's_f1i': BinaryF1Score(),
+            's_auroc': BinaryAUROC(),
+            's_ap': BinaryAveragePrecision(),
+        })
     else:
         metric = ResCollection({
             's_acc': MulticlassAccuracy(num_classes=args.num_classes),
diff --git a/benchmark/trainer/minibatch.py b/benchmark/trainer/minibatch.py
index a5853b4..c63f3e0 100755
--- a/benchmark/trainer/minibatch.py
+++ b/benchmark/trainer/minibatch.py
@@ -63,7 +63,7 @@ def __init__(self,
         metric = metric_loader(args).to(self.device)
         self.evaluator = {k: metric.clone(postfix='_'+k) for k in self.splits}

-        self.criterion = nn.CrossEntropyLoss()
+        self.criterion = nn.BCELoss() if self.num_classes == 1 else nn.NLLLoss()
         self.shuffle = {'train': True, 'val': False, 'test': False}

         self.embed = None
diff --git a/pyg_spectral/nn/models/base_nn.py b/pyg_spectral/nn/models/base_nn.py
index 99d989c..3080308 100644
--- a/pyg_spectral/nn/models/base_nn.py
+++ b/pyg_spectral/nn/models/base_nn.py
@@ -214,7 +214,10 @@ def forward(self,
         x = self.convolute(x, edge_index, batch=batch, batch_size=batch_size)
         if self.out_layers > 0:
             x = self.out_mlp(x, batch=batch, batch_size=batch_size)
-        return x
+        if self.out_channels == 1:
+            return torch.sigmoid(x)
+        else:
+            return torch.log_softmax(x, dim=-1)


 class BaseNNCompose(BaseNN):
diff --git a/pyg_spectral/nn/models/precomputed.py b/pyg_spectral/nn/models/precomputed.py
index c3af4a9..1a9875a 100644
--- a/pyg_spectral/nn/models/precomputed.py
+++ b/pyg_spectral/nn/models/precomputed.py
@@ -57,7 +57,10 @@ def forward(self,
         """
         if self.out_layers > 0:
             x = self.out_mlp(x, batch=batch, batch_size=batch_size)
-        return x
+        if self.out_channels == 1:
+            return torch.sigmoid(x)
+        else:
+            return torch.log_softmax(x, dim=-1)


 class PrecomputedVar(DecoupledVar):
@@ -125,7 +128,10 @@ def forward(self,
         out = conv_mat['out']
         if self.out_layers > 0:
             out = self.out_mlp(out, batch=batch, batch_size=batch_size)
-        return out
+        if self.out_channels == 1:
+            return torch.sigmoid(out)
+        else:
+            return torch.log_softmax(out, dim=-1)


 # ==========
@@ -190,7 +196,10 @@ def forward(self,

         if self.out_layers > 0:
             out = self.out_mlp(out, batch=batch, batch_size=batch_size)
-        return out
+        if self.out_channels == 1:
+            return torch.sigmoid(out)
+        else:
+            return torch.log_softmax(out, dim=-1)


 class PrecomputedVarCompose(DecoupledVarCompose):
@@ -271,4 +280,7 @@ def forward(self,

         if self.out_layers > 0:
             out = self.out_mlp(out, batch=batch, batch_size=batch_size)
-        return out
+        if self.out_channels == 1:
+            return torch.sigmoid(out)
+        else:
+            return torch.log_softmax(out, dim=-1)
diff --git a/pyg_spectral/profile/efficiency.py b/pyg_spectral/profile/efficiency.py
index 93054f1..9b13cc8 100755
--- a/pyg_spectral/profile/efficiency.py
+++ b/pyg_spectral/profile/efficiency.py
@@ -34,6 +34,9 @@ def reset(self):
     def data(self) -> float:
         return self.elapsed_sec

+    def __repr__(self) -> str:
+        return f'{self.data:.2f} s'
+
     def __enter__(self):
         self.start()
         return self
@@ -153,3 +156,22 @@ def update(self, module: Module):
         mem_params = sum([p.nelement()*p.element_size() for p in module.parameters()])
         mem_bufs = sum([b.nelement()*b.element_size() for b in module.buffers()])
         self.set(mem_params + mem_bufs)
+
+
+def log_memory(suffix: str = None, row: int = 0):
+    def decorator(func):
+        def wrapper(self, *args, **kwargs):
+            with self.device:
+                torch.cuda.empty_cache()
+
+            res = func(self, *args, **kwargs)
+            res.concat(
+                [('mem_ram', MemoryRAM()(unit='G')),
+                 ('mem_cuda', MemoryCUDA()(unit='G')),],
+                row=row, suffix=suffix)
+
+            with self.device:
+                torch.cuda.empty_cache()
+            return res
+        return wrapper
+    return decorator
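
Note: a minimal sketch of the loss/activation pairing this diff sets up, assuming forward() now returns sigmoid probabilities when out_channels == 1 and log-probabilities otherwise. The shapes, the squeeze, and the dummy targets below are illustrative assumptions, not code from the repository:

    import torch
    import torch.nn as nn

    num_classes = 1
    # BCELoss expects probabilities in [0, 1] (paired with torch.sigmoid);
    # NLLLoss expects log-probabilities (paired with torch.log_softmax).
    criterion = nn.BCELoss() if num_classes == 1 else nn.NLLLoss()

    logits = torch.randn(8, 1)                    # raw scores before the final activation
    prob = torch.sigmoid(logits).squeeze(-1)      # binary branch of forward()
    target = torch.randint(0, 2, (8,)).float()    # BCELoss needs float targets
    loss = criterion(prob, target)

    # Multiclass branch for comparison: NLLLoss over log_softmax outputs
    # matches the previous CrossEntropyLoss on raw logits.
    logits_mc = torch.randn(8, 5)
    log_prob = torch.log_softmax(logits_mc, dim=-1)
    loss_mc = nn.NLLLoss()(log_prob, torch.randint(0, 5, (8,)))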
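
Note: a quick check of the new binary branch in load_metric.py, assuming the metrics receive the sigmoid probabilities emitted by the model; the toy batch below is made up for illustration:

    import torch
    from torchmetrics.classification import BinaryAccuracy, BinaryAUROC

    prob = torch.tensor([0.9, 0.2, 0.7, 0.4])   # sigmoid outputs
    target = torch.tensor([1, 0, 1, 1])
    acc = BinaryAccuracy()(prob, target)        # thresholds at 0.5 -> 0.75
    auroc = BinaryAUROC()(prob, target)         # every positive outscores the negative -> 1.0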
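
Note: the log_memory decorator added to pyg_spectral/profile/efficiency.py assumes the wrapped method belongs to an object with a device attribute and returns a result object exposing concat(entries, row=..., suffix=...). The Trainer and Result classes below are hypothetical stand-ins for illustration only:

    import torch
    from pyg_spectral.profile.efficiency import log_memory

    class Result:
        # Minimal stand-in for the trainer's result collection.
        def __init__(self):
            self.entries = {}
        def concat(self, pairs, row=0, suffix=None):
            for key, val in pairs:
                name = f'{key}_{suffix}' if suffix else key
                self.entries[(row, name)] = val
            return self

    class Trainer:
        device = torch.device('cuda:0')  # entering a device context needs PyTorch >= 2.0

        @log_memory(suffix='train', row=0)
        def train(self):
            return Result()  # training metrics would be collected here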