From a25b928008720125de17fe1a47c8cf15a2b92550 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9ment=20H=2E=20Benedetti?= Date: Tue, 1 Oct 2024 13:47:25 +0200 Subject: [PATCH] Finished YOLO annotator --- .../_widget_yolo_annotations.py | 179 +++++++++++++----- src/microglia_analyzer/custom-loss.py | 38 ++++ 2 files changed, 174 insertions(+), 43 deletions(-) create mode 100644 src/microglia_analyzer/custom-loss.py diff --git a/src/microglia_analyzer/_widget_yolo_annotations.py b/src/microglia_analyzer/_widget_yolo_annotations.py index 75e35cf..b3f799d 100644 --- a/src/microglia_analyzer/_widget_yolo_annotations.py +++ b/src/microglia_analyzer/_widget_yolo_annotations.py @@ -5,7 +5,7 @@ from napari.utils.notifications import show_info import tifffile -from microglia_analyzer import TIFF_REGEX +import cv2 import numpy as np import os @@ -15,24 +15,30 @@ # Name of the layer containing the current image. _IMAGE_LAYER = "Image" # Colors assigned to each YOLO class. -_COLORS = [ - "#FF0000", - "#00FF00", - "#0000FF", - "#FFFF00", - "#FF00FF", - "#00FFFF", - "#FF8000", - "#8000FF", - "#0080FF", - "#80FF00", - "#FF0080", - "#00FF80", - "#800000", - "#008000", +_COLORS = [ + "#FF4D4D", + "#4DFF4D", + "#4D4DFF", + "#FFD700", + "#FF66FF", + "#66FFFF", + "#FF9900", + "#9933FF", + "#3399FF", + "#99FF33", + "#FF3399", + "#33FF99", + "#B20000", + "#006600", "#800080", "#808000" ] +# Function used to read images. +imread = tifffile.imread +# Indices to have the width and height from an image shape +_WIDTH_HEIGHT = (0, 2) +# Arguments to pass to the imread function. +ARGS = {} # A YOLO bounding-box == a tuple of 5 elements: # - (int) The class to which this box belongs. @@ -69,9 +75,9 @@ def add_media_management_group_ui(self): box.setLayout(layout) # Label + button to select the source directory: - self.select_sources_directory_button = QPushButton("đź“‚ Sources directory") - self.select_sources_directory_button.clicked.connect(self.select_sources_directory) - layout.addWidget(self.select_sources_directory_button) + self.select_root_directory_button = QPushButton("đź“‚ Root directory") + self.select_root_directory_button.clicked.connect(self.select_sources_directory) + layout.addWidget(self.select_root_directory_button) # Label + text box for the inputs sub-folder's name: inputs_name_label = QLabel("Inputs sub-folder:") @@ -113,14 +119,7 @@ def add_classes_management_group_ui(self): layout.addLayout(h_laytout) # Label showing the number of boxes in each class - self.counts_label = QLabel("Counts:") - font = self.counts_label.font() - font.setBold(True) - self.counts_label.setFont(font) - layout.addWidget(self.counts_label) self.count_display_label = QLabel("") - self.counts_label.setMaximumHeight(20) - self.count_display_label.setMaximumHeight(10) layout.addWidget(self.count_display_label) self.layout.addWidget(box) @@ -205,7 +204,7 @@ def bbox2yolo(self, bbox): """ ymax, xmax = self.upper_corner(bbox) ymin, xmin = self.lower_corner(bbox) - height, width = self.viewer.layers[_IMAGE_LAYER].data.shape[:2] + height, width = self.viewer.layers[_IMAGE_LAYER].data.shape[_WIDTH_HEIGHT[0]:_WIDTH_HEIGHT[1]] x = (xmin + xmax) / 2 / width y = (ymin + ymax) / 2 / height w = (xmax - xmin) / width @@ -221,8 +220,15 @@ def layer2yolo(self, layer_name, index): return tuples def write_annotations(self, tuples): + """ + Responsible for writing the annotations in a '.txt' file, and updating the 'classes.txt' file. + + Args: + - tuples (list): A list of tuples, each tuple containing the class index and the YOLO bounding-box. + """ labels_folder = os.path.join(self.root_directory, self.annotations_directory) - current_as_txt = TIFF_REGEX.match(self.image_selector.currentText()).group(1) + ".txt" + name, _ = os.path.splitext(self.image_selector.currentText()) + current_as_txt = name + ".txt" labels_path = os.path.join(labels_folder, current_as_txt) with open(labels_path, "w") as f: for row in tuples: @@ -233,6 +239,10 @@ def write_annotations(self, tuples): show_info("Annotations saved.") def save_state(self): + """ + Saves the current annotations (bounding-boxes) in a '.txt' file. + The class index corresponds to the rank of the layers in the Napari stack. + """ count = 0 lines = [] for l in self.viewer.layers: @@ -259,6 +269,9 @@ def get_new_class_name(self): return full_name def add_yolo_class(self): + """ + Adds a new layer representing a YOLO class. + """ class_name = self.get_new_class_name() if _IMAGE_LAYER not in self.viewer.layers: show_info("No image loaded.") @@ -280,6 +293,14 @@ def add_yolo_class(self): ) def set_root_directory(self, directory): + """ + Sets the root directory (folder in which the 'sources' and 'annotations' folders are) of the application. + Probes the content of the 'sources' folder to find all sub-folders, to propose them in the GUI's dropdown. + No sub-folder is selected by default. + + Args: + - directory (str): The absolute path to the root directory. + """ folders = sorted([f for f in os.listdir(directory) if os.path.isdir(os.path.join(directory, f))]) folders = ["---"] + folders self.inputs_name.clear() @@ -287,6 +308,10 @@ def set_root_directory(self, directory): self.root_directory = directory def set_sources_directory(self): + """ + Whenever the user selects a new source directory, the content of the 'sources' folder is probed. + This function also checks if the 'annotations' folder exists, and creates it if not. + """ source_folder = self.inputs_name.currentText() annotations_folder = source_folder + "-labels" if (source_folder is None) or (source_folder == "---") or (source_folder == ""): @@ -302,17 +327,54 @@ def set_sources_directory(self): self.annotations_name.setText(annotations_folder) self.open_sources_directory() + def update_reader_fx(self): + global imread + global _WIDTH_HEIGHT + global ARGS + ARGS = {} + ext = self.images_list[0] + if (ext == "---") or (ext == ""): + return + ext = ext.split('.')[-1].lower() + if (ext == "tif") or (ext == "tiff"): + imread = tifffile.imread + _WIDTH_HEIGHT = (0, 2) + im = imread(os.path.join(self.root_directory, self.sources_directory, self.images_list[0])) + if len(im.shape) == 3: + _WIDTH_HEIGHT = (1, 3) + else: + imread = cv2.imread + _WIDTH_HEIGHT = (0, 2) + ARGS = {'flags': cv2.IMREAD_GRAYSCALE} + def open_sources_directory(self): + """ + Triggered when the user chooses a new source directory in the GUI's dropdown. + The content of the provided folder is probed to find all TIFF files. + If the folder is empty, a message is displayed, and the images list is set to ['---']. + The first item of the list is selected by default. + It gets opened automatically due to the signal 'currentIndexChanged'. + """ inputs_path = os.path.join(self.root_directory, self.sources_directory) - self.images_list = sorted([f for f in os.listdir(inputs_path) if TIFF_REGEX.match(f) is not None]) + self.images_list = sorted([f for f in os.listdir(inputs_path)]) if len(self.images_list) == 0: # Didn't find any file in the folder. - show_info("Didn't find any TIFF file in the provided folder.") + show_info("Didn't find any image in the provided folder.") self.images_list = ['---'] + else: + self.update_reader_fx() self.image_selector.clear() self.image_selector.addItems(self.images_list) return True def get_classes(self): + """ + Probes the layers stack of Napari to find the classes layers. + These layers are found through the prefix '_CLASS_PREFIX'. + The order matters. + + Returns: + (list): A list of strings containing the name of the classes. + """ classes = [] for l in self.viewer.layers: if l.name.startswith(_CLASS_PREFIX): @@ -320,12 +382,21 @@ def get_classes(self): return classes def clear_classes_layers(self): + """ + Reset the data of each shape layer representing a YOLO class. + Used before loading the annotations of a new image. + """ names = [l.name for l in self.viewer.layers] for n in names: if n.startswith(_CLASS_PREFIX): self.viewer.layers[n].data = [] def restore_classes_layers(self): + """ + Parses the 'classes.txt' file to restore the classes layers. + The file contains the name of the classes, one per line and nothing else. + Creates the associated shape layers with the right colors. + """ classes_path = os.path.join(self.root_directory, "classes.txt") if not os.path.isfile(classes_path): show_info("No classes file found.") @@ -356,7 +427,7 @@ def add_labels(self, data): The class index refers to the index in which shape layers appear in the layers stack of Napari. """ # Boxes are created according to the current image's size. - h, w = self.viewer.layers[_IMAGE_LAYER].data.shape[:2] + h, w = self.viewer.layers[_IMAGE_LAYER].data.shape[_WIDTH_HEIGHT[0]:_WIDTH_HEIGHT[1]] class_layers = [l.name for l in self.viewer.layers if l.name.startswith(_CLASS_PREFIX)] for c, bbox_list in data.items(): rectangles = [] @@ -423,18 +494,20 @@ def open_image(self): Reloads the annotations if some were already made for this image. """ current_image = self.image_selector.currentText() - # Check that the name is valid. if (self.root_directory is None) or (current_image is None) or (current_image == "---") or (current_image == ""): return image_path = os.path.join(self.root_directory, self.sources_directory, current_image) - current_as_txt = TIFF_REGEX.match(current_image).group(1) + ".txt" # Remove extension + adds ".txt" extension. + name, _ = os.path.splitext(current_image) + current_as_txt = name + ".txt" labels_path = os.path.join(self.root_directory, self.annotations_directory, current_as_txt) if not os.path.isfile(image_path): print(f"The image: '{current_image}' doesn't exist.") return - data = tifffile.imread(image_path) + data = imread(image_path, **ARGS) if _IMAGE_LAYER in self.viewer.layers: self.viewer.layers[_IMAGE_LAYER].data = data + self.viewer.layers[_IMAGE_LAYER].contrast_limits = (np.min(data), np.max(data)) + self.viewer.layers[_IMAGE_LAYER].reset_contrast_limits_range() else: self.viewer.add_image(data, name=_IMAGE_LAYER) self.deselect_all() @@ -447,8 +520,10 @@ def open_image(self): def count_boxes(self): """ Counts the number of boxes in each class to make sure annotations are balanced. + No fix is provded, just a display of the current state. + The whole annotations folder is probed to count the boxes. + The current image is not taken into account. """ - classes = self.get_classes() annotations_path = os.path.join(self.root_directory, self.annotations_directory) counts = dict() for f in os.listdir(annotations_path): @@ -462,13 +537,31 @@ def count_boxes(self): c, _, _, _, _ = line.split(" ") c = int(c) counts[c] = counts.get(c, 0) + 1 - - text = "" - for class_idx, class_name in enumerate(classes): - if class_idx not in counts: - text += f'{class_name}: 0
' - else: - text += f'{classes[class_idx]}: {counts[class_idx]}
' + self.update_count_display(counts) + + def update_count_display(self, counts): + """ + Updates the content of the table (in the GUI) displaying how many boxes are in each class. + Beyond 8 classes , the display could have an issue with the height of the QGroupBox. + Args: + - counts (str): is a dictionary where the key is the class index and the value is the number of boxes. + """ + classes = self.get_classes() + total_count = sum(counts.values()) + text = '' + for class_idx, class_name in enumerate(classes): + count = counts.get(class_idx, 0) + text += f''' + + + + + + ''' + text += "
+ {class_name} + + {count} ({round((count/total_count*100) if total_count > 0 else 0, 2)}%) +
" self.count_display_label.setText(text) - self.count_display_label.setMaximumHeight(50 * len(counts)) \ No newline at end of file diff --git a/src/microglia_analyzer/custom-loss.py b/src/microglia_analyzer/custom-loss.py new file mode 100644 index 0000000..9e8d325 --- /dev/null +++ b/src/microglia_analyzer/custom-loss.py @@ -0,0 +1,38 @@ +from skimage.morphology import skeletonize +from skimage.measure import label, regionprops + +def compute_skeleton_penalty(y_pred, min_length=10): + # Binariser la prédiction + y_pred_bin = y_pred > 0.5 + + # Labeliser les composantes connexes + labeled_image = label(y_pred_bin) + + penalty = 0 + for region in regionprops(labeled_image): + # Créer un masque pour chaque région + region_mask = labeled_image == region.label + + # Calculer le squelette de la région + skeleton = skeletonize(region_mask) + + # Calculer la longueur du squelette + skeleton_length = skeleton.sum() + + # Ajouter une pénalité si la longueur du squelette est inférieure au seuil + if skeleton_length < min_length: + penalty += (min_length - skeleton_length) + + return penalty + +def custom_loss(y_true, y_pred): + # Calcul de ta loss actuelle (par exemple, Dice Loss) + dice_loss = dice_loss_function(y_true, y_pred) + + # Calcul des pénalités basées sur la longueur des skeletons + skeleton_penalty = compute_skeleton_penalty(y_pred) + + # Combinaison des pénalités + total_loss = dice_loss + gamma * skeleton_penalty + + return total_loss \ No newline at end of file