Skip to content

Latest commit

 

History

History
103 lines (80 loc) · 3.14 KB

README.md

File metadata and controls

103 lines (80 loc) · 3.14 KB

IMPaSh

Implementation of the paper [arXiv]:

"IMPaSh: A Novel Domain-shift Resistant Representation for Colorectal Cancer Tissue Classification" by Trinh Thi Le Vuong, Quoc Dang Vu, Mostafa Jahanifar, Simon Graham, Jin Tae Kwak, and Nasir Rajpoot. ECCV Workshop 2022.


Model Weights

IMPaSh's encoder and classifier weights:

Snippet of PatchShuffling Module

import numpy as np
from random import shuffle
from PIL import Image

class PatchShuffling:
    """Cut an image into a grid of randomly jittered patches and shuffle them.

    The top-left region of the input is divided into ``n_grid`` x ``n_grid``
    cells of side ``img_size // n_grid``. One ``crop_size`` x ``crop_size``
    patch is taken at a random position inside every cell, the patches are
    put in random order, and they are stitched into a square image of side
    ``n_grid * crop_size``.

    Randomness comes from two generators: ``numpy.random`` picks the patch
    offsets and ``random.shuffle`` picks the patch order, so both have to be
    seeded for reproducible output.

    Args:
        n_grid: Number of cells along each side of the grid.
        img_size: Side length, in pixels, of the region divided into cells.
        crop_size: Side length, in pixels, of the patch taken from each cell.

    Raises:
        ValueError: If ``crop_size`` is not between 1 and the cell size, in
            which case no patch could fit inside a cell.
    """

    def __init__(self, n_grid: int = 3, img_size: int = 255, crop_size: int = 64) -> None:
        self.n_grid = n_grid
        self.img_size = img_size
        self.crop_size = crop_size
        self.grid_size = int(img_size // n_grid)

        # Fail here rather than on the first call, where a negative jitter
        # range would surface as an unrelated "low >= high" error.
        if not 1 <= crop_size <= self.grid_size:
            raise ValueError(
                "crop_size must be between 1 and the cell size "
                "(img_size // n_grid = {}), got {}".format(self.grid_size, crop_size)
            )

        # Largest offset a patch can have inside its cell and still fit.
        self.side = self.grid_size - crop_size

        cell_index = np.arange(n_grid)
        rows, cols = np.meshgrid(cell_index, cell_index, indexing="ij")

        # Top-left corner of every cell in the input image. Despite the
        # names, ``xx`` indexes axis 0 (rows) and ``yy`` axis 1 (columns).
        self.yy = (cols * self.grid_size).reshape(-1)
        self.xx = (rows * self.grid_size).reshape(-1)

        # Top-left corner of every slot in the stitched output image.
        self.re_yy = (cols * crop_size).reshape(-1)
        self.re_xx = (rows * crop_size).reshape(-1)

    def __call__(self, img):
        """Return a patch-shuffled copy of ``img`` as a PIL image.

        Args:
            img: Anything ``np.asarray`` can convert to a ``uint8`` array
                indexed as (row, column, channel). The output buffer is
                allocated with 3 channels.

        Returns:
            A PIL image of side ``n_grid * crop_size`` built from the
            shuffled patches.

        Raises:
            ValueError: If a drawn patch window extends past the image,
                i.e. the image is smaller than the configured grid.
        """
        n_patches = self.n_grid * self.n_grid
        size = self.crop_size

        # Keep this draw order (axis-0 offsets, then axis-1 offsets) so that
        # seeded runs produce the same patches as before.
        r_x = np.random.randint(0, self.side + 1, n_patches)
        r_y = np.random.randint(0, self.side + 1, n_patches)
        img = np.asarray(img, np.uint8)

        top = self.xx + r_x
        left = self.yy + r_y

        # NumPy slicing silently truncates windows that leave the image; a
        # truncated patch would then fail to stitch or be broadcast-stretched.
        if (top + size).max() > img.shape[0] or (left + size).max() > img.shape[1]:
            raise ValueError(
                "image of shape {} is too small for a {}x{} grid of {}-pixel cells".format(
                    img.shape, self.n_grid, self.n_grid, self.grid_size
                )
            )

        crops = [img[t:t + size, c:c + size, :] for t, c in zip(top, left)]
        shuffle(crops)

        shuffling_img = np.zeros([size * self.n_grid, size * self.n_grid, 3], dtype='uint8')
        for out_row, out_col, patch in zip(self.re_xx, self.re_yy, crops):
            shuffling_img[out_row:out_row + size, out_col:out_col + size] = patch

        return Image.fromarray(shuffling_img)

Train the self-supervised encoder (with options)

python main_contrast.py \
  --method IMPaShMoCo \
  --jigsaw_stitch\
  --cosine \
  --dataset_name  k19 \
  --multiprocessing-distributed --world-size 1 --rank 0 \
  --dist-url 'tcp://127.0.0.1:23458'

Train the classifier (with options)

python main_linear.py \
 --method PatchSMoco \
 --ckpt ./save/k19_IMPaSH/ckpt_epoch_200.pth\
 --aug_linear RA \
 --dataset_name  k19 \
 --keephead head \
 --multiprocessing-distributed --world-size 1 --rank 0 \
 --dist-url 'tcp://127.0.0.1:23458'

Inference on target dataset

python main_infer.py \
 --method IMPaSH \
 --ckpt ./save/k19_IMPaSH/ckpt_epoch_200.pth\
 --ckpt_class ./save/k19_IMPaSH_linear_head_True/ckpt_epoch_40.pth\
 --dataset_name  k16 \
 --keephead head \
 --multiprocessing-distributed --world-size 1 --rank 0 \
 --dist-url 'tcp://127.0.0.1:23452'