Skip to content

Commit

Permalink
EfficientNMS: Dynamic Input Shape
Browse files Browse the repository at this point in the history
Signed-off-by: Rajeev Rao <[email protected]>
  • Loading branch information
wraveane authored and rajeevsrao committed Sep 22, 2021
1 parent bdcfa83 commit 5325570
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 29 deletions.
40 changes: 27 additions & 13 deletions plugin/efficientNMSPlugin/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
- [Description](#description)
- [Structure](#structure)
* [Inputs](#inputs)
* [Dynamic Shape Support](#dynamic-shape-support)
* [Box Coding Type](#box-coding-type)
* [Outputs](#outputs)
* [Parameters](#parameters)
Expand Down Expand Up @@ -32,19 +33,35 @@ The plugin has two modes of operation, depending on the given input data. The pl
Most object detection networks work by generating raw predictions from a "localization head" which adjust the coordinates of standard non-learned anchor coordinates to produce a tighter fitting bounding box. This process is called "box decoding", and it usually involves a large number of element-wise operations to transform the anchors to final box coordinates. As this can involve exponential operations on a large number of anchors, it can be computationally expensive, so this plugin gives the option of fusing the box decoder within the NMS operation which can be done in a far more efficient manner, resulting in lower latency for the network.

#### Boxes Input
The boxes input has shape `[batch_size, number_boxes, 4]` or `[batch_size, number_boxes, number_classes, 4]`, where the former is in case a single box prediction is produced for all classes such as in EfficientDet or SSD, and the latter is when separate box predictions are generated for each individual class, such as in FasterRCNN. The final dimension represents the four coordinates that define the bounding box prediction.
> **Input Shape:** `[batch_size, number_boxes, 4]` or `[batch_size, number_boxes, number_classes, 4]`
>
> **Data Type:** `float32` or `float16`
For *Standard NMS* mode, this tensor should contain the final box coordinates for each predicted detection. For *Fused Box Decoder* mode, this tensor should have the raw localization predictions.
The boxes input can have 3 dimensions in case a single box prediction is produced for all classes (such as in EfficientDet or SSD), or 4 dimensions when separate box predictions are generated for each class (such as in FasterRCNN), in which case `number_classes` >= 1 and must match the number of classes in the scores input. The final dimension represents the four coordinates that define the bounding box prediction.

For *Standard NMS* mode, this tensor should contain the final box coordinates for each predicted detection. For *Fused Box Decoder* mode, this tensor should have the raw localization predictions. In either case, this data is given as `4` coordinates which makes up the final shape dimension.

#### Scores Input
The scores input has shape `[batch_size, number_boxes, number_classes]`, such that for each anchor box, there are `num_classes` elements with the predicted scores for each candidate class.
> **Input Shape:** `[batch_size, number_boxes, number_classes]`
>
> **Data Type:** `float32` or `float16`
The scores input has `number_classes` elements with the predicted scores for each candidate class for each of the `number_boxes` anchor boxes.

Usually, the score values will have passed through a sigmoid activation function before reaching the NMS operation. However, as an optimization, the pre-sigmoid raw scores can also be provided to the NMS plugin to reduce overall network latency. If raw scores are given, enable the `score_activation` parameter so they are processed accordingly.

#### Anchors Input (Optional)
Only used in *Fused Box Decoder* mode. It is much more efficient to perform the box decoding steps within this plugin. In this case, the boxes input will be treated as the raw box corrections, and this third input should contain the default anchor/prior box coordinates.
> **Input Shape:** `[1, number_boxes, 4]` or `[batch_size, number_boxes, 4]`
>
> **Data Type:** `float32` or `float16`
Only used in *Fused Box Decoder* mode. It is much more efficient to perform the box decoding within this plugin. In this case, the boxes input will be treated as the raw localization head box corrections, and this third input should contain the default anchor/prior box coordinates.

When used, the anchors input has shape `[1, number_anchors, 4]` or `[batch_size, number_anchors, 4]`, where the former is in case anchors are the same for all images in the batch, and the latter is in case they change for each image -- such as in the box refinement NMS of FasterRCNN's second stage.
When used, the input must have 3 dimensions, where the first one may be either `1` in case anchors are constant for all images in a batch, or `batch_size` in case each image has different anchors -- such as in the box refinement NMS of FasterRCNN's second stage.

### Dynamic Shape Support

Most input shape dimensions, namely `batch_size`, `number_boxes`, and `number_classes`, for all inputs can be defined dynamically at runtime if the TensorRT engine is built with dynamic input shapes. However, once defined, these dimensions must match across all tensors that use them (e.g. the same `number_boxes` dimension must be given for both boxes and scores, etc.)

### Box Coding Type
Different object detection networks represent their box coordinate system differently. The two types supported by this plugin are:
Expand All @@ -58,22 +75,19 @@ In *Fused Box Decoder* mode, the boxes and anchor tensors should both use the sa

### Outputs

The following five output are generated:
The following four output tensors are generated:

- **num_detections:**
This is a `[batch_size, 1]` integer tensor. The last dimension is a scalar indicating the number of valid detections per batch item. It can be less than `keepTopK`. Only the top `num_detections[i]` entries in `nms_boxes[i]`, `nms_scores[i]` and `nms_classes[i]` are valid.
This is a `[batch_size, 1]` tensor of data type `int32`. The last dimension is a scalar indicating the number of valid detections per batch image. It can be less than `max_output_boxes`. Only the top `num_detections[i]` entries in `nms_boxes[i]`, `nms_scores[i]` and `nms_classes[i]` are valid.

- **detection_boxes:**
This is a `[batch_size, max_output_boxes, 4]` floating point tensor containing the coordinates of non-max suppressed boxes. The output coordinates will always be in BoxCorner format, regardless of the input code type.
This is a `[batch_size, max_output_boxes, 4]` tensor of data type `float32` or `float16`, containing the coordinates of non-max suppressed boxes. The output coordinates will always be in BoxCorner format, regardless of the input code type.

- **detection_scores:**
This is a `[batch_size, max_output_boxes]` floating point tensor containing the scores for the boxes.
This is a `[batch_size, max_output_boxes]` tensor of data type `float32` or `float16`, containing the scores for the boxes.

- **detection_classes:**
This is a `[batch_size, max_output_boxes]` integer tensor containing the classes for the boxes.

- **detection_indices:**
This is a `[batch_size * max_output_boxes, 3]` integer tensor that contains the selected box indices for each box kept by NMS. The purpose of this output is to mimic the result of the [NonMaxSuppression](https://github.com/onnx/onnx/blob/master/docs/Operators.md#NonMaxSuppression) ONNX op.
This is a `[batch_size, max_output_boxes]` tensor of data type `int32`, containing the classes for the boxes.

### Parameters

Expand Down
25 changes: 13 additions & 12 deletions plugin/efficientNMSPlugin/efficientNMSInference.cu
Original file line number Diff line number Diff line change
Expand Up @@ -547,46 +547,47 @@ cudaError_t EfficientNMSFilterLauncher(EfficientNMSParameters& param, const T* s
}

template <typename T>
size_t EfficientNMSSortWorkspaceSize(EfficientNMSParameters param)
size_t EfficientNMSSortWorkspaceSize(int batchSize, int numScoreElements)
{
size_t sortedWorkspaceSize = 0;
cub::DoubleBuffer<T> keysDB(nullptr, nullptr);
cub::DoubleBuffer<int> valuesDB(nullptr, nullptr);
cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, sortedWorkspaceSize, keysDB, valuesDB,
param.numScoreElements, param.batchSize, (const int*) nullptr, (const int*) nullptr);
numScoreElements, batchSize, (const int*) nullptr, (const int*) nullptr);
return sortedWorkspaceSize;
}

size_t EfficientNMSWorkspaceSize(EfficientNMSParameters param)
size_t EfficientNMSWorkspaceSize(int batchSize, int numScoreElements, int numClasses, DataType datatype)
{
size_t total = 0, size = 0, align = 256;
size_t total = 0;
const size_t align = 256;
// Counters
// 3 for Filtering
// 1 for Output Indexing
// C for Max per Class Limiting
size = (3 + 1 + param.numClasses) * param.batchSize * sizeof(int);
size_t size = (3 + 1 + numClasses) * batchSize * sizeof(int);
total += size + (size % align ? align - (size % align) : 0);
// Int Buffers
for (int i = 0; i < 4; i++)
{
size = param.batchSize * param.numScoreElements * sizeof(int);
size = batchSize * numScoreElements * sizeof(int);
total += size + (size % align ? align - (size % align) : 0);
}
// Float Buffers
for (int i = 0; i < 2; i++)
{
size = param.batchSize * param.numScoreElements * dataTypeSize(param.datatype);
size = batchSize * numScoreElements * dataTypeSize(datatype);
total += size + (size % align ? align - (size % align) : 0);
}
// Sort Workspace
if (param.datatype == DataType::kHALF)
if (datatype == DataType::kHALF)
{
size = EfficientNMSSortWorkspaceSize<__half>(param);
size = EfficientNMSSortWorkspaceSize<__half>(batchSize, numScoreElements);
total += size + (size % align ? align - (size % align) : 0);
}
else if (param.datatype == DataType::kFLOAT)
else if (datatype == DataType::kFLOAT)
{
size = EfficientNMSSortWorkspaceSize<float>(param);
size = EfficientNMSSortWorkspaceSize<float>(batchSize, numScoreElements);
total += size + (size % align ? align - (size % align) : 0);
}

Expand Down Expand Up @@ -652,7 +653,7 @@ pluginStatus_t EfficientNMSDispatch(EfficientNMSParameters param, const void* bo
T* topScoresData = EfficientNMSWorkspace<T>(workspace, workspaceOffset, param.batchSize * param.numScoreElements);
T* sortedScoresData
= EfficientNMSWorkspace<T>(workspace, workspaceOffset, param.batchSize * param.numScoreElements);
size_t sortedWorkspaceSize = EfficientNMSSortWorkspaceSize<T>(param);
size_t sortedWorkspaceSize = EfficientNMSSortWorkspaceSize<T>(param.batchSize, param.numScoreElements);
char* sortedWorkspaceData = EfficientNMSWorkspace<char>(workspace, workspaceOffset, sortedWorkspaceSize);
cub::DoubleBuffer<T> scoresDB(topScoresData, sortedScoresData);
cub::DoubleBuffer<int> indexDB(topIndexData, sortedIndexData);
Expand Down
2 changes: 1 addition & 1 deletion plugin/efficientNMSPlugin/efficientNMSInference.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

#include "efficientNMSParameters.h"

size_t EfficientNMSWorkspaceSize(EfficientNMSParameters param);
size_t EfficientNMSWorkspaceSize(int batchSize, int numScoreElements, int numClasses, nvinfer1::DataType datatype);

pluginStatus_t EfficientNMSInference(EfficientNMSParameters param, const void* boxesInput, const void* scoresInput,
const void* anchorsInput, void* numDetectionsOutput, void* nmsBoxesOutput, void* nmsScoresOutput,
Expand Down
7 changes: 4 additions & 3 deletions plugin/efficientNMSPlugin/efficientNMSPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,9 +330,10 @@ void EfficientNMSPlugin::configurePlugin(
size_t EfficientNMSPlugin::getWorkspaceSize(
const PluginTensorDesc* inputs, int nbInputs, const PluginTensorDesc* outputs, int nbOutputs) const noexcept
{
EfficientNMSParameters p = mParam;
p.batchSize = inputs[0].dims.d[0];
return EfficientNMSWorkspaceSize(p);
int batchSize = inputs[1].dims.d[0];
int numScoreElements = inputs[1].dims.d[1] * inputs[1].dims.d[2];
int numClasses = inputs[1].dims.d[2];
return EfficientNMSWorkspaceSize(batchSize, numScoreElements, numClasses, mParam.datatype);
}

int EfficientNMSPlugin::enqueue(const PluginTensorDesc* inputDesc, const PluginTensorDesc* outputDesc,
Expand Down

0 comments on commit 5325570

Please sign in to comment.