From de405626dcd9b98e8148435476727a6045d74ff9 Mon Sep 17 00:00:00 2001 From: Huijo Date: Fri, 22 Nov 2024 08:18:37 +0100 Subject: [PATCH] Update has_mask method for mmdet models (handle ConcatDataset) (#1092) Co-authored-by: fatih c. akyon <34196005+fcakyon@users.noreply.github.com> --- sahi/models/mmdet.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/sahi/models/mmdet.py b/sahi/models/mmdet.py index 5bae88b6b..e705f5176 100644 --- a/sahi/models/mmdet.py +++ b/sahi/models/mmdet.py @@ -188,15 +188,31 @@ def num_categories(self): @property def has_mask(self): """ - Returns if model output contains segmentation mask + Returns if model output contains segmentation mask. + Considers both single dataset and ConcatDataset scenarios. """ - # has_mask = self.model.model.with_mask - train_pipeline = self.model.cfg["train_dataloader"]["dataset"]["pipeline"] - has_mask = any( - isinstance(item, dict) and any("mask" in key and value is True for key, value in item.items()) - for item in train_pipeline - ) - return has_mask + + def check_pipeline_for_mask(pipeline): + return any( + isinstance(item, dict) and any("mask" in key and value is True for key, value in item.items()) + for item in pipeline + ) + + # Access the dataset from the configuration + dataset_config = self.model.cfg["train_dataloader"]["dataset"] + + if dataset_config["type"] == "ConcatDataset": + # If using ConcatDataset, check each dataset individually + datasets = dataset_config["datasets"] + for dataset in datasets: + if check_pipeline_for_mask(dataset["pipeline"]): + return True + else: + # Otherwise, assume a single dataset with its own pipeline + if check_pipeline_for_mask(dataset_config["pipeline"]): + return True + + return False @property def category_names(self):