diff --git a/nntrainer/layers/concat_layer.cpp b/nntrainer/layers/concat_layer.cpp index 8a28fb3e80..5536c4a82d 100644 --- a/nntrainer/layers/concat_layer.cpp +++ b/nntrainer/layers/concat_layer.cpp @@ -112,63 +112,43 @@ void ConcatLayer::forwarding(RunLayerContext &context, bool training) { * @todo avoid copy by creating input here as a shared_tensor of the output * here and then this layer can be in_place as well */ - Tensor &output = context.getOutput(SINGLE_INOUT_IDX); - const TensorDim out_dim = output.getDim(); - output.reshape(output_reshape_helper); - unsigned int output_height_offset = 0; - unsigned int data_copy_size = output_reshape_helper.width(); - TensorDim::TensorType tensor_type = out_dim.getTensorType(); + // Store original input tensor dimensions, then reshape input tensors. + std::vector input_tensors; + std::vector original_input_dims; for (unsigned int idx = 0; idx < context.getNumInputs(); idx++) { Tensor &input = context.getInput(idx); - const TensorDim in_dim = input.getDim(); - auto const &irh = input_reshape_helper[idx]; - input.reshape(irh); + original_input_dims.push_back(input.getDim()); + input.reshape(input_reshape_helper[idx]); + input_tensors.push_back(input); + } - if (in_dim.getDataType() == TensorDim::DataType::FP32) { - /** loop over the dimensions before the concat dimension */ - for (unsigned int batch = 0; batch < output.batch(); batch++) { - /** loop over the concat dimension itself */ - for (unsigned int count = 0; count < irh.height(); count++) { - Tensor dest_tensor = Tensor::Map( - output.getAddress(batch, 0, output_height_offset + count, 0), - data_copy_size * sizeof(float), - {1, 1, 1, data_copy_size, tensor_type}); - const Tensor source_tensor = - Tensor::Map(input.getAddress(batch, 0, count, 0), - data_copy_size * sizeof(float), - {1, 1, 1, data_copy_size, tensor_type}); - dest_tensor.copy(source_tensor); - } - } - } else if (in_dim.getDataType() == TensorDim::DataType::FP16) { -#ifdef ENABLE_FP16 - /** loop over the dimensions before the concat dimension */ - for (unsigned int batch = 0; batch < output.batch(); batch++) { - /** loop over the concat dimension itself */ - for (unsigned int count = 0; count < irh.height(); count++) { - Tensor dest_tensor = Tensor::Map<_FP16>( - output.getAddress<_FP16>(batch, 0, output_height_offset + count, 0), - data_copy_size * sizeof(_FP16), - {1, 1, 1, data_copy_size, tensor_type}); - const Tensor source_tensor = - Tensor::Map<_FP16>(input.getAddress<_FP16>(batch, 0, count, 0), - data_copy_size * sizeof(_FP16), - {1, 1, 1, data_copy_size, tensor_type}); - dest_tensor.copy(source_tensor); - } - } -#else - throw std::invalid_argument("Error: enable-fp16 is not enabled"); -#endif + // Store the original output tensor dimension, then reshape the output tensor. + Tensor &output = context.getOutput(SINGLE_INOUT_IDX); + const TensorDim original_output_dim = output.getDim(); + output.reshape(output_reshape_helper); + + // Search for an axis and concatenate tensors. + const TensorDim out_dim = output.getDim(); + const TensorDim in_dim = context.getInput(0).getDim(); + + for (int axis = 0; axis < 4; ++axis) { + if (out_dim[axis] != in_dim[axis]) { + /// @todo Currently a new output tensor is created. This can be optimized. + Tensor result = Tensor::cat(input_tensors, axis); + output.copy(result); + break; } + } - input.reshape(in_dim); - output_height_offset += irh.height(); + // Revert the tensors' dimensions back to their original shape. + for (unsigned int idx = 0; idx < context.getNumInputs(); idx++) { + Tensor &in = context.getInput(idx); + in.reshape(original_input_dims[idx]); } - output.reshape(out_dim); + output.reshape(original_output_dim); } void ConcatLayer::incremental_forwarding(RunLayerContext &context, @@ -229,7 +209,7 @@ void ConcatLayer::calcDerivative(RunLayerContext &context) { unsigned int data_copy_size = output_reshape_helper.width(); TensorDim::TensorType tensor_type = output.getTensorType(); - for (unsigned int idx = 0; idx < context.getNumInputs(); idx++) { + for (unsigned int idx = 0; idx < context.getNumInputs(); idx++) { Tensor &input = context.getOutgoingDerivative(idx); const TensorDim in_dim = input.getDim(); auto const &irh = input_reshape_helper[idx];