From 6de1cf513c9ea8bc63fd3cba482faed542f8b143 Mon Sep 17 00:00:00 2001 From: Chris Yeung Date: Thu, 21 Nov 2024 16:57:58 -0500 Subject: [PATCH] Sort data files in UltrasoundDataset --- UltrasoundSegmentation/datasets.py | 36 ++++++++++++++++++++---------- 1 file changed, 24 insertions(+), 12 deletions(-) diff --git a/UltrasoundSegmentation/datasets.py b/UltrasoundSegmentation/datasets.py index 8dc58bf..3efa1d9 100644 --- a/UltrasoundSegmentation/datasets.py +++ b/UltrasoundSegmentation/datasets.py @@ -15,9 +15,10 @@ def __init__(self, root_folder, imgs_dir="images", gts_dir="labels", tfms_dir="t self.transform = transform # Find all data segmentation files and matching ultrasound files in input directory - self.images = glob.glob(os.path.join(root_folder, "**", imgs_dir, "**", "*.npy"), recursive=True) - self.segmentations = glob.glob(os.path.join(root_folder, "**", gts_dir, "**", "*.npy"), recursive=True) - self.tfm_matrices = glob.glob(os.path.join(root_folder, "**", tfms_dir, "**", "*.npy"), recursive=True) + self.images = sorted(glob.glob(os.path.join(root_folder, "**", imgs_dir, "**", "*.npy"), recursive=True)) + self.segmentations = sorted(glob.glob(os.path.join(root_folder, "**", gts_dir, "**", "*.npy"), recursive=True)) + self.tfm_matrices = sorted(glob.glob(os.path.join(root_folder, "**", tfms_dir, "**", "*.npy"), recursive=True)) + assert len(self.images) == len(self.segmentations), "Number of images and segmentations must match." def __len__(self): """ @@ -61,6 +62,9 @@ def __getitem__(self, index): # If segmentation_data is 2D, add a channel dimension as last dimension if len(segmentation_data.shape) == 2: segmentation_data = np.expand_dims(segmentation_data, axis=-1) + + if len(transform_data.shape) == 2: + transform_data = np.expand_dims(transform_data, axis=0) data = { "image": ultrasound_data, @@ -139,19 +143,22 @@ def __getitem__(self, index): image = np.stack([ np.load(self.data[scan]["image"][index + i])[..., 0] for i in range(self.window_size) - ]) - label = np.stack([ - np.load(self.data[scan]["label"][index + i])[..., 0] - for i in range(self.window_size) - ]) + ], axis=-1) # shape: (H, W, window_size) + + # only take middle frame as label + label = np.load(self.data[scan]["label"][index + self.window_size // 2]) + # If segmentation_data is 2D, add a channel dimension as last dimension + if len(label.shape) == 2: + label = np.expand_dims(label, axis=-1) + transform = np.stack([ np.load(self.data[scan]["transform"][index + i]) for i in range(self.window_size) - ]) + ]) # shape: (window_size, 4, 4) - not affected by transforms - # compute relative transform between each matrix with the first frame + # compute relative transform between each matrix with the middle frame for i in range(self.window_size): - transform[i] = np.linalg.inv(transform[0]) @ transform[i] + transform[i] = np.linalg.inv(transform[self.window_size // 2]) @ transform[i] data = { "image": image, @@ -177,8 +184,13 @@ def __call__(self, data): if __name__ == "__main__": - dataset = SlidingWindowTrackedUSDataset("/mnt/e/PerkLab/Data/Spine/SpineTrainingData/04_Slices_train") + # dataset = SlidingWindowTrackedUSDataset("/mnt/c/Users/chris/Data/Spine/2024_SpineSeg/04_Slices_train") + dataset = UltrasoundDataset("/mnt/c/Users/chris/Data/Breast/AIGTData/train") + print(dataset.images[:5]) + print(dataset.segmentations[:5]) + print(dataset.tfm_matrices[:5]) print(len(dataset)) print(dataset[0]["image"].shape) print(dataset[0]["label"].shape) print(dataset[0]["transform"].shape) + print(dataset[0]["transform"])