Skip to content

Commit

Permalink
classification/regression Heads (#164)
Browse files Browse the repository at this point in the history
* use torch lightning instead of manager

* last changes

* define just num_gpus instead of specify them

* delete managers files and fix lightning log to be for each epoch (not step)

* black fix

* modify make prection file to lightning style

* black changes

* more flake8 corrections

* flake8 corrections

* flake8 corrections

* flake8 corrections

* make model function

* black corrections

* flake8

* added knight test and dataset reformatted

* update

* seperate dataset and dataloader creations

* changed baseline config load for testing

* delete unecessary data

* fix test targets

* black reformat

* reformat flake8

* flake8 fix

* mylint reformat

* black reformat

* tests update

* test fix

* shorter test

* test fix

* test fix

* fix test

* fix test

* fix make targets

* fix make targets

* added 3d-regression head and fixed sampler mistake

* changes requested

* import fix

* flake8 fix

* changed heads to be more general

* fixed imports

* concat right dim

* fix error

* black fix
  • Loading branch information
liamhazan authored Oct 2, 2022
1 parent ddcbf19 commit 060abd2
Show file tree
Hide file tree
Showing 8 changed files with 147 additions and 290 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ ModelMultiHead(
conv_inputs=(('data.input.img', 1),), # input to the backbone model
backbone=BackboneResnet3D(in_channels=1), # PyTorch nn Module
heads=[ # list of heads - gives the option to support multi task / multi head approach
Head3DClassifier(head_name='classification',
Head3D(head_name='classification',
mode="classification",
conv_inputs=[("model.backbone_features", 512)] # Input to the classification head
,),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pandas as pd
from fuse.dl.models import ModelMultiHead
from fuse.dl.models.backbones.backbone_resnet_3d import BackboneResnet3D
from fuse.dl.models.heads.heads_3D import Head3DClassifier
from fuse.dl.models.heads.heads_3D import Head3D
from fuseimg.datasets.knight import KNIGHT
import torch.nn.functional as F
import torch.nn as nn
Expand Down Expand Up @@ -51,11 +51,12 @@ def make_model(use_data: dict, num_classes: int, imaging_dropout: float, fused_d
conv_inputs=(("data.input.img", 1),),
backbone=backbone,
heads=[
Head3DClassifier(
Head3D(
head_name="head_0",
mode="classification",
conv_inputs=conv_inputs,
dropout_rate=imaging_dropout,
num_classes=num_classes,
num_outputs=num_classes,
append_features=append_features,
append_layers_description=(256, 128),
fused_dropout_rate=fused_dropout,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
from fuse.dl.losses.loss_default import LossDefault
from fuse.dl.models.backbones.backbone_resnet_3d import BackboneResnet3D
from fuse.dl.models import ModelMultiHead
from fuse.dl.models.heads.heads_3D import Head3DClassifier
from fuse.dl.models.heads.heads_3D import Head3D
from fuse.dl.lightning.pl_module import LightningModuleDefault

from fuse.utils.utils_debug import FuseDebug
Expand Down Expand Up @@ -116,19 +116,20 @@
def create_model(imaging_dropout: float, clinical_dropout: float, fused_dropout: float) -> torch.nn.Module:
"""
creates the model
See Head3DClassifier for details about imaging_dropout, clinical_dropout, fused_dropout
See Head3D for details about imaging_dropout, clinical_dropout, fused_dropout
"""
model = ModelMultiHead(
conv_inputs=(("data.input.img", 1),),
backbone=BackboneResnet3D(in_channels=1),
heads=[
Head3DClassifier(
Head3D(
head_name="classification",
mode="classification",
conv_inputs=[("model.backbone_features", 512)],
dropout_rate=imaging_dropout,
append_dropout_rate=clinical_dropout,
fused_dropout_rate=fused_dropout,
num_classes=2,
num_outputs=2,
append_features=[("data.input.clinical", 8)],
append_layers_description=(256, 128),
),
Expand Down
2 changes: 1 addition & 1 deletion fuse/dl/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ Implemented backbones include a "vanilla" fully connected network, or Multi Laye
[The KNIGHT challenge example uses a 3D ResNet backbone `BackboneResnet3D`](../../examples/fuse_examples/imaging/classification/knight/baseline/fuse_baseline.py)

Implemented "heads" include a number of parameterized classifier heads, in 1D, 2D and 3D, as well as a dense segmentation head.
[The KNIGHT challenge example uses a 3D classification head `Head3DClassifier`](../../examples/fuse_examples/imaging/classification/knight/baseline/fuse_baseline.py)
[The KNIGHT challenge example uses a 3D classification head `Head3D`](../../examples/fuse_examples/imaging/classification/knight/baseline/fuse_baseline.py)

## templates
This module contains a walkthrough template code rich with comments, to demonstrate training with FuseMedML, with all required building blocks.
4 changes: 2 additions & 2 deletions fuse/dl/models/heads/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from .common import ClassifierFCN, ClassifierFCN3D, ClassifierMLP
from .head_1D_classifier import Head1DClassifier
from .heads_3D import Head3DClassifier
from .heads_1D import Head1D
from .heads_3D import Head3D
from .head_dense_segmentation import HeadDenseSegmentation
from .head_global_pooling_classifier import HeadGlobalPoolingClassifier
182 changes: 0 additions & 182 deletions fuse/dl/models/heads/head_1D_classifier.py

This file was deleted.

117 changes: 117 additions & 0 deletions fuse/dl/models/heads/heads_1D.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
"""
(C) Copyright 2021 IBM Corp.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Created on June 30, 2021
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Dict, Tuple, Sequence, Optional
from fuse.dl.models.heads.common import ClassifierMLP


class Head1D(nn.Module):
def __init__(
self,
head_name: str = "head_0",
mode: str = None, # "classification" or "regression"
conv_inputs: Sequence[Tuple[str, int]] = None,
num_outputs: int = 2, # num classes in case of classification
append_features: Optional[Sequence[Tuple[str, int]]] = None,
layers_description: Sequence[int] = (256,),
append_layers_description: Sequence[int] = tuple(),
append_dropout_rate: float = 0.0,
dropout_rate: float = 0.1,
) -> None:
"""
head 1d.
Output of a forward pass for classification:
'model.logits.head_name' and 'outputs.head_name', both in shape [batch_size, num_outputs]
Output of a forward pass for regression:
'model.output.head_name' in shape [batch_size, num_outputs]
:param head_name: batch_dict key
:param mode: "classification" or "regression"
:param conv_inputs: List of feature map inputs - tuples of (batch_dict key, channel depth)
If multiple inputs are used, they are concatenated on the channel axis
for example:
conv_inputs=(('model.backbone_features', 193),)
:param num_outputs: Number of output classes (in case of classification) or just num outputs in case of regression
:param append_features: Additional vector (one dimensional) inputs, concatenated just before the classifier module
:param layers_description: Layers description for the classifier module - sequence of hidden layers sizes
:param dropout_rate: Dropout rate for classifier module layers
"""
super().__init__()

self.head_name = head_name
self.mode = mode
assert conv_inputs is not None, "conv_inputs must be provided"
self.conv_inputs = conv_inputs
self.append_features = append_features

self.features_size = sum([conv_input[1] for conv_input in self.conv_inputs])

if append_features is not None:
if len(append_layers_description) == 0:
self.features_size += sum([post_concat_input[1] for post_concat_input in append_features])
self.append_features_module = nn.Identity()
else:
self.features_size += append_layers_description[-1]
self.append_features_module = ClassifierMLP(
in_ch=sum([post_concat_input[1] for post_concat_input in append_features]),
num_classes=None,
layers_description=append_layers_description,
dropout_rate=append_dropout_rate,
)

self.head_module = ClassifierMLP(
in_ch=self.features_size,
num_classes=num_outputs,
layers_description=layers_description,
dropout_rate=dropout_rate,
)

def forward(self, batch_dict: Dict) -> Dict:

conv_input = torch.cat([batch_dict[conv_input[0]] for conv_input in self.conv_inputs], dim=1)
global_features = conv_input

if self.append_features is not None:
features = torch.cat([batch_dict[append_feature[0]] for append_feature in self.append_features])
features = self.append_features_module(features)
features = features.reshape(features.shape + (1, 1, 1))
if self.conv_inputs is not None:
global_features = torch.cat((global_features, features), dim=1)
else:
global_features = features
if self.mode == "regression":
prediction = self.head_module(global_features).squeeze(dim=1)
batch_dict["model.output." + self.head_name] = prediction
else:
logits = self.head_module(global_features) # --> res.shape = [batch_size, 2, 1, 1]
if len(logits.shape) > 2:
logits = logits.squeeze(dim=3) # --> res.shape = [batch_size, 2, 1]
logits = logits.squeeze(dim=2) # --> res.shape = [batch_size, 2]

cls_preds = F.softmax(logits, dim=1)

batch_dict["model.logits." + self.head_name] = logits
batch_dict["model.output." + self.head_name] = cls_preds

return batch_dict
Loading

0 comments on commit 060abd2

Please sign in to comment.