From ccb4d8dfe27544fd91d3ad7c27aafd25ac731425 Mon Sep 17 00:00:00 2001 From: Peng Lu Date: Sun, 8 Oct 2023 16:00:05 +0800 Subject: [PATCH] [Refactor] Align test accuracy for AE (#2737) --- .../ae_hrnet-w32_8xb24-300e_coco-512x512.py | 15 +- .../associative_embedding/coco/hrnet_coco.md | 57 ++++++ .../associative_embedding/coco/hrnet_coco.yml | 25 +++ mmpose/codecs/associative_embedding.py | 178 +++++++++--------- mmpose/models/heads/heatmap_heads/ae_head.py | 58 +++++- .../test_codecs/test_associative_embedding.py | 12 +- 6 files changed, 248 insertions(+), 97 deletions(-) create mode 100644 configs/body_2d_keypoint/associative_embedding/coco/hrnet_coco.md create mode 100644 configs/body_2d_keypoint/associative_embedding/coco/hrnet_coco.yml diff --git a/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py b/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py index 5adc1aac1a..a4804cbe37 100644 --- a/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py +++ b/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py @@ -36,6 +36,8 @@ input_size=(512, 512), heatmap_size=(128, 128), sigma=2, + decode_topk=30, + decode_center_shift=0.5, decode_keypoint_order=[ 0, 1, 2, 3, 4, 5, 6, 11, 12, 7, 8, 9, 10, 13, 14, 15, 16 ], @@ -97,7 +99,7 @@ test_cfg=dict( multiscale_test=False, flip_test=True, - shift_heatmap=True, + shift_heatmap=False, restore_heatmap_size=True, align_corners=False)) @@ -113,9 +115,14 @@ dict( type='BottomupResize', input_size=codec['input_size'], - size_factor=32, + size_factor=64, resize_mode='expand'), - dict(type='PackPoseInputs') + dict( + type='PackPoseInputs', + meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape', + 'img_shape', 'input_size', 'input_center', 'input_scale', + 'flip', 'flip_direction', 'flip_indices', 'raw_ann_info', + 'skeleton_links')) ] # data loaders @@ -154,6 +161,6 @@ type='CocoMetric', ann_file=data_root + 'annotations/person_keypoints_val2017.json', nms_mode='none', - score_mode='keypoint', + score_mode='bbox', ) test_evaluator = val_evaluator diff --git a/configs/body_2d_keypoint/associative_embedding/coco/hrnet_coco.md b/configs/body_2d_keypoint/associative_embedding/coco/hrnet_coco.md new file mode 100644 index 0000000000..caae01d60d --- /dev/null +++ b/configs/body_2d_keypoint/associative_embedding/coco/hrnet_coco.md @@ -0,0 +1,57 @@ + + +
+Associative Embedding (NIPS'2017) + +```bibtex +@inproceedings{newell2017associative, + title={Associative embedding: End-to-end learning for joint detection and grouping}, + author={Newell, Alejandro and Huang, Zhiao and Deng, Jia}, + booktitle={Advances in neural information processing systems}, + pages={2277--2287}, + year={2017} +} +``` + +
+ + + +
+HRNet (CVPR'2019) + +```bibtex +@inproceedings{sun2019deep, + title={Deep high-resolution representation learning for human pose estimation}, + author={Sun, Ke and Xiao, Bin and Liu, Dong and Wang, Jingdong}, + booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition}, + pages={5693--5703}, + year={2019} +} +``` + +
+ + + +
+COCO (ECCV'2014) + +```bibtex +@inproceedings{lin2014microsoft, + title={Microsoft coco: Common objects in context}, + author={Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence}, + booktitle={European conference on computer vision}, + pages={740--755}, + year={2014}, + organization={Springer} +} +``` + +
+ +Results on COCO val2017 without multi-scale test + +| Arch | Input Size | AP | AP50 | AP75 | AR | AR50 | ckpt | log | +| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: | +| [HRNet-w32](/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py) | 512x512 | 0.656 | 0.864 | 0.719 | 0.711 | 0.893 | [ckpt](https://download.openmmlab.com/mmpose/bottom_up/hrnet_w32_coco_512x512-bcb8c247_20200816.pth) | [log](https://download.openmmlab.com/mmpose/bottom_up/hrnet_w32_coco_512x512_20200816.log.json) | diff --git a/configs/body_2d_keypoint/associative_embedding/coco/hrnet_coco.yml b/configs/body_2d_keypoint/associative_embedding/coco/hrnet_coco.yml new file mode 100644 index 0000000000..5fcd749f0f --- /dev/null +++ b/configs/body_2d_keypoint/associative_embedding/coco/hrnet_coco.yml @@ -0,0 +1,25 @@ +Collections: +- Name: AE + Paper: + Title: "Associative embedding: End-to-end learning for joint detection and grouping" + URL: https://arxiv.org/abs/1611.05424 + README: https://github.com/open-mmlab/mmpose/blob/main/docs/src/papers/algorithms/associative_embedding.md +Models: +- Config: configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py + In Collection: AE + Metadata: + Architecture: + - AE + - HRNet + Training Data: COCO + Name: ae_hrnet-w32_8xb24-300e_coco-512x512 + Results: + - Dataset: COCO + Metrics: + AP: 0.656 + AP@0.5: 0.864 + AP@0.75: 0.719 + AR: 0.711 + AR@0.5: 0.893 + Task: Body 2D Keypoint + Weights: https://download.openmmlab.com/mmpose/bottom_up/hrnet_w32_coco_512x512-bcb8c247_20200816.pth diff --git a/mmpose/codecs/associative_embedding.py b/mmpose/codecs/associative_embedding.py index 7e080f1657..def9bfd89e 100644 --- a/mmpose/codecs/associative_embedding.py +++ b/mmpose/codecs/associative_embedding.py @@ -1,5 +1,4 @@ # Copyright (c) OpenMMLab. All rights reserved. -from collections import namedtuple from itertools import product from typing import Any, List, Optional, Tuple @@ -16,6 +15,21 @@ refine_keypoints_dark_udp) +def _py_max_match(scores): + """Apply munkres algorithm to get the best match. + + Args: + scores(np.ndarray): cost matrix. + + Returns: + np.ndarray: best match. + """ + m = Munkres() + tmp = m.compute(scores) + tmp = np.array(tmp).astype(int) + return tmp + + def _group_keypoints_by_tags(vals: np.ndarray, tags: np.ndarray, locs: np.ndarray, @@ -54,89 +68,78 @@ def _group_keypoints_by_tags(vals: np.ndarray, np.ndarray: grouped keypoints in shape (G, K, D+1), where the last dimenssion is the concatenated keypoint coordinates and scores. """ + + tag_k, loc_k, val_k = tags, locs, vals K, M, D = locs.shape assert vals.shape == tags.shape[:2] == (K, M) assert len(keypoint_order) == K - # Build Munkres instance - munkres = Munkres() - - # Build a group pool, each group contains the keypoints of an instance - groups = [] + default_ = np.zeros((K, 3 + tag_k.shape[2]), dtype=np.float32) - Group = namedtuple('Group', field_names=['kpts', 'scores', 'tag_list']) + joint_dict = {} + tag_dict = {} + for i in range(K): + idx = keypoint_order[i] - def _init_group(): - """Initialize a group, which is composed of the keypoints, keypoint - scores and the tag of each keypoint.""" - _group = Group( - kpts=np.zeros((K, D), dtype=np.float32), - scores=np.zeros(K, dtype=np.float32), - tag_list=[]) - return _group + tags = tag_k[idx] + joints = np.concatenate((loc_k[idx], val_k[idx, :, None], tags), 1) + mask = joints[:, 2] > val_thr + tags = tags[mask] # shape: [M, L] + joints = joints[mask] # shape: [M, 3 + L], 3: x, y, val - for i in keypoint_order: - # Get all valid candidate of the i-th keypoints - valid = vals[i] > val_thr - if not valid.any(): + if joints.shape[0] == 0: continue - tags_i = tags[i, valid] # (M', L) - vals_i = vals[i, valid] # (M',) - locs_i = locs[i, valid] # (M', D) - - if len(groups) == 0: # Initialize the group pool - for tag, val, loc in zip(tags_i, vals_i, locs_i): - group = _init_group() - group.kpts[i] = loc - group.scores[i] = val - group.tag_list.append(tag) - - groups.append(group) - - else: # Match keypoints to existing groups - groups = groups[:max_groups] - group_tags = [np.mean(g.tag_list, axis=0) for g in groups] - - # Calculate distance matrix between group tags and tag candidates - # of the i-th keypoint - # Shape: (M', 1, L) , (1, G, L) -> (M', G, L) - diff = tags_i[:, None] - np.array(group_tags)[None] - dists = np.linalg.norm(diff, ord=2, axis=2) - num_kpts, num_groups = dists.shape[:2] - - # Experimental cost function for keypoint-group matching - costs = np.round(dists) * 100 - vals_i[..., None] - if num_kpts > num_groups: - padding = np.full((num_kpts, num_kpts - num_groups), - 1e10, - dtype=np.float32) - costs = np.concatenate((costs, padding), axis=1) - - # Match keypoints and groups by Munkres algorithm - matches = munkres.compute(costs) - for kpt_idx, group_idx in matches: - if group_idx < num_groups and dists[kpt_idx, - group_idx] < tag_thr: - # Add the keypoint to the matched group - group = groups[group_idx] + if i == 0 or len(joint_dict) == 0: + for tag, joint in zip(tags, joints): + key = tag[0] + joint_dict.setdefault(key, np.copy(default_))[idx] = joint + tag_dict[key] = [tag] + else: + # shape: [M] + grouped_keys = list(joint_dict.keys()) + # shape: [M, L] + grouped_tags = [np.mean(tag_dict[i], axis=0) for i in grouped_keys] + + # shape: [M, M, L] + diff = joints[:, None, 3:] - np.array(grouped_tags)[None, :, :] + # shape: [M, M] + diff_normed = np.linalg.norm(diff, ord=2, axis=2) + diff_saved = np.copy(diff_normed) + diff_normed = np.round(diff_normed) * 100 - joints[:, 2:3] + + num_added = diff.shape[0] + num_grouped = diff.shape[1] + + if num_added > num_grouped: + diff_normed = np.concatenate( + (diff_normed, + np.zeros((num_added, num_added - num_grouped), + dtype=np.float32) + 1e10), + axis=1) + + pairs = _py_max_match(diff_normed) + for row, col in pairs: + if (row < num_added and col < num_grouped + and diff_saved[row][col] < tag_thr): + key = grouped_keys[col] + joint_dict[key][idx] = joints[row] + tag_dict[key].append(tags[row]) else: - # Initialize a new group with unmatched keypoint - group = _init_group() - groups.append(group) - - group.kpts[i] = locs_i[kpt_idx] - group.scores[i] = vals_i[kpt_idx] - group.tag_list.append(tags_i[kpt_idx]) - - groups = groups[:max_groups] - if groups: - grouped_keypoints = np.stack( - [np.r_['1', g.kpts, g.scores[:, None]] for g in groups]) - else: - grouped_keypoints = np.empty((0, K, D + 1)) + key = tags[row][0] + joint_dict.setdefault(key, np.copy(default_))[idx] = \ + joints[row] + tag_dict[key] = [tags[row]] - return grouped_keypoints + joint_dict_keys = list(joint_dict.keys())[:max_groups] + + if joint_dict_keys: + results = np.array([joint_dict[i] + for i in joint_dict_keys]).astype(np.float32) + results = results[..., :D + 1] + else: + results = np.empty((0, K, D + 1), dtype=np.float32) + return results @KEYPOINT_CODECS.register_module() @@ -210,7 +213,8 @@ def __init__( decode_gaussian_kernel: int = 3, decode_keypoint_thr: float = 0.1, decode_tag_thr: float = 1.0, - decode_topk: int = 20, + decode_topk: int = 30, + decode_center_shift=0.0, decode_max_instances: Optional[int] = None, ) -> None: super().__init__() @@ -222,8 +226,9 @@ def __init__( self.decode_keypoint_thr = decode_keypoint_thr self.decode_tag_thr = decode_tag_thr self.decode_topk = decode_topk + self.decode_center_shift = decode_center_shift self.decode_max_instances = decode_max_instances - self.dedecode_keypoint_order = decode_keypoint_order.copy() + self.decode_keypoint_order = decode_keypoint_order.copy() if self.use_udp: self.scale_factor = ((np.array(input_size) - 1) / @@ -376,7 +381,7 @@ def _group_func(inputs: Tuple): vals, tags, locs, - keypoint_order=self.dedecode_keypoint_order, + keypoint_order=self.decode_keypoint_order, val_thr=self.decode_keypoint_thr, tag_thr=self.decode_tag_thr, max_groups=self.decode_max_instances) @@ -463,13 +468,13 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor f'tagging map ({batch_tags.shape})') # Heatmap NMS - batch_heatmaps = batch_heatmap_nms(batch_heatmaps, - self.decode_nms_kernel) + batch_heatmaps_peak = batch_heatmap_nms(batch_heatmaps, + self.decode_nms_kernel) # Get top-k in each heatmap and and convert to numpy batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy( self._get_batch_topk( - batch_heatmaps, batch_tags, k=self.decode_topk)) + batch_heatmaps_peak, batch_tags, k=self.decode_topk)) # Group keypoint candidates into groups (instances) batch_groups = self._group_keypoints(batch_topk_vals, batch_topk_tags, @@ -482,16 +487,14 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor # Refine the keypoint prediction batch_keypoints = [] batch_keypoint_scores = [] + batch_instance_scores = [] for i, (groups, heatmaps, tags) in enumerate( zip(batch_groups, batch_heatmaps_np, batch_tags_np)): keypoints, scores = groups[..., :-1], groups[..., -1] + instance_scores = scores.mean(axis=-1) if keypoints.size > 0: - # identify missing keypoints - keypoints, scores = self._fill_missing_keypoints( - keypoints, scores, heatmaps, tags) - # refine keypoint coordinates according to heatmap distribution if self.use_udp: keypoints = refine_keypoints_dark_udp( @@ -500,13 +503,20 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor blur_kernel_size=self.decode_gaussian_kernel) else: keypoints = refine_keypoints(keypoints, heatmaps) + keypoints += self.decode_center_shift * \ + (scores > 0).astype(keypoints.dtype)[..., None] + + # identify missing keypoints + keypoints, scores = self._fill_missing_keypoints( + keypoints, scores, heatmaps, tags) batch_keypoints.append(keypoints) batch_keypoint_scores.append(scores) + batch_instance_scores.append(instance_scores) # restore keypoint scale batch_keypoints = [ kpts * self.scale_factor for kpts in batch_keypoints ] - return batch_keypoints, batch_keypoint_scores + return batch_keypoints, batch_keypoint_scores, batch_instance_scores diff --git a/mmpose/models/heads/heatmap_heads/ae_head.py b/mmpose/models/heads/heatmap_heads/ae_head.py index bd12d57a33..c9559eebc2 100644 --- a/mmpose/models/heads/heatmap_heads/ae_head.py +++ b/mmpose/models/heads/heatmap_heads/ae_head.py @@ -2,14 +2,15 @@ from typing import List, Optional, Sequence, Tuple, Union import torch -from mmengine.structures import PixelData +from mmengine.structures import InstanceData, PixelData from mmengine.utils import is_list_of from torch import Tensor from mmpose.models.utils.tta import aggregate_heatmaps, flip_heatmaps from mmpose.registry import MODELS -from mmpose.utils.typing import (ConfigType, Features, OptConfigType, - OptSampleList, Predictions) +from mmpose.utils.tensor_utils import to_numpy +from mmpose.utils.typing import (ConfigType, Features, InstanceList, + OptConfigType, OptSampleList, Predictions) from .heatmap_head import HeatmapHead OptIntSeq = Optional[Sequence[int]] @@ -226,6 +227,57 @@ def _flip_tags(self, return tags + def decode(self, batch_outputs: Union[Tensor, + Tuple[Tensor]]) -> InstanceList: + """Decode keypoints from outputs. + + Args: + batch_outputs (Tensor | Tuple[Tensor]): The network outputs of + a data batch + + Returns: + List[InstanceData]: A list of InstanceData, each contains the + decoded pose information of the instances of one data sample. + """ + + def _pack_and_call(args, func): + if not isinstance(args, tuple): + args = (args, ) + return func(*args) + + if self.decoder is None: + raise RuntimeError( + f'The decoder has not been set in {self.__class__.__name__}. ' + 'Please set the decoder configs in the init parameters to ' + 'enable head methods `head.predict()` and `head.decode()`') + + if self.decoder.support_batch_decoding: + batch_keypoints, batch_scores, batch_instance_scores = \ + _pack_and_call(batch_outputs, self.decoder.batch_decode) + + else: + batch_output_np = to_numpy(batch_outputs, unzip=True) + batch_keypoints = [] + batch_scores = [] + batch_instance_scores = [] + for outputs in batch_output_np: + keypoints, scores, instance_scores = _pack_and_call( + outputs, self.decoder.decode) + batch_keypoints.append(keypoints) + batch_scores.append(scores) + batch_instance_scores.append(instance_scores) + + preds = [ + InstanceData( + bbox_scores=instance_scores, + keypoints=keypoints, + keypoint_scores=scores) + for keypoints, scores, instance_scores in zip( + batch_keypoints, batch_scores, batch_instance_scores) + ] + + return preds + def forward(self, feats: Tuple[Tensor]) -> Tuple[Tensor, Tensor]: """Forward the network. The input is multi scale feature maps and the output is the heatmaps and tags. diff --git a/tests/test_codecs/test_associative_embedding.py b/tests/test_codecs/test_associative_embedding.py index 983fc93fb1..eae65dbedc 100644 --- a/tests/test_codecs/test_associative_embedding.py +++ b/tests/test_codecs/test_associative_embedding.py @@ -146,8 +146,8 @@ def test_decode(self): batch_heatmaps = torch.from_numpy(heatmaps[None]) batch_tags = torch.from_numpy(tags[None]) - batch_keypoints, batch_keypoint_scores = codec.batch_decode( - batch_heatmaps, batch_tags) + batch_keypoints, batch_keypoint_scores, batch_instance_scores = \ + codec.batch_decode(batch_heatmaps, batch_tags) self.assertIsInstance(batch_keypoints, list) self.assertIsInstance(batch_keypoint_scores, list) @@ -184,8 +184,8 @@ def test_decode(self): batch_heatmaps = torch.from_numpy(heatmaps[None]) batch_tags = torch.from_numpy(tags[None]) - batch_keypoints, batch_keypoint_scores = codec.batch_decode( - batch_heatmaps, batch_tags) + batch_keypoints, batch_keypoint_scores, batch_instance_scores = \ + codec.batch_decode(batch_heatmaps, batch_tags) self.assertIsInstance(batch_keypoints, list) self.assertIsInstance(batch_keypoint_scores, list) @@ -222,8 +222,8 @@ def test_decode(self): batch_heatmaps = torch.from_numpy(heatmaps[None]) batch_tags = torch.from_numpy(tags[None]) - batch_keypoints, batch_keypoint_scores = codec.batch_decode( - batch_heatmaps, batch_tags) + batch_keypoints, batch_keypoint_scores, batch_instance_scores = \ + codec.batch_decode(batch_heatmaps, batch_tags) self.assertIsInstance(batch_keypoints, list) self.assertIsInstance(batch_keypoint_scores, list)