[Refactor] Align test accuracy for AE (#2737)
Ben-Louis authored Oct 8, 2023
1 parent e8ac800 commit ccb4d8d
Showing 6 changed files with 248 additions and 97 deletions.
configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py
@@ -36,6 +36,8 @@
     input_size=(512, 512),
     heatmap_size=(128, 128),
     sigma=2,
+    decode_topk=30,
+    decode_center_shift=0.5,
     decode_keypoint_order=[
         0, 1, 2, 3, 4, 5, 6, 11, 12, 7, 8, 9, 10, 13, 14, 15, 16
     ],
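The two options added here feed the decoding changes later in this commit: `decode_topk` keeps more candidate peaks per keypoint, and `decode_center_shift` applies a constant sub-pixel offset to detected keypoints. A minimal numpy sketch of the shift (shapes and values are made up for illustration):

```python
import numpy as np

decode_center_shift = 0.5
keypoints = np.array([[[12.0, 40.0], [0.0, 0.0]]])  # (N=1, K=2, D=2)
scores = np.array([[0.9, 0.0]])                     # (N=1, K=2); 2nd keypoint missing

# Mirrors the line added to batch_decode below: only keypoints with a
# positive score are shifted; missing ones keep their placeholder value.
keypoints += decode_center_shift * (scores > 0).astype(keypoints.dtype)[..., None]
print(keypoints)  # [[[12.5 40.5] [ 0.   0. ]]]
```
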
@@ -97,7 +99,7 @@
     test_cfg=dict(
         multiscale_test=False,
         flip_test=True,
-        shift_heatmap=True,
+        shift_heatmap=False,
         restore_heatmap_size=True,
         align_corners=False))
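For context on `shift_heatmap`: during flip test, the heatmap predicted on the mirrored image is flipped back and, under the old setting, shifted one pixel along x before averaging, to patch up an alignment off-by-one. A rough sketch of that convention (my reading, not code from this repo); with the new `decode_center_shift`, alignment is presumably handled at decoding time instead, so the shift is disabled:

```python
import numpy as np

heatmaps = np.random.rand(17, 128, 128)      # (K, H, W) prediction on the original image
flipped_back = np.random.rand(17, 128, 128)[..., ::-1]  # mirrored-input prediction, flipped back

# shift_heatmap=True: displace the flipped-back map one pixel right before fusing
shifted = np.zeros_like(flipped_back)
shifted[..., 1:] = flipped_back[..., :-1]

fused = 0.5 * (heatmaps + shifted)  # flip-test aggregation
```
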

@@ -113,9 +115,14 @@
     dict(
         type='BottomupResize',
         input_size=codec['input_size'],
-        size_factor=32,
+        size_factor=64,
         resize_mode='expand'),
-    dict(type='PackPoseInputs')
+    dict(
+        type='PackPoseInputs',
+        meta_keys=('id', 'img_id', 'img_path', 'crowd_index', 'ori_shape',
+                   'img_shape', 'input_size', 'input_center', 'input_scale',
+                   'flip', 'flip_direction', 'flip_indices', 'raw_ann_info',
+                   'skeleton_links'))
 ]
 
 # data loaders
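`size_factor` sets the divisibility that `BottomupResize` enforces on the expanded input. Assuming it rounds each side up to a multiple of the factor (the usual way to match the network stride), going from 32 to 64 can change the effective test resolution:

```python
import math

def expanded_side(side: int, size_factor: int) -> int:
    # Round a side length up to the nearest multiple of size_factor.
    return math.ceil(side / size_factor) * size_factor

print(expanded_side(530, 32))  # 544
print(expanded_side(530, 64))  # 576 -- coarser factor, larger padded input
```
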
@@ -154,6 +161,6 @@
     type='CocoMetric',
     ann_file=data_root + 'annotations/person_keypoints_val2017.json',
     nms_mode='none',
-    score_mode='keypoint',
+    score_mode='bbox',
 )
 test_evaluator = val_evaluator
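`score_mode` controls which confidence `CocoMetric` writes into the COCO-format results. Roughly (a toy illustration, not the metric's actual code), `'keypoint'` derives the instance score from the per-keypoint scores, while `'bbox'` uses an instance-level score, which the AE codec now returns explicitly (see `batch_instance_scores` below):

```python
import numpy as np

keypoint_scores = np.array([0.9, 0.8, 0.7, 0.1])  # per-keypoint confidences
instance_score = 0.85                             # instance-level confidence

score_keypoint_mode = float(keypoint_scores.mean())  # roughly what 'keypoint' does
score_bbox_mode = instance_score                     # roughly what 'bbox' does
```
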
57 changes: 57 additions & 0 deletions configs/body_2d_keypoint/associative_embedding/coco/hrnet_coco.md
@@ -0,0 +1,57 @@
<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/1611.05424">Associative Embedding (NIPS'2017)</a></summary>

```bibtex
@inproceedings{newell2017associative,
title={Associative embedding: End-to-end learning for joint detection and grouping},
author={Newell, Alejandro and Huang, Zhiao and Deng, Jia},
booktitle={Advances in neural information processing systems},
pages={2277--2287},
year={2017}
}
```

</details>

<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="http://openaccess.thecvf.com/content_CVPR_2019/html/Sun_Deep_High-Resolution_Representation_Learning_for_Human_Pose_Estimation_CVPR_2019_paper.html">HRNet (CVPR'2019)</a></summary>

```bibtex
@inproceedings{sun2019deep,
title={Deep high-resolution representation learning for human pose estimation},
author={Sun, Ke and Xiao, Bin and Liu, Dong and Wang, Jingdong},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={5693--5703},
year={2019}
}
```

</details>

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://link.springer.com/chapter/10.1007/978-3-319-10602-1_48">COCO (ECCV'2014)</a></summary>

```bibtex
@inproceedings{lin2014microsoft,
title={Microsoft coco: Common objects in context},
author={Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence},
booktitle={European conference on computer vision},
pages={740--755},
year={2014},
organization={Springer}
}
```

</details>

Results on COCO val2017 without multi-scale test

| Arch | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt | log |
| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: |
| [HRNet-w32](/configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py) | 512x512 | 0.656 | 0.864 | 0.719 | 0.711 | 0.893 | [ckpt](https://download.openmmlab.com/mmpose/bottom_up/hrnet_w32_coco_512x512-bcb8c247_20200816.pth) | [log](https://download.openmmlab.com/mmpose/bottom_up/hrnet_w32_coco_512x512_20200816.log.json) |
25 changes: 25 additions & 0 deletions configs/body_2d_keypoint/associative_embedding/coco/hrnet_coco.yml
@@ -0,0 +1,25 @@
Collections:
- Name: AE
  Paper:
    Title: "Associative embedding: End-to-end learning for joint detection and grouping"
    URL: https://arxiv.org/abs/1611.05424
  README: https://github.com/open-mmlab/mmpose/blob/main/docs/src/papers/algorithms/associative_embedding.md
Models:
- Config: configs/body_2d_keypoint/associative_embedding/coco/ae_hrnet-w32_8xb24-300e_coco-512x512.py
  In Collection: AE
  Metadata:
    Architecture:
    - AE
    - HRNet
    Training Data: COCO
  Name: ae_hrnet-w32_8xb24-300e_coco-512x512
  Results:
  - Dataset: COCO
    Metrics:
      AP: 0.656
      AP@0.5: 0.864
      AP@0.75: 0.719
      AR: 0.711
      AR@0.5: 0.893
    Task: Body 2D Keypoint
  Weights: https://download.openmmlab.com/mmpose/bottom_up/hrnet_w32_coco_512x512-bcb8c247_20200816.pth
178 changes: 94 additions & 84 deletions mmpose/codecs/associative_embedding.py
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from collections import namedtuple
 from itertools import product
 from typing import Any, List, Optional, Tuple
 
@@ -16,6 +15,21 @@
                     refine_keypoints_dark_udp)
 
 
+def _py_max_match(scores):
+    """Apply munkres algorithm to get the best match.
+
+    Args:
+        scores(np.ndarray): cost matrix.
+
+    Returns:
+        np.ndarray: best match.
+    """
+    m = Munkres()
+    tmp = m.compute(scores)
+    tmp = np.array(tmp).astype(int)
+    return tmp
+
+
 def _group_keypoints_by_tags(vals: np.ndarray,
                              tags: np.ndarray,
                              locs: np.ndarray,
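`_py_max_match` is a thin wrapper around the `munkres` package's Hungarian solver. A small usage sketch on a made-up cost matrix:

```python
import numpy as np
from munkres import Munkres

costs = [[4.0, 1.0, 3.0],
         [2.0, 0.0, 5.0],
         [3.0, 2.0, 2.0]]
pairs = np.array(Munkres().compute(costs)).astype(int)
print(pairs)  # [[0 1] [1 0] [2 2]] -- (row, column) pairs minimizing total cost
```
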
@@ -54,89 +68,78 @@ def _group_keypoints_by_tags(vals: np.ndarray,
         np.ndarray: grouped keypoints in shape (G, K, D+1), where the last
             dimension is the concatenated keypoint coordinates and scores.
     """
 
+    tag_k, loc_k, val_k = tags, locs, vals
     K, M, D = locs.shape
     assert vals.shape == tags.shape[:2] == (K, M)
     assert len(keypoint_order) == K
 
-    # Build Munkres instance
-    munkres = Munkres()
-
-    # Build a group pool, each group contains the keypoints of an instance
-    groups = []
+    default_ = np.zeros((K, 3 + tag_k.shape[2]), dtype=np.float32)
 
-    Group = namedtuple('Group', field_names=['kpts', 'scores', 'tag_list'])
+    joint_dict = {}
+    tag_dict = {}
+    for i in range(K):
+        idx = keypoint_order[i]
 
-    def _init_group():
-        """Initialize a group, which is composed of the keypoints, keypoint
-        scores and the tag of each keypoint."""
-        _group = Group(
-            kpts=np.zeros((K, D), dtype=np.float32),
-            scores=np.zeros(K, dtype=np.float32),
-            tag_list=[])
-        return _group
+        tags = tag_k[idx]
+        joints = np.concatenate((loc_k[idx], val_k[idx, :, None], tags), 1)
+        mask = joints[:, 2] > val_thr
+        tags = tags[mask]  # shape: [M, L]
+        joints = joints[mask]  # shape: [M, 3 + L], 3: x, y, val
 
-    for i in keypoint_order:
-        # Get all valid candidate of the i-th keypoints
-        valid = vals[i] > val_thr
-        if not valid.any():
+        if joints.shape[0] == 0:
             continue
 
-        tags_i = tags[i, valid]  # (M', L)
-        vals_i = vals[i, valid]  # (M',)
-        locs_i = locs[i, valid]  # (M', D)
-
-        if len(groups) == 0:  # Initialize the group pool
-            for tag, val, loc in zip(tags_i, vals_i, locs_i):
-                group = _init_group()
-                group.kpts[i] = loc
-                group.scores[i] = val
-                group.tag_list.append(tag)
-
-                groups.append(group)
-
-        else:  # Match keypoints to existing groups
-            groups = groups[:max_groups]
-            group_tags = [np.mean(g.tag_list, axis=0) for g in groups]
-
-            # Calculate distance matrix between group tags and tag candidates
-            # of the i-th keypoint
-            # Shape: (M', 1, L) , (1, G, L) -> (M', G, L)
-            diff = tags_i[:, None] - np.array(group_tags)[None]
-            dists = np.linalg.norm(diff, ord=2, axis=2)
-            num_kpts, num_groups = dists.shape[:2]
-
-            # Experimental cost function for keypoint-group matching
-            costs = np.round(dists) * 100 - vals_i[..., None]
-            if num_kpts > num_groups:
-                padding = np.full((num_kpts, num_kpts - num_groups),
-                                  1e10,
-                                  dtype=np.float32)
-                costs = np.concatenate((costs, padding), axis=1)
-
-            # Match keypoints and groups by Munkres algorithm
-            matches = munkres.compute(costs)
-            for kpt_idx, group_idx in matches:
-                if group_idx < num_groups and dists[kpt_idx,
-                                                    group_idx] < tag_thr:
-                    # Add the keypoint to the matched group
-                    group = groups[group_idx]
+        if i == 0 or len(joint_dict) == 0:
+            for tag, joint in zip(tags, joints):
+                key = tag[0]
+                joint_dict.setdefault(key, np.copy(default_))[idx] = joint
+                tag_dict[key] = [tag]
+        else:
+            # shape: [M]
+            grouped_keys = list(joint_dict.keys())
+            # shape: [M, L]
+            grouped_tags = [np.mean(tag_dict[i], axis=0) for i in grouped_keys]
+
+            # shape: [M, M, L]
+            diff = joints[:, None, 3:] - np.array(grouped_tags)[None, :, :]
+            # shape: [M, M]
+            diff_normed = np.linalg.norm(diff, ord=2, axis=2)
+            diff_saved = np.copy(diff_normed)
+            diff_normed = np.round(diff_normed) * 100 - joints[:, 2:3]
+
+            num_added = diff.shape[0]
+            num_grouped = diff.shape[1]
+
+            if num_added > num_grouped:
+                diff_normed = np.concatenate(
+                    (diff_normed,
+                     np.zeros((num_added, num_added - num_grouped),
+                              dtype=np.float32) + 1e10),
+                    axis=1)
+
+            pairs = _py_max_match(diff_normed)
+            for row, col in pairs:
+                if (row < num_added and col < num_grouped
+                        and diff_saved[row][col] < tag_thr):
+                    key = grouped_keys[col]
+                    joint_dict[key][idx] = joints[row]
+                    tag_dict[key].append(tags[row])
                 else:
-                    # Initialize a new group with unmatched keypoint
-                    group = _init_group()
-                    groups.append(group)
-
-                group.kpts[i] = locs_i[kpt_idx]
-                group.scores[i] = vals_i[kpt_idx]
-                group.tag_list.append(tags_i[kpt_idx])
-
-    groups = groups[:max_groups]
-    if groups:
-        grouped_keypoints = np.stack(
-            [np.r_['1', g.kpts, g.scores[:, None]] for g in groups])
-    else:
-        grouped_keypoints = np.empty((0, K, D + 1))
+                    key = tags[row][0]
+                    joint_dict.setdefault(key, np.copy(default_))[idx] = \
+                        joints[row]
+                    tag_dict[key] = [tags[row]]
 
-    return grouped_keypoints
+    joint_dict_keys = list(joint_dict.keys())[:max_groups]
+
+    if joint_dict_keys:
+        results = np.array([joint_dict[i]
+                            for i in joint_dict_keys]).astype(np.float32)
+        results = results[..., :D + 1]
+    else:
+        results = np.empty((0, K, D + 1), dtype=np.float32)
+    return results


@KEYPOINT_CODECS.register_module()
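The rewritten grouping keeps the same cost idea as the code it replaces: `np.round(dist) * 100 - val` makes the bucketed tag distance dominate, with the heatmap score acting only as a tie-breaker inside a bucket, and extra candidates are padded with a 1e10 cost so they start new groups. Toy numbers to make the scaling concrete:

```python
import numpy as np

tag_dist = np.array([[0.2, 1.7],
                     [0.4, 0.3]])  # candidate-to-group tag distances
vals = np.array([[0.9], [0.6]])    # candidate heatmap scores, shape (M, 1)

cost = np.round(tag_dist) * 100 - vals
# round() buckets the distances: a whole tag unit costs 100, so the score
# term (< 1) only orders candidates within the same bucket.
print(cost)
# [[ -0.9 199.1]
#  [ -0.6  -0.6]]
```
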
@@ -210,7 +213,8 @@ def __init__(
         decode_gaussian_kernel: int = 3,
         decode_keypoint_thr: float = 0.1,
         decode_tag_thr: float = 1.0,
-        decode_topk: int = 20,
+        decode_topk: int = 30,
+        decode_center_shift=0.0,
         decode_max_instances: Optional[int] = None,
     ) -> None:
         super().__init__()
@@ -222,8 +226,9 @@ def __init__(
         self.decode_keypoint_thr = decode_keypoint_thr
         self.decode_tag_thr = decode_tag_thr
         self.decode_topk = decode_topk
+        self.decode_center_shift = decode_center_shift
         self.decode_max_instances = decode_max_instances
-        self.dedecode_keypoint_order = decode_keypoint_order.copy()
+        self.decode_keypoint_order = decode_keypoint_order.copy()
 
         if self.use_udp:
             self.scale_factor = ((np.array(input_size) - 1) /
@@ -376,7 +381,7 @@ def _group_func(inputs: Tuple):
                 vals,
                 tags,
                 locs,
-                keypoint_order=self.dedecode_keypoint_order,
+                keypoint_order=self.decode_keypoint_order,
                 val_thr=self.decode_keypoint_thr,
                 tag_thr=self.decode_tag_thr,
                 max_groups=self.decode_max_instances)
@@ -463,13 +468,13 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
                             f'tagging map ({batch_tags.shape})')
 
         # Heatmap NMS
-        batch_heatmaps = batch_heatmap_nms(batch_heatmaps,
-                                           self.decode_nms_kernel)
+        batch_heatmaps_peak = batch_heatmap_nms(batch_heatmaps,
+                                                self.decode_nms_kernel)
 
         # Get top-k in each heatmap and convert to numpy
         batch_topk_vals, batch_topk_tags, batch_topk_locs = to_numpy(
             self._get_batch_topk(
-                batch_heatmaps, batch_tags, k=self.decode_topk))
+                batch_heatmaps_peak, batch_tags, k=self.decode_topk))
 
         # Group keypoint candidates into groups (instances)
         batch_groups = self._group_keypoints(batch_topk_vals, batch_topk_tags,
@@ -482,16 +487,14 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
         # Refine the keypoint prediction
         batch_keypoints = []
         batch_keypoint_scores = []
+        batch_instance_scores = []
         for i, (groups, heatmaps, tags) in enumerate(
                 zip(batch_groups, batch_heatmaps_np, batch_tags_np)):
 
             keypoints, scores = groups[..., :-1], groups[..., -1]
+            instance_scores = scores.mean(axis=-1)
 
             if keypoints.size > 0:
-                # identify missing keypoints
-                keypoints, scores = self._fill_missing_keypoints(
-                    keypoints, scores, heatmaps, tags)
-
                 # refine keypoint coordinates according to heatmap distribution
                 if self.use_udp:
                     keypoints = refine_keypoints_dark_udp(
@@ -500,13 +503,20 @@ def batch_decode(self, batch_heatmaps: Tensor, batch_tags: Tensor
                         blur_kernel_size=self.decode_gaussian_kernel)
                 else:
                     keypoints = refine_keypoints(keypoints, heatmaps)
+                keypoints += self.decode_center_shift * \
+                    (scores > 0).astype(keypoints.dtype)[..., None]
+
+                # identify missing keypoints
+                keypoints, scores = self._fill_missing_keypoints(
+                    keypoints, scores, heatmaps, tags)
 
             batch_keypoints.append(keypoints)
             batch_keypoint_scores.append(scores)
+            batch_instance_scores.append(instance_scores)
 
         # restore keypoint scale
         batch_keypoints = [
             kpts * self.scale_factor for kpts in batch_keypoints
         ]
 
-        return batch_keypoints, batch_keypoint_scores
+        return batch_keypoints, batch_keypoint_scores, batch_instance_scores
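With this change, `batch_decode` returns instance scores alongside the keypoints and keypoint scores. A hedged sketch of a call site (`codec`, `heatmaps`, and `tags` are placeholder names, not from this diff):

```python
from torch import Tensor

def decode_batch(codec, heatmaps: Tensor, tags: Tensor):
    """Placeholder call site; `codec` is an AssociativeEmbedding instance."""
    keypoints, keypoint_scores, instance_scores = codec.batch_decode(
        heatmaps, tags)
    # keypoints: list of (N, K, D) arrays; keypoint_scores: list of (N, K);
    # instance_scores: list of (N,) -- one entry per image in the batch.
    return keypoints, keypoint_scores, instance_scores
```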