From f599bf89546a1bf75b0302fcb94f8442a5c7a9bb Mon Sep 17 00:00:00 2001 From: Mddct Date: Wed, 7 Aug 2024 12:07:12 +0800 Subject: [PATCH] fix norm --- wenet/transformer/attention.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/wenet/transformer/attention.py b/wenet/transformer/attention.py index 73f1cb7a7..9fa855066 100644 --- a/wenet/transformer/attention.py +++ b/wenet/transformer/attention.py @@ -668,6 +668,9 @@ def forward( q = WENET_APPLY_ROTARY_EMB[self.style](q, pos_emb) k = WENET_APPLY_ROTARY_EMB[self.style](k, pos_emb) + if self.qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) k, v, new_cache = self._update_kv_and_cache(k, v, cache,