Skip to content

Commit

Permalink
Finished YOLO annotator
Browse files Browse the repository at this point in the history
  • Loading branch information
c-h-benedetti committed Oct 1, 2024
1 parent 02bdfe6 commit a25b928
Show file tree
Hide file tree
Showing 2 changed files with 174 additions and 43 deletions.
179 changes: 136 additions & 43 deletions src/microglia_analyzer/_widget_yolo_annotations.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from napari.utils.notifications import show_info

import tifffile
from microglia_analyzer import TIFF_REGEX
import cv2

import numpy as np
import os
Expand All @@ -15,24 +15,30 @@
# Name of the layer containing the current image.
_IMAGE_LAYER = "Image"
# Colors assigned to each YOLO class.
_COLORS = [
"#FF0000",
"#00FF00",
"#0000FF",
"#FFFF00",
"#FF00FF",
"#00FFFF",
"#FF8000",
"#8000FF",
"#0080FF",
"#80FF00",
"#FF0080",
"#00FF80",
"#800000",
"#008000",
_COLORS = [
"#FF4D4D",
"#4DFF4D",
"#4D4DFF",
"#FFD700",
"#FF66FF",
"#66FFFF",
"#FF9900",
"#9933FF",
"#3399FF",
"#99FF33",
"#FF3399",
"#33FF99",
"#B20000",
"#006600",
"#800080",
"#808000"
]
# Function used to read images.
imread = tifffile.imread
# Indices to have the width and height from an image shape
_WIDTH_HEIGHT = (0, 2)
# Arguments to pass to the imread function.
ARGS = {}

# A YOLO bounding-box == a tuple of 5 elements:
# - (int) The class to which this box belongs.
Expand Down Expand Up @@ -69,9 +75,9 @@ def add_media_management_group_ui(self):
box.setLayout(layout)

# Label + button to select the source directory:
self.select_sources_directory_button = QPushButton("📂 Sources directory")
self.select_sources_directory_button.clicked.connect(self.select_sources_directory)
layout.addWidget(self.select_sources_directory_button)
self.select_root_directory_button = QPushButton("📂 Root directory")
self.select_root_directory_button.clicked.connect(self.select_sources_directory)
layout.addWidget(self.select_root_directory_button)

# Label + text box for the inputs sub-folder's name:
inputs_name_label = QLabel("Inputs sub-folder:")
Expand Down Expand Up @@ -113,14 +119,7 @@ def add_classes_management_group_ui(self):
layout.addLayout(h_laytout)

# Label showing the number of boxes in each class
self.counts_label = QLabel("Counts:")
font = self.counts_label.font()
font.setBold(True)
self.counts_label.setFont(font)
layout.addWidget(self.counts_label)
self.count_display_label = QLabel("")
self.counts_label.setMaximumHeight(20)
self.count_display_label.setMaximumHeight(10)
layout.addWidget(self.count_display_label)

self.layout.addWidget(box)
Expand Down Expand Up @@ -205,7 +204,7 @@ def bbox2yolo(self, bbox):
"""
ymax, xmax = self.upper_corner(bbox)
ymin, xmin = self.lower_corner(bbox)
height, width = self.viewer.layers[_IMAGE_LAYER].data.shape[:2]
height, width = self.viewer.layers[_IMAGE_LAYER].data.shape[_WIDTH_HEIGHT[0]:_WIDTH_HEIGHT[1]]
x = (xmin + xmax) / 2 / width
y = (ymin + ymax) / 2 / height
w = (xmax - xmin) / width
Expand All @@ -221,8 +220,15 @@ def layer2yolo(self, layer_name, index):
return tuples

def write_annotations(self, tuples):
"""
Responsible for writing the annotations in a '.txt' file, and updating the 'classes.txt' file.
Args:
- tuples (list): A list of tuples, each tuple containing the class index and the YOLO bounding-box.
"""
labels_folder = os.path.join(self.root_directory, self.annotations_directory)
current_as_txt = TIFF_REGEX.match(self.image_selector.currentText()).group(1) + ".txt"
name, _ = os.path.splitext(self.image_selector.currentText())
current_as_txt = name + ".txt"
labels_path = os.path.join(labels_folder, current_as_txt)
with open(labels_path, "w") as f:
for row in tuples:
Expand All @@ -233,6 +239,10 @@ def write_annotations(self, tuples):
show_info("Annotations saved.")

def save_state(self):
"""
Saves the current annotations (bounding-boxes) in a '.txt' file.
The class index corresponds to the rank of the layers in the Napari stack.
"""
count = 0
lines = []
for l in self.viewer.layers:
Expand All @@ -259,6 +269,9 @@ def get_new_class_name(self):
return full_name

def add_yolo_class(self):
"""
Adds a new layer representing a YOLO class.
"""
class_name = self.get_new_class_name()
if _IMAGE_LAYER not in self.viewer.layers:
show_info("No image loaded.")
Expand All @@ -280,13 +293,25 @@ def add_yolo_class(self):
)

def set_root_directory(self, directory):
"""
Sets the root directory (folder in which the 'sources' and 'annotations' folders are) of the application.
Probes the content of the 'sources' folder to find all sub-folders, to propose them in the GUI's dropdown.
No sub-folder is selected by default.
Args:
- directory (str): The absolute path to the root directory.
"""
folders = sorted([f for f in os.listdir(directory) if os.path.isdir(os.path.join(directory, f))])
folders = ["---"] + folders
self.inputs_name.clear()
self.inputs_name.addItems(folders)
self.root_directory = directory

def set_sources_directory(self):
"""
Whenever the user selects a new source directory, the content of the 'sources' folder is probed.
This function also checks if the 'annotations' folder exists, and creates it if not.
"""
source_folder = self.inputs_name.currentText()
annotations_folder = source_folder + "-labels"
if (source_folder is None) or (source_folder == "---") or (source_folder == ""):
Expand All @@ -302,30 +327,76 @@ def set_sources_directory(self):
self.annotations_name.setText(annotations_folder)
self.open_sources_directory()

def update_reader_fx(self):
global imread
global _WIDTH_HEIGHT
global ARGS
ARGS = {}
ext = self.images_list[0]
if (ext == "---") or (ext == ""):
return
ext = ext.split('.')[-1].lower()
if (ext == "tif") or (ext == "tiff"):
imread = tifffile.imread
_WIDTH_HEIGHT = (0, 2)
im = imread(os.path.join(self.root_directory, self.sources_directory, self.images_list[0]))
if len(im.shape) == 3:
_WIDTH_HEIGHT = (1, 3)
else:
imread = cv2.imread
_WIDTH_HEIGHT = (0, 2)
ARGS = {'flags': cv2.IMREAD_GRAYSCALE}

def open_sources_directory(self):
"""
Triggered when the user chooses a new source directory in the GUI's dropdown.
The content of the provided folder is probed to find all TIFF files.
If the folder is empty, a message is displayed, and the images list is set to ['---'].
The first item of the list is selected by default.
It gets opened automatically due to the signal 'currentIndexChanged'.
"""
inputs_path = os.path.join(self.root_directory, self.sources_directory)
self.images_list = sorted([f for f in os.listdir(inputs_path) if TIFF_REGEX.match(f) is not None])
self.images_list = sorted([f for f in os.listdir(inputs_path)])
if len(self.images_list) == 0: # Didn't find any file in the folder.
show_info("Didn't find any TIFF file in the provided folder.")
show_info("Didn't find any image in the provided folder.")
self.images_list = ['---']
else:
self.update_reader_fx()
self.image_selector.clear()
self.image_selector.addItems(self.images_list)
return True

def get_classes(self):
"""
Probes the layers stack of Napari to find the classes layers.
These layers are found through the prefix '_CLASS_PREFIX'.
The order matters.
Returns:
(list): A list of strings containing the name of the classes.
"""
classes = []
for l in self.viewer.layers:
if l.name.startswith(_CLASS_PREFIX):
classes.append(l.name[len(_CLASS_PREFIX):])
return classes

def clear_classes_layers(self):
"""
Reset the data of each shape layer representing a YOLO class.
Used before loading the annotations of a new image.
"""
names = [l.name for l in self.viewer.layers]
for n in names:
if n.startswith(_CLASS_PREFIX):
self.viewer.layers[n].data = []

def restore_classes_layers(self):
"""
Parses the 'classes.txt' file to restore the classes layers.
The file contains the name of the classes, one per line and nothing else.
Creates the associated shape layers with the right colors.
"""
classes_path = os.path.join(self.root_directory, "classes.txt")
if not os.path.isfile(classes_path):
show_info("No classes file found.")
Expand Down Expand Up @@ -356,7 +427,7 @@ def add_labels(self, data):
The class index refers to the index in which shape layers appear in the layers stack of Napari.
"""
# Boxes are created according to the current image's size.
h, w = self.viewer.layers[_IMAGE_LAYER].data.shape[:2]
h, w = self.viewer.layers[_IMAGE_LAYER].data.shape[_WIDTH_HEIGHT[0]:_WIDTH_HEIGHT[1]]
class_layers = [l.name for l in self.viewer.layers if l.name.startswith(_CLASS_PREFIX)]
for c, bbox_list in data.items():
rectangles = []
Expand Down Expand Up @@ -423,18 +494,20 @@ def open_image(self):
Reloads the annotations if some were already made for this image.
"""
current_image = self.image_selector.currentText()
# Check that the name is valid.
if (self.root_directory is None) or (current_image is None) or (current_image == "---") or (current_image == ""):
return
image_path = os.path.join(self.root_directory, self.sources_directory, current_image)
current_as_txt = TIFF_REGEX.match(current_image).group(1) + ".txt" # Remove extension + adds ".txt" extension.
name, _ = os.path.splitext(current_image)
current_as_txt = name + ".txt"
labels_path = os.path.join(self.root_directory, self.annotations_directory, current_as_txt)
if not os.path.isfile(image_path):
print(f"The image: '{current_image}' doesn't exist.")
return
data = tifffile.imread(image_path)
data = imread(image_path, **ARGS)
if _IMAGE_LAYER in self.viewer.layers:
self.viewer.layers[_IMAGE_LAYER].data = data
self.viewer.layers[_IMAGE_LAYER].contrast_limits = (np.min(data), np.max(data))
self.viewer.layers[_IMAGE_LAYER].reset_contrast_limits_range()
else:
self.viewer.add_image(data, name=_IMAGE_LAYER)
self.deselect_all()
Expand All @@ -447,8 +520,10 @@ def open_image(self):
def count_boxes(self):
"""
Counts the number of boxes in each class to make sure annotations are balanced.
No fix is provded, just a display of the current state.
The whole annotations folder is probed to count the boxes.
The current image is not taken into account.
"""
classes = self.get_classes()
annotations_path = os.path.join(self.root_directory, self.annotations_directory)
counts = dict()
for f in os.listdir(annotations_path):
Expand All @@ -462,13 +537,31 @@ def count_boxes(self):
c, _, _, _, _ = line.split(" ")
c = int(c)
counts[c] = counts.get(c, 0) + 1

text = ""
for class_idx, class_name in enumerate(classes):
if class_idx not in counts:
text += f'<span style="color: {_COLORS[class_idx % len(_COLORS)]}"><b>{class_name}</b></span>: 0<br>'
else:
text += f'<span style="color: {_COLORS[class_idx % len(_COLORS)]}"><b>{classes[class_idx]}</b></span>: {counts[class_idx]}<br>'
self.update_count_display(counts)

def update_count_display(self, counts):
"""
Updates the content of the table (in the GUI) displaying how many boxes are in each class.
Beyond 8 classes , the display could have an issue with the height of the QGroupBox.
Args:
- counts (str): is a dictionary where the key is the class index and the value is the number of boxes.
"""
classes = self.get_classes()
total_count = sum(counts.values())
text = '<table>'
for class_idx, class_name in enumerate(classes):
count = counts.get(class_idx, 0)
text += f'''
<tr>
<td style="background-color: {_COLORS[class_idx % len(_COLORS)]}; color: white; padding: 6px;">
<b>{class_name}</b>
</td>
<td style="padding: 6px;">
{count} ({round((count/total_count*100) if total_count > 0 else 0, 2)}%)
</td>
</tr>
'''
text += "</table>"
self.count_display_label.setText(text)
self.count_display_label.setMaximumHeight(50 * len(counts))
38 changes: 38 additions & 0 deletions src/microglia_analyzer/custom-loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from skimage.morphology import skeletonize
from skimage.measure import label, regionprops

def compute_skeleton_penalty(y_pred, min_length=10):
# Binariser la prédiction
y_pred_bin = y_pred > 0.5

# Labeliser les composantes connexes
labeled_image = label(y_pred_bin)

penalty = 0
for region in regionprops(labeled_image):
# Créer un masque pour chaque région
region_mask = labeled_image == region.label

# Calculer le squelette de la région
skeleton = skeletonize(region_mask)

# Calculer la longueur du squelette
skeleton_length = skeleton.sum()

# Ajouter une pénalité si la longueur du squelette est inférieure au seuil
if skeleton_length < min_length:
penalty += (min_length - skeleton_length)

return penalty

def custom_loss(y_true, y_pred):
# Calcul de ta loss actuelle (par exemple, Dice Loss)
dice_loss = dice_loss_function(y_true, y_pred)

# Calcul des pénalités basées sur la longueur des skeletons
skeleton_penalty = compute_skeleton_penalty(y_pred)

# Combinaison des pénalités
total_loss = dice_loss + gamma * skeleton_penalty

return total_loss

0 comments on commit a25b928

Please sign in to comment.