Skip to content

Commit

Permalink
Update nms.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sebastianfis authored Mar 23, 2024
1 parent b59317a commit 64062a6
Showing 1 changed file with 20 additions and 12 deletions.
32 changes: 20 additions & 12 deletions nms.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,14 @@ def nms(dets, scores, thresh):


def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False,
labels=(), max_det=300):

nc = prediction.shape[2] - 5 # number of classes
labels=(), max_det=300, v8=False):
# TODO: Correct this for changed parameter number in yolov8!!!
if v8:
split_val=-1
else:
split_val=0

nc = prediction.shape[2] - 5 + split_val # number of classes
xc = prediction[..., 4] > conf_thres # candidates

# Checks
Expand All @@ -77,29 +82,32 @@ def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=Non
# Cat apriori labels if autolabelling
if labels and len(labels[xi]):
l = labels[xi]
v = np.zeros((len(l), nc + 5))
v[:, :4] = l[:, 1:5] # box
v[:, 4] = 1.0 # conf
v[range(len(l)), l[:, 0].long() + 5] = 1.0 # cls
v = np.zeros((len(l), nc + 5 + split_val))
v[:, :4] = l[:, 1:5 + split_val] # box
v[:, 4 + split_val] = 1.0 # conf
v[range(len(l)), l[:, 0].long() + 5 + split_val] = 1.0 # cls
x = np.concatenate((x, v), 0)

# If none remain process next image
if not x.shape[0]:
continue

# Compute conf
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf
if v8:
x[:, 4:] *= x[:, 4]
else:
x[:, 5:] *= x[:, 4:5] # conf = obj_conf * cls_conf

# Box (center x, center y, width, height) to (x1, y1, x2, y2)
box = xywh2xyxy(x[:, :4])

# Detections matrix nx6 (xyxy, conf, cls)
if multi_label:
i, j = (x[:, 5:] > conf_thres).nonzero(as_tuple=False).T
x = np.concatenate((box[i], x[i, j + 5, None], j[:, None].astype(float)), axis=1)
i, j = (x[:, 5 + split_val:] > conf_thres).nonzero(as_tuple=False).T
x = np.concatenate((box[i], x[i, j + 5 + split_val, None], j[:, None].astype(float)), axis=1)
else: # best class only
conf = np.amax(x[:, 5:], axis=1, keepdims=True)
j = np.argmax(x[:, 5:], axis=1).reshape(conf.shape)
conf = np.amax(x[:, 5 + split_val:], axis=1, keepdims=True)
j = np.argmax(x[:, 5 + split_val:], axis=1).reshape(conf.shape)
x = np.concatenate((box, conf, j.astype(float)), axis=1)[conf.flatten() > conf_thres]

# Filter by class
Expand Down

0 comments on commit 64062a6

Please sign in to comment.