diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h index 9414f01680..d6f16c485b 100644 --- a/api/ccapi/include/layer.h +++ b/api/ccapi/include/layer.h @@ -366,9 +366,9 @@ Reshape(const std::vector &properties = {}) { /** * @brief Helper function to create addition layer */ -inline std::unique_ptr Addition( - const std::vector &properties = {}, - const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) { +inline std::unique_ptr +Addition(const std::vector &properties = {}, + const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) { return createLayer(LayerType::LAYER_ADDITION, properties, compute_engine); } @@ -376,8 +376,9 @@ inline std::unique_ptr Addition( * @brief Helper function to create concat layer */ inline std::unique_ptr -Concat(const std::vector &properties = {}) { - return createLayer(LayerType::LAYER_CONCAT, properties); +Concat(const std::vector &properties = {}, + const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) { + return createLayer(LayerType::LAYER_CONCAT, properties, compute_engine); } /** diff --git a/nntrainer/cl_context.cpp b/nntrainer/cl_context.cpp index 438031d586..136a74ac8c 100644 --- a/nntrainer/cl_context.cpp +++ b/nntrainer/cl_context.cpp @@ -6,7 +6,7 @@ * @date 23 Feb 2024 * @see https://github.com/nnstreamer/nntrainer * @author Debadri Samaddar - * @author Niket Agarwal + * @author Niket Agarwal * @bug No known bugs except for NYI items * @brief This file contains app context related functions and classes that * manages the global configuration of the current OpenCL environment. It also @@ -15,6 +15,7 @@ #include #include +#include #include #include @@ -36,6 +37,9 @@ static void add_default_object(ClContext &cc) { cc.registerFactory(nntrainer::createLayer, SwiGLULayerCl::type, ml::train::LayerType::LAYER_SWIGLU); + + cc.registerFactory(nntrainer::createLayer, ConcatLayerCl::type, + ml::train::LayerType::LAYER_CONCAT); } static void registerer(ClContext &cc) noexcept { diff --git a/nntrainer/layers/cl_layers/concat_cl.cpp b/nntrainer/layers/cl_layers/concat_cl.cpp new file mode 100644 index 0000000000..50e8c5ac7a --- /dev/null +++ b/nntrainer/layers/cl_layers/concat_cl.cpp @@ -0,0 +1,540 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Niket Agarwal + * + * @file concat_cl.cpp + * @date 2 July 2024 + * @brief Implementation of Concat Layer + * @see https://github.com/nnstreamer/nntrainer + * @author Niket Agarwal + * @bug No known bugs except for NYI items + * + */ + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +std::string concat_cl_kernel_fp16_ = + R"( + #pragma OPENCL EXTENSION cl_khr_fp16 : enable + __kernel void concat_cl_fp16(__global const half* in1, + __global const half* in2, + __global half* out, + const int batch_size, + const int channels, + const int height, + const int width1, + const int width2) { + int global_id = get_global_id(0); + + int total_width = width1 + width2; + + int width = total_width; + + // 4D space coordinates + int w = global_id % total_width; + int h = (global_id / total_width) % height; + int c = (global_id / (total_width * height)) % channels; + int b = global_id / (total_width * height * channels); + + int output_index = ((b * channels + c) * height + h) * total_width + w; + + // Determining if the index is in in1 or in2 + if (w < width1) { + // in1 index calculation + int input1_index = ((b * channels + c) * height + h) * width1 + w; + out[output_index] = in1[input1_index]; + + } else { + // in2 index calculation + int input2_index = ((b * channels + c) * height + h) * width2 + (w - width1); + out[output_index] = in2[input2_index]; + } +})"; + +std::string concat_cl_kernel_ = + R"(__kernel void concat_cl(__global const float* in1, + __global const float* in2, + __global float* out, + const int batch_size, + const int channels, + const int height, + const int width1, + const int width2) { + int global_id = get_global_id(0); + + int total_width = width1 + width2; + + int width = total_width; + + // 4D space coordinates + int w = global_id % total_width; + int h = (global_id / total_width) % height; + int c = (global_id / (total_width * height)) % channels; + int b = global_id / (total_width * height * channels); + + int output_index = ((b * channels + c) * height + h) * total_width + w; + + // Determining if the index is in in1 or in2 + if (w < width1) { + // in1 index calculation + int input1_index = ((b * channels + c) * height + h) * width1 + w; + out[output_index] = in1[input1_index]; + + } else { + // in2 index calculation + int input2_index = ((b * channels + c) * height + h) * width2 + (w - width1); + out[output_index] = in2[input2_index]; + } +})"; + +namespace nntrainer { +ConcatLayerCl::ConcatLayerCl() : Layer(), leading_helper_dim(1) {} + +static constexpr size_t SINGLE_INOUT_IDX = 0; +static constexpr size_t INPUT_IDX_1 = 0; +static constexpr size_t INPUT_IDX_2 = 1; + +void ConcatLayerCl::finalize(InitLayerContext &context) { + auto &concat_dimension_prop = std::get(concat_props); + /** for backward compatibility, default concat dimension will be channel */ + /// @todo this is hacky way to force concat dimension to width if channel + /// dimension is taken, this is because recurrent realizer, return sequence + /// exploits concat layer but have no control over where to stack/axis + unsigned int concat_dimension = + context.getInputDimensions().front().channel() > 1 ? 3 : 1; + if (!concat_dimension_prop.empty()) + concat_dimension = concat_dimension_prop.get(); + + /** + * The concat is only done along the axis dimension. + * For example, consider 2 inputs a, b with dimensions [b,c,h,w] each + * 1. concat_dimension = 1, output_dim = [b,c_a+c_b,h,w] + * 2. concat_dimension = 2, output_dim = [b,c,h_a+h_b,w] + * 3. concat_dimension = 3, output_dim = [b,c,h,w_a+w_b] + */ + auto const &input_dims = context.getInputDimensions(); + const TensorDim &input_dim_0 = input_dims[SINGLE_INOUT_IDX]; + unsigned int concat_dim_val = input_dim_0.getTensorDim(concat_dimension); + + for (unsigned int idx = 1; idx < input_dims.size(); ++idx) { + const TensorDim &dim = input_dims[idx]; + + for (unsigned int i = 0; i < ml::train::TensorDim::getNumDim(); ++i) { + if (i == concat_dimension) + continue; + NNTR_THROW_IF(input_dim_0[i] != dim[i], std::runtime_error) + << "Error: concat layer requires same shape from all input layers " + "along non-concat dimension"; + } + concat_dim_val += dim[concat_dimension]; + } + + TensorDim output_dim = input_dim_0; + output_dim.setTensorDim(concat_dimension, concat_dim_val); + + context.setOutputDimensions({output_dim}); + + /** + * Setup output_reshape_helper to which output will be reshaped in forwarding + * to facilitate easier processing. + * + * The helper shape consolidates all the dimensions before the axis + * together and all the dimensions after the axis to facilitate + * easier splitting of the data. + */ + leading_helper_dim = 1; + output_reshape_helper.channel(1); + output_reshape_helper.height(1); + output_reshape_helper.width(1); + for (unsigned int idx = 1; idx < concat_dimension; ++idx) { + leading_helper_dim *= output_dim.getTensorDim(idx); + } + + output_reshape_helper.height(output_dim.getTensorDim(concat_dimension)); + + for (unsigned int idx = concat_dimension + 1; + idx < ml::train::TensorDim::getNumDim(); ++idx) { + output_reshape_helper.width(output_reshape_helper.width() * + output_dim.getTensorDim(idx)); + } + + /** + * Setup input_reshape_helper to which inputs will be reshaped in forwarding + * to facilitate easier processing. + */ + input_reshape_helper.resize(input_dims.size()); + for (unsigned int idx = 0; idx < input_reshape_helper.size(); idx++) { + input_reshape_helper[idx] = output_reshape_helper; + input_reshape_helper[idx].height( + input_dims[idx].getTensorDim(concat_dimension)); + } + + setBatch(input_dims[SINGLE_INOUT_IDX].batch()); +} + +void ConcatLayerCl::forwarding(RunLayerContext &context, bool training) { + Tensor &out = context.getOutput(SINGLE_INOUT_IDX); + const Tensor &in1 = context.getInput(INPUT_IDX_1); + const Tensor &in2 = context.getInput(INPUT_IDX_2); + ConcatProcess(in1, in2, out, context); +} + +void ConcatLayerCl::incremental_forwarding(RunLayerContext &context, + unsigned int from, unsigned int to, + bool training) { + /** + * @todo create another kernel for incremental_forwarding taking into + * consideration from and to values + */ +} + +opencl::Kernel ConcatLayerCl::kernel_concat; +opencl::Kernel ConcatLayerCl::kernel_concat_fp16; + +void ConcatLayerCl::ConcatProcess(Tensor const &in1, Tensor const &in2, + Tensor &result, RunLayerContext &context) { + + unsigned int input_batch_size, input_height, in1_width, input_channels, + in2_width; + auto dim1 = in1.getDim(); + auto dim2 = in2.getDim(); + input_batch_size = dim1.batch(); + input_height = dim1.height(); + input_channels = dim1.channel(); + in1_width = dim1.width(); + in2_width = dim2.width(); + + if (in1.getDataType() == ml::train::TensorDim::DataType::FP32) { + const float *data1 = in1.getData(); + const float *data2 = in2.getData(); + float *rdata = result.getData(); + concat_cl(data1, data2, rdata, input_batch_size, input_channels, + input_height, in1_width, in2_width, context); + } else if (in1.getDataType() == ml::train::TensorDim::DataType::FP16) { +#ifdef ENABLE_FP16 + const _FP16 *data1 = in1.getData<_FP16>(); + const _FP16 *data2 = in2.getData<_FP16>(); + _FP16 *rdata = result.getData<_FP16>(); + concat_cl_fp16(data1, data2, rdata, input_batch_size, input_channels, + input_height, in1_width, in2_width, context); +#else + throw std::invalid_argument("Error: enable-fp16 is not enabled"); +#endif + } +} + +void ConcatLayerCl::concat_cl(const float *matAdata, const float *vecXdata, + float *vecYdata, unsigned int input_batch_size, + unsigned int input_channels, + unsigned int input_height, unsigned int in1_width, + unsigned int in2_width, + RunLayerContext &context) { + + bool result = false; + + do { + result = + context.clCreateKernel(concat_cl_kernel_, context.LayerKernel::CONCAT, + ConcatLayerCl::kernel_concat); + if (!result) { + break; + } + + int dim = int(input_batch_size * input_channels * input_height * + (in1_width + in2_width)); + + opencl::Buffer inputA(context.context_inst_, + sizeof(float) * input_batch_size * input_channels * + input_height * in1_width, + true, nullptr); + + opencl::Buffer inputX(context.context_inst_, + sizeof(float) * input_batch_size * input_channels * + input_height * in2_width, + true, nullptr); + + opencl::Buffer inOutY(context.context_inst_, + sizeof(float) * input_batch_size * input_channels * + input_height * (in1_width + in2_width), + true, nullptr); + + result = inputA.WriteData(context.command_queue_inst_, matAdata); + if (!result) { + break; + } + + result = inputX.WriteData(context.command_queue_inst_, vecXdata); + if (!result) { + break; + } + + result = inOutY.WriteData(context.command_queue_inst_, vecYdata); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat.SetKernelArguments(0, &inputA, + sizeof(cl_mem)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat.SetKernelArguments(1, &inputX, + sizeof(cl_mem)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat.SetKernelArguments(2, &inOutY, + sizeof(cl_mem)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat.SetKernelArguments( + 3, &input_batch_size, sizeof(int)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat.SetKernelArguments(4, &input_channels, + sizeof(int)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat.SetKernelArguments(5, &input_height, + sizeof(int)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat.SetKernelArguments(6, &in1_width, + sizeof(int)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat.SetKernelArguments(7, &in2_width, + sizeof(int)); + if (!result) { + break; + } + + const int work_groups_count[3] = {dim, 1, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value + + result = context.command_queue_inst_.DispatchCommand( + ConcatLayerCl::kernel_concat, work_groups_count, work_group_size); + if (!result) { + break; + } + + result = inOutY.ReadData(context.command_queue_inst_, vecYdata); + if (!result) { + break; + } + + } while (false); +} + +void ConcatLayerCl::concat_cl_fp16( + const __fp16 *matAdata, const __fp16 *vecXdata, __fp16 *vecYdata, + unsigned int input_batch_size, unsigned int input_channels, + unsigned int input_height, unsigned int in1_width, unsigned int in2_width, + RunLayerContext &context) { + + bool result = false; + + do { + result = context.clCreateKernel(concat_cl_kernel_fp16_, + context.LayerKernel::CONCAT_FP16, + ConcatLayerCl::kernel_concat_fp16); + if (!result) { + break; + } + + int dim = int(input_batch_size * input_channels * input_height * + (in1_width + in2_width)); + + opencl::Buffer inputA(context.context_inst_, + sizeof(__fp16) * input_batch_size * input_channels * + input_height * in1_width, + true, nullptr); + + opencl::Buffer inputX(context.context_inst_, + sizeof(__fp16) * input_batch_size * input_channels * + input_height * in2_width, + true, nullptr); + + opencl::Buffer inOutY(context.context_inst_, + sizeof(__fp16) * input_batch_size * input_channels * + input_height * (in1_width + in2_width), + true, nullptr); + + result = inputA.WriteData(context.command_queue_inst_, matAdata); + if (!result) { + break; + } + + result = inputX.WriteData(context.command_queue_inst_, vecXdata); + if (!result) { + break; + } + + result = inOutY.WriteData(context.command_queue_inst_, vecYdata); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat_fp16.SetKernelArguments( + 0, &inputA, sizeof(cl_mem)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat_fp16.SetKernelArguments( + 1, &inputX, sizeof(cl_mem)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat_fp16.SetKernelArguments( + 2, &inOutY, sizeof(cl_mem)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat_fp16.SetKernelArguments( + 3, &input_batch_size, sizeof(int)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat_fp16.SetKernelArguments( + 4, &input_channels, sizeof(int)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat_fp16.SetKernelArguments( + 5, &input_height, sizeof(int)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat_fp16.SetKernelArguments(6, &in1_width, + sizeof(int)); + if (!result) { + break; + } + + result = ConcatLayerCl::kernel_concat_fp16.SetKernelArguments(7, &in2_width, + sizeof(int)); + if (!result) { + break; + } + + const int work_groups_count[3] = {dim, 1, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value + + result = context.command_queue_inst_.DispatchCommand( + ConcatLayerCl::kernel_concat_fp16, work_groups_count, work_group_size); + if (!result) { + break; + } + + result = inOutY.ReadData(context.command_queue_inst_, vecYdata); + if (!result) { + break; + } + + } while (false); +} + +void ConcatLayerCl::calcDerivative(RunLayerContext &context) { + /** + * @todo avoid copy by creating input here as a shared_tensor of the output + * here and then this layer can be in_place as well + */ + Tensor output = context.getIncomingDerivative(SINGLE_INOUT_IDX); + + output.reshape(output_reshape_helper); + unsigned int output_height_offset = 0; + unsigned int data_copy_size = output_reshape_helper.width(); + TensorDim::TensorType tensor_type = output.getTensorType(); + + for (unsigned int idx = 0; idx < context.getNumInputs(); idx++) { + Tensor &input = context.getOutgoingDerivative(idx); + const TensorDim in_dim = input.getDim(); + auto const &irh = input_reshape_helper[idx]; + input.reshape(irh); + + if (in_dim.getDataType() == TensorDim::DataType::FP32) { + /** loop over the dimensions before the concat dimension */ + for (unsigned int batch = 0; batch < output.batch(); batch++) { + /** loop over the concat dimension itself */ + for (unsigned int count = 0; count < irh.height(); count++) { + const Tensor source_tensor = Tensor::Map( + output.getAddress(batch, 0, output_height_offset + count, 0), + data_copy_size * sizeof(float), + {1, 1, 1, data_copy_size, tensor_type}); + Tensor dest_tensor = + Tensor::Map(input.getAddress(batch, 0, count, 0), + data_copy_size * sizeof(float), + {1, 1, 1, data_copy_size, tensor_type}); + dest_tensor.copy(source_tensor); + } + } + } else if (in_dim.getDataType() == TensorDim::DataType::FP16) { +#ifdef ENABLE_FP16 + /** loop over the dimensions before the concat dimension */ + for (unsigned int batch = 0; batch < output.batch(); batch++) { + /** loop over the concat dimension itself */ + for (unsigned int count = 0; count < irh.height(); count++) { + const Tensor source_tensor = Tensor::Map<_FP16>( + output.getAddress<_FP16>(batch, 0, output_height_offset + count, 0), + data_copy_size * sizeof(_FP16), + {1, 1, 1, data_copy_size, tensor_type}); + Tensor dest_tensor = + Tensor::Map<_FP16>(input.getAddress<_FP16>(batch, 0, count, 0), + data_copy_size * sizeof(_FP16), + {1, 1, 1, data_copy_size, tensor_type}); + dest_tensor.copy(source_tensor); + } + } +#else + throw std::invalid_argument("Error: enable-fp16 is not enabled"); +#endif + } + + input.reshape(in_dim); + output_height_offset += irh.height(); + } +} + +void ConcatLayerCl::setProperty(const std::vector &values) { + auto remain_props = loadProperties(values, concat_props); + NNTR_THROW_IF(!remain_props.empty(), std::invalid_argument) + << "[ConcatLayer] Unknown Layer Properties count " + + std::to_string(values.size()); +} + +void ConcatLayerCl::exportTo(Exporter &exporter, + const ml::train::ExportMethods &method) const { + Layer::exportTo(exporter, method); + exporter.saveResult(concat_props, method, this); +} + +} /* namespace nntrainer */ diff --git a/nntrainer/layers/cl_layers/concat_cl.h b/nntrainer/layers/cl_layers/concat_cl.h new file mode 100644 index 0000000000..5806f654c2 --- /dev/null +++ b/nntrainer/layers/cl_layers/concat_cl.h @@ -0,0 +1,184 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Niket Agarwal + * + * @file concat_cl.h + * @date 2 July 2024 + * @brief Implementation of Concat Layer + * @see https://github.com/nnstreamer/nntrainer + * @author Niket Agarwal + * @bug No known bugs except for NYI items + * + */ + +#ifndef __CONCAT_LAYER_CL_H__ +#define __CONCAT_LAYER_CL_H__ +#ifdef __cplusplus + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nntrainer { + +/** + * @class Concat Layer + * @brief Concat Layer + */ +class ConcatLayerCl : public Layer { +public: + /** + * @brief Constructor of Concat Layer + */ + ConcatLayerCl(); + + /** + * @brief Destructor of Concat Layer + */ + ~ConcatLayerCl() = default; + + /** + * @brief Move constructor of ConcatLayer. + * @param[in] ConcatLayer && + */ + ConcatLayerCl(ConcatLayerCl &&rhs) noexcept = default; + + /** + * @brief Move assignment operator. + * @parma[in] rhs ConcatLayer to be moved. + */ + ConcatLayerCl &operator=(ConcatLayerCl &&rhs) = default; + + /** + * @copydoc Layer::finalize(InitLayerContext &context) + */ + void finalize(InitLayerContext &context) override; + + /** + * @copydoc Layer::forwarding(RunLayerContext &context, bool training) + */ + void forwarding(RunLayerContext &context, bool training) override; + + /** + * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned + * int from, unsigned int to, bool training) + */ + void incremental_forwarding(RunLayerContext &context, unsigned int from, + unsigned int to, bool training) override; + + /** + * @copydoc Layer::calcDerivative(RunLayerContext &context) + */ + void calcDerivative(RunLayerContext &context) override; + + /** + * @copydoc Layer::getType() + */ + const std::string getType() const override { return ConcatLayerCl::type; }; + + /** + * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods + * method) + */ + void exportTo(Exporter &exporter, + const ml::train::ExportMethods &method) const override; + + /** + * @copydoc Layer::supportBackwarding() + */ + bool supportBackwarding() const override { return false; } + + /** + * @copydoc Layer::setProperty(const PropertyType type, const std::string + * &value) + */ + void setProperty(const std::vector &values) override; + + /** + * @copydoc Layer::setBatch(RunLayerContext &context, unsigned int batch) + */ + void setBatch(RunLayerContext &context, unsigned int batch) override { + setBatch(batch); + } + + inline static const std::string type = "concat"; + + static opencl::Kernel kernel_concat; + static opencl::Kernel kernel_concat_fp16; + + /** + * @brief Process data and dimensions for concat + * @param[in] input1 Tensor + * @param[in] input2 Tensor + * @param[in] result Tensor + * @param[in] RunLayerContext reference + */ + void ConcatProcess(Tensor const &in1, Tensor const &in2, Tensor &result, + RunLayerContext &context); + + /** + * @brief concat computation + * @param[in] matAdata float * for Input Tensor A + * @param[in] vecXdata float * for Input Tensor X + * @param[in] vecYdata float * for Output Tensor Y + * @param[in] input_batch_size represents the number of samples in the input + * tensor + * @param[in] input_channels represents the channels of the input tensor + * @param[in] input_height represents the height of the input tensor + * @param[in] in1_width represents the width of the input tensor A + * @param[in] in2_width represents the width of the input tensor X + * @param[in] context RunLayerContext reference + */ + void concat_cl(const float *matAdata, const float *vecXdata, float *vecYdata, + unsigned int input_batch_size, unsigned int input_channels, + unsigned int input_height, unsigned int in1_width, + unsigned int in2_width, RunLayerContext &context); + + /** + * @brief concat computation + * @param[in] matAdata fp16 * for Input Tensor A + * @param[in] vecXdata fp16 * for Input Tensor X + * @param[in] vecYdata fp16 * for Output Tensor Y + * @param[in] input_batch_size represents the number of samples in the input + * tensor + * @param[in] input_channels represents the channels of the input tensor + * @param[in] input_height represents the height of the input tensor + * @param[in] in1_width represents the width of the input tensor A + * @param[in] in2_width represents the width of the input tensor X + * @param[in] context RunLayerContext reference + */ + void concat_cl_fp16(const __fp16 *matAdata, const __fp16 *vecXdata, + __fp16 *vecYdata, unsigned int input_batch_size, + unsigned int input_channels, unsigned int input_height, + unsigned int in1_width, unsigned int in2_width, + RunLayerContext &context); + +private: + unsigned int leading_helper_dim; /**< batch dimension of helper dimension not + containing the actual batch */ + std::vector + input_reshape_helper; /** helper dimension to reshape inputs */ + TensorDim output_reshape_helper; /** helper dimension to reshape outputs */ + std::tuple concat_props; + + /** + * @brief set batch for the internal variables + * + * @param batch update batch size + */ + void setBatch(unsigned int batch) { + for (auto &irh : input_reshape_helper) + irh.batch(batch * leading_helper_dim); + output_reshape_helper.batch(batch * leading_helper_dim); + } +}; + +} // namespace nntrainer + +#endif /* __cplusplus */ +#endif /* __CONCAT_LAYER_CL_H__ */ diff --git a/nntrainer/layers/cl_layers/meson.build b/nntrainer/layers/cl_layers/meson.build index 68622d1c23..9b3879c5f8 100644 --- a/nntrainer/layers/cl_layers/meson.build +++ b/nntrainer/layers/cl_layers/meson.build @@ -2,6 +2,7 @@ cl_layer_sources = [ 'fc_layer_cl.cpp', 'addition_layer_cl.cpp', 'swiglu_cl.cpp', + 'concat_cl.cpp', ] foreach s : cl_layer_sources diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp index 143b4867b3..6bfe90f763 100644 --- a/nntrainer/layers/layer_context.cpp +++ b/nntrainer/layers/layer_context.cpp @@ -703,6 +703,10 @@ std::string RunLayerContext::getKernelName(LayerKernel layerKernel) { return "sscal_cl"; case LayerKernel::SSCAL_FP16: return "sscal_cl_fp16"; + case LayerKernel::CONCAT: + return "concat_cl"; + case LayerKernel::CONCAT_FP16: + return "concat_cl_fp16"; default: return ""; } diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h index 9a621dae1f..a2598d5063 100644 --- a/nntrainer/layers/layer_context.h +++ b/nntrainer/layers/layer_context.h @@ -830,18 +830,20 @@ class RunLayerContext { * getKernelName function. */ enum LayerKernel { - SGEMV = 1 << 0, /**< placeholder for kernel name */ - DOT = 1 << 1, /**< placeholder for kernel name */ - SGEMM = 1 << 2, /**< placeholder for kernel name */ - SGEMV_FP16 = 1 << 3, /**< placeholder for kernel name */ - DOT_FP16 = 1 << 4, /**< placeholder for kernel name */ - SGEMM_FP16 = 1 << 5, /**< placeholder for kernel name */ - ADD = 1 << 6, /**< placeholder for kernel name */ - ADD_FP16 = 1 << 7, /**< placeholder for kernel name */ - SWIGLU = 1 << 8, /**< placeholder for kernel name */ - SWIGLU_FP16 = 1 << 9, /**< placeholder for kernel name */ - SSCAL = 1 << 10, /**< placeholder for kernel name */ - SSCAL_FP16 = 1 << 11, /**< placeholder for kernel name */ + SGEMV = 1 << 0, /**< placeholder for kernel name */ + DOT = 1 << 1, /**< placeholder for kernel name */ + SGEMM = 1 << 2, /**< placeholder for kernel name */ + SGEMV_FP16 = 1 << 3, /**< placeholder for kernel name */ + DOT_FP16 = 1 << 4, /**< placeholder for kernel name */ + SGEMM_FP16 = 1 << 5, /**< placeholder for kernel name */ + ADD = 1 << 6, /**< placeholder for kernel name */ + ADD_FP16 = 1 << 7, /**< placeholder for kernel name */ + SWIGLU = 1 << 8, /**< placeholder for kernel name */ + SWIGLU_FP16 = 1 << 9, /**< placeholder for kernel name */ + SSCAL = 1 << 10, /**< placeholder for kernel name */ + SSCAL_FP16 = 1 << 11, /**< placeholder for kernel name */ + CONCAT = 1 << 12, /**< placeholder for kernel name */ + CONCAT_FP16 = 1 << 13, /**< placeholder for kernel name */ }; /** diff --git a/test/jni/Android.mk b/test/jni/Android.mk index af80f48833..06c8d9aa69 100644 --- a/test/jni/Android.mk +++ b/test/jni/Android.mk @@ -442,6 +442,7 @@ LOCAL_SRC_FILES := \ ../unittest/layers/unittest_layer_node.cpp \ ../unittest/layers/unittest_layers.cpp \ ../unittest/layers/unittest_layers_impl.cpp \ + ../unittest/layers/unittest_layers_concat_cl.cpp \ ../unittest/layers/unittest_layers_swiglu_cl.cpp \ ../unittest/layers/unittest_layers_fully_connected_cl.cpp \ ../unittest/layers/unittest_layers_input.cpp \ diff --git a/test/unittest/layers/unittest_layers_concat_cl.cpp b/test/unittest/layers/unittest_layers_concat_cl.cpp new file mode 100644 index 0000000000..792b905d9a --- /dev/null +++ b/test/unittest/layers/unittest_layers_concat_cl.cpp @@ -0,0 +1,43 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Niket Agarwal + * + * @file unittest_layers_concat_cl.cpp + * @date 2 July 2024 + * @brief Concat Layer Test + * @see https://github.com/nnstreamer/nntrainer + * @author Niket Agarwal + * @bug No known bugs except for NYI items + */ +#include + +#include + +#include +#include + +auto semantic_concat_gpu = LayerSemanticsParamType( + nntrainer::createLayer, + nntrainer::ConcatLayerCl::type, {}, + LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1); + +GTEST_PARAMETER_TEST(ConcatGPU, LayerSemanticsGpu, + ::testing::Values(semantic_concat_gpu)); + +auto concat_dim3 = LayerGoldenTestParamType( + nntrainer::createLayer, {"axis=3"}, + "2:3:3:2,2:3:3:3", "concat_dim3.nnlayergolden", + LayerGoldenTestParamOptions::SKIP_CALC_DERIV, "nchw", "fp32", "fp32"); + +GTEST_PARAMETER_TEST(ConcatGPU, LayerGoldenTest, + ::testing::Values(concat_dim3)); + +#ifdef ENABLE_FP16 +auto concat_dim3_w16a16 = LayerGoldenTestParamType( + nntrainer::createLayer, {"axis=3"}, + "2:3:3:2,2:3:3:3", "concat_dim3_w16a16.nnlayergolden", + LayerGoldenTestParamOptions::SKIP_CALC_DERIV, "nchw", "fp16", "fp16"); + +GTEST_PARAMETER_TEST(ConcatGPU16, LayerGoldenTest, + ::testing::Values(concat_dim3_w16a16)); +#endif