# videopose_PSTMO.py
import os
import time
from common.arguments import parse_args
from common.camera import *
from common.generators import *
from common.loss import *
from common.model import *
from common.utils import Timer, evaluate, add_path
from common.inference_3d import *
from model.block.refine import refine
from model.stmo import Model
import pdb
# from joints_detectors.openpose.main import generate_kpts as open_pose
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # see issue #152
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
metadata = {'layout_name': 'coco', 'num_joints': 17, 'keypoints_symmetry': [[1, 3, 5, 7, 9, 11, 13, 15], [2, 4, 6, 8, 10, 12, 14, 16]]}
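# keypoints_symmetry above lists the left/right COCO keypoint indices; they are used
# below to set kps_left / kps_right for the horizontal-flip test-time augmentation.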
add_path()
# Timing helper: with no argument, return the current time;
# given a previous checkpoint, return (elapsed seconds, current time).
def ckpt_time(ckpt=None):
    if not ckpt:
        return time.time()
    else:
        return time.time() - float(ckpt), time.time()
time0 = ckpt_time()
def get_detector_2d(detector_name):
    def get_alpha_pose():
        from joints_detectors.Alphapose.gene_npz import generate_kpts as alpha_pose
        return alpha_pose

    def get_hr_pose():
        from joints_detectors.hrnet.pose_estimation.video import generate_kpts as hr_pose
        return hr_pose

    def get_mediapipe_pose():
        from joints_detectors.mediapipe.pose import generate_kpts as mediapipe_pose
        return mediapipe_pose

    detector_map = {
        'alpha_pose': get_alpha_pose,
        'hr_pose': get_hr_pose,
        # 'open_pose': open_pose,
        'mediapipe_pose': get_mediapipe_pose,
    }

    assert detector_name in detector_map, f'2D detector: {detector_name} not implemented yet!'

    return detector_map[detector_name]()
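# Note: each entry in detector_map above is a small getter, so unused detector
# backends (and their heavy dependencies) are only imported when actually selected.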
class Skeleton:
    def parents(self):
        return np.array([-1, 0, 1, 2, 0, 4, 5, 0, 7, 8, 9, 8, 11, 12, 8, 14, 15])

    def joints_right(self):
        return [1, 2, 3, 14, 15, 16]
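# Minimal stand-in skeleton: it only exposes the parent of each of the 17 predicted
# joints and the right-side joint indices, which appears to be all render_animation uses.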
def main(args):
    detector_2d = get_detector_2d(args.detector_2d)
    assert detector_2d, 'detector_2d should be in ({alpha, hr, mediapipe}_pose)'

    # 2D keypoints: load them from an .npz file or generate them from the video
    # args.input_npz = './outputs/alpha_pose_skiing_cut/skiing_cut.npz'
    if not args.input_npz:
        video_name = args.viz_video
        keypoints = detector_2d(video_name)  # detect 2D keypoints, around 40 it/s, shape [frame, 17, 2]
    else:
        npz = np.load(args.input_npz)
        keypoints = npz['kpts']  # (N, 17, 2)

    keypoints_symmetry = metadata['keypoints_symmetry']
    kps_left, kps_right = list(keypoints_symmetry[0]), list(keypoints_symmetry[1])
    joints_left, joints_right = list([4, 5, 6, 11, 12, 13]), list([1, 2, 3, 14, 15, 16])

    # normalize keypoints, assuming fixed camera parameters (viewport 1000 x 1002)
    keypoints = normalize_screen_coordinates(keypoints[..., :2], w=1000, h=1002)
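    # normalize_screen_coordinates (common/camera.py) maps pixel coordinates into
    # roughly [-1, 1] while preserving aspect ratio; note that w and h here are
    # hard-coded rather than read from the actual video resolution.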
    # model_pos = TemporalModel(17, 2, 17, filter_widths=[3, 3, 3, 3, 3], causal=args.causal, dropout=args.dropout, channels=args.channels,
    #                           dense=args.dense)

    model = {}
    model['trans'] = Model(args).cuda()
    # model['trans'] = Model(args)
    # if torch.cuda.is_available():
    #     model_pos = model_pos.cuda()

    ckpt, time1 = ckpt_time(time0)
    print('-------------- load data spends {:.2f} seconds'.format(ckpt))
    # load trained model
    # chk_filename = os.path.join(args.checkpoint, args.resume if args.resume else args.evaluate)
    # print('Loading checkpoint', chk_filename)
    # checkpoint = torch.load(chk_filename, map_location=lambda storage, loc: storage)  # map loc onto storage (load on CPU)
    # model_pos.load_state_dict(checkpoint['model_pos'])

    model_dict = model['trans'].state_dict()

    no_refine_path = "checkpoint/PSTMOS_no_refine_48_5137_in_the_wild.pth"
    pre_dict = torch.load(no_refine_path, map_location=torch.device('cpu'))
    for key, value in pre_dict.items():
        name = key[7:]  # strip the 'module.' prefix left by nn.DataParallel when the weights were saved
        model_dict[name] = pre_dict[key]
    model['trans'].load_state_dict(model_dict)
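    # Note: the prefix strip above assumes every key in the released P-STMO checkpoint
    # starts with 'module.'; a checkpoint saved without DataParallel would need the
    # keys copied unmodified instead.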
    ckpt, time2 = ckpt_time(time1)
    print('-------------- load 3D model spends {:.2f} seconds'.format(ckpt))
    # Receptive field: 243 frames for args.arc [3, 3, 3, 3, 3]
    receptive_field = args.frames
    pad = (receptive_field - 1) // 2  # padding on each side
    causal_shift = 0

    print('Rendering...')
    input_keypoints = keypoints.copy()
    print(input_keypoints.shape)

    # gen = UnchunkedGenerator(None, None, [input_keypoints],
    #                          pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation,
    #                          kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right)
    # test_data = Fusion(opt=args, train=False, dataset=dataset, root_path=root_path, MAE=opt.MAE)
    # test_dataloader = torch.utils.data.DataLoader(test_data, batch_size=1,
    #                                               shuffle=False, num_workers=0, pin_memory=True)
    # prediction = evaluate(gen, model_pos, return_predictions=True)

    gen = Evaluate_Generator(128, None, None, [input_keypoints], args.stride,
                             pad=pad, causal_shift=causal_shift, augment=args.test_time_augmentation, shuffle=False,
                             kps_left=kps_left, kps_right=kps_right, joints_left=joints_left, joints_right=joints_right)
    prediction = val(args, gen, model)  # [frame, 17, 3]
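    # prediction is in camera coordinates; the steps below convert it to world
    # coordinates and shift it so its lowest point sits at z = 0.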
    # save 3D joint points
    np.save(f'outputs/test_3d_{args.video_name}_output.npy', prediction, allow_pickle=True)

    rot = np.array([0.14070565, -0.15007018, -0.7552408, 0.62232804], dtype=np.float32)  # camera rotation (quaternion) for camera_to_world
    prediction = camera_to_world(prediction, R=rot, t=0)

    # We don't have the trajectory, but at least we can rebase the height
    prediction[:, :, 2] -= np.min(prediction[:, :, 2])
    np.save(f'outputs/test_3d_output_{args.video_name}_postprocess.npy', prediction, allow_pickle=True)
    anim_output = {'Ours': prediction}
    input_keypoints = image_coordinates(input_keypoints[..., :2], w=1000, h=1002)
    # pdb.set_trace()

    ckpt, time3 = ckpt_time(time2)
    print('-------------- generating 3D reconstruction data spends {:.2f} seconds'.format(ckpt))
    if not args.viz_output:
        args.viz_output = 'outputs/alpha_result.mp4'

    from common.visualization import render_animation
    render_animation(input_keypoints, anim_output,
                     Skeleton(), 25, args.viz_bitrate, np.array(70., dtype=np.float32), args.viz_output,
                     limit=args.viz_limit, downsample=args.viz_downsample, size=args.viz_size,
                     input_video_path=args.viz_video, viewport=(1000, 1002),
                     input_video_skip=args.viz_skip)

    ckpt, time4 = ckpt_time(time3)
    print('total spends {:.2f} seconds'.format(ckpt))
def inference_video(video_path, detector_2d):
    """
    Run the full pipeline: video -> 2D keypoints -> 3D poses -> rendered video.

    :param detector_2d: name of the 2D joint detector, one of {alpha_pose, hr_pose, mediapipe_pose}
    :param video_path: path to the input video
    :return: None
    """
    args = parse_args()

    args.detector_2d = detector_2d
    dir_name = os.path.dirname(video_path)
    basename = os.path.basename(video_path)
    args.video_name = basename[:basename.rfind('.')]
    args.viz_video = video_path
    # args.viz_export = f'{dir_name}/{args.detector_2d}_{video_name}_data.npy'
    args.viz_output = f'./outputs/{args.detector_2d}_{args.video_name}_video.mp4'
    # args.viz_limit = 20
    # args.input_npz = 'outputs/alpha_pose_test/test.npz'

    args.evaluate = 'pretrained_h36m_detectron_coco.bin'

    with Timer(video_path):
        main(args)
if __name__ == '__main__':
    inference_video('./input/H017_GF_01_20210922_151118.mp4', 'mediapipe_pose')
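    # Other runs follow the same pattern, e.g. with AlphaPose (hypothetical input path):
    # inference_video('./input/my_clip.mp4', 'alpha_pose')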