Skip to content

Commit

Permalink
Added v8 parameter to edgetpumodel.py and nms.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastianfis committed Mar 24, 2024
1 parent 64062a6 commit 54a1c54
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
5 changes: 3 additions & 2 deletions edgetpumodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,14 +167,15 @@ def forward(self, x:np.ndarray, with_nms=True) -> np.ndarray:
# Scale output
result = (common.output_tensor(self.interpreter, 0).astype('float32') - self.output_zero) * self.output_scale
if self.v8:
result = np.transpose(result,[0,2,1])
result = np.transpose(result, [0, 2, 1]) # tranpose for yoolov8 models

self.inference_time = time.time() - tstart

if with_nms:

tstart = time.time()
nms_result = non_max_suppression(result, self.conf_thresh, self.iou_thresh, self.filter_classes, self.agnostic_nms, max_det=self.max_det)
nms_result = non_max_suppression(result, self.conf_thresh, self.iou_thresh, self.filter_classes,
self.agnostic_nms, max_det=self.max_det, v8=self.v8)
self.nms_time = time.time() - tstart

return nms_result
Expand Down
10 changes: 5 additions & 5 deletions nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,11 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
labels=(), max_det=300, v8=False):
# TODO: Correct this for changed parameter number in yolov8!!!
if v8:
split_val=-1
split_val = -1
else:
split_val=0
split_val = 0

nc = prediction.shape[2] - 5 + split_val # number of classes
nc = prediction.shape[2] - 5 - split_val # number of classes
xc = prediction[..., 4] > conf_thres # candidates

# Checks
Expand All @@ -83,7 +83,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
if labels and len(labels[xi]):
l = labels[xi]
v = np.zeros((len(l), nc + 5 + split_val))
v[:, :4] = l[:, 1:5 + split_val] # box
v[:, :4] = l[:, 1:(5 + split_val)] # box
v[:, 4 + split_val] = 1.0 # conf
v[range(len(l)), l[:, 0].long() + 5 + split_val] = 1.0 # cls
x = np.concatenate((x, v), 0)
Expand All @@ -103,7 +103,7 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non

# Detections matrix nx6 (xyxy, conf, cls)
if multi_label:
i, j = (x[:, 5 + split_val:] > conf_thres).nonzero(as_tuple=False).T
i, j = (x[:, (5 + split_val):] > conf_thres).nonzero(as_tuple=False).T
x = np.concatenate((box[i], x[i, j + 5 + split_val, None], j[:, None].astype(float)), axis=1)
else: # best class only
conf = np.amax(x[:, 5 + split_val:], axis=1, keepdims=True)
Expand Down

0 comments on commit 54a1c54

Please sign in to comment.