-
Notifications
You must be signed in to change notification settings - Fork 2
/
yolov5_tflite_video_inference.py
87 lines (63 loc) · 2.91 KB
/
yolov5_tflite_video_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
from yolov5_tflite_inference import yolov5_tflite
import argparse
import cv2
from PIL import Image
from utils import letterbox_image, scale_coords
import numpy as np
import time
def detect_video(weights,video_filepath,img_size,conf_thres,iou_thres):
start_time = time.time()
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video = cv2.VideoCapture(video_filepath)
fps = video.get(cv2.CAP_PROP_FPS)
h = int(video.get(3))
w = int(video.get(4))
print(w,h)
#h = 1280
#w = 720
result_video_filepath = video_filepath.split('/')[-1].split('.')[0] + 'yolov5_output.mp4'
out = cv2.VideoWriter(result_video_filepath,fourcc,int(fps),(h,w))
yolov5_tflite_obj = yolov5_tflite(weights,img_size,conf_thres,iou_thres)
size = (img_size,img_size)
no_of_frames = 0
while True:
check, frame = video.read()
if not check:
break
#frame = cv2.resize(frame,(h,w))
no_of_frames += 1
image_resized = letterbox_image(Image.fromarray(frame),size)
image_array = np.asarray(image_resized)
normalized_image_array = image_array.astype(np.float32) / 255.0
result_boxes, result_scores, result_class_names = yolov5_tflite_obj.detect(normalized_image_array)
if len(result_boxes) > 0:
result_boxes = scale_coords(size,np.array(result_boxes),(w,h))
font = cv2.FONT_HERSHEY_SIMPLEX
# org
org = (20, 40)
# fontScale
fontScale = 0.5
# Blue color in BGR
color = (0, 255, 0)
# Line thickness of 1 px
thickness = 1
for i,r in enumerate(result_boxes):
org = (int(r[0]),int(r[1]))
cv2.rectangle(frame, (int(r[0]),int(r[1])), (int(r[2]),int(r[3])), (255,0,0), 1)
cv2.putText(frame, str(int(100*result_scores[i])) + '% ' + str(result_class_names[i]), org, font,
fontScale, color, thickness, cv2.LINE_AA)
out.write(frame)
end_time = time.time()
print('FPS:',1/(end_time-start_time))
start_time = end_time
out.release()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('-w','--weights', type=str, default='yolov5s-fp16.tflite', help='model.tflite path(s)')
parser.add_argument('-v','--video_path', type=str, required=True, help='video path')
parser.add_argument('--img_size', type=int, default=416, help='image size') # height, width
parser.add_argument('--conf_thres', type=float, default=0.25, help='object confidence threshold')
parser.add_argument('--iou_thres', type=float, default=0.45, help='IOU threshold for NMS')
opt = parser.parse_args()
print(opt)
detect_video(opt.weights,opt.video_path,opt.img_size,opt.conf_thres,opt.iou_thres)