diff --git a/nms.py b/nms.py index fd81d56..372fc39 100644 --- a/nms.py +++ b/nms.py @@ -55,8 +55,8 @@ def non_max_suppresion_v8(prediction, conf_thres=0.25, iou_thres=0.45, classes=N labels=(), max_det=300): # TODO: Test this for changed parameter number in yolov8!!! All that could be detected with last commit were persons! - nc = prediction.shape[2] - 5 # number of classes - xc = prediction[..., 4] > conf_thres # candidates + nc = prediction.shape[2] - 4 # number of classes + xc = np.amax(prediction[..., 4:], axis=2, keepdims=True) > conf_thres # Checks assert 0 <= conf_thres <= 1, f'Invalid Confidence threshold {conf_thres}, valid values are between 0.0 and 1.0'