diff --git a/wenet/utils/class_utils.py b/wenet/utils/class_utils.py index 3123dc4f7..314e1b826 100644 --- a/wenet/utils/class_utils.py +++ b/wenet/utils/class_utils.py @@ -3,8 +3,8 @@ # Copyright [2023-11-28] import torch from wenet.paraformer.embedding import ParaformerPositinoalEncoding -from wenet.transformer.positionwise_feed_forward import (GatedVariantsMLP, - MoEFFNLayer) +from wenet.transformer.positionwise_feed_forward import ( + GatedVariantsMLP, MoEFFNLayer, PositionwiseFeedForward) from wenet.transformer.swish import Swish from wenet.transformer.subsampling import ( @@ -73,7 +73,7 @@ } WENET_MLP_CLASSES = { - 'position_wise_feed_forward': PositionalEncoding, + 'position_wise_feed_forward': PositionwiseFeedForward, 'moe': MoEFFNLayer, 'gated': GatedVariantsMLP }