Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Seg example #85

Open
wants to merge 20 commits into
base: fuse2
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/fuse_examples/imaging/segmentation/siim/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SIIM-ACR Pneumothorax Segmentation with Fuse
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import pandas as pd
from glob import glob
import random
import pickle
from typing import Sequence, Hashable, Union, Optional, List, Dict
from pathlib import Path


def filter_files(files, include=(), exclude=()):
    """
    Filter path-like objects by substrings of their file name.

    :param files: iterable of objects exposing a ``name`` attribute (e.g. ``pathlib.Path``)
    :param include: substrings that must ALL appear in ``name`` for a file to be kept
    :param exclude: substrings of which NONE may appear in ``name`` for a file to be kept
    :return: the kept files as a sorted list
    """
    # Materialize once so one-shot iterables can be re-checked for every file.
    include = tuple(include)
    exclude = tuple(exclude)
    kept = [
        f for f in files
        if all(token in f.name for token in include)
        and not any(token in f.name for token in exclude)
    ]
    return sorted(kept)


def ls(x, recursive=False, include=(), exclude=()):
    """
    List the entries of a directory, optionally recursively, filtered by name.

    :param x: directory to list; must provide ``iterdir()`` and ``glob()`` (e.g. ``pathlib.Path``)
    :param recursive: if True, walk all sub-directories using the '**/*' glob pattern
    :param include: substrings that must all appear in an entry's name
    :param exclude: substrings that must not appear in an entry's name
    :return: sorted list of the matching entries
    """
    entries = x.glob('**/*') if recursive else x.iterdir()
    return filter_files(list(entries), include=include, exclude=exclude)


def _list_dicom_files(data_folder) -> List[str]:
    """Return the sorted paths (as strings) of all entries under data_folder whose name contains '.dcm'."""
    # Call ls() directly instead of attaching it to pathlib.Path, which would
    # modify the Path class for the whole process.
    return [str(fn) for fn in ls(Path(data_folder), recursive=True, include=['.dcm'])]


def get_data_sample_ids(
        phase: str,  # can be ['train', 'validation']
        data_folder: Optional[str] = None,
        partition_file: Optional[str] = None,
        val_split: float = 0.2,
        override_partition: bool = True,
        data_shuffle: bool = True
) -> List[str]:
    """
    Return the sample ids (DICOM file paths) that belong to the requested phase.

    :param phase: 'train' or 'validation'; only consulted when partition_file is given
    :param data_folder: root folder that is searched recursively for '.dcm' files
    :param partition_file: Optional, path of a pickle file holding the train/validation split.
                           If None, every '.dcm' file under data_folder is returned.
                           In the 'train' phase with override_partition=True, a new split is
                           created and dumped into the file; otherwise the split is loaded from it.
    :param val_split: validation proportion used when a new split is created
    :param override_partition: specifies if the given partition file is filled with a new train/val split
    :param data_shuffle: specifies if the file list is randomly shuffled before splitting
    :return: list of sample ids of the requested phase
    :raises ValueError: if partition_file is given and phase is neither 'train' nor 'validation'
    :raises Exception: if a new split is requested but no '.dcm' file is found
    """
    if partition_file is None:
        # No partition requested - every file is a sample, whatever the phase.
        return _list_dicom_files(data_folder)

    if phase not in ('train', 'validation'):
        raise ValueError("phase must be 'train' or 'validation', got {!r}".format(phase))

    if phase == 'train' and override_partition:
        sample_descs = _list_dicom_files(data_folder)

        if len(sample_descs) == 0:
            raise Exception('Error detecting input source in FuseDataSourceDefault')

        if data_shuffle:
            # random shuffle the file-list
            random.shuffle(sample_descs)

        # split to train-validation
        n_train = int(len(sample_descs) * (1 - val_split))
        splits = {'train': sample_descs[:n_train], 'val': sample_descs[n_train:]}

        with open(partition_file, "wb") as pickle_out:
            pickle.dump(splits, pickle_out)
        return splits['train']

    # Read a previously stored split so that the same partition is reused.
    # NOTE: unpickling can execute arbitrary code - only load partition files from a trusted source.
    with open(partition_file, "rb") as pickle_in:
        repartition = pickle.load(pickle_in)
    return repartition['train' if phase == 'train' else 'val']
121 changes: 121 additions & 0 deletions examples/fuse_examples/imaging/segmentation/siim/image_mask_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@

"""
(C) Copyright 2021 IBM Corp.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

Created on June 30, 2021

"""

import numpy as np
import pandas as pd
from skimage.io import imread
import torch
from pathlib import Path
import PIL
import pydicom

from typing import Optional, Tuple

from fuse.data.ops.op_base import OpBase
from fuse.utils.ndict import NDict
from fuse.data.utils.sample import get_sample_id

# from fuse.data.processor.processor_base import FuseProcessorBase


def rle2mask(rles, width, height):
    """
    Decode run-length encodings into a single mask.

    Every encoding is a whitespace separated string of "start length" pairs.
    Each start is relative: it is counted from the end of the previous run
    (from position 0 for the first run of an encoding). Encoded pixels are set
    to 255, all other pixels are 0.

    :param rles: iterable of rle strings; empty strings and the '-1' sentinel are skipped
    :param width: mask width
    :param height: mask height
    :return: float array of shape (height, width) - the flat buffer is reshaped
             to (width, height) and then transposed
    :raises ValueError: if an encoding does not consist of complete start/length pairs
    """
    mask = np.zeros(width * height)
    for rle in rles:
        tokens = rle.split()
        if not tokens or tokens == ['-1']:
            # nothing to draw for this entry
            continue
        if len(tokens) % 2:
            raise ValueError('malformed RLE, expected start/length pairs: {!r}'.format(rle))

        values = [int(tok) for tok in tokens]
        position = 0
        for offset, length in zip(values[0::2], values[1::2]):
            position += offset
            mask[position:position + length] = 255
            position += length

    return mask.reshape(width, height).T


class OpImageMaskLoader(OpBase):
    """
    Op that loads either a DICOM image or its segmentation mask into the sample.

    When constructed with ``data_csv`` the op decodes the run-length-encoded
    mask rows that match the sample; otherwise it reads the DICOM file itself.
    In both cases the sample id is used as the path of the DICOM file.
    """

    # RLE masks are decoded on a square canvas of this side length before being resized.
    _MASK_SIDE = 1024

    def __init__(self,
                 data_csv: str = None,
                 size: int = 512,
                 normalization: float = 255.0, **kwargs):
        """
        Create the loader op
        :param data_csv: Optional, path to a csv file with 'ImageId' and ' EncodedPixels' columns.
                         If given, the op produces masks; otherwise it produces images.
        :param size: side length of the (square) output image / mask
        :param normalization: value the loaded image intensities are divided by
        :param kwargs: forwarded to OpBase
        """
        super().__init__(**kwargs)

        self.df = pd.read_csv(data_csv) if data_csv else None
        self.size = (size, size)
        self.norm = normalization

    def __call__(self, sample_dict: NDict, op_id: Optional[str], key_in: str, key_out: str):
        """
        Load the image or mask of the sample and store it in sample_dict[key_out].
        :param sample_dict: the sample; its sample id is the path of the DICOM file
        :param op_id: unused
        :param key_in: unused
        :param key_out: key under which the loaded array is stored
        :return: the updated sample_dict
        """
        desc = get_sample_id(sample_dict)

        if self.df is not None:
            image = self._load_mask(desc)
        else:
            image = self._load_image(desc)

        # convert image from shape (H x W x C) to shape (C x H x W);
        # arrays without a channel axis get a leading axis of size 1
        if len(image.shape) > 2:
            image = np.moveaxis(image, -1, 0)
        else:
            image = np.expand_dims(image, 0)

        sample_dict[key_out] = image
        return sample_dict

    def _load_mask(self, desc: str) -> np.ndarray:
        """Decode the csv rows whose ImageId equals the file stem of desc into a float32 0/1 mask."""
        rows = self.df.ImageId == Path(desc).stem
        # '-1' marks "no mask". Strip before comparing: the csv is read without
        # whitespace trimming (see the ' EncodedPixels' column name), so values may be padded.
        encodings = [enc.strip() for enc in self.df.loc[rows, ' EncodedPixels']]
        encodings = [enc for enc in encodings if enc != '-1']

        side = self._MASK_SIDE
        if encodings:
            mask = rle2mask(encodings, side, side).astype(np.uint8)
        else:
            mask = np.zeros((side, side), dtype=np.uint8)

        mask = np.asarray(PIL.Image.fromarray(mask).resize(self.size))
        # binarize after resizing
        return (mask > 0).astype('float32')

    def _load_image(self, desc: str) -> np.ndarray:
        """Read the DICOM file at desc, resize it and scale it by the configured normalization."""
        pixels = pydicom.dcmread(desc).pixel_array
        image = np.asarray(PIL.Image.fromarray(pixels).resize(self.size))
        return image.astype('float32') / self.norm

    def reverse(self, sample_dict: NDict, key_to_reverse: str, key_to_follow: str, op_id: Optional[str]) -> dict:
        """Loading is not reversed - the sample is returned unchanged."""
        return sample_dict
Loading