From 3440e411e3c361ecdb72392187a80befc310197d Mon Sep 17 00:00:00 2001 From: Huijo Date: Fri, 12 Jul 2024 09:41:56 +0200 Subject: [PATCH 1/3] Handle mmdet's {'with_mask': False} case --- sahi/models/mmdet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sahi/models/mmdet.py b/sahi/models/mmdet.py index b4b363d3c..74142cded 100644 --- a/sahi/models/mmdet.py +++ b/sahi/models/mmdet.py @@ -192,7 +192,7 @@ def has_mask(self): """ # has_mask = self.model.model.with_mask train_pipeline = self.model.cfg["train_dataloader"]["dataset"]["pipeline"] - has_mask = any(isinstance(item, dict) and any("mask" in key for key in item.keys()) for item in train_pipeline) + has_mask = any(isinstance(item, dict) and any('mask' in key and value is True for key, value in item.items()) for item in train_pipeline) return has_mask @property From 3a1e9539aa7e9234bb00fe411789b929a0f54f20 Mon Sep 17 00:00:00 2001 From: Huijo Date: Fri, 2 Aug 2024 14:21:37 +0200 Subject: [PATCH 2/3] follow the format --- sahi/models/mmdet.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sahi/models/mmdet.py b/sahi/models/mmdet.py index 74142cded..5bae88b6b 100644 --- a/sahi/models/mmdet.py +++ b/sahi/models/mmdet.py @@ -192,7 +192,10 @@ def has_mask(self): """ # has_mask = self.model.model.with_mask train_pipeline = self.model.cfg["train_dataloader"]["dataset"]["pipeline"] - has_mask = any(isinstance(item, dict) and any('mask' in key and value is True for key, value in item.items()) for item in train_pipeline) + has_mask = any( + isinstance(item, dict) and any("mask" in key and value is True for key, value in item.items()) + for item in train_pipeline + ) return has_mask @property From c955616e0f0e3dcef9c41940f5250c7261a16076 Mon Sep 17 00:00:00 2001 From: huijo Date: Thu, 31 Oct 2024 21:26:25 +0100 Subject: [PATCH 3/3] handle ConcatDataset in has_mask --- sahi/models/mmdet.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/sahi/models/mmdet.py b/sahi/models/mmdet.py index 5bae88b6b..e705f5176 100644 --- a/sahi/models/mmdet.py +++ b/sahi/models/mmdet.py @@ -188,15 +188,31 @@ def num_categories(self): @property def has_mask(self): """ - Returns if model output contains segmentation mask + Returns if model output contains segmentation mask. + Considers both single dataset and ConcatDataset scenarios. """ - # has_mask = self.model.model.with_mask - train_pipeline = self.model.cfg["train_dataloader"]["dataset"]["pipeline"] - has_mask = any( - isinstance(item, dict) and any("mask" in key and value is True for key, value in item.items()) - for item in train_pipeline - ) - return has_mask + + def check_pipeline_for_mask(pipeline): + return any( + isinstance(item, dict) and any("mask" in key and value is True for key, value in item.items()) + for item in pipeline + ) + + # Access the dataset from the configuration + dataset_config = self.model.cfg["train_dataloader"]["dataset"] + + if dataset_config["type"] == "ConcatDataset": + # If using ConcatDataset, check each dataset individually + datasets = dataset_config["datasets"] + for dataset in datasets: + if check_pipeline_for_mask(dataset["pipeline"]): + return True + else: + # Otherwise, assume a single dataset with its own pipeline + if check_pipeline_for_mask(dataset_config["pipeline"]): + return True + + return False @property def category_names(self):