-
Notifications
You must be signed in to change notification settings - Fork 254
/
PoseEstimateLoader.py
39 lines (31 loc) · 1.33 KB
/
PoseEstimateLoader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import cv2
import torch
from SPPE.src.main_fast_inference import InferenNet_fast, InferenNet_fastRes50
from SPPE.src.utils.img import crop_dets
from pPose_nms import pose_nms
from SPPE.src.utils.eval import getPrediction
class SPPE_FastPose(object):
    """Pose estimator wrapping the SPPE FastPose inference networks.

    Args:
        backbone: Either ``'resnet50'`` (builds ``InferenNet_fastRes50``) or
            ``'resnet101'`` (builds ``InferenNet_fast``).
        input_height: Height forwarded to ``crop_dets`` and ``getPrediction``.
        input_width: Width forwarded to ``crop_dets`` and ``getPrediction``.
        device: Torch device the model and the cropped inputs are moved to.

    Raises:
        AssertionError: If ``backbone`` is not one of the supported names.
    """

    def __init__(self,
                 backbone,
                 input_height=320,
                 input_width=256,
                 device='cuda'):
        # Raised explicitly (rather than via an ``assert`` statement) so the
        # check still runs under ``python -O``; otherwise an unsupported name
        # would silently fall through to the ResNet-50 branch below. The
        # exception type and message are kept for backward compatibility.
        if backbone not in ('resnet50', 'resnet101'):
            raise AssertionError('{} backbone is not support yet!'.format(backbone))

        self.inp_h = input_height
        self.inp_w = input_width
        self.device = device

        if backbone == 'resnet101':
            self.model = InferenNet_fast().to(device)
        else:
            self.model = InferenNet_fastRes50().to(device)
        # Switch the network to evaluation mode.
        self.model.eval()

    def predict(self, image, bboxs, bboxs_scores):
        """Estimate poses for the given detections in ``image``.

        Args:
            image: Source image handed to ``crop_dets``.
            bboxs: Detection boxes; passed to ``crop_dets`` and ``pose_nms``.
            bboxs_scores: Per-box scores; passed to ``pose_nms``.

        Returns:
            Whatever ``pose_nms`` returns for the predicted keypoints.
        """
        inps, pt1, pt2 = crop_dets(image, bboxs, self.inp_h, self.inp_w)

        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            pose_hm = self.model(inps.to(self.device)).detach().cpu()

        # Cut eyes and ears: keep channel 0 and channels 5 onward,
        # i.e. drop channels 1-4 along dim 1.
        pose_hm = torch.cat([pose_hm[:, :1, ...], pose_hm[:, 5:, ...]], dim=1)

        # The first value returned by getPrediction is not needed here.
        _, xy_img, scores = getPrediction(pose_hm, pt1, pt2, self.inp_h, self.inp_w,
                                          pose_hm.shape[-2], pose_hm.shape[-1])

        return pose_nms(bboxs, bboxs_scores, xy_img, scores)