From 54a1c549175d2d3d00cb358e56518d496d4b7446 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Sun, 24 Mar 2024 01:15:15 +0100 Subject: [PATCH] Added v8 parameter to edgetpumodel.py and nms.py --- edgetpumodel.py | 5 +++-- nms.py | 10 +++++----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/edgetpumodel.py b/edgetpumodel.py index 11e6076..00f30e6 100644 --- a/edgetpumodel.py +++ b/edgetpumodel.py @@ -167,14 +167,15 @@ def forward(self, x:np.ndarray, with_nms=True) -> np.ndarray: # Scale output result = (common.output_tensor(self.interpreter, 0).astype('float32') - self.output_zero) * self.output_scale if self.v8: - result = np.transpose(result,[0,2,1]) + result = np.transpose(result, [0, 2, 1]) # tranpose for yoolov8 models self.inference_time = time.time() - tstart if with_nms: tstart = time.time() - nms_result = non_max_suppression(result, self.conf_thresh, self.iou_thresh, self.filter_classes, self.agnostic_nms, max_det=self.max_det) + nms_result = non_max_suppression(result, self.conf_thresh, self.iou_thresh, self.filter_classes, + self.agnostic_nms, max_det=self.max_det, v8=self.v8) self.nms_time = time.time() - tstart return nms_result diff --git a/nms.py b/nms.py index 2978cef..541a642 100644 --- a/nms.py +++ b/nms.py @@ -53,11 +53,11 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non labels=(), max_det=300, v8=False): # TODO: Correct this for changed parameter number in yolov8!!! if v8: - split_val=-1 + split_val = -1 else: - split_val=0 + split_val = 0 - nc = prediction.shape[2] - 5 + split_val # number of classes + nc = prediction.shape[2] - 5 - split_val # number of classes xc = prediction[..., 4] > conf_thres # candidates # Checks @@ -83,7 +83,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non if labels and len(labels[xi]): l = labels[xi] v = np.zeros((len(l), nc + 5 + split_val)) - v[:, :4] = l[:, 1:5 + split_val] # box + v[:, :4] = l[:, 1:(5 + split_val)] # box v[:, 4 + split_val] = 1.0 # conf v[range(len(l)), l[:, 0].long() + 5 + split_val] = 1.0 # cls x = np.concatenate((x, v), 0) @@ -103,7 +103,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non # Detections matrix nx6 (xyxy, conf, cls) if multi_label: - i, j = (x[:, 5 + split_val:] > conf_thres).nonzero(as_tuple=False).T + i, j = (x[:, (5 + split_val):] > conf_thres).nonzero(as_tuple=False).T x = np.concatenate((box[i], x[i, j + 5 + split_val, None], j[:, None].astype(float)), axis=1) else: # best class only conf = np.amax(x[:, 5 + split_val:], axis=1, keepdims=True)