Skip to content

Commit

Permalink
[Layer] add tanh-based approximate gelu activation function
Browse files Browse the repository at this point in the history
- Add a tanh-based approximate GELU (tanh GELU) activation for vision transformers.
- Rename quick GELU to sigmoid GELU (it is a sigmoid-based approximation of GELU).

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Seungbaek Hong <[email protected]>
  • Loading branch information
baek2sm authored and myungjoo committed Jul 11, 2024
1 parent 147dbe1 commit 64bd12e
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 29 deletions.
64 changes: 54 additions & 10 deletions nntrainer/layers/acti_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,13 @@ class ActiFunc {
in_place = false;
this->setActivation<Tensor>(gelu<T>, geluPrime<T>);
break;
case ActivationType::ACT_QUICK_GELU:
case ActivationType::ACT_TANH_GELU:
in_place = false;
this->setActivation<Tensor>(quickGelu<T>, quickGeluPrime<T>);
this->setActivation<Tensor>(tanhGelu<T>, tanhGeluPrime<T>);
break;
case ActivationType::ACT_SIGMOID_GELU:
in_place = false;
this->setActivation<Tensor>(sigmoidGelu<T>, sigmoidGeluPrime<T>);
break;
case ActivationType::ACT_ELU:
this->setActivation<T>(elu<T>, eluPrime<T>);
Expand Down Expand Up @@ -462,30 +466,70 @@ class ActiFunc {
}

/**
* @brief quick gelu activation function (gelu approximation)
* @brief tanh-based gelu approximate function
* @param[in] t_in input tensor
* @param[out] t_out output tensor
*/
template <typename T = float>
static Tensor &quickGelu(Tensor const &t_in, Tensor &t_out) {
static Tensor &tanhGelu(Tensor const &t_in, Tensor &t_out) {
t_in.apply<T>(
[&](T x) { return static_cast<T>(x * (sigmoid<T>(static_cast<T>(1.702 * x)))); }, t_out);
[&](T x) {
return static_cast<T>(
0.5 * x *
(1 + tanhFloat<T>(
static_cast<T>(sqrt(2 / M_PI) * (x + 0.044715 * pow(x, 3))))));
},
t_out);
return t_out;
}

/**
* @brief derivative quick gelu function
* @brief derivative of tanh-based gelu approximate function
* @param[in] t_in input tensor
* @param[in] t_out output tensor
* @param[in] outgoing_derivative outgoing derivative
* @param[in] incoming_derivative incoming derivative
*/
template <typename T = float>
static Tensor &quickGeluPrime(Tensor const &t_in, Tensor const &t_out,
Tensor &outgoing_derivative,
Tensor const &incoming_derivative = Tensor()) {
static Tensor &tanhGeluPrime(Tensor const &t_in, Tensor const &t_out,
Tensor &outgoing_derivative,
Tensor const &incoming_derivative = Tensor()) {
// NYI
ml_logw("tanhGeluPrime which is calculate derivate of tanhGelu function is "
"not yet implemented");
return outgoing_derivative;
}

/**
 * @brief sigmoid-based approximation of gelu (quick gelu),
 *        evaluated element-wise as x * sigmoid(1.702 * x)
 * @param[in] t_in input tensor
 * @param[out] t_out output tensor handed to Tensor::apply
 * @return reference to t_out
 */
template <typename T = float>
static Tensor &sigmoidGelu(Tensor const &t_in, Tensor &t_out) {
  // Keep the capture-default so the closure type matches the original call.
  auto approximate = [&](T value) {
    // 1.702 is a double literal, so the product is formed before narrowing to T.
    const T scaled = static_cast<T>(1.702 * value);
    return static_cast<T>(value * (sigmoid<T>(scaled)));
  };
  t_in.apply<T>(approximate, t_out);
  return t_out;
}

/**
* @brief derivative of sigmoid-based gelu approximate function
* @param[in] t_in input tensor
* @param[in] t_out output tensor
* @param[in] outgoing_derivative outgoing derivative
* @param[in] incoming_derivative incoming derivative
*/
template <typename T = float>
static Tensor &
sigmoidGeluPrime(Tensor const &t_in, Tensor const &t_out,
Tensor &outgoing_derivative,
Tensor const &incoming_derivative = Tensor()) {
// NYI
ml_logw("quickGeluPrime which is calculate derivate of quickGelu function is not yet implemented");
ml_logw("sigmoidGeluPrime which is calculate derivate of sigmoidGelu "
"function is not yet implemented");
return outgoing_derivative;
}

Expand Down
39 changes: 20 additions & 19 deletions nntrainer/layers/common_properties.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,21 @@ namespace nntrainer {
* accordingly
*/
enum class ActivationType {
ACT_TANH, /**< tanh */
ACT_SIGMOID, /**< sigmoid */
ACT_RELU, /**< ReLU */
ACT_SWISH, /**< Swish */
ACT_GELU, /**< GELU */
ACT_QUICK_GELU, /**< Quick GELU */
ACT_SOFTMAX, /**< softmax */
ACT_SOFTPLUS, /**< softplus */
ACT_LEAKY_RELU, /**< Leaky ReLU */
ACT_ELU, /**< ELU */
ACT_SELU, /**< SELU */
ACT_MISH, /**< Mish */
ACT_NONE, /**< no op */
ACT_UNKNOWN /**< unknown */
ACT_TANH, /**< tanh */
ACT_SIGMOID, /**< sigmoid */
ACT_RELU, /**< ReLU */
ACT_SWISH, /**< Swish */
ACT_GELU, /**< GELU */
ACT_TANH_GELU, /**< tanh GELU */
ACT_SIGMOID_GELU, /**< sigmoid GELU */
ACT_SOFTMAX, /**< softmax */
ACT_SOFTPLUS, /**< softplus */
ACT_LEAKY_RELU, /**< Leaky ReLU */
ACT_ELU, /**< ELU */
ACT_SELU, /**< SELU */
ACT_MISH, /**< Mish */
ACT_NONE, /**< no op */
ACT_UNKNOWN /**< unknown */
};

namespace props {
Expand Down Expand Up @@ -910,12 +911,12 @@ struct ActivationTypeInfo {
static constexpr std::initializer_list<Enum> EnumList = {
Enum::ACT_TANH, Enum::ACT_SIGMOID, Enum::ACT_RELU,
Enum::ACT_SOFTMAX, Enum::ACT_LEAKY_RELU, Enum::ACT_SWISH,
Enum::ACT_GELU, Enum::ACT_QUICK_GELU, Enum::ACT_NONE,
Enum::ACT_UNKNOWN};
Enum::ACT_GELU, Enum::ACT_TANH_GELU, Enum::ACT_SIGMOID_GELU,
Enum::ACT_NONE, Enum::ACT_UNKNOWN};

static constexpr const char *EnumStr[] = {
"tanh", "sigmoid", "relu", "softmax", "leaky_relu",
"swish", "gelu", "quick_gelu", "none", "unknown"};
"tanh", "sigmoid", "relu", "softmax", "leaky_relu", "swish",
"gelu", "tanh_gelu", "sigmoid_gelu", "none", "unknown"};
};

/**
Expand Down Expand Up @@ -1122,7 +1123,7 @@ struct UpsampleModeInfo {
enum class Interpolation { nearest, bilinear };

using Enum = Interpolation;

static constexpr std::initializer_list<Interpolation> EnumList = {
Interpolation::nearest, Interpolation::bilinear};

Expand Down

0 comments on commit 64bd12e

Please sign in to comment.