[Layer] Improve forwarding logic of ConcatLayer
This PR updates the current ConcatLayer forwarding logic for faster computation.

**Changes proposed in this PR:**
- Utilize the Tensor::concat() operation to perform forwarding, replacing the manual mapping and copying (a short usage sketch follows the file summary below).

**Self-evaluation:**
1. Build test: [X]Passed [ ]Failed [ ]Skipped
2. Run test:   [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Donghyeon Jeong <[email protected]>
djeong20 authored and jijoongmoon committed Aug 12, 2024
1 parent 0328879 commit 3b11453
Showing 1 changed file with 29 additions and 49 deletions.
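
As referenced in the change list above, the new forwarding path hands the whole copy job to a single tensor concatenation. As a rough intuition for that operation, here is a minimal usage sketch; the tensors t1 and t2 and their shapes are assumptions made for illustration (only the Tensor::cat(...) call itself appears in the diff below), and the snippet presumes the nntrainer Tensor type is in scope:

  // Assume two inputs that agree on every dimension except axis 2 (height):
  //   t1 with shape {2, 1, 3, 4} and t2 with shape {2, 1, 5, 4}.
  std::vector<Tensor> parts = {t1, t2};
  Tensor joined = Tensor::cat(parts, 2 /* axis */);
  // joined has shape {2, 1, 8, 4}: sizes add up along the concat axis (3 + 5),
  // while every other dimension stays the same.

This shape property is exactly what the new code exploits when it looks for the first axis on which the output and input dimensions differ.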
78 changes: 29 additions & 49 deletions nntrainer/layers/concat_layer.cpp
@@ -112,63 +112,43 @@ void ConcatLayer::forwarding(RunLayerContext &context, bool training) {
    * @todo avoid copy by creating input here as a shared_tensor of the output
    * here and then this layer can be in_place as well
    */
-  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);

-  const TensorDim out_dim = output.getDim();
-  output.reshape(output_reshape_helper);
-  unsigned int output_height_offset = 0;
-  unsigned int data_copy_size = output_reshape_helper.width();
-  TensorDim::TensorType tensor_type = out_dim.getTensorType();
+  // Store original input tensor dimensions, then reshape input tensors.
+  std::vector<Tensor> input_tensors;
+  std::vector<TensorDim> original_input_dims;

   for (unsigned int idx = 0; idx < context.getNumInputs(); idx++) {
     Tensor &input = context.getInput(idx);
-    const TensorDim in_dim = input.getDim();
-    auto const &irh = input_reshape_helper[idx];
-    input.reshape(irh);
+    original_input_dims.push_back(input.getDim());
+    input.reshape(input_reshape_helper[idx]);
+    input_tensors.push_back(input);
+  }

-    if (in_dim.getDataType() == TensorDim::DataType::FP32) {
-      /** loop over the dimensions before the concat dimension */
-      for (unsigned int batch = 0; batch < output.batch(); batch++) {
-        /** loop over the concat dimension itself */
-        for (unsigned int count = 0; count < irh.height(); count++) {
-          Tensor dest_tensor = Tensor::Map<float>(
-            output.getAddress<float>(batch, 0, output_height_offset + count, 0),
-            data_copy_size * sizeof(float),
-            {1, 1, 1, data_copy_size, tensor_type});
-          const Tensor source_tensor =
-            Tensor::Map<float>(input.getAddress<float>(batch, 0, count, 0),
-                               data_copy_size * sizeof(float),
-                               {1, 1, 1, data_copy_size, tensor_type});
-          dest_tensor.copy(source_tensor);
-        }
-      }
-    } else if (in_dim.getDataType() == TensorDim::DataType::FP16) {
-#ifdef ENABLE_FP16
-      /** loop over the dimensions before the concat dimension */
-      for (unsigned int batch = 0; batch < output.batch(); batch++) {
-        /** loop over the concat dimension itself */
-        for (unsigned int count = 0; count < irh.height(); count++) {
-          Tensor dest_tensor = Tensor::Map<_FP16>(
-            output.getAddress<_FP16>(batch, 0, output_height_offset + count, 0),
-            data_copy_size * sizeof(_FP16),
-            {1, 1, 1, data_copy_size, tensor_type});
-          const Tensor source_tensor =
-            Tensor::Map<_FP16>(input.getAddress<_FP16>(batch, 0, count, 0),
-                               data_copy_size * sizeof(_FP16),
-                               {1, 1, 1, data_copy_size, tensor_type});
-          dest_tensor.copy(source_tensor);
-        }
-      }
-#else
-      throw std::invalid_argument("Error: enable-fp16 is not enabled");
-#endif
+  // Store the original output tensor dimension, then reshape the output tensor.
+  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
+  const TensorDim original_output_dim = output.getDim();
+  output.reshape(output_reshape_helper);
+
+  // Search for an axis and concatenate tensors.
+  const TensorDim out_dim = output.getDim();
+  const TensorDim in_dim = context.getInput(0).getDim();
+
+  for (int axis = 0; axis < 4; ++axis) {
+    if (out_dim[axis] != in_dim[axis]) {
+      /// @todo Currently a new output tensor is created. This can be optimized.
+      Tensor result = Tensor::cat(input_tensors, axis);
+      output.copy(result);
+      break;
     }
+  }

-    input.reshape(in_dim);
-    output_height_offset += irh.height();
+  // Revert the tensors' dimensions back to their original shape.
+  for (unsigned int idx = 0; idx < context.getNumInputs(); idx++) {
+    Tensor &in = context.getInput(idx);
+    in.reshape(original_input_dims[idx]);
   }

-  output.reshape(out_dim);
+  output.reshape(original_output_dim);
 }

 void ConcatLayer::incremental_forwarding(RunLayerContext &context,
@@ -229,7 +209,7 @@ void ConcatLayer::calcDerivative(RunLayerContext &context) {
   unsigned int data_copy_size = output_reshape_helper.width();
   TensorDim::TensorType tensor_type = output.getTensorType();

-  for (unsigned int idx = 0; idx < context.getNumInputs(); idx++) {
+  for (unsigned int idx = 0; idx < context.getNumInputs(); idx++) {
     Tensor &input = context.getOutgoingDerivative(idx);
     const TensorDim in_dim = input.getDim();
     auto const &irh = input_reshape_helper[idx];
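
Side note on the axis search in the forwarding hunk above: for a valid concatenation, the reshaped output matches each reshaped input on every dimension except the one being concatenated along, so the first mismatching dimension identifies the axis. A standalone sketch of that check, assuming a hypothetical Dims alias in place of nntrainer's TensorDim:

  #include <array>

  // Hypothetical stand-in for a 4-D tensor shape (batch, channel, height, width).
  using Dims = std::array<unsigned int, 4>;

  // Returns the first axis on which the two shapes disagree; for a concat output
  // compared against one of its inputs this is the concatenation axis.
  // Returns -1 if the shapes are identical (nothing to concatenate).
  int find_concat_axis(const Dims &out, const Dims &in) {
    for (int axis = 0; axis < 4; ++axis) {
      if (out[axis] != in[axis])
        return axis;
    }
    return -1;
  }

In the committed code the same comparison is written inline over TensorDim, and the found axis feeds straight into Tensor::cat(input_tensors, axis).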
