diff --git a/nntrainer/npu/qnn/LLaMAPackage/config/LLaMAOpPackageHtp.xml b/nntrainer/npu/qnn/LLaMAPackage/config/LLaMAOpPackageHtp.xml new file mode 100755 index 000000000..259f786ef --- /dev/null +++ b/nntrainer/npu/qnn/LLaMAPackage/config/LLaMAOpPackageHtp.xml @@ -0,0 +1,1792 @@ + + + + + + + LLaMASuperSiLU + + + fused SiLU function + + + + + + + in[0] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + in[1] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + a_scale + true + QNN_DATATYPE_FLOAT_32 + + SCALAR + + N-1 + + + + b_scale + true + QNN_DATATYPE_FLOAT_32 + + SCALAR + + N-1 + + + + o_scale + true + QNN_DATATYPE_FLOAT_32 + + SCALAR + + N-1 + + + + HTP + + + + SiLU + + + Applies the Sigmoid Linear Unit (SiLU) function, element-wise. The SiLU function is also known as the swish function. + + + + + + + in[0] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + HTP + + + + LLaMAReLU + + + LLaMA ReLU + + + + + in[0] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + HTP + + + + LLaMALinear + + + LLaMA Linear + + + + + in[0] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + in[1] + + weights + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + in[2] + + bias + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + in_scale + true + QNN_DATATYPE_FLOAT_32 + + SCALAR + + N-1 + + + + weight_scale + true + QNN_DATATYPE_FLOAT_32 + + SCALAR + + N-1 + + + + bias_scale + true + QNN_DATATYPE_FLOAT_32 + + SCALAR + + N-1 + + + + output_scale + true + QNN_DATATYPE_FLOAT_32 + + SCALAR + + N-1 + + + + HTP + + + + + Attention + + + Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention Is All You Need. 
+ + + + + in[0] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [BATCH, HEAD, SEQ, EMB] + + + + + in[1] + + attention mask + + true + BACKEND_SPECIFIC + + 4D + NHWC + [BATCH, SEQ] + + + + + in[2] + + Q + + true + BACKEND_SPECIFIC + + 4D + [HEAD, EMB, EMB] + + + + + in[3] + + K + + true + BACKEND_SPECIFIC + + 4D + [HEAD, EMB, EMB] + + + + + in[4] + + V + + true + BACKEND_SPECIFIC + + 4D + [HEAD, EMB, EMB] + + + + + + out[0] + + The output activation + + + + + + true + BACKEND_SPECIFIC + + 4D + NHWC + [BATCH, HEAD, SEQ, EMB] + + + + + HTP + + + + QLayerNorm + + + LayerNorm QFP version + + + + + in[0] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + weights + + LayerNorm weights + + true + BACKEND_SPECIFIC + + 1D + [ EMB] + + + + + bias + + LayerNorm weights + + true + BACKEND_SPECIFIC + + 1D + [ EMB] + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + HTP + + + + RMSNorm + + + LLaMA RMSNorm + + + + + in[0] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + weights + + RMSNorm weights + + true + BACKEND_SPECIFIC + + 1D + [ EMB] + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + HTP + + + + RoPE + + + LLaMA RoPE + + + + + in[0] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + sin + + RoPE sin weights + + true + BACKEND_SPECIFIC + + 2D + [ 16384, hidden state ] + + + + + cos + + RoPE cos weights + + true + BACKEND_SPECIFIC + + 2D + [ 16384, hidden state ] + + + + + h_cnt + + h_cnt + + true + BACKEND_SPECIFIC + + SCALAR + + N-1 + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + pose_type + false + QNN_DATATYPE_UINT_32 + + SCALAR + + N-1 + + + + HTP + + + + IRoPE + + + LLaMA IRoPE + + + + + in[0] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + sin + + RoPE sin weights + + true + BACKEND_SPECIFIC + + 2D + [ 16384, hidden state ] + + + + + cos + + RoPE cos weights + + true + BACKEND_SPECIFIC + + 2D + [ 16384, hidden state ] + + + + + h_cnt + + h_cnt + + true + BACKEND_SPECIFIC + + SCALAR + + N-1 + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + pose_type + false + QNN_DATATYPE_UINT_32 + + SCALAR + + N-1 + + + + HTP + + + + LLaMADequantize + + + LLaMA Dequantize + + + + + in[0] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + scale + true + QNN_DATATYPE_FLOAT_32 + + SCALAR + + N-1 + + + + HTP + + + + LLaMAQuantize + + + LLaMA Quantize + + + + + in[0] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + scale + true + QNN_DATATYPE_FLOAT_32 + + SCALAR + + N-1 + + + + HTP + + + + CausalMask + + + LLaMA CausalMask + + + + + in[0] + + input activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + HTP + + + + + HeadMatmul + + + LLaMA HeadMatmul + + + + + in[0] + + X activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + in[1] + + Y activation + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + out[0] + + output activation + + true + 
BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + transpose_in0 + false + QNN_DATATYPE_BOOL_8 + + SCALAR + + + + + transpose_in1 + false + QNN_DATATYPE_BOOL_8 + + SCALAR + + + + + + HTP + + + + LLaMAMul + + + LLaMA element-wise mul + + + + + in[0] + + X + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + in[1] + + Y + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + + HTP + + + + LLaMAAdd + + + LLaMA element-wise add + + + + + in[0] + + X + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + in[1] + + Y + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + out[0] + + output activation + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + + HTP + + + + KVCache + + + Decoder KVCache + + + + + in[0] + + new KV activation output + + true + BACKEND_SPECIFIC + + 4D + NHWC + [N, C, H , W] + + + + + seq_pos + + current output sequence position + + true + BACKEND_SPECIFIC + + 1D + [1] + + + + + out[0] + + New KVCache + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + hidden_dim + true + QNN_DATATYPE_UINT_32 + + SCALAR + + + + + + HTP + + + + WNop + + + CPU NPU Sync waiting op + + + + + in[0] + + sync input + + true + BACKEND_SPECIFIC + + 4D + NHWC + [1] + + + + in[1] + + sync input var + + true + BACKEND_SPECIFIC + + 1D + [N, C, H , W] + + + + + out[0] + + sync output + + true + BACKEND_SPECIFIC + + 4D + [N, C, H , W] + + + + + sync_var + + sync singnal variable + + true + BACKEND_SPECIFIC + + 1D + [1] + + + + + sync_type + true + QNN_DATATYPE_UINT_32 + + SCALAR + + + + + + HTP + + + + MergeOutput + + + Merge q k v x into one tensor for mllm + + + + + in[0] + + merge q input + + true + BACKEND_SPECIFIC + + 4D + NHWC + + + + in[1] + + merge k input + + true + BACKEND_SPECIFIC + + 4D + NHWC + + + + in[2] + + merge v input + + true + BACKEND_SPECIFIC + + 4D + NHWC + + + + in[3] + + merge x input + + true + BACKEND_SPECIFIC + + 4D + NHWC + + + + + out[0] + + merge output + + true + BACKEND_SPECIFIC + + 4D + NHWC + + + + + num + true + QNN_DATATYPE_UINT_32 + + SCALAR + + + + + HTP + + + + SplitInput + + + Split q k v into three tensors for mllm + + + + + in[0] + + merge input + + true + BACKEND_SPECIFIC + + 4D + NHWC + + + + + in[1] + + merge sequence + + true + BACKEND_SPECIFIC + + 1D + NHWC + + + + + + out[0] + + q output + + true + BACKEND_SPECIFIC + + 4D + NHWC + + + + + out[1] + + k output + + true + BACKEND_SPECIFIC + + 4D + NHWC + + + + + num + true + QNN_DATATYPE_UINT_32 + + SCALAR + + + + + HTP + + + + + + + LLaMASuperSiLU + SiLU + Attention + RMSNorm + RoPE + IRoPE + LLaMAQuantize + LLaMAMul + LLaMAAdd + LLaMAReLU + CausalMask + HeadMatmul + + + + + SiLU + + + in[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + + out[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + + + LLaMASuperSiLU + + + in[0] + QNN_DATATYPE_SFIXED_POINT_8 + + + in[1] + QNN_DATATYPE_SFIXED_POINT_8 + + + + + + out[0] + QNN_DATATYPE_SFIXED_POINT_8 + + + + + + LLaMAReLU + + + in[0] + QNN_DATATYPE_UFIXED_POINT_8 + + + + + out[0] + QNN_DATATYPE_UFIXED_POINT_8 + + + + + + LLaMALinear + + + in[0] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_8 + + + in[1] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_8 + + + in[2] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_FLOAT_32 + + + + out[0] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_8 + + + + + + + Attention + + + in[0] + 
QNN_DATATYPE_FLOAT_16 + + + in[1] + QNN_DATATYPE_UINT_32 + + + in[2] + QNN_DATATYPE_UFIXED_POINT_8 + + + in[3] + QNN_DATATYPE_UFIXED_POINT_8 + + + in[4] + QNN_DATATYPE_UFIXED_POINT_8 + + + + out[0] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_UFIXED_POINT_16 + + + + + + + QLayerNorm + + + in[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + weights + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + bias + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + out[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + + + + RMSNorm + + + in[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + weights + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + out[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + QNN_DATATYPE_SFIXED_POINT_8 + + + + + + RoPE + + + in[0] + QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + sin + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + cos + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + h_cnt + QNN_DATATYPE_UINT_32 + + + + out[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + + + IRoPE + + + in[0] + QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + sin + QNN_DATATYPE_SFIXED_POINT_8 + + + + cos + QNN_DATATYPE_SFIXED_POINT_8 + + + + h_cnt + QNN_DATATYPE_UINT_32 + + + + out[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + + + LLaMAQuantize + + + in[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + out[0] + QNN_DATATYPE_SFIXED_POINT_8 + + + + + + LLaMADequantize + + + in[0] + QNN_DATATYPE_SFIXED_POINT_8 + + + out[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + + + CausalMask + + + in[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + out[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + + + HeadMatmul + + + in[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + in[1] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + out[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + + + LLaMAMul + + + in[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + in[1] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + out[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + + + + LLaMAAdd + + + in[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + in[1] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + out[0] + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + + + KVCache + + + in[0] + QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + seq_pos + QNN_DATATYPE_UINT_32 + + + + out[0] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + + + WNop + + + in[0] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + in[1] + QNN_DATATYPE_UINT_32 + + + + out[0] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + sync_var + QNN_DATATYPE_UINT_32 + + + + + + + MergeOutput + + + in[0] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + in[1] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + in[2] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + in[3] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + out[0] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + 
+ + + + SplitInput + + + in[0] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + in[1] + QNN_DATATYPE_UINT_32 + + + + out[0] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + out[1] + QNN_DATATYPE_UFIXED_POINT_8 + QNN_DATATYPE_SFIXED_POINT_8 + QNN_DATATYPE_FLOAT_16 + QNN_DATATYPE_FLOAT_32 + + + + + + + diff --git a/nntrainer/npu/qnn/LLaMAPackage/src/ops/Attention.cpp b/nntrainer/npu/qnn/LLaMAPackage/src/ops/Attention.cpp new file mode 100755 index 000000000..934815b34 --- /dev/null +++ b/nntrainer/npu/qnn/LLaMAPackage/src/ops/Attention.cpp @@ -0,0 +1,120 @@ +//============================================================================== +// Auto Generated Code for LLaMAPackage +//============================================================================== + +#include "HTP/core/constraints.h" +#include "HTP/core/op_package_feature_support.h" +#include "HTP/core/op_register_ext.h" +#include "HTP/core/optimize.h" +#include "HTP/core/simple_reg.h" +#include "QnnOpPackage.h" + +BEGIN_PKG_OP_DEFINITION(PKG_Attention); + +// op execute function declarations +template +GraphStatus attentionImpl(TensorType1 &out_0, const TensorType1 &in_0, + const TensorType1 &in_1, const TensorType &in_2, + const TensorType &in_3, const TensorType &in_4); + +// forward declaration of sample cost function +static float attentionCostFunc(const Op *op); + +/* + * method 1 for defining op, using default cost value (i.e. GLACIAL) and default + * flag (Flags::RESOURCE_HVX) syntax: DEF_PACKAGE_OP(F,OP) e.g. + * DEF_PACKAGE_OP((attentionImpl), "Attention") + */ +DEF_PACKAGE_OP((attentionImpl), "Attention") + +/* + * method 2 for defining op with specified cost value (one of GLACIAL, SNAIL, + * FAST, FREE) and provided flags syntax: + * DEF_PACKAGE_OP_AND_COST_AND_FLAGS(F,OP,COST,...) can use zero or more flags, + * FLAG options are IS_CONST, INHIBIT_CONST_PROP, RESOURCE_HVX, RESOURCE_HMX(not + * supported in external op packages) e.g. + * DEF_PACKAGE_OP_AND_COST_AND_FLAGS((attentionImpl), "Attention", SNAIL) + */ + +/* + * method 3 for defining op with cost function pointer and provided flags + * cost function pointer type: typedef float (*cost_function) (const Op * op); + * syntax: DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS(F,OP,COST_F,...) + * e.g. DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS((attentionImpl), "Attention", + * attentionCostFunc, Flags::RESOURCE_HVX) + */ + +/* + * optimization definitions + * need to be global in the package + * one definition per optimization + * syntax: + * DEF_PACKAGE_OPTIMIZATION(PRIORITY,MATCHCODE,CONSTRAINTCODE,REPLACECODE) + * PRIORITY predefined values include EARLY(2000), MIDDLE(3000), LATE(4000) + * HTP core provides some replacement functions for op package to use + * for more information about optimization rules, please refer to HTP core + * documentations + */ + +/* + * op parameter order definitions + * need to be global in the package + * one definition per op, and this is optional + * syntax: + * DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) 
+ * one or more parameters can be specified for each op + * order of parameters listed determines the order of parameters passed into op + * execution functions if an op does not have a parameter order definition, + * parameter order passed into Qnn_addNode will be passed into op execution + * functions if an op has a parameter order definition, any parameter passed + * into Qnn_addNode with unlisted name will be abandoned if two or more op + * packages with the same package name will be registered, they cannot list + * conflicting parameter orders + * PARAM refers to parameter name as a string literal + * MANDATORY refers to whether this parameter is required to be provided at + * Qnn_addNode DEFAULT is used when MANDATORY is false if provided as + * Qnn_Param_t*, DEFAULT will be used for graph construction when this parameter + * is not provided at Qnn_addNode if provided as nullptr, graph construction + * will skip this parameter when this parameter is not provided at Qnn_addNode + */ + +/* execute functions for ops */ + +template +GraphStatus attentionImpl(TensorType1 &out_0, const TensorType1 &in_0, + const TensorType1 &in_1, const TensorType &in_2, + const TensorType &in_3, const TensorType &in_4) + +{ + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + return GraphStatus::Success; +} + +__attribute__((unused)) static float attentionCostFunc(const Op *op) { + /* + * add code here + * */ + + float cost = 0.0; // add cost computation here + return cost; +} + +/* At the bottom of the op file, call END_PKG_OP_DEFINITION(), + where is as BEGIN_PKG_OP_DEFINITION +*/ +END_PKG_OP_DEFINITION(PKG_Attention); diff --git a/nntrainer/npu/qnn/LLaMAPackage/src/ops/IRoPE.cpp b/nntrainer/npu/qnn/LLaMAPackage/src/ops/IRoPE.cpp new file mode 100644 index 000000000..15e8a911c --- /dev/null +++ b/nntrainer/npu/qnn/LLaMAPackage/src/ops/IRoPE.cpp @@ -0,0 +1,219 @@ +//============================================================================== +// Auto Generated Code for LLaMAPackage +//============================================================================== + +#include "HTP/core/constraints.h" +#include "HTP/core/op_package_feature_support.h" +#include "HTP/core/op_register_ext.h" +#include "HTP/core/optimize.h" +#include "HTP/core/simple_reg.h" +#include "QnnOpPackage.h" + +BEGIN_PKG_OP_DEFINITION(PKG_IRoPE); + +// op execute function declarations +template +GraphStatus iropeImpl(TensorType &out_0, const TensorType &in_0, + const TensorType &in_1, const TensorType &cos, + const TensorType1 &h_cnt, const Tensor &pose_type); + +// forward declaration of sample cost function +static float iropeCostFunc(const Op *op); + +/* + * method 1 for defining op, using default cost value (i.e. GLACIAL) and default + * flag (Flags::RESOURCE_HVX) syntax: DEF_PACKAGE_OP(F,OP) e.g. + * DEF_PACKAGE_OP((iropeImpl), "IRoPE") + */ +DEF_PACKAGE_OP((iropeImpl), "IRoPE") + +/* + * method 2 for defining op with specified cost value (one of GLACIAL, SNAIL, + * FAST, FREE) and provided flags syntax: + * DEF_PACKAGE_OP_AND_COST_AND_FLAGS(F,OP,COST,...) 
can use zero or more flags, + * FLAG options are IS_CONST, INHIBIT_CONST_PROP, RESOURCE_HVX, RESOURCE_HMX(not + * supported in external op packages) e.g. + * DEF_PACKAGE_OP_AND_COST_AND_FLAGS((iropeImpl), "IRoPE", SNAIL) + */ + +/* + * method 3 for defining op with cost function pointer and provided flags + * cost function pointer type: typedef float (*cost_function) (const Op * op); + * syntax: DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS(F,OP,COST_F,...) + * e.g. DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS((iropeImpl), "IRoPE", iropeCostFunc, Flags::RESOURCE_HVX) + */ + +/* + * optimization definitions + * need to be global in the package + * one definition per optimization + * syntax: + * DEF_PACKAGE_OPTIMIZATION(PRIORITY,MATCHCODE,CONSTRAINTCODE,REPLACECODE) + * PRIORITY predefined values include EARLY(2000), MIDDLE(3000), LATE(4000) + * HTP core provides some replacement functions for op package to use + * for more information about optimization rules, please refer to HTP core + * documentations + */ + +/* + * op parameter order definitions + * need to be global in the package + * one definition per op, and this is optional + * syntax: + * DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) + * one or more parameters can be specified for each op + * order of parameters listed determines the order of parameters passed into op + * execution functions if an op does not have a parameter order definition, + * parameter order passed into Qnn_addNode will be passed into op execution + * functions if an op has a parameter order definition, any parameter passed + * into Qnn_addNode with unlisted name will be abandoned if two or more op + * packages with the same package name will be registered, they cannot list + * conflicting parameter orders + * PARAM refers to parameter name as a string literal + * MANDATORY refers to whether this parameter is required to be provided at + * Qnn_addNode DEFAULT is used when MANDATORY is false if provided as + * Qnn_Param_t*, DEFAULT will be used for graph construction when this parameter + * is not provided at Qnn_addNode if provided as nullptr, graph construction + * will skip this parameter when this parameter is not provided at Qnn_addNode + */ +DEF_PACKAGE_PARAM_ORDER("IRoPE", "pose_type", true, nullptr) + +/* execute functions for ops */ + +template +GraphStatus iropeImpl(TensorType &out_0, const TensorType &in_0, + const TensorType &sin, const TensorType &cos, + const TensorType1 &h_cnt, const Tensor &pose_type) + +{ + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. 
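+ *
+ * For pose_type == 4, the loops below apply the rotate-half RoPE
+ * formulation to a uint8-quantized input (zero-point 128), using the
+ * sin/cos rows selected by the current sequence offset h_cnt:
+ *   x1 = in[d] - 128,               x2 = in[d + D/2] - 128
+ *   s  = (sin[d] - 128) * sin_scale, c = (cos[d] - 128) * cos_scale
+ *   out[d]       = x1 * c - x2 * s
+ *   out[d + D/2] = x1 * s + x2 * c
+ * writing the result as float32 or float16 depending on the output dtype.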
+ */ + auto pose_type_ = pose_type(0, 0, 0, 0); + auto h_cnt_ = static_cast(h_cnt(0, 0, 0, 0)); + + out_0.set_dims(in_0); + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + uint32_t half_dimension = d_in / 2; + + auto sin_ptr = (uint8_t *)sin.raw_data_const(); + auto cos_ptr = (uint8_t *)cos.raw_data_const(); + + auto in_ptr = (uint8_t *)in_0.raw_data_const(); + + sin_ptr += half_dimension * h_cnt_; + cos_ptr += half_dimension * h_cnt_; + + // float scale_ = in_0.get_interface_scale() * sin.get_interface_scale() * + // cos.get_interface_scale(); + + if (pose_type_ == 4) { + DType dtype = out_0.get_dtype(); + + if (dtype == DType::Float32) { + + auto out_ptr = (float *)out_0.raw_data(); + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + + int partial_dimension = d_in; + for (Idx d = 0; d < partial_dimension / 2; ++d) { + int in_value = *in_ptr; + int in_value_2 = *(in_ptr + half_dimension); + + int sin_value = *(sin_ptr + d); + int cos_value = *(cos_ptr + d); + float value = (in_value - 128) * (cos_value - 128) * + cos.get_interface_scale() - + (in_value_2 - 128) * (sin_value - 128) * + sin.get_interface_scale(); + float value2 = (in_value - 128) * (sin_value - 128) * + sin.get_interface_scale() + + (in_value_2 - 128) * (cos_value - 128) * + cos.get_interface_scale(); + + *out_ptr = value; + *(out_ptr + half_dimension) = value2; + + out_ptr++; + in_ptr++; + } + + in_ptr += half_dimension; + out_ptr += half_dimension; + } + + sin_ptr += half_dimension; + cos_ptr += half_dimension; + } + } + } else if (dtype == DType::Float16) { + + auto out_ptr = (__fp16 *)out_0.raw_data(); + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + + int partial_dimension = d_in; + for (Idx d = 0; d < partial_dimension / 2; ++d) { + int in_value = *in_ptr; + int in_value_2 = *(in_ptr + half_dimension); + + int sin_value = *(sin_ptr + d); + int cos_value = *(cos_ptr + d); + float value = (in_value - 128) * (cos_value - 128) * + cos.get_interface_scale() - + (in_value_2 - 128) * (sin_value - 128) * + sin.get_interface_scale(); + float value2 = (in_value - 128) * (sin_value - 128) * + sin.get_interface_scale() + + (in_value_2 - 128) * (cos_value - 128) * + cos.get_interface_scale(); + + *out_ptr = static_cast<__fp16>(value); + *(out_ptr + half_dimension) = static_cast<__fp16>(value2); + + out_ptr++; + in_ptr++; + } + + in_ptr += half_dimension; + out_ptr += half_dimension; + } + + sin_ptr += half_dimension; + cos_ptr += half_dimension; + } + } + } + } + + return GraphStatus::Success; +} + +__attribute__((unused)) static float iropeCostFunc(const Op *op) { + /* + * add code here + * */ + + float cost = 0.0; // add cost computation here + return cost; +} + +/* At the bottom of the op file, call END_PKG_OP_DEFINITION(), + where is as BEGIN_PKG_OP_DEFINITION +*/ +END_PKG_OP_DEFINITION(PKG_IRoPE); diff --git a/nntrainer/npu/qnn/LLaMAPackage/src/ops/LLaMADequantize.cpp b/nntrainer/npu/qnn/LLaMAPackage/src/ops/LLaMADequantize.cpp new file mode 100644 index 000000000..ca181fb68 --- /dev/null +++ b/nntrainer/npu/qnn/LLaMAPackage/src/ops/LLaMADequantize.cpp @@ -0,0 +1,398 @@ +//============================================================================== +// Auto Generated Code for LLaMAPackage +//============================================================================== + +#include "HTP/core/constraints.h" +#include "HTP/core/op_package_feature_support.h" +#include "HTP/core/op_register_ext.h" +#include 
"HTP/core/optimize.h" +#include "HTP/core/simple_reg.h" +#include "QnnOpPackage.h" + +BEGIN_PKG_OP_DEFINITION(PKG_LLaMADequantize); + +// op execute function declarations +template +GraphStatus llamadequantizeImpl(TensorType1 &out_0, const TensorType1 &in_0, + const PlainFloatTensor &scale); + +// forward declaration of sample cost function +static float llamadequantizeCostFunc(const Op *op); + +/* + * method 1 for defining op, using default cost value (i.e. GLACIAL) and default + * flag (Flags::RESOURCE_HVX) syntax: DEF_PACKAGE_OP(F,OP) e.g. + * DEF_PACKAGE_OP((llamadequantizeImpl), + * "LLaMADequantize") + */ +DEF_PACKAGE_OP((llamadequantizeImpl), "LLaMADequantize") + +/* + * method 2 for defining op with specified cost value (one of GLACIAL, SNAIL, + * FAST, FREE) and provided flags syntax: + * DEF_PACKAGE_OP_AND_COST_AND_FLAGS(F,OP,COST,...) can use zero or more flags, + * FLAG options are IS_CONST, INHIBIT_CONST_PROP, RESOURCE_HVX, RESOURCE_HMX(not + * supported in external op packages) e.g. + * DEF_PACKAGE_OP_AND_COST_AND_FLAGS((llamadequantizeImpl), "LLaMADequantize", SNAIL) + */ + +/* + * method 3 for defining op with cost function pointer and provided flags + * cost function pointer type: typedef float (*cost_function) (const Op * op); + * syntax: DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS(F,OP,COST_F,...) + * e.g. + * DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS((llamadequantizeImpl), "LLaMADequantize", + * llamadequantizeCostFunc, Flags::RESOURCE_HVX) + */ + +/* + * optimization definitions + * need to be global in the package + * one definition per optimization + * syntax: + * DEF_PACKAGE_OPTIMIZATION(PRIORITY,MATCHCODE,CONSTRAINTCODE,REPLACECODE) + * PRIORITY predefined values include EARLY(2000), MIDDLE(3000), LATE(4000) + * HTP core provides some replacement functions for op package to use + * for more information about optimization rules, please refer to HTP core + * documentations + */ + +/* + * op parameter order definitions + * need to be global in the package + * one definition per op, and this is optional + * syntax: + * DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) 
+ * one or more parameters can be specified for each op + * order of parameters listed determines the order of parameters passed into op + * execution functions if an op does not have a parameter order definition, + * parameter order passed into Qnn_addNode will be passed into op execution + * functions if an op has a parameter order definition, any parameter passed + * into Qnn_addNode with unlisted name will be abandoned if two or more op + * packages with the same package name will be registered, they cannot list + * conflicting parameter orders + * PARAM refers to parameter name as a string literal + * MANDATORY refers to whether this parameter is required to be provided at + * Qnn_addNode DEFAULT is used when MANDATORY is false if provided as + * Qnn_Param_t*, DEFAULT will be used for graph construction when this parameter + * is not provided at Qnn_addNode if provided as nullptr, graph construction + * will skip this parameter when this parameter is not provided at Qnn_addNode + */ +DEF_PACKAGE_PARAM_ORDER("LLaMADequantize", "scale", true, nullptr) + +#ifndef REFERENCE_OP +/* execute functions for ops */ +#include "hvx_internal.h" +#include "qhmath_hvx.h" +#include +#include + +#define BLOCK_SIZE (8 * 1024 / VLEN) /* vector chunks */ +#define L2FETCH_AHEAD (BLOCK_SIZE) + +static inline int32_t float_to_fp16s(float input) { + union { + int32_t i; + __fp16 f[2]; + } fp32 = {.f = {(__fp16)input, (__fp16)input}}; + return fp32.i; +} + +static HVX_INLINE_ALWAYS uint32_t float_to_bits(float x) { + union { + float f; + uint32_t i; + } fp32 = {.f = x}; + return fp32.i; +} + +/* execute functions for ops */ +int32_t qhmath_hvx_dequantize_ahf(int8_t *restrict input, + int8_t *restrict output, uint32_t size, + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector scale_vec; + + int32_t block, l2fetch_block; + int32_t leftover = size & 63; + int32_t vectors_in_rounddown = size / 128; // element number! + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + + uint32_t convert = 0x00800080; + HVX_Vector convert_vector = Q6_V_vsplat_R(convert); + + scale_vec = Q6_V_vsplat_R(float_to_fp16s(scale)); + HVX_Vector zero_v_sf = Q6_V_vzero(); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; ++j) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); + + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + + *optr++ = Q6_Vhf_equals_Vqf16( + Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), scale_vec)); + *optr++ = Q6_Vhf_equals_Vqf16( + Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), scale_vec)); + + sline1p = sline1c; + } + } + + if (vectors_in_rounddown > 0) { + + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? 
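+    // Tail: when the input pointer is vector-aligned and there is no
+    // remainder, the last full chunk is already held in sline1p, so reuse
+    // it instead of loading one vector past the end of the buffer.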
sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); + + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + + *optr++ = Q6_Vhf_equals_Vqf16( + Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), scale_vec)); + *optr++ = Q6_Vhf_equals_Vqf16( + Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), scale_vec)); + } + + return 0; +} + +// Only support 128x dimension +int32_t qhmath_hvx_dequantize_af(int8_t *restrict input, + int8_t *restrict output, uint32_t size, + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector scale_vec; + HVX_Vector one_vec; + + int32_t block, l2fetch_block; + int32_t leftover = size & 127; + int32_t vectors_in_rounddown = size / 128; + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + + uint32_t convert = 0x00800080; + HVX_Vector convert_vector = Q6_V_vsplat_R(convert); + + scale_vec = Q6_V_vsplat_R(float_to_bits(scale)); + one_vec = Q6_V_vsplat_R(float_to_fp16s(1.0)); + HVX_Vector zero_v_sf = Q6_V_vzero(); + scale_vec = Q6_Vqf32_vadd_VsfVsf(scale_vec, Q6_V_vzero()); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; ++j) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); + + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + + HVX_VectorPair result1 = + Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), one_vec); + result1 = Q6_W_vshuff_VVR(Q6_V_hi_W(result1), Q6_V_lo_W(result1), -4); + + HVX_VectorPair result2 = + Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), one_vec); + result2 = Q6_W_vshuff_VVR(Q6_V_hi_W(result2), Q6_V_lo_W(result2), -4); + + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), scale_vec)); + + sline1p = sline1c; + } + } + + if (vectors_in_rounddown > 0) { + + sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? 
sline1p : *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); + + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + + HVX_VectorPair result1 = + Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), one_vec); + result1 = Q6_W_vshuff_VVR(Q6_V_hi_W(result1), Q6_V_lo_W(result1), -4); + + HVX_VectorPair result2 = + Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), one_vec); + result2 = Q6_W_vshuff_VVR(Q6_V_hi_W(result2), Q6_V_lo_W(result2), -4); + + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), scale_vec)); + *optr++ = Q6_Vsf_equals_Vqf32( + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), scale_vec)); + } + + return 0; +} + +template +GraphStatus llamadequantizeImpl(TensorType1 &out_0, const TensorType1 &in_0, + const PlainFloatTensor &scale) + +{ + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + + // HVX Method -- FP32 Version + out_0.set_dims(in_0); + + // NHWC + auto in_ptr = (int8_t *)in_0.raw_data_const(); + auto out_ptr = (int8_t *)out_0.raw_data(); + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + float scale_ = scale(0, 0, 0, 0); + + size_t size = b_in * h_in * w_in * d_in; + + if (in_0.get_dtype() == DType::QUInt8 && + out_0.get_dtype() == DType::Float16) { + qhmath_hvx_dequantize_ahf(in_ptr, out_ptr, size, scale_); + } else { + qhmath_hvx_dequantize_af(in_ptr, out_ptr, size, scale_); + } + + return GraphStatus::Success; +} +#else +template +GraphStatus llamadequantizeImpl(TensorType1 &out_0, const TensorType1 &in_0, + const PlainFloatTensor &scale) + +{ + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. 
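+ *
+ * Reference (non-HVX) path: each quantized uint8 value is dequantized as
+ *   out = (q - 128) * scale
+ * i.e. a fixed zero-point of 128, with the result written as float32 or
+ * float16 depending on the output dtype.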
+ */ + + // HVX Method -- FP32 Version + out_0.set_dims(in_0); + + float scale_ = scale(0, 0, 0, 0); + + auto in_ptr = (uint8_t *)in_0.raw_data_const(); + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + if (out_0.get_dtype() == DType::Float32) { + auto out_ptr = (float *)out_0.raw_data(); + + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + for (Idx d = 0; d < d_in; d++) { + + int32_t inval = static_cast(*in_ptr++); + *out_ptr++ = (inval - 128) * scale_; + } + } + } + } + } else if (out_0.get_dtype() == DType::Float16) { + + auto out_ptr = (__fp16 *)out_0.raw_data(); + + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + for (Idx d = 0; d < d_in; d++) { + + int32_t inval = static_cast(*in_ptr++); + *out_ptr++ = (__fp16)((inval - 128) * scale_); + } + } + } + } + } + + return GraphStatus::Success; +} + +#endif + +__attribute__((unused)) static float llamadequantizeCostFunc(const Op *op) { + /* + * add code here + * */ + + float cost = 0.0; // add cost computation here + return cost; +} + +/* At the bottom of the op file, call END_PKG_OP_DEFINITION(), + where is as BEGIN_PKG_OP_DEFINITION +*/ +END_PKG_OP_DEFINITION(PKG_LLaMADequantize); diff --git a/nntrainer/npu/qnn/LLaMAPackage/src/ops/LLaMAQuantize.cpp b/nntrainer/npu/qnn/LLaMAPackage/src/ops/LLaMAQuantize.cpp new file mode 100644 index 000000000..2d0aa5503 --- /dev/null +++ b/nntrainer/npu/qnn/LLaMAPackage/src/ops/LLaMAQuantize.cpp @@ -0,0 +1,1298 @@ +//============================================================================== +// Auto Generated Code for LLaMAPackage +//============================================================================== + +#include "HTP/core/constraints.h" +#include "HTP/core/op_package_feature_support.h" +#include "HTP/core/op_register_ext.h" +#include "HTP/core/optimize.h" +#include "HTP/core/simple_reg.h" +#include "QnnOpPackage.h" + +BEGIN_PKG_OP_DEFINITION(PKG_LLaMAQuantize); + +// op execute function declarations +template +GraphStatus llamaquantizeImpl(TensorType1 &out_0, const TensorType1 &in_0, + const PlainFloatTensor &scale); + +// forward declaration of sample cost function +static float llamaquantizeCostFunc(const Op *op); + +/* + * method 1 for defining op, using default cost value (i.e. GLACIAL) and default + * flag (Flags::RESOURCE_HVX) syntax: DEF_PACKAGE_OP(F,OP) e.g. + * DEF_PACKAGE_OP((llamaquantizeImpl), "LLaMAQuantize") + */ +DEF_PACKAGE_OP((llamaquantizeImpl), "LLaMAQuantize") + +/* + * method 2 for defining op with specified cost value (one of GLACIAL, SNAIL, + * FAST, FREE) and provided flags syntax: + * DEF_PACKAGE_OP_AND_COST_AND_FLAGS(F,OP,COST,...) can use zero or more flags, + * FLAG options are IS_CONST, INHIBIT_CONST_PROP, RESOURCE_HVX, RESOURCE_HMX(not + * supported in external op packages) e.g. + * DEF_PACKAGE_OP_AND_COST_AND_FLAGS((llamaquantizeImpl), "LLaMAQuantize", SNAIL) + */ + +/* + * method 3 for defining op with cost function pointer and provided flags + * cost function pointer type: typedef float (*cost_function) (const Op * op); + * syntax: DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS(F,OP,COST_F,...) + * e.g. 
DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS((llamaquantizeImpl), "LLaMAQuantize", llamaquantizeCostFunc, + * Flags::RESOURCE_HVX) + */ + +/* + * optimization definitions + * need to be global in the package + * one definition per optimization + * syntax: + * DEF_PACKAGE_OPTIMIZATION(PRIORITY,MATCHCODE,CONSTRAINTCODE,REPLACECODE) + * PRIORITY predefined values include EARLY(2000), MIDDLE(3000), LATE(4000) + * HTP core provides some replacement functions for op package to use + * for more information about optimization rules, please refer to HTP core + * documentations + */ + +/* + * op parameter order definitions + * need to be global in the package + * one definition per op, and this is optional + * syntax: + * DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) + * one or more parameters can be specified for each op + * order of parameters listed determines the order of parameters passed into op + * execution functions if an op does not have a parameter order definition, + * parameter order passed into Qnn_addNode will be passed into op execution + * functions if an op has a parameter order definition, any parameter passed + * into Qnn_addNode with unlisted name will be abandoned if two or more op + * packages with the same package name will be registered, they cannot list + * conflicting parameter orders + * PARAM refers to parameter name as a string literal + * MANDATORY refers to whether this parameter is required to be provided at + * Qnn_addNode DEFAULT is used when MANDATORY is false if provided as + * Qnn_Param_t*, DEFAULT will be used for graph construction when this parameter + * is not provided at Qnn_addNode if provided as nullptr, graph construction + * will skip this parameter when this parameter is not provided at Qnn_addNode + */ +DEF_PACKAGE_PARAM_ORDER("LLaMAQuantize", "scale", true, nullptr) +#ifndef REFERENCE_OP + +#include "hvx_internal.h" +#include "qhmath_hvx.h" +#include +#include + +#define BLOCK_SIZE (8 * 1024 / VLEN) /* vector chunks */ +#define L2FETCH_AHEAD (BLOCK_SIZE) + +static HVX_INLINE_ALWAYS uint32_t float_to_bits(float x) { + union { + float f; + uint32_t i; + } fp32 = {.f = x}; + return fp32.i; +} + +static inline int32_t float_to_fp16s(float input) { + union { + int32_t i; + __fp16 f[2]; + } fp32 = {.f = {(__fp16)input, (__fp16)input}}; + return fp32.i; +} + +#define FP16_MANTISA 10 +#define FP16_EXPONENT_MASK 0x1f +#define FP16_EXPONENT_BIAS 0xf +#define FP16_MANTISA_MASK 0x000003ff +#define FP16_SIGN 15 +#define FP16_NEG_1 0xbc00 + +/* execute functions for ops */ +int32_t qhmath_hvx_quantize_ahf(__fp16 *restrict input, __fp16 *restrict output, + uint32_t size, float low_level, + float high_level, float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector sline2p, sline2c, sline2; + HVX_Vector sline3p, sline3c, sline3; + HVX_Vector sline4p, sline4c, sline4; + + HVX_Vector sout1, sout2, sout3, sout4; + HVX_Vector low_level_vec, high_level_vec, scale_vec, es_vec; + int32_t block, l2fetch_block; + // int32_t leftover = size & 31; + int32_t vectors_in_rounddown = size / 64; + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + sline2p = *iptr++; + sline3p = *iptr++; + sline4p = *iptr++; + + HVX_Vector uintconvert = Q6_V_vsplat_R(0x80808080); + + float es = 0.5; + low_level_vec = Q6_V_vsplat_R(float_to_fp16s(low_level)); + 
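+  // The clamp bounds, quantization scale and the 0.5 rounding offset are
+  // splatted into fp16 vectors so every lane is quantized with the same
+  // parameters; the results are later packed to bytes with a +128 offset.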
high_level_vec = Q6_V_vsplat_R(float_to_fp16s(high_level)); + scale_vec = Q6_V_vsplat_R(float_to_fp16s(scale)); + es_vec = Q6_V_vsplat_R(float_to_fp16s(es)); + + HVX_Vector zero_v_sf = Q6_V_vzero(); + es_vec = Q6_Vqf16_vadd_VhfVhf(es_vec, zero_v_sf); + + HVX_Vector expmask = Q6_Vh_vsplat_R(FP16_EXPONENT_MASK); + HVX_Vector expbias = Q6_Vh_vsplat_R(FP16_EXPONENT_BIAS); + HVX_Vector manmask = Q6_Vh_vsplat_R(FP16_MANTISA_MASK); + HVX_Vector exp23 = Q6_Vh_vsplat_R(23 - 1); + HVX_Vector exp0 = Q6_Vh_vsplat_R(0 - 1); + HVX_Vector negone = Q6_Vh_vsplat_R(FP16_NEG_1); + HVX_Vector zero = Q6_V_vzero(); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; j += 4) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + sout1 = Q6_Vqf16_vmpy_VhfVhf(sline1, scale_vec); + sout1 = Q6_Vqf16_vadd_Vqf16Vqf16(sout1, es_vec); + sout1 = Q6_Vhf_equals_Vqf16(sout1); + sout1 = Q6_Vhf_vmin_VhfVhf(sout1, high_level_vec); + sout1 = Q6_Vhf_vmax_VhfVhf(sout1, low_level_vec); + + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout1, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); + + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout1, man); + + HVX_Vector sign = Q6_Vh_vasr_VhR(sout1, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout1, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout1, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout1, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = + Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout1, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout1, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout1 = Q6_V_vmux_QVV(expgte23, sout1, tsout1); + } + + sout1 = Q6_Vh_equals_Vhf(sout1); + + sline2c = *iptr++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); + + sout2 = Q6_Vqf16_vmpy_VhfVhf(sline2, scale_vec); + sout2 = Q6_Vqf16_vadd_Vqf16Vqf16(sout2, es_vec); + sout2 = Q6_Vhf_equals_Vqf16(sout2); + sout2 = Q6_Vhf_vmin_VhfVhf(sout2, high_level_vec); + sout2 = Q6_Vhf_vmax_VhfVhf(sout2, low_level_vec); + + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout2, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); + + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout2, man); + + HVX_Vector sign = Q6_Vh_vasr_VhR(sout2, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout2, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout2, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector 
shift1 = Q6_Vh_vasl_VhR(sout2, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = + Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout2, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout2, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout2 = Q6_V_vmux_QVV(expgte23, sout2, tsout1); + } + + sout2 = Q6_Vh_equals_Vhf(sout2); + + sline3c = *iptr++; + sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t)input); + + sout3 = Q6_Vqf16_vmpy_VhfVhf(sline3, scale_vec); + sout3 = Q6_Vqf16_vadd_Vqf16Vqf16(sout3, es_vec); + sout3 = Q6_Vhf_equals_Vqf16(sout3); + sout3 = Q6_Vhf_vmin_VhfVhf(sout3, high_level_vec); + sout3 = Q6_Vhf_vmax_VhfVhf(sout3, low_level_vec); + + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout3, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); + + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout3, man); + + HVX_Vector sign = Q6_Vh_vasr_VhR(sout3, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout3, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout3, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout3, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = + Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout3, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout3, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout3 = Q6_V_vmux_QVV(expgte23, sout3, tsout1); + } + + sout3 = Q6_Vh_equals_Vhf(sout3); + + sline4c = *iptr++; + sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t)input); + + sout4 = Q6_Vqf16_vmpy_VhfVhf(sline4, scale_vec); + sout4 = Q6_Vqf16_vadd_Vqf16Vqf16(sout4, es_vec); + sout4 = Q6_Vhf_equals_Vqf16(sout4); + sout4 = Q6_Vhf_vmin_VhfVhf(sout4, high_level_vec); + sout4 = Q6_Vhf_vmax_VhfVhf(sout4, low_level_vec); + + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout4, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); + + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout4, man); + + HVX_Vector sign = Q6_Vh_vasr_VhR(sout4, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout4, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout4, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout4, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = + Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout4, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout4, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = 
Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout4 = Q6_V_vmux_QVV(expgte23, sout4, tsout1); + } + + sout4 = Q6_Vh_equals_Vhf(sout4); + + HVX_Vector reql_h = Q6_Vb_vpack_VhVh_sat(sout2, sout1); + *optr++ = Q6_Vb_vadd_VbVb(reql_h, uintconvert); + + HVX_Vector reqh_h = Q6_Vb_vpack_VhVh_sat(sout4, sout3); + *optr++ = Q6_Vb_vadd_VbVb(reqh_h, uintconvert); + + sline1p = sline1c; + sline2p = sline2c; + sline3p = sline3c; + sline4p = sline4c; + } + } + + return 0; +} + +int32_t qhmath_hvx_quantize_ahf_int8(__fp16 *restrict input, + __fp16 *restrict output, uint32_t size, + float low_level, float high_level, + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector sline2p, sline2c, sline2; + HVX_Vector sline3p, sline3c, sline3; + HVX_Vector sline4p, sline4c, sline4; + + HVX_Vector sout1, sout2, sout3, sout4; + HVX_Vector low_level_vec, high_level_vec, scale_vec, es_vec; + int32_t block, l2fetch_block; + // int32_t leftover = size & 31; + int32_t vectors_in_rounddown = size / 64; + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + sline2p = *iptr++; + sline3p = *iptr++; + sline4p = *iptr++; + + float es = 0.5; + low_level_vec = Q6_V_vsplat_R(float_to_fp16s(low_level)); + high_level_vec = Q6_V_vsplat_R(float_to_fp16s(high_level)); + scale_vec = Q6_V_vsplat_R(float_to_fp16s(scale)); + es_vec = Q6_V_vsplat_R(float_to_fp16s(es)); + + HVX_Vector zero_v_sf = Q6_V_vzero(); + es_vec = Q6_Vqf16_vadd_VhfVhf(es_vec, zero_v_sf); + + HVX_Vector expmask = Q6_Vh_vsplat_R(FP16_EXPONENT_MASK); + HVX_Vector expbias = Q6_Vh_vsplat_R(FP16_EXPONENT_BIAS); + HVX_Vector manmask = Q6_Vh_vsplat_R(FP16_MANTISA_MASK); + HVX_Vector exp23 = Q6_Vh_vsplat_R(23 - 1); + HVX_Vector exp0 = Q6_Vh_vsplat_R(0 - 1); + HVX_Vector negone = Q6_Vh_vsplat_R(FP16_NEG_1); + HVX_Vector zero = Q6_V_vzero(); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; j += 4) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + sout1 = Q6_Vqf16_vmpy_VhfVhf(sline1, scale_vec); + sout1 = Q6_Vqf16_vadd_Vqf16Vqf16(sout1, es_vec); + sout1 = Q6_Vhf_equals_Vqf16(sout1); + sout1 = Q6_Vhf_vmin_VhfVhf(sout1, high_level_vec); + sout1 = Q6_Vhf_vmax_VhfVhf(sout1, low_level_vec); + + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout1, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); + + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout1, man); + + HVX_Vector sign = Q6_Vh_vasr_VhR(sout1, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout1, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout1, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout1, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = + Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + 
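+        // If the fraction bits masked out above are already zero the value
+        // is integral, so the unmodified sout1 is kept instead of the
+        // floor-adjusted result.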
tsout1 = Q6_V_vmux_QVV(maneqzero, sout1, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout1, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout1 = Q6_V_vmux_QVV(expgte23, sout1, tsout1); + } + + sout1 = Q6_Vh_equals_Vhf(sout1); + + sline2c = *iptr++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); + + sout2 = Q6_Vqf16_vmpy_VhfVhf(sline2, scale_vec); + sout2 = Q6_Vqf16_vadd_Vqf16Vqf16(sout2, es_vec); + sout2 = Q6_Vhf_equals_Vqf16(sout2); + sout2 = Q6_Vhf_vmin_VhfVhf(sout2, high_level_vec); + sout2 = Q6_Vhf_vmax_VhfVhf(sout2, low_level_vec); + + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout2, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); + + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout2, man); + + HVX_Vector sign = Q6_Vh_vasr_VhR(sout2, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout2, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout2, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout2, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = + Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout2, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout2, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout2 = Q6_V_vmux_QVV(expgte23, sout2, tsout1); + } + + sout2 = Q6_Vh_equals_Vhf(sout2); + + sline3c = *iptr++; + sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t)input); + + sout3 = Q6_Vqf16_vmpy_VhfVhf(sline3, scale_vec); + sout3 = Q6_Vqf16_vadd_Vqf16Vqf16(sout3, es_vec); + sout3 = Q6_Vhf_equals_Vqf16(sout3); + sout3 = Q6_Vhf_vmin_VhfVhf(sout3, high_level_vec); + sout3 = Q6_Vhf_vmax_VhfVhf(sout3, low_level_vec); + + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout3, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); + + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout3, man); + + HVX_Vector sign = Q6_Vh_vasr_VhR(sout3, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout3, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout3, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout3, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = + Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout3, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout3, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout3 = Q6_V_vmux_QVV(expgte23, sout3, tsout1); + } + + sout3 = Q6_Vh_equals_Vhf(sout3); + + sline4c = *iptr++; + sline4 = Q6_V_valign_VVR(sline4c, sline4p, 
(size_t)input); + + sout4 = Q6_Vqf16_vmpy_VhfVhf(sline4, scale_vec); + sout4 = Q6_Vqf16_vadd_Vqf16Vqf16(sout4, es_vec); + sout4 = Q6_Vhf_equals_Vqf16(sout4); + sout4 = Q6_Vhf_vmin_VhfVhf(sout4, high_level_vec); + sout4 = Q6_Vhf_vmax_VhfVhf(sout4, low_level_vec); + + { + HVX_Vector exp = Q6_Vh_vasr_VhR(sout4, FP16_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vh_vsub_VhVh(exp, expbias); + + HVX_Vector man = Q6_Vh_vasr_VhVh(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout4, man); + + HVX_Vector sign = Q6_Vh_vasr_VhR(sout4, FP16_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VhVh(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VhVh(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VhVh(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VhVh(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vh_vadd_VhVh(sout4, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout4, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vh_vasl_VhR(sout4, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VhVh(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = + Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout4, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout4, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout4 = Q6_V_vmux_QVV(expgte23, sout4, tsout1); + } + + sout4 = Q6_Vh_equals_Vhf(sout4); + + HVX_Vector reql_h = Q6_Vb_vpack_VhVh_sat(sout2, sout1); + *optr++ = reql_h; + + HVX_Vector reqh_h = Q6_Vb_vpack_VhVh_sat(sout4, sout3); + *optr++ = reqh_h; + + sline1p = sline1c; + sline2p = sline2c; + sline3p = sline3c; + sline4p = sline4c; + } + } + + return 0; +} + +#define FLOAT_MANTISA 23 +#define FLOAT_EXPONENT_MASK 0xff +#define FLOAT_EXPONENT_BIAS 0x7f +#define FLOAT_MANTISA_MASK 0x007fffff +#define FLOAT_SIGN 31 +#define FLOAT_NEG_1 0xBF800000 +#define ROUND_2_SCALE 22 +#define ROUND_SCALE ((1 << ROUND_2_SCALE) * 1.0f) + +int32_t qhmath_hvx_quantize_af(float *restrict input, float *restrict output, + uint32_t size, float low_level, float high_level, + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector sline2p, sline2c, sline2; + HVX_Vector sline3p, sline3c, sline3; + HVX_Vector sline4p, sline4c, sline4; + + HVX_Vector sout1, sout2, sout3, sout4; + HVX_Vector low_level_vec, high_level_vec, scale_vec, es_vec, round_scale_vec; + int32_t block, l2fetch_block; + // int32_t leftover = size & 31; + int32_t vectors_in_rounddown = size / 32; + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + sline2p = *iptr++; + sline3p = *iptr++; + sline4p = *iptr++; + + float es = 0.5f; + low_level_vec = Q6_V_vsplat_R(float_to_bits(low_level)); + high_level_vec = Q6_V_vsplat_R(float_to_bits(high_level)); + scale_vec = Q6_V_vsplat_R(float_to_bits(scale)); + es_vec = Q6_V_vsplat_R(float_to_bits(es)); + round_scale_vec = Q6_V_vsplat_R(float_to_bits(ROUND_SCALE)); + + HVX_Vector zero_v_sf = Q6_V_vzero(); + es_vec = Q6_Vqf32_vadd_VsfVsf(es_vec, zero_v_sf); + + HVX_Vector uintconvert = Q6_V_vsplat_R(0x80808080); + + // HVX_Vector expmask = Q6_V_vsplat_R(FLOAT_EXPONENT_MASK); + // HVX_Vector expbias = Q6_V_vsplat_R(FLOAT_EXPONENT_BIAS); + // HVX_Vector manmask = Q6_V_vsplat_R(FLOAT_MANTISA_MASK); + // HVX_Vector 
exp23 = Q6_V_vsplat_R(23 - 1); + // HVX_Vector exp0 = Q6_V_vsplat_R(0 - 1); + // HVX_Vector negone = Q6_V_vsplat_R(FLOAT_NEG_1); + // HVX_Vector zero = Q6_V_vzero(); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; j += 4) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + sout1 = Q6_Vqf32_vmpy_VsfVsf(sline1, scale_vec); + sout1 = Q6_Vqf32_vadd_Vqf32Vqf32(sout1, es_vec); + sout1 = Q6_Vsf_equals_Vqf32(sout1); + sout1 = Q6_Vsf_vmin_VsfVsf(sout1, high_level_vec); + sout1 = Q6_Vsf_vmax_VsfVsf(sout1, low_level_vec); + sout1 = Q6_Vqf32_vmpy_VsfVsf(sout1, round_scale_vec); + sout1 = Q6_Vsf_equals_Vqf32(sout1); + + // { + // HVX_Vector exp = Q6_Vw_vasr_VwR(sout1, FLOAT_MANTISA); + // exp = Q6_V_vand_VV(exp, expmask); + // exp = Q6_Vw_vsub_VwVw(exp, expbias); + + // HVX_Vector man = Q6_Vw_vasr_VwVw(manmask, exp); + // HVX_Vector manzero = Q6_V_vand_VV(sout1, man); + + // HVX_Vector sign = Q6_Vw_vasr_VwR(sout1, FLOAT_SIGN); + // HVX_Vector issignpos = Q6_Q_vcmp_eq_VwVw(sign, zero); + + // HVX_Vector expgte23 = Q6_Q_vcmp_gt_VwVw(exp, exp23); + // HVX_Vector expgte0 = Q6_Q_vcmp_gt_VwVw(exp, exp0); + // HVX_Vector maneqzero = Q6_Q_vcmp_eq_VwVw(manzero, zero); + + // HVX_Vector exppos_signneg = Q6_Vw_vadd_VwVw(sout1, man); + // man = Q6_V_vnot_V(man); + // HVX_Vector exppos_signpos = Q6_V_vand_VV(sout1, man); + // exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + // HVX_Vector shift1 = Q6_Vw_vasl_VwR(sout1, 1); + // HVX_Vector iszero = Q6_Q_vcmp_eq_VwVw(shift1, zero); + + // // exp >= 0 + // HVX_Vector tsout1 = Q6_V_vmux_QVV(issignpos, exppos_signpos, + // exppos_signneg); tsout1 = Q6_V_vmux_QVV(maneqzero, sout1, tsout1); + + // // exp < 0 (-1, 1) + // HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout1, negone); + // tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + // tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + // sout1 = Q6_V_vmux_QVV(expgte23, sout1, tsout1); + // } + + sout1 = Q6_Vw_equals_Vsf(sout1); + sout1 = Q6_Vw_vasr_VwR(sout1, ROUND_2_SCALE); + // sout1 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout1, + // Q6_V_vzero()), 0); + + sline2c = *iptr++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); + + sout2 = Q6_Vqf32_vmpy_VsfVsf(sline2, scale_vec); + sout2 = Q6_Vqf32_vadd_Vqf32Vqf32(sout2, es_vec); + sout2 = Q6_Vsf_equals_Vqf32(sout2); + sout2 = Q6_Vsf_vmin_VsfVsf(sout2, high_level_vec); + sout2 = Q6_Vsf_vmax_VsfVsf(sout2, low_level_vec); + sout2 = Q6_Vqf32_vmpy_VsfVsf(sout2, round_scale_vec); + sout2 = Q6_Vsf_equals_Vqf32(sout2); + + // { + // HVX_Vector exp = Q6_Vw_vasr_VwR(sout2, FLOAT_MANTISA); + // exp = Q6_V_vand_VV(exp, expmask); + // exp = Q6_Vw_vsub_VwVw(exp, expbias); + + // HVX_Vector man = Q6_Vw_vasr_VwVw(manmask, exp); + // HVX_Vector manzero = Q6_V_vand_VV(sout2, man); + + // HVX_Vector sign = Q6_Vw_vasr_VwR(sout2, FLOAT_SIGN); + // HVX_Vector issignpos = Q6_Q_vcmp_eq_VwVw(sign, zero); + + // HVX_Vector expgte23 = Q6_Q_vcmp_gt_VwVw(exp, exp23); + // HVX_Vector expgte0 = Q6_Q_vcmp_gt_VwVw(exp, exp0); + // HVX_Vector maneqzero = Q6_Q_vcmp_eq_VwVw(manzero, zero); + + // HVX_Vector exppos_signneg = Q6_Vw_vadd_VwVw(sout2, man); + // man = Q6_V_vnot_V(man); + // HVX_Vector exppos_signpos = Q6_V_vand_VV(sout2, man); + // exppos_signneg = 
Q6_V_vand_VV(exppos_signneg, man); + // HVX_Vector shift1 = Q6_Vw_vasl_VwR(sout2, 1); + // HVX_Vector iszero = Q6_Q_vcmp_eq_VwVw(shift1, zero); + + // // exp >= 0 + // HVX_Vector tsout1 = Q6_V_vmux_QVV(issignpos, exppos_signpos, + // exppos_signneg); tsout1 = Q6_V_vmux_QVV(maneqzero, sout2, tsout1); + + // // exp < 0 (-1, 1) + // HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout2, negone); + // tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + // tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + // sout2 = Q6_V_vmux_QVV(expgte23, sout2, tsout1); + // } + + sout2 = Q6_Vw_equals_Vsf(sout2); + sout2 = Q6_Vw_vasr_VwR(sout2, ROUND_2_SCALE); + // sout2 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout2, + // Q6_V_vzero()), 0); + + sline3c = *iptr++; + sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t)input); + + sout3 = Q6_Vqf32_vmpy_VsfVsf(sline3, scale_vec); + sout3 = Q6_Vqf32_vadd_Vqf32Vqf32(sout3, es_vec); + sout3 = Q6_Vsf_equals_Vqf32(sout3); + sout3 = Q6_Vsf_vmin_VsfVsf(sout3, high_level_vec); + sout3 = Q6_Vsf_vmax_VsfVsf(sout3, low_level_vec); + sout3 = Q6_Vqf32_vmpy_VsfVsf(sout3, round_scale_vec); + sout3 = Q6_Vsf_equals_Vqf32(sout3); + + // { + // HVX_Vector exp = Q6_Vw_vasr_VwR(sout3, FLOAT_MANTISA); + // exp = Q6_V_vand_VV(exp, expmask); + // exp = Q6_Vw_vsub_VwVw(exp, expbias); + + // HVX_Vector man = Q6_Vw_vasr_VwVw(manmask, exp); + // HVX_Vector manzero = Q6_V_vand_VV(sout3, man); + + // HVX_Vector sign = Q6_Vw_vasr_VwR(sout3, FLOAT_SIGN); + // HVX_Vector issignpos = Q6_Q_vcmp_eq_VwVw(sign, zero); + + // HVX_Vector expgte23 = Q6_Q_vcmp_gt_VwVw(exp, exp23); + // HVX_Vector expgte0 = Q6_Q_vcmp_gt_VwVw(exp, exp0); + // HVX_Vector maneqzero = Q6_Q_vcmp_eq_VwVw(manzero, zero); + + // HVX_Vector exppos_signneg = Q6_Vw_vadd_VwVw(sout3, man); + // man = Q6_V_vnot_V(man); + // HVX_Vector exppos_signpos = Q6_V_vand_VV(sout3, man); + // exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + // HVX_Vector shift1 = Q6_Vw_vasl_VwR(sout3, 1); + // HVX_Vector iszero = Q6_Q_vcmp_eq_VwVw(shift1, zero); + + // // exp >= 0 + // HVX_Vector tsout1 = Q6_V_vmux_QVV(issignpos, exppos_signpos, + // exppos_signneg); tsout1 = Q6_V_vmux_QVV(maneqzero, sout3, tsout1); + + // // exp < 0 (-1, 1) + // HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout3, negone); + // tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + // tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + // sout3 = Q6_V_vmux_QVV(expgte23, sout3, tsout1); + // } + + sout3 = Q6_Vw_equals_Vsf(sout3); + sout3 = Q6_Vw_vasr_VwR(sout3, ROUND_2_SCALE); + // sout3 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout3, + // Q6_V_vzero()), 0); + + sline4c = *iptr++; + sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t)input); + + sout4 = Q6_Vqf32_vmpy_VsfVsf(sline4, scale_vec); + sout4 = Q6_Vqf32_vadd_Vqf32Vqf32(sout4, es_vec); + sout4 = Q6_Vsf_equals_Vqf32(sout4); + sout4 = Q6_Vsf_vmin_VsfVsf(sout4, high_level_vec); + sout4 = Q6_Vsf_vmax_VsfVsf(sout4, low_level_vec); + sout4 = Q6_Vqf32_vmpy_VsfVsf(sout4, round_scale_vec); + sout4 = Q6_Vsf_equals_Vqf32(sout4); + + // { + // HVX_Vector exp = Q6_Vw_vasr_VwR(sout4, FLOAT_MANTISA); + // exp = Q6_V_vand_VV(exp, expmask); + // exp = Q6_Vw_vsub_VwVw(exp, expbias); + + // HVX_Vector man = Q6_Vw_vasr_VwVw(manmask, exp); + // HVX_Vector manzero = Q6_V_vand_VV(sout4, man); + + // HVX_Vector sign = Q6_Vw_vasr_VwR(sout4, FLOAT_SIGN); + // HVX_Vector issignpos = Q6_Q_vcmp_eq_VwVw(sign, zero); + + // HVX_Vector expgte23 = Q6_Q_vcmp_gt_VwVw(exp, exp23); + // HVX_Vector expgte0 = Q6_Q_vcmp_gt_VwVw(exp, 
exp0); + // HVX_Vector maneqzero = Q6_Q_vcmp_eq_VwVw(manzero, zero); + + // HVX_Vector exppos_signneg = Q6_Vw_vadd_VwVw(sout4, man); + // man = Q6_V_vnot_V(man); + // HVX_Vector exppos_signpos = Q6_V_vand_VV(sout4, man); + // exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + // HVX_Vector shift1 = Q6_Vw_vasl_VwR(sout4, 1); + // HVX_Vector iszero = Q6_Q_vcmp_eq_VwVw(shift1, zero); + + // // exp >= 0 + // HVX_Vector tsout1 = Q6_V_vmux_QVV(issignpos, exppos_signpos, + // exppos_signneg); tsout1 = Q6_V_vmux_QVV(maneqzero, sout4, tsout1); + + // // exp < 0 (-1, 1) + // HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout4, negone); + // tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + // tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + // sout4 = Q6_V_vmux_QVV(expgte23, sout4, tsout1); + // } + + sout4 = Q6_Vw_equals_Vsf(sout4); + sout4 = Q6_Vw_vasr_VwR(sout4, ROUND_2_SCALE); + // sout4 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout4, + // Q6_V_vzero()), 0); + + HVX_Vector reql_h = Q6_Vh_vpack_VwVw_sat(sout2, sout1); + HVX_Vector reqh_h = Q6_Vh_vpack_VwVw_sat(sout4, sout3); + HVX_Vector req_b = Q6_Vb_vpack_VhVh_sat(reqh_h, reql_h); + + *optr++ = Q6_Vb_vadd_VbVb(req_b, uintconvert); + + sline1p = sline1c; + sline2p = sline2c; + sline3p = sline3c; + sline4p = sline4c; + } + } + + return 0; +} + +int32_t qhmath_hvx_quantize_af_out_int8(float *restrict input, + float *restrict output, uint32_t size, + float low_level, float high_level, + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_UVector *optr = (HVX_UVector *)output; + + HVX_Vector sline1p, sline1c, sline1; + HVX_Vector sline2p, sline2c, sline2; + HVX_Vector sline3p, sline3c, sline3; + HVX_Vector sline4p, sline4c, sline4; + + HVX_Vector sout1, sout2, sout3, sout4; + HVX_Vector low_level_vec, high_level_vec, scale_vec, es_vec; + int32_t block, l2fetch_block; + // int32_t leftover = size & 31; + int32_t vectors_in_rounddown = size / 32; + // int32_t leftover_size = leftover * sizeof(float); + + sline1p = *iptr++; + sline2p = *iptr++; + sline3p = *iptr++; + sline4p = *iptr++; + + float es = 0.5f; + low_level_vec = Q6_V_vsplat_R(float_to_bits(low_level)); + high_level_vec = Q6_V_vsplat_R(float_to_bits(high_level)); + scale_vec = Q6_V_vsplat_R(float_to_bits(scale)); + es_vec = Q6_V_vsplat_R(float_to_bits(es)); + + HVX_Vector zero_v_sf = Q6_V_vzero(); + es_vec = Q6_Vqf32_vadd_VsfVsf(es_vec, zero_v_sf); + + HVX_Vector expmask = Q6_V_vsplat_R(FLOAT_EXPONENT_MASK); + HVX_Vector expbias = Q6_V_vsplat_R(FLOAT_EXPONENT_BIAS); + HVX_Vector manmask = Q6_V_vsplat_R(FLOAT_MANTISA_MASK); + HVX_Vector exp23 = Q6_V_vsplat_R(23 - 1); + HVX_Vector exp0 = Q6_V_vsplat_R(0 - 1); + HVX_Vector negone = Q6_V_vsplat_R(FLOAT_NEG_1); + HVX_Vector zero = Q6_V_vzero(); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + block = Q6_R_min_RR(i, BLOCK_SIZE); + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t j = 0; j < block; j += 4) { + sline1c = *iptr++; + sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t)input); + + sout1 = Q6_Vqf32_vmpy_VsfVsf(sline1, scale_vec); + sout1 = Q6_Vqf32_vadd_Vqf32Vqf32(sout1, es_vec); + sout1 = Q6_Vsf_equals_Vqf32(sout1); + sout1 = Q6_Vsf_vmin_VsfVsf(sout1, high_level_vec); + sout1 = Q6_Vsf_vmax_VsfVsf(sout1, low_level_vec); + + { + HVX_Vector exp = Q6_Vw_vasr_VwR(sout1, FLOAT_MANTISA); + 
exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vw_vsub_VwVw(exp, expbias); + + HVX_Vector man = Q6_Vw_vasr_VwVw(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout1, man); + + HVX_Vector sign = Q6_Vw_vasr_VwR(sout1, FLOAT_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VwVw(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VwVw(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VwVw(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VwVw(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vw_vadd_VwVw(sout1, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout1, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vw_vasl_VwR(sout1, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VwVw(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = + Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout1, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout1, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout1 = Q6_V_vmux_QVV(expgte23, sout1, tsout1); + } + + sout1 = Q6_Vw_equals_Vsf(sout1); + // sout1 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout1, + // Q6_V_vzero()), 0); + + sline2c = *iptr++; + sline2 = Q6_V_valign_VVR(sline2c, sline2p, (size_t)input); + + sout2 = Q6_Vqf32_vmpy_VsfVsf(sline2, scale_vec); + sout2 = Q6_Vqf32_vadd_Vqf32Vqf32(sout2, es_vec); + sout2 = Q6_Vsf_equals_Vqf32(sout2); + sout2 = Q6_Vsf_vmin_VsfVsf(sout2, high_level_vec); + sout2 = Q6_Vsf_vmax_VsfVsf(sout2, low_level_vec); + + { + HVX_Vector exp = Q6_Vw_vasr_VwR(sout2, FLOAT_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vw_vsub_VwVw(exp, expbias); + + HVX_Vector man = Q6_Vw_vasr_VwVw(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout2, man); + + HVX_Vector sign = Q6_Vw_vasr_VwR(sout2, FLOAT_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VwVw(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VwVw(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VwVw(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VwVw(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vw_vadd_VwVw(sout2, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout2, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vw_vasl_VwR(sout2, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VwVw(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = + Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout2, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout2, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout2 = Q6_V_vmux_QVV(expgte23, sout2, tsout1); + } + + sout2 = Q6_Vw_equals_Vsf(sout2); + // sout2 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout2, + // Q6_V_vzero()), 0); + + sline3c = *iptr++; + sline3 = Q6_V_valign_VVR(sline3c, sline3p, (size_t)input); + + sout3 = Q6_Vqf32_vmpy_VsfVsf(sline3, scale_vec); + sout3 = Q6_Vqf32_vadd_Vqf32Vqf32(sout3, es_vec); + sout3 = Q6_Vsf_equals_Vqf32(sout3); + sout3 = Q6_Vsf_vmin_VsfVsf(sout3, high_level_vec); + sout3 = Q6_Vsf_vmax_VsfVsf(sout3, low_level_vec); + + { + HVX_Vector exp = Q6_Vw_vasr_VwR(sout3, FLOAT_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vw_vsub_VwVw(exp, expbias); + + HVX_Vector man = Q6_Vw_vasr_VwVw(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout3, man); + + HVX_Vector sign = 
Q6_Vw_vasr_VwR(sout3, FLOAT_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VwVw(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VwVw(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VwVw(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VwVw(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vw_vadd_VwVw(sout3, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout3, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vw_vasl_VwR(sout3, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VwVw(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = + Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout3, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout3, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout3 = Q6_V_vmux_QVV(expgte23, sout3, tsout1); + } + + sout3 = Q6_Vw_equals_Vsf(sout3); + // sout3 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout3, + // Q6_V_vzero()), 0); + + sline4c = *iptr++; + sline4 = Q6_V_valign_VVR(sline4c, sline4p, (size_t)input); + + sout4 = Q6_Vqf32_vmpy_VsfVsf(sline4, scale_vec); + sout4 = Q6_Vqf32_vadd_Vqf32Vqf32(sout4, es_vec); + sout4 = Q6_Vsf_equals_Vqf32(sout4); + sout4 = Q6_Vsf_vmin_VsfVsf(sout4, high_level_vec); + sout4 = Q6_Vsf_vmax_VsfVsf(sout4, low_level_vec); + + { + HVX_Vector exp = Q6_Vw_vasr_VwR(sout4, FLOAT_MANTISA); + exp = Q6_V_vand_VV(exp, expmask); + exp = Q6_Vw_vsub_VwVw(exp, expbias); + + HVX_Vector man = Q6_Vw_vasr_VwVw(manmask, exp); + HVX_Vector manzero = Q6_V_vand_VV(sout4, man); + + HVX_Vector sign = Q6_Vw_vasr_VwR(sout4, FLOAT_SIGN); + HVX_Vector issignpos = Q6_Q_vcmp_eq_VwVw(sign, zero); + + HVX_Vector expgte23 = Q6_Q_vcmp_gt_VwVw(exp, exp23); + HVX_Vector expgte0 = Q6_Q_vcmp_gt_VwVw(exp, exp0); + HVX_Vector maneqzero = Q6_Q_vcmp_eq_VwVw(manzero, zero); + + HVX_Vector exppos_signneg = Q6_Vw_vadd_VwVw(sout4, man); + man = Q6_V_vnot_V(man); + HVX_Vector exppos_signpos = Q6_V_vand_VV(sout4, man); + exppos_signneg = Q6_V_vand_VV(exppos_signneg, man); + HVX_Vector shift1 = Q6_Vw_vasl_VwR(sout4, 1); + HVX_Vector iszero = Q6_Q_vcmp_eq_VwVw(shift1, zero); + + // exp >= 0 + HVX_Vector tsout1 = + Q6_V_vmux_QVV(issignpos, exppos_signpos, exppos_signneg); + tsout1 = Q6_V_vmux_QVV(maneqzero, sout4, tsout1); + + // exp < 0 (-1, 1) + HVX_Vector tsout2 = Q6_V_vmux_QVV(iszero, sout4, negone); + tsout2 = Q6_V_vmux_QVV(issignpos, zero, tsout2); + + tsout1 = Q6_V_vmux_QVV(expgte0, tsout1, tsout2); + sout4 = Q6_V_vmux_QVV(expgte23, sout4, tsout1); + } + + sout4 = Q6_Vw_equals_Vsf(sout4); + // sout4 = qhmath_hvx_vw_convert_vqf32_rmode(Q6_Vqf32_vadd_VsfVsf(sout4, + // Q6_V_vzero()), 0); + + HVX_Vector reql_h = Q6_Vh_vpack_VwVw_sat(sout2, sout1); + HVX_Vector reqh_h = Q6_Vh_vpack_VwVw_sat(sout4, sout3); + HVX_Vector req_b = Q6_Vb_vpack_VhVh_sat(reqh_h, reql_h); + + *optr++ = req_b; + + sline1p = sline1c; + sline2p = sline2c; + sline3p = sline3c; + sline4p = sline4c; + } + } + + return 0; +} + +template +GraphStatus llamaquantizeImpl(TensorType1 &out_0, const TensorType1 &in_0, + const PlainFloatTensor &scale) + +{ + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. 
The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + + // HVX Method -- FP32 Version + out_0.set_dims(in_0); + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + float scale_ = scale(0, 0, 0, 0); + + scale_ = 1.0f / scale_; + + size_t size = b_in * h_in * w_in * d_in; + DType dtype = in_0.get_dtype(); + + if (dtype == DType::Float16 && out_0.get_dtype() == DType::QUInt8) { + // NHWC + auto in_ptr = (__fp16 *)in_0.raw_data_const(); + auto out_ptr = (__fp16 *)out_0.raw_data(); + + qhmath_hvx_quantize_ahf(in_ptr, out_ptr, size, -128.0f, 127.0f, scale_); + } + if (dtype == DType::Float32 && out_0.get_dtype() == DType::QUInt8) { + + // NHWC + auto in_ptr = (float *)in_0.raw_data_const(); + auto out_ptr = (float *)out_0.raw_data(); + qhmath_hvx_quantize_af(in_ptr, out_ptr, size, -128.0f, 127.0f, scale_); + } + + if (dtype == DType::Float16 && out_0.get_dtype() == DType::QInt8) { + // NHWC + auto in_ptr = (__fp16 *)in_0.raw_data_const(); + auto out_ptr = (__fp16 *)out_0.raw_data(); + + qhmath_hvx_quantize_ahf_int8(in_ptr, out_ptr, size, -128.0f, 127.0f, + scale_); + } + + if (dtype == DType::Float32 && out_0.get_dtype() == DType::QInt8) { + + // NHWC + auto in_ptr = (float *)in_0.raw_data_const(); + auto out_ptr = (float *)out_0.raw_data(); + qhmath_hvx_quantize_af_out_int8(in_ptr, out_ptr, size, -128.0f, 127.0f, + scale_); + } + + // auto out_ptr = (int8_t*)out_0.raw_data(); + + // out_ptr[0] = (int)dtype; + // out_ptr[1] = (int)out_0.get_dtype(); + + return GraphStatus::Success; +} + +#else + +extern float Round(float num); + +template +GraphStatus llamaquantizeImpl(TensorType1 &out_0, const TensorType1 &in_0, + const PlainFloatTensor &scale) + +{ + out_0.set_dims(in_0); + + float scale_ = scale(0, 0, 0, 0); + + auto out_ptr = (int8_t *)out_0.raw_data(); + + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + + for (Idx w = 0; w < w_in; w++) { + + for (Idx d = 0; d < d_in; d++) { + + float inval = in_0(b, h, w, d); + + // float result = Round(inval / scale_); + + long v = lroundf(inval / scale_); + + if (v > 127) + v = 127; + + if (v < -128) + v = -128; + + if (out_0.get_dtype() == DType::QUInt8) + v += 128; + + *out_ptr++ = static_cast(v); + } + } + } + } + + return GraphStatus::Success; +} +#endif + +__attribute__((unused)) static float llamaquantizeCostFunc(const Op *op) { + /* + * add code here + * */ + + float cost = 0.0; // add cost computation here + return cost; +} + +/* At the bottom of the op file, call END_PKG_OP_DEFINITION(), + where is as BEGIN_PKG_OP_DEFINITION +*/ +END_PKG_OP_DEFINITION(PKG_LLaMAQuantize); diff --git a/nntrainer/npu/qnn/LLaMAPackage/src/ops/MergeOutput.cpp b/nntrainer/npu/qnn/LLaMAPackage/src/ops/MergeOutput.cpp new file mode 100644 index 000000000..214c463ed --- /dev/null +++ b/nntrainer/npu/qnn/LLaMAPackage/src/ops/MergeOutput.cpp @@ -0,0 +1,162 @@ +//============================================================================== +// Auto Generated Code for LLaMAPackage +//============================================================================== + +#include "HTP/core/constraints.h" +#include "HTP/core/op_package_feature_support.h" +#include "HTP/core/op_register_ext.h" +#include "HTP/core/optimize.h" 
+#include "HTP/core/simple_reg.h" +#include "QnnOpPackage.h" + +BEGIN_PKG_OP_DEFINITION(PKG_MergeOutput); + +// op execute function declarations +template +GraphStatus mergeoutputImpl(TensorType &out_0, const TensorType &in_0, + const TensorType &in_1, const TensorType &in_2, + const TensorType &in_3, const Tensor &num); + +// forward declaration of sample cost function +static float mergeoutputCostFunc(const Op *op); + +/* + * method 1 for defining op, using default cost value (i.e. GLACIAL) and default + * flag (Flags::RESOURCE_HVX) syntax: DEF_PACKAGE_OP(F,OP) e.g. + * DEF_PACKAGE_OP((mergeoutputImpl), "MergeOutput") + */ +DEF_PACKAGE_OP((mergeoutputImpl), "MergeOutput") + +/* + * method 2 for defining op with specified cost value (one of GLACIAL, SNAIL, + * FAST, FREE) and provided flags syntax: + * DEF_PACKAGE_OP_AND_COST_AND_FLAGS(F,OP,COST,...) can use zero or more flags, + * FLAG options are IS_CONST, INHIBIT_CONST_PROP, RESOURCE_HVX, RESOURCE_HMX(not + * supported in external op packages) e.g. + * DEF_PACKAGE_OP_AND_COST_AND_FLAGS((mergeoutputImpl), + * "MergeOutput", SNAIL) + */ + +/* + * method 3 for defining op with cost function pointer and provided flags + * cost function pointer type: typedef float (*cost_function) (const Op * op); + * syntax: DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS(F,OP,COST_F,...) + * e.g. DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS((mergeoutputImpl), + * "MergeOutput", mergeoutputCostFunc, Flags::RESOURCE_HVX) + */ + +/* + * optimization definitions + * need to be global in the package + * one definition per optimization + * syntax: + * DEF_PACKAGE_OPTIMIZATION(PRIORITY,MATCHCODE,CONSTRAINTCODE,REPLACECODE) + * PRIORITY predefined values include EARLY(2000), MIDDLE(3000), LATE(4000) + * HTP core provides some replacement functions for op package to use + * for more information about optimization rules, please refer to HTP core + * documentations + */ + +/* + * op parameter order definitions + * need to be global in the package + * one definition per op, and this is optional + * syntax: + * DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) + * one or more parameters can be specified for each op + * order of parameters listed determines the order of parameters passed into op + * execution functions if an op does not have a parameter order definition, + * parameter order passed into Qnn_addNode will be passed into op execution + * functions if an op has a parameter order definition, any parameter passed + * into Qnn_addNode with unlisted name will be abandoned if two or more op + * packages with the same package name will be registered, they cannot list + * conflicting parameter orders + * PARAM refers to parameter name as a string literal + * MANDATORY refers to whether this parameter is required to be provided at + * Qnn_addNode DEFAULT is used when MANDATORY is false if provided as + * Qnn_Param_t*, DEFAULT will be used for graph construction when this parameter + * is not provided at Qnn_addNode if provided as nullptr, graph construction + * will skip this parameter when this parameter is not provided at Qnn_addNode + */ + +/* execute functions for ops */ + +template +GraphStatus mergeoutputImpl(TensorType &out_0, const TensorType &in_0, + const TensorType &in_1, const TensorType &in_2, + const TensorType &in_3, const Tensor &num) + +{ + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. 
The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + + auto [b_in_0, h_in_0, w_in_0, d_in_0] = in_0.dims(); + auto [b_in_1, h_in_1, w_in_1, d_in_1] = in_1.dims(); + auto [b_in_2, h_in_2, w_in_2, d_in_2] = in_2.dims(); + auto [b_in_3, h_in_3, w_in_3, d_in_3] = in_3.dims(); + + const size_t dims[] = {b_in_0, h_in_0 + h_in_1 + h_in_2 + h_in_3 * 4, w_in_0, + d_in_0}; + + out_0.set_dims(dims); + + DType dtype = in_0.get_dtype(); + uint32_t bitwidth = 4; + + if (dtype == DType::QUInt8 || dtype == DType::QInt8) { + + bitwidth = 1; + + } else if (dtype == DType::Float16) { + + bitwidth = 2; + } else if (dtype == DType::Float32) { + + bitwidth = 4; + } + + const uint8_t *in_ptr_0 = (uint8_t *)in_0.raw_data_const(); + const uint8_t *in_ptr_1 = (uint8_t *)in_1.raw_data_const(); + const uint8_t *in_ptr_2 = (uint8_t *)in_2.raw_data_const(); + // const uint8_t *in_ptr_3 = (uint8_t*)in_3.raw_data_const(); + + uint8_t *out_ptr = (uint8_t *)out_0.raw_data(); + + memcpy(out_ptr, in_ptr_0, b_in_0 * h_in_0 * w_in_0 * d_in_0 * bitwidth); + out_ptr += b_in_0 * h_in_0 * w_in_0 * d_in_0 * bitwidth; + + memcpy(out_ptr, in_ptr_1, b_in_1 * h_in_1 * w_in_1 * d_in_1 * bitwidth); + out_ptr += b_in_1 * h_in_1 * w_in_1 * d_in_1 * bitwidth; + + memcpy(out_ptr, in_ptr_2, b_in_2 * h_in_2 * w_in_2 * d_in_2 * bitwidth); + out_ptr += b_in_2 * h_in_2 * w_in_2 * d_in_2 * bitwidth; + + // memcpy(out_ptr, in_ptr_3, b_in_3 * h_in_3 * w_in_3 * d_in_3 * bitwidth * + // 4); + + return GraphStatus::Success; +} + +__attribute__((unused)) static float mergeoutputCostFunc(const Op *op) { + /* + * add code here + * */ + + float cost = 0.0; // add cost computation here + return cost; +} + +/* At the bottom of the op file, call END_PKG_OP_DEFINITION(), + where is as BEGIN_PKG_OP_DEFINITION +*/ +END_PKG_OP_DEFINITION(PKG_MergeOutput); diff --git a/nntrainer/npu/qnn/LLaMAPackage/src/ops/RoPE.cpp b/nntrainer/npu/qnn/LLaMAPackage/src/ops/RoPE.cpp new file mode 100644 index 000000000..ade2aa0ad --- /dev/null +++ b/nntrainer/npu/qnn/LLaMAPackage/src/ops/RoPE.cpp @@ -0,0 +1,1058 @@ +//============================================================================== +// Auto Generated Code for LLaMAPackage +//============================================================================== + +#include "HTP/core/constraints.h" +#include "HTP/core/op_package_feature_support.h" +#include "HTP/core/op_register_ext.h" +#include "HTP/core/optimize.h" +#include "HTP/core/simple_reg.h" +#include "QnnOpPackage.h" + +BEGIN_PKG_OP_DEFINITION(PKG_RoPE); + +// op execute function declarations +template +GraphStatus ropeImpl(TensorType &out_0, const TensorType &in_0, + const TensorType &sin, const TensorType &cos, + const TensorType1 &h_cnt, const Tensor &pose_type); + +// forward declaration of sample cost function +static float ropeCostFunc(const Op *op); + +/* + * method 1 for defining op, using default cost value (i.e. GLACIAL) and default + * flag (Flags::RESOURCE_HVX) syntax: DEF_PACKAGE_OP(F,OP) e.g. + * DEF_PACKAGE_OP((ropeImpl), "RoPE") + */ +DEF_PACKAGE_OP((ropeImpl), "RoPE") + +/* + * method 2 for defining op with specified cost value (one of GLACIAL, SNAIL, + * FAST, FREE) and provided flags syntax: + * DEF_PACKAGE_OP_AND_COST_AND_FLAGS(F,OP,COST,...) 
can use zero or more flags, + * FLAG options are IS_CONST, INHIBIT_CONST_PROP, RESOURCE_HVX, RESOURCE_HMX(not + * supported in external op packages) e.g. + * DEF_PACKAGE_OP_AND_COST_AND_FLAGS((ropeImpl), "RoPE", + * SNAIL) + */ + +/* + * method 3 for defining op with cost function pointer and provided flags + * cost function pointer type: typedef float (*cost_function) (const Op * op); + * syntax: DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS(F,OP,COST_F,...) + * e.g. DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS((ropeImpl), + * "RoPE", ropeCostFunc, Flags::RESOURCE_HVX) + */ + +/* + * optimization definitions + * need to be global in the package + * one definition per optimization + * syntax: + * DEF_PACKAGE_OPTIMIZATION(PRIORITY,MATCHCODE,CONSTRAINTCODE,REPLACECODE) + * PRIORITY predefined values include EARLY(2000), MIDDLE(3000), LATE(4000) + * HTP core provides some replacement functions for op package to use + * for more information about optimization rules, please refer to HTP core + * documentations + */ + +/* + * op parameter order definitions + * need to be global in the package + * one definition per op, and this is optional + * syntax: + * DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...) + * one or more parameters can be specified for each op + * order of parameters listed determines the order of parameters passed into op + * execution functions if an op does not have a parameter order definition, + * parameter order passed into Qnn_addNode will be passed into op execution + * functions if an op has a parameter order definition, any parameter passed + * into Qnn_addNode with unlisted name will be abandoned if two or more op + * packages with the same package name will be registered, they cannot list + * conflicting parameter orders + * PARAM refers to parameter name as a string literal + * MANDATORY refers to whether this parameter is required to be provided at + * Qnn_addNode DEFAULT is used when MANDATORY is false if provided as + * Qnn_Param_t*, DEFAULT will be used for graph construction when this parameter + * is not provided at Qnn_addNode if provided as nullptr, graph construction + * will skip this parameter when this parameter is not provided at Qnn_addNode + */ +DEF_PACKAGE_PARAM_ORDER("RoPE", "pose_type", true, nullptr) + +/* execute functions for ops */ + +#ifndef REFERENCE_OP + +#include "hvx_internal.h" +#include "qhmath_hvx.h" +#include +#include + +#define BLOCK_SIZE (8 * 1024 / VLEN) /* vector chunks */ +#define L2FETCH_AHEAD (BLOCK_SIZE) +#define ONE 0x3F800000 +#define M_ONE 0xAF800000 + +int32_t hvx_rope_af(float *restrict input, float *restrict sin, + float *restrict cos, float *restrict output, uint32_t size, + uint32_t partial_dimension) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_Vector *iptr_half = (HVX_Vector *)(input + partial_dimension / 2); + HVX_Vector *iptr2 = (HVX_Vector *)sin; + HVX_Vector *iptr3 = (HVX_Vector *)cos; + HVX_UVector *optr = (HVX_UVector *)output; + HVX_UVector *optr_half = (HVX_UVector *)(output + partial_dimension / 2); + ; + HVX_Vector sline1; + HVX_Vector sline1_half; + HVX_Vector sinline1p, sinline1c, sinline1; + HVX_Vector cosline1p, cosline1c, cosline1; + + int32_t l2fetch_block; + int32_t leftover = size & 31; + int32_t vectors_in_rounddown = size / 32; + int32_t leftover_size = leftover * sizeof(float); + + sinline1p = *iptr2++; + cosline1p = *iptr3++; + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { 
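+    // Walk the input in BLOCK_SIZE vector chunks: prefetch the upcoming input/sin/cos
+    // vectors into L2, then for each half-dimension column apply the rotation
+    // out = in * cos - in_half * sin and out_half = in * sin + in_half * cos across all rows.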
+ l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + l2fetch(iptr2 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + l2fetch(iptr3 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t d = 0; d < partial_dimension / 2; d += 32) { + cosline1c = *iptr3++; + cosline1 = Q6_V_valign_VVR(cosline1c, cosline1p, (size_t)cos); + cosline1p = cosline1c; + + sinline1c = *iptr2++; + sinline1 = Q6_V_valign_VVR(sinline1c, sinline1p, (size_t)sin); + sinline1p = sinline1c; + + HVX_Vector *jiptr = iptr + d / 32; + HVX_Vector *jiptr_half = iptr_half + d / 32; + HVX_Vector *joptr = optr + d / 32; + HVX_Vector *joptr_half = optr_half + d / 32; + + for (int32_t j = 0; j < size / partial_dimension; j++) { + sline1 = *jiptr; + sline1_half = *jiptr_half; + + // auto value = in_value * cos_value - in_value_2 * sin_value; + { + HVX_Vector cos_middle_value_qf32 = + Q6_Vqf32_vmpy_VsfVsf(sline1, cosline1); + HVX_Vector sin_middle_value_qf32 = + Q6_Vqf32_vmpy_VsfVsf(sline1_half, sinline1); + *joptr = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32( + cos_middle_value_qf32, sin_middle_value_qf32)); + } + + // auto value2 = in_value * sin_value + in_value_2 * cos_value; + { + HVX_Vector cos_middle_value_qf32 = + Q6_Vqf32_vmpy_VsfVsf(sline1_half, cosline1); + HVX_Vector sin_middle_value_qf32 = + Q6_Vqf32_vmpy_VsfVsf(sline1, sinline1); + *joptr_half = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32( + cos_middle_value_qf32, sin_middle_value_qf32)); + } + + jiptr += partial_dimension / 32; + jiptr_half += partial_dimension / 32; + joptr += partial_dimension / 32; + joptr_half += partial_dimension / 32; + } + } + } + + // if (vectors_in_rounddown > 0) { + + // sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? 
sline1p : *iptr++; + // sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); + // sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, + // sline1)); + + // } + + if (leftover_size > 0) + return -1; + + return 0; +} + +static inline int32_t float_to_fp16s(float input) { + union { + int32_t i; + __fp16 f[2]; + } fp32 = {.f = {(__fp16)input, (__fp16)input}}; + return fp32.i; +} + +int32_t hvx_rope_uint8_af(uint8_t *restrict input, float *restrict sin, + float *restrict cos, float *restrict output, + uint32_t size, uint32_t partial_dimension) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_Vector *iptr2 = (HVX_Vector *)sin; + HVX_Vector *iptr3 = (HVX_Vector *)cos; + HVX_UVector *optr = (HVX_UVector *)output; + + int32_t l2fetch_block; + int32_t leftover = size & 127; + int32_t vectors_in_rounddown = size / 128; + int32_t leftover_size = leftover * sizeof(float); + + HVX_Vector zero_v_sf = Q6_V_vzero(); + uint32_t convert = 0x00800080; + HVX_Vector convert_vector = Q6_V_vsplat_R(convert); + HVX_Vector one_vec = Q6_V_vsplat_R(float_to_fp16s(1.0)); + + // + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + // + HVX_Vector sinline1_low = *iptr2; + HVX_Vector cosline1_low = *iptr3; + sinline1_low = Q6_Vqf32_vadd_VsfVsf(sinline1_low, Q6_V_vzero()); + cosline1_low = Q6_Vqf32_vadd_VsfVsf(cosline1_low, Q6_V_vzero()); + + HVX_Vector sinline1_high = *(iptr2 + 1); + HVX_Vector cosline1_high = *(iptr3 + 1); + sinline1_high = Q6_Vqf32_vadd_VsfVsf(sinline1_high, Q6_V_vzero()); + cosline1_high = Q6_Vqf32_vadd_VsfVsf(cosline1_high, Q6_V_vzero()); + + for (int32_t j = 0; j < size / partial_dimension; j++) { + + HVX_Vector sline1 = *iptr++; + + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); + + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + + HVX_VectorPair result1 = + Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), one_vec); + result1 = Q6_W_vshuff_VVR(Q6_V_hi_W(result1), Q6_V_lo_W(result1), -4); + + HVX_VectorPair result2 = + Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), one_vec); + result2 = Q6_W_vshuff_VVR(Q6_V_hi_W(result2), Q6_V_lo_W(result2), -4); + + // auto value = in_value * cos_value - in_value_2 * sin_value; + { + HVX_Vector cos_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), cosline1_low); + HVX_Vector sin_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), sinline1_low); + *optr = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32( + cos_middle_value_qf32, sin_middle_value_qf32)); + } + + // auto value2 = in_value * sin_value + in_value_2 * cos_value; + { + HVX_Vector cos_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), cosline1_low); + HVX_Vector sin_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), sinline1_low); + *(optr + 2) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32( + cos_middle_value_qf32, sin_middle_value_qf32)); + } + + // auto value = in_value * cos_value - in_value_2 * sin_value; + { + HVX_Vector cos_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), cosline1_high); + HVX_Vector sin_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), 
sinline1_high); + *(optr + 1) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vsub_Vqf32Vqf32( + cos_middle_value_qf32, sin_middle_value_qf32)); + } + + // auto value2 = in_value * sin_value + in_value_2 * cos_value; + { + HVX_Vector cos_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), cosline1_high); + HVX_Vector sin_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), sinline1_high); + *(optr + 3) = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vqf32( + cos_middle_value_qf32, sin_middle_value_qf32)); + } + + optr += 4; + } + } + + // if (vectors_in_rounddown > 0) { + + // sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + // sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); + // sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, + // sline1)); + + // } + + if (leftover_size > 0) + return -1; + + return 0; +} + +int32_t hvx_rope_uint8_ahf(uint8_t *restrict input, float *restrict sin, + float *restrict cos, __fp16 *restrict output, + uint32_t size, uint32_t partial_dimension, + float scale) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_Vector *iptr2 = (HVX_Vector *)sin; + HVX_Vector *iptr3 = (HVX_Vector *)cos; + HVX_UVector *optr = (HVX_UVector *)output; + + int32_t l2fetch_block; + int32_t leftover = size & 127; + int32_t vectors_in_rounddown = size / 128; + int32_t leftover_size = leftover * sizeof(float); + + HVX_Vector zero_v_sf = Q6_V_vzero(); + uint32_t convert = 0x00800080; + HVX_Vector convert_vector = Q6_V_vsplat_R(convert); + + HVX_Vector scale_vec = Q6_V_vsplat_R(float_to_fp16s(scale)); + + // + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + // + HVX_Vector sinline1_low = *iptr2; + HVX_Vector cosline1_low = *iptr3; + sinline1_low = Q6_Vqf32_vadd_VsfVsf(sinline1_low, Q6_V_vzero()); + cosline1_low = Q6_Vqf32_vadd_VsfVsf(cosline1_low, Q6_V_vzero()); + + HVX_Vector sinline1_high = *(iptr2 + 1); + HVX_Vector cosline1_high = *(iptr3 + 1); + sinline1_high = Q6_Vqf32_vadd_VsfVsf(sinline1_high, Q6_V_vzero()); + cosline1_high = Q6_Vqf32_vadd_VsfVsf(cosline1_high, Q6_V_vzero()); + + for (int32_t j = 0; j < size / partial_dimension; j++) { + + HVX_Vector sline1 = *iptr++; + + HVX_VectorPair temp = Q6_Wh_vadd_VubVub(sline1, zero_v_sf); + + temp = Q6_W_vshuff_VVR(Q6_V_hi_W(temp), Q6_V_lo_W(temp), -2); + HVX_Vector sout1 = Q6_Vh_vsub_VhVh(Q6_V_lo_W(temp), convert_vector); + HVX_Vector sout2 = Q6_Vh_vsub_VhVh(Q6_V_hi_W(temp), convert_vector); + + HVX_VectorPair result1 = + Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout1), scale_vec); + result1 = Q6_W_vshuff_VVR(Q6_V_hi_W(result1), Q6_V_lo_W(result1), -4); + + HVX_VectorPair result2 = + Q6_Wqf32_vmpy_VhfVhf(Q6_Vhf_equals_Vh(sout2), scale_vec); + result2 = Q6_W_vshuff_VVR(Q6_V_hi_W(result2), Q6_V_lo_W(result2), -4); + + { + HVX_Vector first; + HVX_Vector second; + // auto value = in_value * cos_value - in_value_2 * sin_value; + { + HVX_Vector cos_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), cosline1_low); + HVX_Vector sin_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), sinline1_low); + first = Q6_Vqf32_vsub_Vqf32Vqf32(cos_middle_value_qf32, + sin_middle_value_qf32); + } + + // auto value = in_value * cos_value - in_value_2 * sin_value; + { + HVX_Vector cos_middle_value_qf32 = + 
Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), cosline1_high); + HVX_Vector sin_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), sinline1_high); + second = Q6_Vqf32_vsub_Vqf32Vqf32(cos_middle_value_qf32, + sin_middle_value_qf32); + } + + HVX_Vector r = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(second, first)); + r = Q6_Vh_vdeal_Vh(r); + *optr = r; + } + + { + HVX_Vector first; + HVX_Vector second; + // auto value2 = in_value * sin_value + in_value_2 * cos_value; + { + HVX_Vector cos_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result2), cosline1_low); + HVX_Vector sin_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(result1), sinline1_low); + first = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32, + sin_middle_value_qf32); + } + + // auto value2 = in_value * sin_value + in_value_2 * cos_value; + { + HVX_Vector cos_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result2), cosline1_high); + HVX_Vector sin_middle_value_qf32 = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(result1), sinline1_high); + second = Q6_Vqf32_vadd_Vqf32Vqf32(cos_middle_value_qf32, + sin_middle_value_qf32); + } + HVX_Vector r = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(second, first)); + r = Q6_Vh_vdeal_Vh(r); + *(optr + 1) = r; + } + + optr += 2; + } + } + + // if (vectors_in_rounddown > 0) { + + // sline1c = is_aligned(iptr, VLEN) && leftover == 0 ? sline1p : *iptr++; + // sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); + // sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, + // sline1)); + + // } + + if (leftover_size > 0) + return -1; + + return 0; +} + +int32_t hvx_rope_ahf(__fp16 *restrict input, float *restrict sin, + float *restrict cos, __fp16 *restrict output, + uint32_t size, uint32_t partial_dimension) { + if ((input == NULL) || (output == NULL) || (size == 0)) { + return -1; + } + + HVX_Vector *iptr = (HVX_Vector *)input; + HVX_Vector *iptr_half = (HVX_Vector *)(input + partial_dimension / 2); + HVX_Vector *iptr2 = (HVX_Vector *)sin; + HVX_Vector *iptr3 = (HVX_Vector *)cos; + HVX_UVector *optr = (HVX_UVector *)output; + HVX_UVector *optr_half = (HVX_UVector *)(output + partial_dimension / 2); + ; + HVX_Vector sline1; + HVX_Vector sline1_half; + + int32_t l2fetch_block; + int32_t leftover = size & 63; + int32_t vectors_in_rounddown = size / 64; + int32_t leftover_size = leftover * sizeof(float); + + HVX_Vector one_vsf = Q6_V_vsplat_R(ONE); + HVX_Vector m_one_vqf32 = Q6_Vqf32_vsub_VsfVsf(Q6_V_vzero(), one_vsf); + + HVX_Vector one_vhf = Q6_V_vsplat_R(float_to_fp16s(1.0)); + // HVX_Vector m_one_vqf16 = Q6_Vqf32_vsub_VsfVsf(Q6_V_vzero(), one_vhf); + + for (int32_t i = vectors_in_rounddown - 1; i > 0; i -= BLOCK_SIZE) { + l2fetch_block = Q6_R_min_RR(i - L2FETCH_AHEAD, BLOCK_SIZE); + + if (l2fetch_block > 0) { + l2fetch(iptr + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + l2fetch(iptr2 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + l2fetch(iptr3 + L2FETCH_AHEAD, VLEN, VLEN, l2fetch_block, 0); + } + + for (int32_t d = 0; d < partial_dimension / 2; d += 64) { + + HVX_Vector sinline1_low = *iptr2++; + HVX_Vector cosline1_low = *iptr3++; + + HVX_Vector sinline1_high = *iptr2++; + HVX_Vector cosline1_high = *iptr3++; + + HVX_Vector *jiptr = iptr + d / 64; + HVX_Vector *jiptr_half = iptr_half + d / 64; + HVX_Vector *joptr = optr + d / 64; + HVX_Vector *joptr_half = optr_half + d / 64; + + for (int32_t j = 0; j < size / partial_dimension; j++) { + sline1 = *jiptr; + sline1_half = *jiptr_half; + + HVX_VectorPair sline1_half_pair = + 
Q6_Wqf32_vmpy_VhfVhf(sline1_half, one_vhf); + HVX_VectorPair sline1_pair = Q6_Wqf32_vmpy_VhfVhf(sline1, one_vhf); + + sline1_half_pair = Q6_W_vshuff_VVR(Q6_V_hi_W(sline1_half_pair), + Q6_V_lo_W(sline1_half_pair), -4); + sline1_pair = + Q6_W_vshuff_VVR(Q6_V_hi_W(sline1_pair), Q6_V_lo_W(sline1_pair), -4); + + HVX_Vector m_sline1_half_low = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_lo_W(sline1_half_pair), m_one_vqf32); + HVX_Vector m_sline1_half_hi = + Q6_Vqf32_vmpy_Vqf32Vqf32(Q6_V_hi_W(sline1_half_pair), m_one_vqf32); + + // auto value = in_value * cos_value - in_value_2 * sin_value; + HVX_Vector middle_value_low; + { + HVX_Vector cosline1_vqf32_low = + Q6_Vqf32_vadd_VsfVsf(cosline1_low, Q6_V_vzero()); + HVX_Vector cos_middle_value_qf32_low = Q6_Vqf32_vmpy_Vqf32Vqf32( + Q6_V_lo_W(sline1_pair), cosline1_vqf32_low); + + HVX_Vector sinline1_vqf32_low = + Q6_Vqf32_vadd_VsfVsf(sinline1_low, Q6_V_vzero()); + + HVX_Vector sin_middle_value_qf32_low = + Q6_Vqf32_vmpy_Vqf32Vqf32(m_sline1_half_low, sinline1_vqf32_low); + middle_value_low = Q6_Vqf32_vadd_Vqf32Vqf32( + cos_middle_value_qf32_low, sin_middle_value_qf32_low); + } + + // auto value2 = in_value * sin_value + in_value_2 * cos_value; + + HVX_Vector middle_value_half_low; + { + HVX_Vector cosline1_vqf32_low = + Q6_Vqf32_vadd_VsfVsf(cosline1_low, Q6_V_vzero()); + HVX_Vector cos_middle_value_qf32_low = Q6_Vqf32_vmpy_Vqf32Vqf32( + Q6_V_lo_W(sline1_half_pair), cosline1_vqf32_low); + + HVX_Vector sinline1_vqf32_low = + Q6_Vqf32_vadd_VsfVsf(sinline1_low, Q6_V_vzero()); + HVX_Vector sin_middle_value_qf32_low = Q6_Vqf32_vmpy_Vqf32Vqf32( + Q6_V_lo_W(sline1_pair), sinline1_vqf32_low); + + middle_value_half_low = Q6_Vqf32_vadd_Vqf32Vqf32( + cos_middle_value_qf32_low, sin_middle_value_qf32_low); + } + + // second qf16 vector + HVX_Vector middle_value_high; + { + HVX_Vector cosline1_vqf32_high = + Q6_Vqf32_vadd_VsfVsf(cosline1_high, Q6_V_vzero()); + HVX_Vector cos_middle_value_qf32_high = Q6_Vqf32_vmpy_Vqf32Vqf32( + Q6_V_hi_W(sline1_pair), cosline1_vqf32_high); + + HVX_Vector sinline1_vqf32_high = + Q6_Vqf32_vadd_VsfVsf(sinline1_high, Q6_V_vzero()); + + HVX_Vector sin_middle_value_qf32_high = + Q6_Vqf32_vmpy_Vqf32Vqf32(m_sline1_half_hi, sinline1_vqf32_high); + middle_value_high = Q6_Vqf32_vadd_Vqf32Vqf32( + cos_middle_value_qf32_high, sin_middle_value_qf32_high); + } + + // auto value2 = in_value * sin_value + in_value_2 * cos_value; + + HVX_Vector middle_value_half_high; + { + HVX_Vector cosline1_vqf32_high = + Q6_Vqf32_vadd_VsfVsf(cosline1_high, Q6_V_vzero()); + HVX_Vector cos_middle_value_qf32_high = Q6_Vqf32_vmpy_Vqf32Vqf32( + Q6_V_hi_W(sline1_half_pair), cosline1_vqf32_high); + + HVX_Vector sinline1_vqf32_high = + Q6_Vqf32_vadd_VsfVsf(sinline1_high, Q6_V_vzero()); + HVX_Vector sin_middle_value_qf32_high = Q6_Vqf32_vmpy_Vqf32Vqf32( + Q6_V_hi_W(sline1_pair), sinline1_vqf32_high); + + middle_value_half_high = Q6_Vqf32_vadd_Vqf32Vqf32( + cos_middle_value_qf32_high, sin_middle_value_qf32_high); + } + + HVX_Vector sline = Q6_Vhf_equals_Wqf32( + Q6_W_vcombine_VV(middle_value_high, middle_value_low)); + sline = Q6_Vh_vdeal_Vh(sline); + + HVX_Vector sline_half = Q6_Vhf_equals_Wqf32( + Q6_W_vcombine_VV(middle_value_half_high, middle_value_half_low)); + sline_half = Q6_Vh_vdeal_Vh(sline_half); + + *joptr = sline; + *joptr_half = sline_half; + + jiptr += partial_dimension / 64; + jiptr_half += partial_dimension / 64; + joptr += partial_dimension / 64; + joptr_half += partial_dimension / 64; + } + } + } + + // if (vectors_in_rounddown > 0) { + + // sline1c = is_aligned(iptr, 
VLEN) && leftover == 0 ? sline1p : *iptr++; + // sline1 = Q6_V_valign_VVR(sline1c, sline1p, (size_t) input); + // sum = Q6_Vqf32_vadd_Vqf32Vqf32(sum, Q6_Vqf32_vmpy_VsfVsf(sline1, + // sline1)); + + // } + + if (leftover_size > 0) + return -1; + + return 0; +} + +template +GraphStatus ropeImpl(TensorType &out_0, const TensorType &in_0, + const TensorType &sin, const TensorType &cos, + const TensorType1 &h_cnt, const Tensor &pose_type) { + + out_0.set_dims(in_0); + + auto pose_type_ = pose_type(0, 0, 0, 0); + auto h_cnt_ = static_cast(h_cnt(0, 0, 0, 0)); + + if (pose_type_ == 4) { + + DType dtype = out_0.get_dtype(); + + if (in_0.get_dtype() == DType::Float32 && dtype == DType::Float32) { + auto in_ptr = (float *)in_0.raw_data_const(); + auto sin_ptr = (float *)sin.raw_data_const(); + auto cos_ptr = (float *)cos.raw_data_const(); + auto out_ptr = (float *)out_0.raw_data(); + + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + uint32_t half_dimension = d_in / 2; + sin_ptr += half_dimension * h_cnt_; + cos_ptr += half_dimension * h_cnt_; + + int partial_dimension = d_in; + + // NSHD + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + + // for (Idx w = 0; w < w_in; w++) { + hvx_rope_af(in_ptr, sin_ptr, cos_ptr, out_ptr, w_in * d_in, + partial_dimension); + + in_ptr += w_in * d_in; + out_ptr += w_in * d_in; + // } + + sin_ptr += half_dimension; + cos_ptr += half_dimension; + } + } + } else if (in_0.get_dtype() == DType::Float16 && dtype == DType::Float16) { + + auto in_ptr = (__fp16 *)in_0.raw_data_const(); + auto sin_ptr = (float *)sin.raw_data_const(); + auto cos_ptr = (float *)cos.raw_data_const(); + auto out_ptr = (__fp16 *)out_0.raw_data(); + + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + uint32_t half_dimension = d_in / 2; + sin_ptr += half_dimension * h_cnt_; + cos_ptr += half_dimension * h_cnt_; + + int partial_dimension = d_in; + + // NSHD + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + + // for (Idx w = 0; w < w_in; w++) { + hvx_rope_ahf(in_ptr, sin_ptr, cos_ptr, out_ptr, w_in * d_in, + partial_dimension); + + in_ptr += w_in * d_in; + out_ptr += w_in * d_in; + // } + + sin_ptr += half_dimension; + cos_ptr += half_dimension; + } + } + } else if (in_0.get_dtype() == DType::QUInt8 && dtype == DType::Float32) { + auto in_ptr = (uint8_t *)in_0.raw_data_const(); + auto sin_ptr = (float *)sin.raw_data_const(); + auto cos_ptr = (float *)cos.raw_data_const(); + auto out_ptr = (float *)out_0.raw_data(); + + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + uint32_t half_dimension = d_in / 2; + sin_ptr += half_dimension * h_cnt_; + cos_ptr += half_dimension * h_cnt_; + + int partial_dimension = d_in; + + // NSHD + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + + // for (Idx w = 0; w < w_in; w++) { + hvx_rope_uint8_af(in_ptr, sin_ptr, cos_ptr, out_ptr, w_in * d_in, + partial_dimension); + + in_ptr += w_in * d_in; + out_ptr += w_in * d_in; + // } + + sin_ptr += half_dimension; + cos_ptr += half_dimension; + } + } + } else if (in_0.get_dtype() == DType::QUInt8 && dtype == DType::Float16) { + + auto in_ptr = (uint8_t *)in_0.raw_data_const(); + auto sin_ptr = (float *)sin.raw_data_const(); + auto cos_ptr = (float *)cos.raw_data_const(); + auto out_ptr = (__fp16 *)out_0.raw_data(); + + float scale_ = in_0.get_interface_scale(); + + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + + uint32_t half_dimension = d_in / 2; + sin_ptr += half_dimension * h_cnt_; + cos_ptr += half_dimension * h_cnt_; + + int partial_dimension = d_in; + + // NSHD 
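+      // Quantized uint8 path: each hvx_rope_uint8_ahf call dequantizes one w_in x d_in slice
+      // using the input's interface scale, applies the rotation, and writes fp16 output;
+      // the sin/cos tables advance by d_in / 2 floats for every step of the inner (h) loop.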
+ for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + + // for (Idx w = 0; w < w_in; w++) { + hvx_rope_uint8_ahf(in_ptr, sin_ptr, cos_ptr, out_ptr, w_in * d_in, + partial_dimension, scale_); + + in_ptr += w_in * d_in; + out_ptr += w_in * d_in; + // } + + sin_ptr += half_dimension; + cos_ptr += half_dimension; + } + } + } + + } else { + + // only support pose_type == 2 (LLaMA) now + return GraphStatus::ErrorFatal; + } + + return GraphStatus::Success; +} + +#else + +template +GraphStatus ropeImpl(TensorType &out_0, const TensorType &in_0, + const TensorType &sin, const TensorType &cos, + const TensorType1 &h_cnt, const Tensor &pose_type) + +{ + /* + * add code here + * */ + /* + * To have good performance and stability, it is required to avoid heap memory + * allocation in this function. The heap memory allocation includes but not + * limited to calling malloc, operator new, constructing STL container objects + * like std::vector with default allocator, and adding items like calling + * std::vector::push_back to STL container objects with default allocator. + * + * Please check in SDK documentation for more information. + */ + + debuglog("RoPE execute... dims=(%zdx%zdx%zdx%zd)", in_0.dim(0), in_0.dim(1), + in_0.dim(2), in_0.dim(3)); + debuglog("RoPE execute... dims=(%zdx%zdx%zdx%zd)", sin.dim(0), sin.dim(1), + sin.dim(2), sin.dim(3)); + debuglog("RoPE execute... dims=(%zdx%zdx%zdx%zd)", cos.dim(0), cos.dim(1), + cos.dim(2), cos.dim(3)); + + // BSHD => NHWC + + // Todo: We need consider to store the sequence position if we have KV Cache + + auto pose_type_ = pose_type(0, 0, 0, 0); + auto h_cnt_ = static_cast(h_cnt(0, 0, 0, 0)); + + out_0.set_dims(in_0); + auto [b_in, h_in, w_in, d_in] = in_0.dims(); + if (pose_type_ == 4) { + DType dtype = out_0.get_dtype(); + + if (dtype == DType::Float32) { + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + + int s = h; // BSHD order + int partial_dimension = d_in; + int half = (int)(partial_dimension / 2); + for (Idx d = 0; d < partial_dimension / 2; ++d) { + float in_value = in_0(b, h, w, d); + float in_value_2 = in_0(b, h, w, d + half); + float sin_value = sin(0, 0, s + h_cnt_, d); + float cos_value = cos(0, 0, s + h_cnt_, d); + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + out_0(b, h, w, d) = value; + out_0(b, h, w, d + half) = value2; + } + } + } + } + } else if (dtype == DType::Float16) { + + auto in_ptr = (__fp16 *)in_0.raw_data_const(); + // auto sin_ptr = (__fp16*)sin.raw_data_const(); + // auto cos_ptr = (__fp16*)cos.raw_data_const(); + auto out_ptr = (__fp16 *)out_0.raw_data(); + + for (Idx b = 0; b < b_in; b++) { + for (Idx h = 0; h < h_in; h++) { + for (Idx w = 0; w < w_in; w++) { + + int s = h; // BSHD order + int partial_dimension = d_in; + int half = (int)(partial_dimension / 2); + for (Idx d = 0; d < partial_dimension / 2; ++d) { + __fp16 in_value = *in_ptr; + __fp16 in_value_2 = *(in_ptr + half); + float sin_value = sin(0, 0, s + h_cnt_, d); + float cos_value = cos(0, 0, s + h_cnt_, d); + auto value = in_value * cos_value - in_value_2 * sin_value; + auto value2 = in_value * sin_value + in_value_2 * cos_value; + *out_ptr = static_cast<__fp16>(value); + *(out_ptr + half) = static_cast<__fp16>(value2); + + out_ptr++; + in_ptr++; + } + + out_ptr += half; + in_ptr += half; + } + } + } + } + } + + // for (Idx b = 0; b < b_in; b++) { + // for (Idx h = 0; h < h_in; h++) { + // for (Idx w = 0; w < 
+  // TODO: accumulate the processed sequence length into h_cnt_ (and wrap it at
+  // the maximum position) once the history position is tracked across calls.
+
+  return GraphStatus::Success;
+}
+
+#endif
+
+__attribute__((unused)) static float ropeCostFunc(const Op *op) {
+  /*
+   * add code here
+   * */
+
+  float cost = 0.0; // add cost computation here
+  return cost;
+}
+
+/* At the bottom of the op file, call END_PKG_OP_DEFINITION(<name>),
+   where <name> is as BEGIN_PKG_OP_DEFINITION
+*/
+END_PKG_OP_DEFINITION(PKG_RoPE);
diff --git a/nntrainer/npu/qnn/LLaMAPackage/src/ops/WNop.cpp b/nntrainer/npu/qnn/LLaMAPackage/src/ops/WNop.cpp
new file mode 100644
index 000000000..0b27016e9
--- /dev/null
+++ b/nntrainer/npu/qnn/LLaMAPackage/src/ops/WNop.cpp
@@ -0,0 +1,170 @@
+//==============================================================================
+// Auto Generated Code for LLaMAPackage
+//==============================================================================
+
+#include "HTP/core/constraints.h"
+#include "HTP/core/op_package_feature_support.h"
+#include "HTP/core/op_register_ext.h"
+#include "HTP/core/optimize.h"
+#include "HTP/core/simple_reg.h"
+#include "QnnOpPackage.h"
+
+#include <cstdint>
+#include <cstring>
+
+BEGIN_PKG_OP_DEFINITION(PKG_WNop);
+
+// op execute function declarations
+template <typename TensorType, typename TensorType1>
+GraphStatus wnopImpl(TensorType &out_0, TensorType1 &sync_var,
+                     const TensorType &in_0, const TensorType &in_1,
+                     const Tensor &sync_type);
+
+// forward declaration of sample cost function
+static float wnopCostFunc(const Op *op);
+
+/*
+ * method 1 for defining op, using default cost value (i.e. GLACIAL) and default
+ * flag (Flags::RESOURCE_HVX) syntax: DEF_PACKAGE_OP(F,OP) e.g.
+ * DEF_PACKAGE_OP((wnopImpl<Tensor, Tensor>), "WNop")
+ */
+DEF_PACKAGE_OP((wnopImpl<Tensor, Tensor>), "WNop")
+
+/*
+ * method 2 for defining op with specified cost value (one of GLACIAL, SNAIL,
+ * FAST, FREE) and provided flags syntax:
+ * DEF_PACKAGE_OP_AND_COST_AND_FLAGS(F,OP,COST,...) can use zero or more flags,
+ * FLAG options are IS_CONST, INHIBIT_CONST_PROP, RESOURCE_HVX, RESOURCE_HMX (not
+ * supported in external op packages), e.g.
+ * DEF_PACKAGE_OP_AND_COST_AND_FLAGS((wnopImpl<Tensor, Tensor>), "WNop", SNAIL)
+ */
+
+/*
+ * method 3 for defining op with cost function pointer and provided flags
+ * cost function pointer type: typedef float (*cost_function) (const Op * op);
+ * syntax: DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS(F,OP,COST_F,...)
+ * e.g. DEF_PACKAGE_OP_AND_COST_F_AND_FLAGS((wnopImpl<Tensor, Tensor>), "WNop",
+ * wnopCostFunc, Flags::RESOURCE_HVX)
+ */
+
+/*
+ * optimization definitions
+ * need to be global in the package
+ * one definition per optimization
+ * syntax:
+ * DEF_PACKAGE_OPTIMIZATION(PRIORITY,MATCHCODE,CONSTRAINTCODE,REPLACECODE)
+ * PRIORITY predefined values include EARLY(2000), MIDDLE(3000), LATE(4000)
+ * HTP core provides some replacement functions for op packages to use
+ * for more information about optimization rules, please refer to the HTP core
+ * documentation
+ */
+
+/*
+ * op parameter order definitions
+ * need to be global in the package
+ * one definition per op, and this is optional
+ * syntax:
+ * DEF_PACKAGE_PARAM_ORDER(OP,PARAM1,MANDATORY1,DEFAULT1,PARAM2,MANDATORY2,DEFAULT2...)
+ * one or more parameters can be specified for each op
+ * the order of parameters listed determines the order of parameters passed into
+ * op execution functions
+ * if an op does not have a parameter order definition, the parameter order
+ * passed into Qnn_addNode will be passed into op execution functions
+ * if an op has a parameter order definition, any parameter passed into
+ * Qnn_addNode with an unlisted name will be abandoned
+ * if two or more op packages with the same package name are registered, they
+ * cannot list conflicting parameter orders
+ * PARAM refers to the parameter name as a string literal
+ * MANDATORY refers to whether this parameter is required to be provided at
+ * Qnn_addNode
+ * DEFAULT is used when MANDATORY is false
+ * if provided as a Qnn_Param_t*, DEFAULT will be used for graph construction
+ * when this parameter is not provided at Qnn_addNode
+ * if provided as nullptr, graph construction will skip this parameter when it
+ * is not provided at Qnn_addNode
+ */
+DEF_PACKAGE_PARAM_ORDER("WNop", "sync_type", true, nullptr)
+
+/* execute functions for ops */
+
+template <typename TensorType, typename TensorType1>
+GraphStatus wnopImpl(TensorType &out_0, TensorType1 &sync_var,
+                     const TensorType &in_0, const TensorType &in_1,
+                     const Tensor &sync_type)
+
+{
+  /*
+   * add code here
+   * */
+  /*
+   * To have good performance and stability, it is required to avoid heap memory
+   * allocation in this function. Heap memory allocation includes, but is not
+   * limited to, calling malloc, operator new, constructing STL container objects
+   * like std::vector with the default allocator, and adding items (e.g. calling
+   * std::vector::push_back) to STL container objects with the default allocator.
+   *
+   * Please check the SDK documentation for more information.
+   */
+
+  out_0.set_dims(in_0);
+
+  auto sync_type_ = sync_type(0, 0, 0, 0);
+
+  // sync_type == 0: copy in_0 to out_0, then raise sync_var to signal the CPU
+  // sync_type == 1: busy-wait until the CPU raises in_1, then copy in_0 to out_0
+
+  DType dtype = in_0.get_dtype();
+  uint32_t bitwidth = 4; // element size in bytes
+
+  if (dtype == DType::QUInt8) {
+    bitwidth = 1;
+  } else if (dtype == DType::Float16) {
+    bitwidth = 2;
+  } else if (dtype == DType::Float32) {
+    bitwidth = 4;
+  }
+
+  if (sync_type_ == 0) {
+
+    auto [b_in, h_in, w_in, d_in] = in_0.dims();
+
+    auto in_ptr = (void *)in_0.raw_data_const();
+    auto out_ptr = (void *)out_0.raw_data();
+
+    memcpy(out_ptr, in_ptr, b_in * h_in * w_in * d_in * bitwidth);
+
+    sync_var(0, 0, 0, 0) = 1; // tell the CPU the data is ready
+
+  } else if (sync_type_ == 1) {
+
+    // spin until the CPU-side flag becomes non-zero
+    while (in_1(0, 0, 0, 0) == 0) {
+      Q6_V_vzero();
+    }
+
+    auto [b_in, h_in, w_in, d_in] = in_0.dims();
+
+    auto in_ptr = (void *)in_0.raw_data_const();
+    auto out_ptr = (void *)out_0.raw_data();
+
+    memcpy(out_ptr, in_ptr, b_in * h_in * w_in * d_in * bitwidth);
+  }
+
+  return GraphStatus::Success;
+}
+
+__attribute__((unused)) static float wnopCostFunc(const Op *op) {
+  /*
+   * add code here
+   * */
+
+  float cost = 0.0; // add cost computation here
+  return cost;
+}
+
+/* At the bottom of the op file, call END_PKG_OP_DEFINITION(<name>),
+   where <name> is as BEGIN_PKG_OP_DEFINITION
+*/
+END_PKG_OP_DEFINITION(PKG_WNop);
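For reference, the CPU/NPU handshake that wnopImpl implements can be summarized with a small host-side model. This is only an illustrative sketch under stated assumptions: the type and names below (WNopRef, cpu_signal, run) are hypothetical and not part of the op package; in the real op the role of sync_var is played by the sync_var output tensor and the role of cpu_signal by the in_1 input tensor.

#include <atomic>
#include <cstddef>
#include <cstdint>
#include <cstring>

// Hypothetical host-side reference model of the WNop handshake (illustration only).
// sync_type == 0: producer copies the payload and raises a flag for the CPU.
// sync_type == 1: consumer spins until the CPU raises its flag, then copies.
struct WNopRef {
  std::atomic<int> sync_var{0};   // stands in for the sync_var output tensor
  std::atomic<int> cpu_signal{0}; // stands in for in_1(0, 0, 0, 0)

  void run(const uint8_t *in, uint8_t *out, std::size_t bytes, int sync_type) {
    if (sync_type == 0) {
      std::memcpy(out, in, bytes); // pass the data through
      sync_var.store(1);           // analogous to sync_var(0,0,0,0) = 1
    } else if (sync_type == 1) {
      while (cpu_signal.load() == 0) {
        // busy-wait, mirroring the Q6_V_vzero() spin loop on the NPU
      }
      std::memcpy(out, in, bytes);
    }
  }
};

In the actual op, the copy size is computed from the tensor dimensions and the per-element byte width, and the spin loop reads in_1(0, 0, 0, 0) directly instead of an atomic flag.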