diff --git a/nms.py b/nms.py index e134ca3..6fa2dac 100644 --- a/nms.py +++ b/nms.py @@ -101,14 +101,14 @@ def non_max_suppresion_v8(prediction, conf_thres=0.25, iou_thres=0.45, classes=N nc = prediction.shape[2] - 4 # number of classes # xc = prediction[..., 4] > conf_thres # candidates - xc = [] - for xi, x in enumerate(prediction): # image index, image inference - conf = np.amax(x[:, 4:], axis=1, keepdims=True) - if np.any(conf > conf_thres): - xc.append(True) - else: - xc.append(False) - # xc = np.amax(prediction[..., 4:], axis=2, keepdims=True) > conf_thres + # xc = [] + # for xi, x in enumerate(prediction): # image index, image inference + # conf = np.amax(x[:, 4:], axis=1, keepdims=True) + # if np.any(conf > conf_thres): + # xc.append(True) + # else: + # xc.append(False) + xc = np.amax(prediction[..., 4:], axis=2) > conf_thres # Checks assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'