From 402d7583b5e9b6a0fedef2e5d45e9c44747017b8 Mon Sep 17 00:00:00 2001 From: ZhangYiqin <312065559@qq.com> Date: Thu, 8 Aug 2024 18:42:41 +0800 Subject: [PATCH] Update confusion_matrix.py The current implementation do not work correctly when num_classes is larger than 16. Because the gt seg map was forced to uint8 dtype. --- tools/analysis_tools/confusion_matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tools/analysis_tools/confusion_matrix.py b/tools/analysis_tools/confusion_matrix.py index 39756cdfdd..a146ee7940 100644 --- a/tools/analysis_tools/confusion_matrix.py +++ b/tools/analysis_tools/confusion_matrix.py @@ -63,7 +63,7 @@ def calculate_confusion_matrix(dataset, results): for idx, per_img_res in enumerate(results): res_segm = per_img_res gt_segm = dataset[idx]['data_samples'] \ - .gt_sem_seg.data.squeeze().numpy().astype(np.uint8) + .gt_sem_seg.data.squeeze().numpy().astype(np.uint32) gt_segm, res_segm = gt_segm.flatten(), res_segm.flatten() if reduce_zero_label: gt_segm = gt_segm - 1