Skip to content

Commit

Permalink
Sort data files in UltrasoundDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
chriscyyeung committed Nov 21, 2024
1 parent 5d7e3fa commit 6de1cf5
Showing 1 changed file with 24 additions and 12 deletions.
36 changes: 24 additions & 12 deletions UltrasoundSegmentation/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@ def __init__(self, root_folder, imgs_dir="images", gts_dir="labels", tfms_dir="t
self.transform = transform

# Find all data segmentation files and matching ultrasound files in input directory
self.images = glob.glob(os.path.join(root_folder, "**", imgs_dir, "**", "*.npy"), recursive=True)
self.segmentations = glob.glob(os.path.join(root_folder, "**", gts_dir, "**", "*.npy"), recursive=True)
self.tfm_matrices = glob.glob(os.path.join(root_folder, "**", tfms_dir, "**", "*.npy"), recursive=True)
self.images = sorted(glob.glob(os.path.join(root_folder, "**", imgs_dir, "**", "*.npy"), recursive=True))
self.segmentations = sorted(glob.glob(os.path.join(root_folder, "**", gts_dir, "**", "*.npy"), recursive=True))
self.tfm_matrices = sorted(glob.glob(os.path.join(root_folder, "**", tfms_dir, "**", "*.npy"), recursive=True))
assert len(self.images) == len(self.segmentations), "Number of images and segmentations must match."

def __len__(self):
"""
Expand Down Expand Up @@ -61,6 +62,9 @@ def __getitem__(self, index):
# If segmentation_data is 2D, add a channel dimension as last dimension
if len(segmentation_data.shape) == 2:
segmentation_data = np.expand_dims(segmentation_data, axis=-1)

if len(transform_data.shape) == 2:
transform_data = np.expand_dims(transform_data, axis=0)

data = {
"image": ultrasound_data,
Expand Down Expand Up @@ -139,19 +143,22 @@ def __getitem__(self, index):
image = np.stack([
np.load(self.data[scan]["image"][index + i])[..., 0]
for i in range(self.window_size)
])
label = np.stack([
np.load(self.data[scan]["label"][index + i])[..., 0]
for i in range(self.window_size)
])
], axis=-1) # shape: (H, W, window_size)

# only take middle frame as label
label = np.load(self.data[scan]["label"][index + self.window_size // 2])
# If segmentation_data is 2D, add a channel dimension as last dimension
if len(label.shape) == 2:
label = np.expand_dims(label, axis=-1)

transform = np.stack([
np.load(self.data[scan]["transform"][index + i])
for i in range(self.window_size)
])
]) # shape: (window_size, 4, 4) - not affected by transforms

# compute relative transform between each matrix with the first frame
# compute relative transform between each matrix with the middle frame
for i in range(self.window_size):
transform[i] = np.linalg.inv(transform[0]) @ transform[i]
transform[i] = np.linalg.inv(transform[self.window_size // 2]) @ transform[i]

data = {
"image": image,
Expand All @@ -177,8 +184,13 @@ def __call__(self, data):


if __name__ == "__main__":
dataset = SlidingWindowTrackedUSDataset("/mnt/e/PerkLab/Data/Spine/SpineTrainingData/04_Slices_train")
# dataset = SlidingWindowTrackedUSDataset("/mnt/c/Users/chris/Data/Spine/2024_SpineSeg/04_Slices_train")
dataset = UltrasoundDataset("/mnt/c/Users/chris/Data/Breast/AIGTData/train")
print(dataset.images[:5])
print(dataset.segmentations[:5])
print(dataset.tfm_matrices[:5])
print(len(dataset))
print(dataset[0]["image"].shape)
print(dataset[0]["label"].shape)
print(dataset[0]["transform"].shape)
print(dataset[0]["transform"])

0 comments on commit 6de1cf5

Please sign in to comment.