[Layer] add tanh-based approximate gelu activation function
- add a tanh-based approximate gelu (tanh gelu) for vision transformers; the formulas are given below.
- rename quick gelu to sigmoid gelu (it is a sigmoid-based approximate gelu).
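
For reference, the two approximations implemented in this commit are (matching the code in acti_func.h below):

$$\text{tanh gelu}(x) = 0.5\,x\left(1 + \tanh\!\left(\sqrt{2/\pi}\,\bigl(x + 0.044715\,x^{3}\bigr)\right)\right)$$

$$\text{sigmoid gelu}(x) = x \cdot \sigma(1.702\,x)$$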

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Seungbaek Hong <[email protected]>
baek2sm committed Jul 1, 2024
1 parent 9e917a3 commit c957d13
Showing 2 changed files with 59 additions and 25 deletions.
47 changes: 40 additions & 7 deletions nntrainer/layers/acti_func.h
@@ -78,9 +78,13 @@ class ActiFunc {
in_place = false;
this->setActivation<Tensor>(gelu<T>, geluPrime<T>);
break;
case ActivationType::ACT_QUICK_GELU:
case ActivationType::ACT_TANH_GELU:
in_place = false;
this->setActivation<Tensor>(quickGelu<T>, quickGeluPrime<T>);
this->setActivation<Tensor>(tanhGelu<T>, tanhGeluPrime<T>);
break;
case ActivationType::ACT_SIGMOID_GELU:
in_place = false;
this->setActivation<Tensor>(sigmoidGelu<T>, sigmoidGeluPrime<T>);
break;
case ActivationType::ACT_ELU:
this->setActivation<T>(elu<T>, eluPrime<T>);
@@ -462,30 +466,59 @@ class ActiFunc {
}

/**
* @brief quick gelu activation function (gelu approximation)
* @brief tanh-based approximate gelu function
* @param[in] t_in input tensor
* @param[in] t_out output tensor
*/
template <typename T = float>
static Tensor &tanhGelu(Tensor const &t_in, Tensor &t_out) {
t_in.apply<T>(
[&](T x) { return static_cast<T>(
0.5 * x * (1 + tanhFloat<T>(static_cast<T>(sqrt(2/M_PI) * (x + 0.044715 * pow(x, 3)))))); }, t_out);
return t_out;
}

/**
* @brief derivative of the tanh-based approximate gelu function
* @param[in] t_in input tensor
* @param[in] t_out output tensor
* @param[in] outgoing_derivative outgoing derivative
* @param[in] incoming_derivative incoming derivative
*/
template <typename T = float>
static Tensor &tanhGeluPrime(Tensor const &t_in, Tensor const &t_out,
Tensor &outgoing_derivative,
Tensor const &incoming_derivative = Tensor()) {
// NYI
ml_logw("tanhGeluPrime which is calculate derivate of tanhGelu function is not yet implemented");
return outgoing_derivative;
}

/**
* @brief sigmoid-based approximate gelu function (quick gelu)
* @param[in] t_in input tensor
* @param[in] t_out output tensor
*/
template <typename T = float>
static Tensor &quickGelu(Tensor const &t_in, Tensor &t_out) {
static Tensor &sigmoidGelu(Tensor const &t_in, Tensor &t_out) {
t_in.apply<T>(
[&](T x) { return static_cast<T>(x * (sigmoid<T>(static_cast<T>(1.702 * x)))); }, t_out);
return t_out;
}

/**
* @brief derivative quick gelu function
* @brief derivative of the sigmoid-based approximate gelu function
* @param[in] t_in input tensor
* @param[in] t_out output tensor
* @param[in] outgoing_derivative outgoing derivative
* @param[in] incoming_derivative incoming derivative
*/
template <typename T = float>
static Tensor &quickGeluPrime(Tensor const &t_in, Tensor const &t_out,
static Tensor &sigmoidGeluPrime(Tensor const &t_in, Tensor const &t_out,
Tensor &outgoing_derivative,
Tensor const &incoming_derivative = Tensor()) {
// NYI
ml_logw("quickGeluPrime which is calculate derivate of quickGelu function is not yet implemented");
ml_logw("sigmoidGeluPrime which is calculate derivate of sigmoidGelu function is not yet implemented");
return outgoing_derivative;
}

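Both tanhGeluPrime and sigmoidGeluPrime are left as NYI stubs in this commit. For readers who want the analytic gradient, here is a minimal sketch (not part of the commit) of how tanhGeluPrime could be filled in; it assumes the element-wise Tensor::apply<T> and tanhFloat<T> helpers already used above, and leaves out the chain-rule multiplication by incoming_derivative:

```cpp
// Hypothetical sketch only -- not from this commit.
// For f(x) = 0.5 x (1 + tanh(u)) with u = sqrt(2/pi) * (x + 0.044715 x^3),
// f'(x) = 0.5 (1 + tanh(u)) + 0.5 x (1 - tanh(u)^2) * sqrt(2/pi) * (1 + 3 * 0.044715 x^2).
template <typename T = float>
static Tensor &tanhGeluPrime(Tensor const &t_in, Tensor const &t_out,
                             Tensor &outgoing_derivative,
                             Tensor const &incoming_derivative = Tensor()) {
  t_in.apply<T>(
    [&](T x) {
      const T c = static_cast<T>(sqrt(2 / M_PI));
      const T u = c * (x + static_cast<T>(0.044715) * x * x * x);
      const T th = tanhFloat<T>(u);
      const T du = c * (1 + 3 * static_cast<T>(0.044715) * x * x);
      return static_cast<T>(0.5 * (1 + th) + 0.5 * x * (1 - th * th) * du);
    },
    outgoing_derivative);
  // Chain rule: element-wise multiplication by incoming_derivative would go
  // here, following whatever convention the other *Prime functions in this file use.
  return outgoing_derivative;
}
```

sigmoidGeluPrime would follow the same pattern; for f(x) = x·σ(1.702x) the derivative is σ(1.702x) + 1.702·x·σ(1.702x)·(1 − σ(1.702x)).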
37 changes: 19 additions & 18 deletions nntrainer/layers/common_properties.h
@@ -30,20 +30,21 @@ namespace nntrainer {
* accordingly
*/
enum class ActivationType {
ACT_TANH, /**< tanh */
ACT_SIGMOID, /**< sigmoid */
ACT_RELU, /**< ReLU */
ACT_SWISH, /**< Swish */
ACT_GELU, /**< GELU */
ACT_QUICK_GELU, /**< Quick GELU */
ACT_SOFTMAX, /**< softmax */
ACT_SOFTPLUS, /**< softplus */
ACT_LEAKY_RELU, /**< Leaky ReLU */
ACT_ELU, /**< ELU */
ACT_SELU, /**< SELU */
ACT_MISH, /**< Mish */
ACT_NONE, /**< no op */
ACT_UNKNOWN /**< unknown */
ACT_TANH, /**< tanh */
ACT_SIGMOID, /**< sigmoid */
ACT_RELU, /**< ReLU */
ACT_SWISH, /**< Swish */
ACT_GELU, /**< GELU */
ACT_TANH_GELU, /**< tanh GELU */
ACT_SIGMOID_GELU, /**< sigmoid GELU */
ACT_SOFTMAX, /**< softmax */
ACT_SOFTPLUS, /**< softplus */
ACT_LEAKY_RELU, /**< Leaky ReLU */
ACT_ELU, /**< ELU */
ACT_SELU, /**< SELU */
ACT_MISH, /**< Mish */
ACT_NONE, /**< no op */
ACT_UNKNOWN /**< unknown */
};

namespace props {
@@ -866,13 +867,13 @@ struct ActivationTypeInfo {
static constexpr std::initializer_list<Enum> EnumList = {
Enum::ACT_TANH, Enum::ACT_SIGMOID, Enum::ACT_RELU,
Enum::ACT_SOFTMAX, Enum::ACT_LEAKY_RELU, Enum::ACT_SWISH,
Enum::ACT_GELU, Enum::ACT_QUICK_GELU, Enum::ACT_NONE,
Enum::ACT_UNKNOWN};
Enum::ACT_GELU, Enum::ACT_TANH_GELU, Enum::ACT_SIGMOID_GELU,
Enum::ACT_NONE, Enum::ACT_UNKNOWN};

static constexpr const char *EnumStr[] = {"tanh", "sigmoid", "relu",
"softmax", "leaky_relu", "swish",
"gelu", "quick_gelu", "none",
"unknown"};
"gelu", "tanh_gelu", "sigmoid_gelu",
"none", "unknown"};
};

/**
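With the "tanh_gelu" and "sigmoid_gelu" strings registered in ActivationTypeInfo above, the new activations should be selectable by name through layer properties. A hypothetical usage sketch (not verified against this commit), assuming nntrainer's ml::train::createLayer C++ API and an activation layer type that accepts an "activation=<name>" property:

```cpp
// Hypothetical usage sketch: the names come from the EnumStr table added above;
// the createLayer call and "activation" layer type are assumptions, not part of this commit.
#include <layer.h>  // ml::train::createLayer (ccapi header name assumed)

int main() {
  auto tanh_gelu_act =
    ml::train::createLayer("activation", {"activation=tanh_gelu"});
  auto sigmoid_gelu_act =
    ml::train::createLayer("activation", {"activation=sigmoid_gelu"});
  return 0;
}
```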
