Segmentation tasks are becoming easier day by day, especially due to the rise in the number of open-source packages like Segmentation Models PyTorch. Similarly, with the advent of mixed precision training, even beginners can train state-of-the-art models on their PCs. In this project I experiment with fp16 training with the help of the PyTorch Lightning framework.
PyTorch Lightning is a lightweight wrapper over PyTorch used by researchers worldwide to speed up their deep learning experiments. You can use this project to set up your own image segmentation project easily. The project structure is similar to the directory structure used in Kaggle competitions, so it is straightforward to convert this project into a Kaggle kernel.
- PyTorch (An open source deep learning platform)
- PyTorch Lightning (A lightweight PyTorch wrapper for ML researchers)
- Segmentation Models PyTorch (Python library with SOTA Networks for Image Segmentation based on PyTorch)
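All three are available on PyPI and can typically be installed with `pip install torch pytorch-lightning segmentation-models-pytorch` (exact versions depend on your Python and CUDA setup).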
In a nutshell, here's how to train your own segmentation model with PyTorch Lightning and Segmentation Models PyTorch. For example, assume you want to train a U-Net with a ResNet-34 encoder to compete in the Carvana Image Masking Challenge. You would do the following:
- Create an `input` folder and place your dataset in it, then edit `config.py` so that the train and test paths point to their respective folders:
```python
INPUT = 'input'
OUTPUT = 'output'
TRAIN_PATH = f'{INPUT}/train_hq'
MASK_PATH = f'{INPUT}/train_masks'
TEST_PATH = f'{INPUT}/test_hq'
```
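The rest of `config.py` holds the hyperparameters referenced in the snippets below. The names match the ones used throughout this README, but the values shown here are placeholder assumptions, not the project's settings:

```python
TRAIN_FOLDS = f'{OUTPUT}/train_folds.csv'  # written by folds.py

CLASSES = 1                          # assumed: Carvana is a binary (car vs. background) task
CROP_SIZE = 256                      # assumed resize target for images and masks
MODEL_MEAN = (0.485, 0.456, 0.406)   # ImageNet statistics, matching the pretrained encoder
MODEL_STD = (0.229, 0.224, 0.225)
LR = 1e-4                            # assumed learning rate
TRAIN_BATCH_SIZE = 4                 # assumed
VAL_BATCH_SIZE = 8                   # assumed
EPOCHS = 5                           # assumed
```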
- In `config.py`, change the value of `MODEL_NAME` to the name of the model you wish to use; here we use `smp_unet_resnet34`:
```python
MODEL_NAME = 'smp_unet_resnet34'
```
- In `model_dispatcher.py`, register your model in the `MODELS` dictionary, defining it with the Segmentation Models PyTorch library:
```python
import segmentation_models_pytorch as smp

import config

MODELS = {
    'smp_unet_resnet34': smp.Unet('resnet34', encoder_weights='imagenet',
                                  classes=config.CLASSES, activation='softmax'),
}
```
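Swapping architectures then only takes one more dictionary entry and a matching `MODEL_NAME` in `config.py`. For example, a hypothetical FPN entry (`smp.FPN` is a real Segmentation Models PyTorch class; the key name here is our own):

```python
# hypothetical extra entry; set MODEL_NAME = 'smp_fpn_resnet34' in config.py to use it
MODELS['smp_fpn_resnet34'] = smp.FPN('resnet34', encoder_weights='imagenet',
                                     classes=config.CLASSES, activation='softmax')
```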
- In `dataset.py`, create a dataset class like this:
```python
import albumentations as A
import numpy as np
import pandas as pd
import torch
from PIL import Image

import config


class CarvanaDataset:
    def __init__(self, folds):
        df = pd.read_csv(config.TRAIN_FOLDS)
        df = df[['img', 'kfold']]
        df = df[df.kfold.isin(folds)].reset_index(drop=True)
        self.image_ids = df.img.values

        if len(folds) == 1:
            # validation: only resize and normalize
            self.aug = A.Compose([
                A.Resize(config.CROP_SIZE, config.CROP_SIZE, always_apply=True),
                A.Normalize(config.MODEL_MEAN, config.MODEL_STD, always_apply=True)
            ])
        else:
            # training: add light geometric augmentation
            self.aug = A.Compose([
                A.Resize(config.CROP_SIZE, config.CROP_SIZE, always_apply=True),
                A.ShiftScaleRotate(
                    shift_limit=0.0625,
                    scale_limit=0.1,
                    rotate_limit=15,
                    p=0.9),
                A.Normalize(config.MODEL_MEAN, config.MODEL_STD, always_apply=True)
            ])

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, item):
        img_name = self.image_ids[item]
        image = np.array(Image.open(f'{config.TRAIN_PATH}/{img_name}.jpg'))
        mask = np.array(Image.open(f'{config.MASK_PATH}/{img_name}_mask.gif'))

        augmented = self.aug(image=image, mask=mask)
        image = augmented['image']
        mask = augmented['mask']

        # HWC -> CHW, as PyTorch expects
        image = np.transpose(image, (2, 0, 1)).astype(np.float32)
        return {
            'image': torch.tensor(image, dtype=torch.float),
            'mask': torch.tensor(mask, dtype=torch.float)
        }
```
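As a quick sanity check, along the lines of what `test_dataset.py` does, you can load a sample and inspect its shapes (the fold indices below are assumptions, and `output/train_folds.csv` must already exist; see the next step):

```python
ds = CarvanaDataset(folds=[0, 1, 2, 3])
sample = ds[0]
print(len(ds))                # number of images in the selected folds
print(sample['image'].shape)  # torch.Size([3, CROP_SIZE, CROP_SIZE])
print(sample['mask'].shape)   # torch.Size([CROP_SIZE, CROP_SIZE])
```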
- Now, run `folds.py` to create folds for training and validation. You can test your dataset using `test_dataset.py`. A minimal sketch of what `folds.py` might do is shown below.
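`folds.py` itself is not reproduced in this README; the following sketch shows one way it could build the `img`/`kfold` columns that `dataset.py` reads (the 5-fold split and file layout are assumptions):

```python
import glob
import os

import pandas as pd
from sklearn import model_selection

import config

# collect image ids (file names without extension) from the training folder
imgs = [os.path.splitext(os.path.basename(p))[0]
        for p in glob.glob(f'{config.TRAIN_PATH}/*.jpg')]
df = pd.DataFrame({'img': imgs, 'kfold': -1})

# assign each image to one of 5 folds
kf = model_selection.KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (_, val_idx) in enumerate(kf.split(df)):
    df.loc[val_idx, 'kfold'] = fold

df.to_csv(config.TRAIN_FOLDS, index=False)
```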
- Now we can build our Lightning Module:
```python
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader

import config
import model_dispatcher
from dataset import CarvanaDataset
from dice_loss import IoULoss  # losses live in dice_loss.py (see project structure)


class CarvanaModel(pl.LightningModule):
    def __init__(self, train_folds, val_folds):
        super(CarvanaModel, self).__init__()
        # import model from model dispatcher
        self.model = model_dispatcher.MODELS[config.MODEL_NAME]
        self.train_folds = train_folds
        self.val_folds = val_folds

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_nb):
        # REQUIRED
        x = batch['image']
        y = batch['mask']
        y_hat = self(x)
        loss = IoULoss()(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        # OPTIONAL
        x = batch['image']
        y = batch['mask']
        y_hat = self(x)
        return {'val_loss': IoULoss()(y_hat, y)}

    def validation_epoch_end(self, outputs):
        # OPTIONAL
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=config.LR)

    def train_dataloader(self):
        # REQUIRED
        return DataLoader(CarvanaDataset(folds=self.train_folds),
                          shuffle=True, batch_size=config.TRAIN_BATCH_SIZE)

    def val_dataloader(self):
        # OPTIONAL
        return DataLoader(CarvanaDataset(folds=self.val_folds),
                          batch_size=config.VAL_BATCH_SIZE)
```
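The `IoULoss` used above is defined in `dice_loss.py` and was taken from the SegLoss repo (see Acknowledgements). For intuition only, here is a minimal soft-IoU sketch, assuming single-channel probability outputs:

```python
import torch
import torch.nn as nn


class SoftIoULoss(nn.Module):
    """Illustrative soft IoU; the project's IoULoss from dice_loss.py is more general."""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.eps = eps

    def forward(self, y_hat, y):
        # flatten each sample to a vector, then compute soft intersection/union
        y_hat = y_hat.reshape(y_hat.size(0), -1)
        y = y.reshape(y.size(0), -1)
        intersection = (y_hat * y).sum(dim=1)
        union = y_hat.sum(dim=1) + y.sum(dim=1) - intersection
        return (1.0 - (intersection + self.eps) / (union + self.eps)).mean()
```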
- Now we can train our model using the `Lightning_trainer.py` script. The fold assignment below (folds 0-3 for training, fold 4 for validation) is one possible split:
```python
import pytorch_lightning as pl

import config
from Lightning_module import CarvanaModel

train_folds, val_folds = [0, 1, 2, 3], [4]  # assumed 5-fold split
carvana_model = CarvanaModel(train_folds, val_folds)

# most basic trainer, uses good defaults (1 gpu)
trainer = pl.Trainer(gpus=1, accumulate_grad_batches=64, amp_level='O1',
                     precision=16, profiler=True, max_epochs=config.EPOCHS)
trainer.fit(carvana_model)
```
Note: we use AMP (automatic mixed precision) via the NVIDIA Apex library; `precision=16` together with `amp_level='O1'` enables fp16 training, and `accumulate_grad_batches=64` accumulates gradients over 64 batches, simulating a batch size 64 times larger than `TRAIN_BATCH_SIZE`.
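Checkpoints land in `lightning_logs/version_#` (see the project structure below). A hypothetical snippet for reloading one for evaluation; the checkpoint file name depends on your run:

```python
import torch

from Lightning_module import CarvanaModel

# illustrative path; pick the actual .ckpt file produced by your run
ckpt = torch.load('lightning_logs/version_0/checkpoints/epoch=4.ckpt',
                  map_location='cpu')
carvana_model = CarvanaModel(train_folds=[0, 1, 2, 3], val_folds=[4])
carvana_model.load_state_dict(ckpt['state_dict'])
carvana_model.eval()
```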
```
├── input
│   ├── test_hq - folder containing test images.
│   ├── train_hq - folder containing train images.
│   └── train_masks - folder containing train masks.
│
├── lightning_logs
│   └── version_# - training checkpoints are saved here.
│
├── output
│   └── train_folds.csv - this file is generated when we run folds.py.
│
└── src
    ├── config.py - contains all the hyperparameters for the project.
    ├── dataset.py - contains the dataset object.
    ├── decoders.py - placeholder, to be used in the future.
    ├── dice_loss.py - defines various losses which can be used to train the model.
    ├── encoders.py - placeholder, to be used in the future.
    ├── folds.py - creates folds for cross-validation.
    ├── Lightning_module.py - contains the Lightning Module.
    ├── Lightning_trainer.py - run this file to train the model.
    ├── Lightning_tester.py - evaluates the trained model.
    ├── model_dispatcher.py - contains model definitions.
    ├── ND_Crossentropy.py - helper functions for the loss functions.
    ├── test_dataset.py - used to test the dataset object.
    └── test_model.py - contains the inference process.
```
- Write a Unetify script using `encoders.py` and `decoders.py`
- Add Augmentations to the dataset to make the model more robust
- Add Hydra support
Any kind of enhancement or contribution is welcome.
- Loss functions were taken from the SegLoss repo.