From ed2d27f64070dc0eed07a8c519edb83fbda80809 Mon Sep 17 00:00:00 2001 From: Niket Agarwal Date: Thu, 6 Jun 2024 16:36:23 +0530 Subject: [PATCH] [GPU/OpenCL] Initial version of SwiGLU Layer with OpenCL ops Added naive version of OpenCL implementation for SwiGLU Layer. Incorporated kernel for ops used. Added unit test for SwiGLU_layer_cl. Signed-off-by: Niket Agarwal --- api/ccapi/include/layer.h | 17 +- api/nntrainer-api-common.h | 1 + nntrainer/cl_context.cpp | 5 + nntrainer/layers/cl_layers/meson.build | 1 + nntrainer/layers/cl_layers/swiglu_cl.cpp | 272 ++++++++++++++++++ nntrainer/layers/cl_layers/swiglu_cl.h | 137 +++++++++ nntrainer/layers/layer_context.cpp | 5 + nntrainer/layers/layer_context.h | 7 +- test/input_gen/gen_layer_tests.py | 8 + test/jni/Android.mk | 3 +- .../layers/unittest_layers_swiglu_cl.cpp | 49 ++++ 11 files changed, 499 insertions(+), 6 deletions(-) create mode 100644 nntrainer/layers/cl_layers/swiglu_cl.cpp create mode 100644 nntrainer/layers/cl_layers/swiglu_cl.h create mode 100644 test/unittest/layers/unittest_layers_swiglu_cl.cpp diff --git a/api/ccapi/include/layer.h b/api/ccapi/include/layer.h index 7fcf1b06d6..81afe86ee2 100644 --- a/api/ccapi/include/layer.h +++ b/api/ccapi/include/layer.h @@ -7,6 +7,7 @@ * @see https://github.com/nnstreamer/nntrainer * @author Parichay Kapoor * @author Debadri Samaddar + * @author Niket Agarwal * @bug No known bugs except for NYI items * @brief This is layers interface for c++ API * @@ -34,9 +35,10 @@ namespace train { * @brief Enumeration of layer type */ enum LayerType { - LAYER_IN = ML_TRAIN_LAYER_TYPE_INPUT, /**< Input Layer type */ - LAYER_FC = ML_TRAIN_LAYER_TYPE_FC, /**< Fully Connected Layer type */ - LAYER_BN = ML_TRAIN_LAYER_TYPE_BN, /**< Batch Normalization Layer type */ + LAYER_IN = ML_TRAIN_LAYER_TYPE_INPUT, /**< Input Layer type */ + LAYER_FC = ML_TRAIN_LAYER_TYPE_FC, /**< Fully Connected Layer type */ + LAYER_SWIGLU = ML_TRAIN_LAYER_TYPE_SWIGLU, /**< Swiglu Layer type */ + LAYER_BN = ML_TRAIN_LAYER_TYPE_BN, /**< Batch Normalization Layer type */ LAYER_CONV2D = ML_TRAIN_LAYER_TYPE_CONV2D, /**< Convolution 2D Layer type */ LAYER_POOLING2D = ML_TRAIN_LAYER_TYPE_POOLING2D, /**< Pooling 2D Layer type */ LAYER_FLATTEN = ML_TRAIN_LAYER_TYPE_FLATTEN, /**< Flatten Layer type */ @@ -295,6 +297,15 @@ inline std::unique_ptr FullyConnected( return createLayer(LayerType::LAYER_FC, properties, compute_engine); } +/** + * @brief Helper function to create Swiglu layer + */ +inline std::unique_ptr +Swiglu(const std::vector &properties = {}, + const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) { + return createLayer(LayerType::LAYER_SWIGLU, properties, compute_engine); +} + /** * @brief Helper function to create batch normalization layer */ diff --git a/api/nntrainer-api-common.h b/api/nntrainer-api-common.h index b37a3a750d..4c762150cc 100644 --- a/api/nntrainer-api-common.h +++ b/api/nntrainer-api-common.h @@ -63,6 +63,7 @@ typedef enum { ML_TRAIN_LAYER_TYPE_POSITIONAL_ENCODING = 28, /**< Positional Encoding Layer type (Since 7.0) */ ML_TRAIN_LAYER_TYPE_IDENTITY = 29, /**< Identity Layer type (Since 8.0) */ + ML_TRAIN_LAYER_TYPE_SWIGLU = 30, /**< Swiglu Layer type */ ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP = 300, /**< Preprocess flip Layer (Since 6.5) */ ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE = diff --git a/nntrainer/cl_context.cpp b/nntrainer/cl_context.cpp index b92a14ca0d..438031d586 100644 --- a/nntrainer/cl_context.cpp +++ b/nntrainer/cl_context.cpp @@ -6,6 +6,7 @@ * @date 23 Feb 2024 * @see 
https://github.com/nnstreamer/nntrainer * @author Debadri Samaddar + * @author Niket Agarwal * @bug No known bugs except for NYI items * @brief This file contains app context related functions and classes that * manages the global configuration of the current OpenCL environment. It also @@ -15,6 +16,7 @@ #include #include #include +#include namespace nntrainer { @@ -31,6 +33,9 @@ static void add_default_object(ClContext &cc) { cc.registerFactory(nntrainer::createLayer, AdditionLayerCL::type, ml::train::LayerType::LAYER_ADDITION); + + cc.registerFactory(nntrainer::createLayer, SwiGLULayerCl::type, + ml::train::LayerType::LAYER_SWIGLU); } static void registerer(ClContext &cc) noexcept { diff --git a/nntrainer/layers/cl_layers/meson.build b/nntrainer/layers/cl_layers/meson.build index f28b56cd55..68622d1c23 100644 --- a/nntrainer/layers/cl_layers/meson.build +++ b/nntrainer/layers/cl_layers/meson.build @@ -1,6 +1,7 @@ cl_layer_sources = [ 'fc_layer_cl.cpp', 'addition_layer_cl.cpp', + 'swiglu_cl.cpp', ] foreach s : cl_layer_sources diff --git a/nntrainer/layers/cl_layers/swiglu_cl.cpp b/nntrainer/layers/cl_layers/swiglu_cl.cpp new file mode 100644 index 0000000000..ed4e65bb5e --- /dev/null +++ b/nntrainer/layers/cl_layers/swiglu_cl.cpp @@ -0,0 +1,272 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * + * @file swiglu_cl.cpp + * @date 6th June 2024 + * @brief Implementation of SwiGLU activation function + * @see https://github.com/nnstreamer/nntrainer + * @author Niket Agarwal + * @bug No known bugs except for NYI items + * + */ + +#include "swiglu_cl.h" +#include + +std::string swiglu_cl_kernel_fp16_ = + R"( + #pragma OPENCL EXTENSION cl_khr_fp16 : enable + __kernel void swiglu_cl_fp16(__global const half *in1, __global const half *in2, __global half *out) { + int i = get_global_id(0); + half swish = in1[i] * exp(in1[i]) / (1 + exp(in1[i])); + out[i] = swish * in2[i]; +})"; + +std::string swiglu_cl_kernel_ = + R"(__kernel void swiglu_cl(__global const float *in1, __global const float *in2, __global float *out) { + int i = get_global_id(0); + float swish = in1[i] * exp(in1[i]) / (1 + exp(in1[i])); + out[i] = swish * in2[i]; +})"; + +namespace nntrainer { + +static constexpr size_t OUT_IDX = 0; +static constexpr size_t INPUT_IDX_1 = 0; +static constexpr size_t INPUT_IDX_2 = 1; + +void SwiGLULayerCl::finalize(nntrainer::InitLayerContext &context) { + context.setOutputDimensions({context.getInputDimensions()[0]}); +} + +void SwiGLULayerCl::forwarding(RunLayerContext &context, bool training) { + Tensor &in1 = context.getInput(INPUT_IDX_1); + Tensor &in2 = context.getInput(INPUT_IDX_2); + Tensor &out = context.getOutput(OUT_IDX); + swigluProcess(in1, in2, out, context); +} + +void SwiGLULayerCl::incremental_forwarding(RunLayerContext &context, + unsigned int from, unsigned int to, + bool training) { + Tensor &in1 = context.getInput(INPUT_IDX_1); + Tensor &in2 = context.getInput(INPUT_IDX_2); + Tensor &out = context.getOutput(OUT_IDX); + + if (from) { + NNTR_THROW_IF(to - from != 1, std::invalid_argument) + << "incremental step size is not 1"; + from = 0; + to = 1; + } + + swigluProcess(in1, in2, out, context); +} + +opencl::Kernel SwiGLULayerCl::kernel_swiglu; +opencl::Kernel SwiGLULayerCl::kernel_swiglu_fp16; + +void SwiGLULayerCl::swigluProcess(Tensor const &in1, Tensor const &in2, + Tensor &result, RunLayerContext &context) { + + unsigned int dim1, dim2; + dim1 = in1.batch() * in1.channel() * in1.height(); + dim2 = in1.width(); + + if (in1.getDataType() == ml::train::TensorDim::DataType::FP32) { 
+ const float *data1 = in1.getData(); + const float *data2 = in2.getData(); + float *rdata = result.getData(); + swiglu_cl(data1, data2, rdata, dim1, dim2, context); + } else if (in1.getDataType() == ml::train::TensorDim::DataType::FP16) { +#ifdef ENABLE_FP16 + const _FP16 *data1 = in1.getData<_FP16>(); + const _FP16 *data2 = in2.getData<_FP16>(); + _FP16 *rdata = result.getData<_FP16>(); + swiglu_cl_fp16(data1, data2, rdata, dim1, dim2, context); +#else + throw std::invalid_argument("Error: enable-fp16 is not enabled"); +#endif + } +} + +void SwiGLULayerCl::swiglu_cl(const float *matAdata, const float *vecXdata, + float *vecYdata, unsigned int dim1, + unsigned int dim2, RunLayerContext &context) { + + bool result = false; + + do { + result = + context.clCreateKernel(swiglu_cl_kernel_, context.LayerKernel::SWIGLU, + SwiGLULayerCl::kernel_swiglu); + if (!result) { + break; + } + + int dim = int(dim1 * dim2); + opencl::Buffer inputA(context.context_inst_, sizeof(float) * dim1 * dim2, true, + nullptr); + + opencl::Buffer inputX(context.context_inst_, sizeof(float) * dim1 * dim2, true, + nullptr); + + opencl::Buffer inOutY(context.context_inst_, sizeof(float) * dim1 * dim2, true, + nullptr); + + result = inputA.WriteData(context.command_queue_inst_, matAdata); + if (!result) { + break; + } + + result = inputX.WriteData(context.command_queue_inst_, vecXdata); + if (!result) { + break; + } + + result = inOutY.WriteData(context.command_queue_inst_, vecYdata); + if (!result) { + break; + } + + result = SwiGLULayerCl::kernel_swiglu.SetKernelArguments(0, &inputA, + sizeof(cl_mem)); + if (!result) { + break; + } + + result = SwiGLULayerCl::kernel_swiglu.SetKernelArguments(1, &inputX, + sizeof(cl_mem)); + if (!result) { + break; + } + + result = SwiGLULayerCl::kernel_swiglu.SetKernelArguments(2, &inOutY, + sizeof(cl_mem)); + if (!result) { + break; + } + + const int work_groups_count[3] = {dim, 1, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value + + result = context.command_queue_inst_.DispatchCommand( + SwiGLULayerCl::kernel_swiglu, work_groups_count, work_group_size); + if (!result) { + break; + } + + result = inOutY.ReadData(context.command_queue_inst_, vecYdata); + if (!result) { + break; + } + + } while (false); +} + +void SwiGLULayerCl::swiglu_cl_fp16(const __fp16 *matAdata, + const __fp16 *vecXdata, __fp16 *vecYdata, + unsigned int dim1, unsigned int dim2, + RunLayerContext &context) { + + bool result = false; + + do { + result = context.clCreateKernel(swiglu_cl_kernel_fp16_, + context.LayerKernel::SWIGLU_FP16, + SwiGLULayerCl::kernel_swiglu_fp16); + if (!result) { + break; + } + + int dim = int(dim1 * dim2); + opencl::Buffer inputA(context.context_inst_, sizeof(__fp16) * dim1 * dim2, true, + nullptr); + + opencl::Buffer inputX(context.context_inst_, sizeof(__fp16) * dim1 * dim2, true, + nullptr); + + opencl::Buffer inOutY(context.context_inst_, sizeof(__fp16) * dim1 * dim2, true, + nullptr); + + result = inputA.WriteData(context.command_queue_inst_, matAdata); + if (!result) { + break; + } + + result = inputX.WriteData(context.command_queue_inst_, vecXdata); + if (!result) { + break; + } + + result = inOutY.WriteData(context.command_queue_inst_, vecYdata); + if (!result) { + break; + } + + result = SwiGLULayerCl::kernel_swiglu_fp16.SetKernelArguments( + 0, &inputA, sizeof(cl_mem)); + if (!result) { + break; + } + + result = SwiGLULayerCl::kernel_swiglu_fp16.SetKernelArguments( + 1, &inputX, sizeof(cl_mem)); + if (!result) { + break; + } + + result = 
SwiGLULayerCl::kernel_swiglu_fp16.SetKernelArguments( + 2, &inOutY, sizeof(cl_mem)); + if (!result) { + break; + } + + const int work_groups_count[3] = {dim, 1, 1}; + const int work_group_size[3] = {32, 32, 1}; // test-value + + result = context.command_queue_inst_.DispatchCommand( + SwiGLULayerCl::kernel_swiglu_fp16, work_groups_count, work_group_size); + if (!result) { + break; + } + + result = inOutY.ReadData(context.command_queue_inst_, vecYdata); + if (!result) { + break; + } + + } while (false); +} + +void SwiGLULayerCl::calcDerivative(nntrainer::RunLayerContext &context) { + std::throw_with_nested(std::runtime_error("Training is not supported yet.")); +} + +void SwiGLULayerCl::setProperty(const std::vector &values) { + auto remain_props = loadProperties(values, swiglu_props); + if (!remain_props.empty()) { + std::string msg = "[SwigluLayerCl] Unknown Layer Properties count " + + std::to_string(values.size()); + throw exception::not_supported(msg); + } +} + +#ifdef PLUGGABLE + +Layer *create_swiglu_layer_cl() { + auto layer = new SwiGLULayerCl(); + return layer; +} + +void destroy_swiglu_layer_cl(Layer *layer) { + delete layer; +} + +extern "C" { +LayerPluggable ml_train_layer_pluggable{create_swiglu_layer_cl, + destroy_swiglu_layer_cl}; +} + +#endif +} // namespace nntrainer diff --git a/nntrainer/layers/cl_layers/swiglu_cl.h b/nntrainer/layers/cl_layers/swiglu_cl.h new file mode 100644 index 0000000000..3001c527ff --- /dev/null +++ b/nntrainer/layers/cl_layers/swiglu_cl.h @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Niket Agarwal + * + * @file swiglu_cl.h + * @date 6th June 2024 + * @brief Implementation of SwiGLU activation function + * @see https://github.com/nnstreamer/nntrainer + * @author Niket Agarwal + * @bug No known bugs except for NYI items + * + */ + +#ifndef __SWIGLU_LAYER_CL_H__ +#define __SWIGLU_LAYER_CL_H__ + +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace nntrainer { + +/** + * @brief A SwiGLU layer + * + */ +class SwiGLULayerCl final : public Layer { +public: + /** + * @brief Construct a new SwiGLU layer object + * + */ + SwiGLULayerCl() : Layer(), swiglu_props(props::Print()) {} + + /** + * @brief Destroy the SwiGLU layer object + * + */ + ~SwiGLULayerCl() {} + + /** + * @copydoc Layer::finalize(InitLayerContext &context) + */ + void finalize(InitLayerContext &context) override; + + /** + * @copydoc Layer::forwarding(RunLayerContext &context, bool training) + */ + void forwarding(RunLayerContext &context, bool training) override; + + /** + * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned + * int from, unsigned int to, bool training) + */ + void incremental_forwarding(RunLayerContext &context, unsigned int from, + unsigned int to, bool training) override; + + /** + * @copydoc Layer::calcDerivative(RunLayerContext &context) + */ + void calcDerivative(RunLayerContext &context) override; + + /** + * @copydoc bool supportBackwarding() const + */ + bool supportBackwarding() const override { return true; }; + + /** + * @copydoc Layer::exportTo(Exporter &exporter, ExportMethods method) + */ + void exportTo(Exporter &exporter, + const ml::train::ExportMethods &method) const override {}; + + /** + * @copydoc Layer::getType() + */ + const std::string getType() const override { return SwiGLULayerCl::type; }; + + /** + * @copydoc Layer::setProperty(const std::vector &values) + */ + void setProperty(const std::vector &values) override; + + inline static const 
std::string type = "swiglu"; + + static opencl::Kernel kernel_swiglu; + static opencl::Kernel kernel_swiglu_fp16; + + std::tuple swiglu_props; /**< swiglu layer properties : unit - + number of output neurons */ + + /** + * @brief Process data and dimensions for swiglu operation + * @param[in] input1 Tensor + * @param[in] input2 Tensor + * @param[in] result Tensor + * @param[in] RunLayerContext reference + */ + void swigluProcess(Tensor const &in1, Tensor const &in2, Tensor &result, + RunLayerContext &context); + + /** + * @brief swiglu computation + * @param[in] matAdata float * for Input Vector A + * @param[in] vecXdata float * for Input Vector X + * @param[in] vecYdata float * for Output Vector Y + * @param[in] dim1 number of elements in input vector A + * @param[in] dim1 number of elements in input vector X + * @param[in] context RunLayerContext reference + */ + void swiglu_cl(const float *matAdata, const float *vecXdata, float *vecYdata, + unsigned int dim1, unsigned int dim2, + RunLayerContext &context); + + /** + * @brief fp16 swiglu computation + * @param[in] matAdata fp16 * for Input Vector A + * @param[in] vecXdata fp16 * for Input Vector X + * @param[in] vecYdata fp16 * for Output Vector Y + * @param[in] dim1 number of elements in input vector A + * @param[in] dim1 number of elements in input vector X + * @param[in] context RunLayerContext reference + */ + void swiglu_cl_fp16(const __fp16 *matAdata, const __fp16 *vecXdata, + __fp16 *vecYdata, unsigned int dim1, unsigned int dim2, + RunLayerContext &context); +}; + +} // namespace nntrainer + +#endif /* __SWIGLU_LAYER_CL_H__ */ diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp index 798f8a1a5e..b959a0af20 100644 --- a/nntrainer/layers/layer_context.cpp +++ b/nntrainer/layers/layer_context.cpp @@ -7,6 +7,7 @@ * @see https://github.com/nnstreamer/nntrainer * @author Parichay Kapoor * @author Debadri Samaddar + * @author Niket Agarwal * @bug No known bugs except for NYI items * @brief This is the layer context for each layer */ @@ -694,6 +695,10 @@ std::string RunLayerContext::getKernelName(LayerKernel layerKernel) { return "addition_cl"; case LayerKernel::ADD_FP16: return "addition_cl_fp16"; + case LayerKernel::SWIGLU: + return "swiglu_cl"; + case LayerKernel::SWIGLU_FP16: + return "swiglu_cl_fp16"; default: return ""; } diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h index 842789f6eb..fc0ee91f49 100644 --- a/nntrainer/layers/layer_context.h +++ b/nntrainer/layers/layer_context.h @@ -7,6 +7,7 @@ * @see https://github.com/nnstreamer/nntrainer * @author Parichay Kapoor * @author Debadri Samaddar + * @author Niket Agarwal * @bug No known bugs except for NYI items * @brief This is the layer context for each layer */ @@ -835,8 +836,10 @@ class RunLayerContext { SGEMV_FP16 = 1 << 3, /**< placeholder for kernel name */ DOT_FP16 = 1 << 4, /**< placeholder for kernel name */ SGEMM_FP16 = 1 << 5, /**< placeholder for kernel name */ - ADD = 1 << 6, /**< placeholder for kernel name */ - ADD_FP16 = 1 << 7 /**< placeholder for kernel name */ + ADD = 1 << 6, /**< placeholder for kernel name */ + ADD_FP16 = 1 << 7, /**< placeholder for kernel name */ + SWIGLU = 1 << 8, /**< placeholder for kernel name */ + SWIGLU_FP16 = 1 << 9 /**< placeholder for kernel name */ }; /** diff --git a/test/input_gen/gen_layer_tests.py b/test/input_gen/gen_layer_tests.py index cf8e713983..99017d071f 100644 --- a/test/input_gen/gen_layer_tests.py +++ b/test/input_gen/gen_layer_tests.py @@ -18,6 +18,7 @@ 
@author Jihoon Lee @author Sungsik Kong @author Debadri Samaddar +@author Niket Agarwal """ import warnings @@ -889,3 +890,10 @@ def swiglu(inputs): added = K.layers.Add() record_single(added, [(3, 4, 3, 4), (3, 4, 3, 4)], "added_w32a32_2") + + record_single_fp16( + swiglu_layer, + [(2, 3, 3, 3), (2, 3, 3, 3)], + "swiglu", + input_type="float", + ) diff --git a/test/jni/Android.mk b/test/jni/Android.mk index 963beb3b01..2e947e5289 100644 --- a/test/jni/Android.mk +++ b/test/jni/Android.mk @@ -441,9 +441,10 @@ LOCAL_SRC_FILES := \ ../unittest/layers/unittest_layer_node.cpp \ ../unittest/layers/unittest_layers.cpp \ ../unittest/layers/unittest_layers_impl.cpp \ + ../unittest/layers/unittest_layers_swiglu_cl.cpp \ + ../unittest/layers/unittest_layers_fully_connected_cl.cpp \ ../unittest/layers/unittest_layers_input.cpp \ ../unittest/layers/unittest_layers_loss.cpp \ - ../unittest/layers/unittest_layers_fully_connected_cl.cpp \ ../unittest/layers/unittest_layers_fully_connected.cpp \ ../unittest/layers/unittest_layers_batch_normalization.cpp \ ../unittest/layers/unittest_layers_layer_normalization.cpp \ diff --git a/test/unittest/layers/unittest_layers_swiglu_cl.cpp b/test/unittest/layers/unittest_layers_swiglu_cl.cpp new file mode 100644 index 0000000000..7e0e0998e4 --- /dev/null +++ b/test/unittest/layers/unittest_layers_swiglu_cl.cpp @@ -0,0 +1,49 @@ +// SPDX-License-Identifier: Apache-2.0 +/** + * Copyright (C) 2024 Niket Agarwal + * + * @file unittest_layers_swiglu_cl.cpp + * @date 6th June 2024 + * @brief Swiglu Layer Test + * @see https://github.com/nnstreamer/nntrainer + * @author Niket Agarwal + * @bug No known bugs except for NYI items + */ +#include + +#include + +#include +#include + +auto semantic_swiglu_gpu = LayerSemanticsParamType( + nntrainer::createLayer, + nntrainer::SwiGLULayerCl::type, {}, + LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 1); + +GTEST_PARAMETER_TEST(SwigluGPU, LayerSemanticsGpu, + ::testing::Values(semantic_swiglu_gpu)); + +auto swiglu_basic_plain = + LayerGoldenTestParamType(nntrainer::createLayer, {}, + "2:3:3:3,2:3:3:3", "swiglu.nnlayergolden", + LayerGoldenTestParamOptions::SKIP_CALC_DERIV | + LayerGoldenTestParamOptions::SKIP_CALC_GRAD | + LayerGoldenTestParamOptions::USE_INC_FORWARD, + "nchw", "fp32", "fp32"); + +GTEST_PARAMETER_TEST(SwigluGPU, LayerGoldenTest, + ::testing::Values(swiglu_basic_plain)); + +#ifdef ENABLE_FP16 +auto swiglu_basic_plain_w16a16 = + LayerGoldenTestParamType(nntrainer::createLayer, {}, + "2:3:3:3,2:3:3:3", "swiglufp16.nnlayergolden", + LayerGoldenTestParamOptions::SKIP_CALC_DERIV | + LayerGoldenTestParamOptions::SKIP_CALC_GRAD | + LayerGoldenTestParamOptions::USE_INC_FORWARD, + "nchw", "fp16", "fp16"); + +GTEST_PARAMETER_TEST(SwigluGPU16, LayerGoldenTest, + ::testing::Values(swiglu_basic_plain_w16a16)); +#endif
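
Note: both kernels implement the same element-wise operation, out[i] = swish(in1[i]) * in2[i], where swish(x) = x * sigmoid(x) = x * exp(x) / (1 + exp(x)). Below is a minimal host-side C++ sketch of that math which can be used to sanity-check the GPU output against the golden data; the helper name and its standalone form are illustrative only and are not part of this change.

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Reference element-wise SwiGLU: out[i] = swish(in1[i]) * in2[i].
    // Mirrors the formula used in the swiglu_cl / swiglu_cl_fp16 kernels above.
    std::vector<float> swiglu_reference(const std::vector<float> &in1,
                                        const std::vector<float> &in2) {
      std::vector<float> out(in1.size());
      for (std::size_t i = 0; i < in1.size(); ++i) {
        // sigmoid(x) written as exp(x) / (1 + exp(x)), matching the kernels.
        const float sig = std::exp(in1[i]) / (1.0f + std::exp(in1[i]));
        out[i] = in1[i] * sig * in2[i];
      }
      return out;
    }

Since the commit message describes this as a naive first version: exp(x) / (1 + exp(x)) can overflow for moderately large positive x, particularly in the fp16 kernel. The mathematically equivalent form sigmoid(x) = 1 / (1 + exp(-x)) avoids that overflow and could be a follow-up refinement.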