-
Notifications
You must be signed in to change notification settings - Fork 299
/
zipformer.py
1909 lines (1634 loc) · 73.2 KB
/
zipformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
#!/usr/bin/env python3
# Copyright 2022 Xiaomi Corp. (authors: Daniel Povey)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import itertools
import logging
import math
import random
import warnings
from typing import List, Optional, Tuple, Union
import torch
from encoder_interface import EncoderInterface
from scaling import (
ScaledLinear, # not as in other dirs.. just scales down initial parameter values.
)
from scaling import (
ActivationBalancer,
BasicNorm,
DoubleSwish,
Identity,
MaxEig,
ScaledConv1d,
Whiten,
_diag,
penalize_abs_values_gt,
random_clamp,
softmax,
)
from torch import Tensor, nn
from icefall.dist import get_rank
from icefall.utils import is_jit_tracing, make_pad_mask
class Zipformer(EncoderInterface):
    """
    Multi-stack Zipformer encoder.

    The input features are first subsampled by ``Conv2dSubsampling``
    (T -> (T - 7) // 2) and then passed through a sequence of encoder stacks.
    A stack whose downsampling factor is not 1 is wrapped in a
    ``DownsampledZipformerEncoder`` and runs at a reduced frame rate.  The
    outputs of earlier stacks may be combined into the input of later ones
    (see ``_init_skip_modules``).  The output of the last stack is finally
    downsampled by ``output_downsampling_factor``.

    Args:
        num_features (int): Number of input features.
        output_downsampling_factor (int): Downsampling factor applied to the
            output of the last stack; ``forward`` currently asserts it is 2.
        encoder_dims (Tuple[int]): Embedding dimension of each encoder stack.
        attention_dim (Tuple[int]): Attention dimension of each encoder stack.
        encoder_unmasked_dims (Tuple[int]): For each stack, the number of
            leading dims that are never zeroed by the random feature masks
            used in training (each must be <= the stack's encoder dim).
        zipformer_downsampling_factors (Tuple[int]): Frame-rate downsampling
            factor of each stack, relative to the output of ``encoder_embed``.
        nhead (Tuple[int]): Number of attention heads of each stack.
        feedforward_dim (Tuple[int]): Feedforward dimension of each stack.
        num_encoder_layers (Tuple[int]): Number of layers in each stack.
        dropout (float): Dropout rate.
        cnn_module_kernels (Tuple[int]): Convolution kernel size of each stack.
        pos_dim (int): Positional-encoding dimension passed to the
            attention modules.
        warmup_batches (float): Number of batches to warm up over.
    """

    def __init__(
        self,
        num_features: int,
        output_downsampling_factor: int = 2,
        encoder_dims: Tuple[int] = (384, 384),
        attention_dim: Tuple[int] = (256, 256),
        encoder_unmasked_dims: Tuple[int] = (256, 256),
        zipformer_downsampling_factors: Tuple[int] = (2, 4),
        nhead: Tuple[int] = (8, 8),
        feedforward_dim: Tuple[int] = (1536, 2048),
        num_encoder_layers: Tuple[int] = (12, 12),
        dropout: float = 0.1,
        cnn_module_kernels: Tuple[int] = (31, 31),
        pos_dim: int = 4,
        warmup_batches: float = 4000.0,
    ) -> None:
        super(Zipformer, self).__init__()

        self.num_features = num_features
        assert 0 < encoder_dims[0] <= encoder_dims[1]
        self.encoder_dims = encoder_dims
        self.encoder_unmasked_dims = encoder_unmasked_dims
        self.zipformer_downsampling_factors = zipformer_downsampling_factors
        self.output_downsampling_factor = output_downsampling_factor

        # will be written to, see set_batch_count()
        self.batch_count = 0
        self.warmup_end = warmup_batches

        for u, d in zip(encoder_unmasked_dims, encoder_dims):
            assert u <= d, (u, d)

        # self.encoder_embed converts the input of shape (N, T, num_features)
        # to the shape (N, (T - 7)//2, encoder_dims).
        # That is, it does two things simultaneously:
        #   (1) subsampling: T -> (T - 7)//2
        #   (2) embedding: num_features -> encoder_dims
        self.encoder_embed = Conv2dSubsampling(
            num_features, encoder_dims[0], dropout=dropout
        )

        # each one will be ZipformerEncoder or DownsampledZipformerEncoder
        encoders = []

        num_encoders = len(encoder_dims)
        for i in range(num_encoders):
            encoder_layer = ZipformerEncoderLayer(
                encoder_dims[i],
                attention_dim[i],
                nhead[i],
                feedforward_dim[i],
                dropout,
                cnn_module_kernels[i],
                pos_dim,
            )

            # The warmup period is split into (num_encoders + 1) equal
            # segments.  During the first segment we let the Conv2dSubsampling
            # layer learn something; stack i is then warmed up over segment
            # i + 1.
            encoder = ZipformerEncoder(
                encoder_layer,
                num_encoder_layers[i],
                dropout,
                warmup_begin=warmup_batches * (i + 1) / (num_encoders + 1),
                warmup_end=warmup_batches * (i + 2) / (num_encoders + 1),
            )

            if zipformer_downsampling_factors[i] != 1:
                encoder = DownsampledZipformerEncoder(
                    encoder,
                    input_dim=encoder_dims[i - 1] if i > 0 else encoder_dims[0],
                    output_dim=encoder_dims[i],
                    downsample=zipformer_downsampling_factors[i],
                )
            encoders.append(encoder)
        self.encoders = nn.ModuleList(encoders)

        # initializes self.skip_layers and self.skip_modules
        self._init_skip_modules()

        self.downsample_output = AttentionDownsample(
            encoder_dims[-1], encoder_dims[-1], downsample=output_downsampling_factor
        )

    def _get_layer_skip_dropout_prob(self):
        """
        Return the probability of NOT applying a skip-connection combiner.

        This is 0.0 in eval mode.  In training mode it decays linearly from
        0.5 at batch 0 to 0.025 at ``self.warmup_end`` batches and stays at
        0.025 afterwards.
        """
        if not self.training:
            return 0.0
        batch_count = self.batch_count
        min_dropout_prob = 0.025

        if batch_count > self.warmup_end:
            return min_dropout_prob
        else:
            return 0.5 - (batch_count / self.warmup_end) * (0.5 - min_dropout_prob)

    def _init_skip_modules(self):
        """
        Initialize self.skip_layers and self.skip_modules.

        A stack i (i >= 2) whose downsampling factor is smaller than that of
        stack i - 1 gets a skip connection from the most recent earlier stack
        j (j <= i - 2) whose factor is <= its own (or from stack 0 if there is
        none).  Its input is then a SimpleCombiner of the outputs of stacks j
        and i - 1.  All other stacks get SimpleCombinerIdentity and a
        skip_layers entry of None.

        Example: if self.zipformer_downsampling_factors = (1, 2, 4, 8, 4, 2),
        then at the input of stack index 4 (zero-based), which has
        downsampling factor 4, we combine the outputs of stacks 2 and 3; and at
        the input of stack index 5, which has downsampling factor 2, we
        combine the outputs of stacks 1 and 4.
        """
        skip_layers = []
        skip_modules = []
        z = self.zipformer_downsampling_factors
        for i in range(len(z)):
            if i <= 1 or z[i - 1] <= z[i]:
                skip_layers.append(None)
                skip_modules.append(SimpleCombinerIdentity())
            else:
                # TEMP
                for j in range(i - 2, -1, -1):
                    if z[j] <= z[i] or j == 0:
                        # TEMP logging statement.
                        logging.info(
                            f"At encoder stack {i}, which has downsampling_factor={z[i]}, we will "
                            f"combine the outputs of layers {j} and {i-1}, with downsampling_factors={z[j]} and {z[i-1]}."
                        )
                        skip_layers.append(j)
                        skip_modules.append(
                            SimpleCombiner(
                                self.encoder_dims[j],
                                self.encoder_dims[i - 1],
                                min_weight=(0.0, 0.25),
                            )
                        )
                        break
        self.skip_layers = skip_layers
        self.skip_modules = nn.ModuleList(skip_modules)

    def get_feature_masks(self, x: torch.Tensor) -> List[float]:
        # Note: The actual return type is Union[List[float], List[Tensor]],
        # but to make torch.jit.script() work, we use List[float]
        """
        Return one feature mask per encoder stack.

        In eval mode (and when scripting or tracing), this returns
        [1.0] * num_encoders.  In training mode it returns randomized feature
        masks.  On about 15% of frames, these masks zero out all encoder dims
        at or above ``encoder_unmasked_dims[i]``, so on those frames the stack
        effectively uses a smaller encoder dim.

        We generate the random masks at this level because we want the masks
        to agree across the encoder stacks.  A single random value is drawn
        per frame at the rate of the most-downsampled stack.  For a stack with
        downsampling factor ds, each value is repeated
        max_downsampling_factor // ds times, so the masked regions line up
        across stacks.

        Args:
           x: the embeddings (needed for the shape and dtype and device), of shape
             (num_frames, batch_size, encoder_dims0)
        """
        num_encoders = len(self.encoder_dims)
        if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
            return [1.0] * num_encoders

        (num_frames0, batch_size, _encoder_dims0) = x.shape

        assert self.encoder_dims[0] == _encoder_dims0, (
            self.encoder_dims,
            _encoder_dims0,
        )

        max_downsampling_factor = max(self.zipformer_downsampling_factors)

        num_frames_max = num_frames0 + max_downsampling_factor - 1

        feature_mask_dropout_prob = 0.15

        # frame_mask_max shape: (num_frames_max, batch_size, 1)
        frame_mask_max = (
            torch.rand(num_frames_max, batch_size, 1, device=x.device)
            > feature_mask_dropout_prob
        ).to(x.dtype)

        feature_masks = []
        for i in range(num_encoders):
            ds = self.zipformer_downsampling_factors[i]
            upsample_factor = max_downsampling_factor // ds

            frame_mask = (
                frame_mask_max.unsqueeze(1)
                .expand(num_frames_max, upsample_factor, batch_size, 1)
                .reshape(num_frames_max * upsample_factor, batch_size, 1)
            )
            num_frames = (num_frames0 + ds - 1) // ds
            frame_mask = frame_mask[:num_frames]
            feature_mask = torch.ones(
                num_frames,
                batch_size,
                self.encoder_dims[i],
                dtype=x.dtype,
                device=x.device,
            )
            u = self.encoder_unmasked_dims[i]
            # only dims >= u are subject to masking
            feature_mask[:, :, u:] *= frame_mask
            feature_masks.append(feature_mask)

        return feature_masks

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            The input tensor. Its shape is (batch_size, seq_len, feature_dim).
          x_lens:
            A tensor of shape (batch_size,) containing the number of frames in
            `x` before padding.
        Returns:
          Return a tuple containing 2 tensors:
            - embeddings: its shape is (batch_size, output_seq_len, encoder_dims[-1])
            - lengths, a tensor of shape (batch_size,) containing the number
              of frames in `embeddings` before padding.
        """
        x = self.encoder_embed(x)

        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # matches the (T - 7) // 2 subsampling of self.encoder_embed
        lengths = (x_lens - 7) >> 1
        assert x.size(0) == lengths.max().item(), (x.shape, lengths, lengths.max())
        mask = make_pad_mask(lengths, x.size(0))

        outputs = []
        feature_masks = self.get_feature_masks(x)

        for i, (module, skip_module) in enumerate(
            zip(self.encoders, self.skip_modules)
        ):
            ds = self.zipformer_downsampling_factors[i]
            k = self.skip_layers[i]
            if isinstance(k, int):
                layer_skip_dropout_prob = self._get_layer_skip_dropout_prob()
                if torch.jit.is_scripting() or torch.jit.is_tracing():
                    x = skip_module(outputs[k], x)
                elif (not self.training) or random.random() > layer_skip_dropout_prob:
                    x = skip_module(outputs[k], x)
            x = module(
                x,
                feature_mask=feature_masks[i],
                src_key_padding_mask=None if mask is None else mask[..., ::ds],
            )
            outputs.append(x)

        x = self.downsample_output(x)
        # class Downsample has this rounding behavior..
        assert self.output_downsampling_factor == 2, self.output_downsampling_factor
        lengths = (lengths + 1) >> 1

        x = x.permute(1, 0, 2)  # (T, N, C) ->(N, T, C)

        return x, lengths
class ZipformerEncoderLayer(nn.Module):
    """
    ZipformerEncoderLayer is made up of self-attn, pooling, feedforward and
    convolution networks.

    Args:
        d_model: the number of expected features in the input (required).
        attention_dim: the attention dimension of the self-attention module (required).
        nhead: the number of heads in the multiheadattention models (required).
        feedforward_dim: the dimension of the feedforward network model (default=2048).
        dropout: the dropout value of the feedforward modules (default=0.1).
        cnn_module_kernel (int): Kernel size of convolution module.
        pos_dim: positional-encoding dimension passed to the attention module.

    Examples::
        >>> encoder_layer = ZipformerEncoderLayer(d_model=512, attention_dim=256, nhead=8)
        >>> src = torch.rand(10, 32, 512)
        >>> pos_emb = torch.rand(1, 19, 512)
        >>> out = encoder_layer(src, pos_emb)
    """

    def __init__(
        self,
        d_model: int,
        attention_dim: int,
        nhead: int,
        feedforward_dim: int = 2048,
        dropout: float = 0.1,
        cnn_module_kernel: int = 31,
        pos_dim: int = 4,
    ) -> None:
        super(ZipformerEncoderLayer, self).__init__()

        self.d_model = d_model

        # will be written to, see set_batch_count()
        self.batch_count = 0

        self.self_attn = RelPositionMultiheadAttention(
            d_model,
            attention_dim,
            nhead,
            pos_dim,
            dropout=0.0,
        )

        self.pooling = PoolingModule(d_model)

        self.feed_forward1 = FeedforwardModule(d_model, feedforward_dim, dropout)

        self.feed_forward2 = FeedforwardModule(d_model, feedforward_dim, dropout)

        self.feed_forward3 = FeedforwardModule(d_model, feedforward_dim, dropout)

        self.conv_module1 = ConvolutionModule(d_model, cnn_module_kernel)

        self.conv_module2 = ConvolutionModule(d_model, cnn_module_kernel)

        self.norm_final = BasicNorm(d_model)

        self.bypass_scale = nn.Parameter(torch.tensor(0.5))

        # try to ensure the output is close to zero-mean (or at least, zero-median).
        self.balancer = ActivationBalancer(
            d_model,
            channel_dim=-1,
            min_positive=0.45,
            max_positive=0.55,
            max_abs=6.0,
        )
        self.whiten = Whiten(
            num_groups=1, whitening_limit=5.0, prob=(0.025, 0.25), grad_scale=0.01
        )

    def get_bypass_scale(self):
        """
        Return the scale applied to (layer output - layer input).

        In eval mode, or when scripting or tracing, this is the raw parameter.
        In training, on a random 10% of calls the raw parameter is returned,
        so that it still gets gradients if it drifts out of range.  Otherwise
        it is clamped to [clamp_min, 1.0], where clamp_min decays linearly
        from 0.75 to 0.25 over the first 20000 batches.
        """
        if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
            return self.bypass_scale
        if random.random() < 0.1:
            # ensure we get grads if self.bypass_scale becomes out of range
            return self.bypass_scale
        # hardcode warmup period for bypass scale
        warmup_period = 20000.0
        initial_clamp_min = 0.75
        final_clamp_min = 0.25
        if self.batch_count > warmup_period:
            clamp_min = final_clamp_min
        else:
            clamp_min = initial_clamp_min - (self.batch_count / warmup_period) * (
                initial_clamp_min - final_clamp_min
            )
        return self.bypass_scale.clamp(min=clamp_min, max=1.0)

    def get_dynamic_dropout_rate(self):
        """
        Return the dropout rate for the dynamic modules (self_attn, pooling,
        convolution).

        The rate starts at 0.2 and decreases linearly to 0 over the first 2000
        batches.  Its purpose is to keep training stable at the beginning by
        making the network focus on the feedforward modules.  It is always 0.0
        in eval mode and when scripting or tracing.
        """
        if torch.jit.is_scripting() or not self.training or torch.jit.is_tracing():
            return 0.0
        warmup_period = 2000.0
        initial_dropout_rate = 0.2
        final_dropout_rate = 0.0
        if self.batch_count > warmup_period:
            return final_dropout_rate
        # Linear interpolation from initial_dropout_rate (batch 0) to
        # final_dropout_rate (batch warmup_period).  Note the subtraction of
        # the two rates: multiplying them would keep the rate constant at
        # initial_dropout_rate whenever final_dropout_rate == 0.
        return initial_dropout_rate - (
            initial_dropout_rate - final_dropout_rate
        ) * (self.batch_count / warmup_period)

    def forward(
        self,
        src: Tensor,
        pos_emb: Tensor,
        src_mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        """
        Pass the input through the encoder layer.

        Args:
            src: the sequence to the encoder layer (required).
            pos_emb: Positional embedding tensor (required).
            src_mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            src: (S, N, E).
            pos_emb: (N, 2*S-1, E); ZipformerEncoder passes the output of
                RelPositionalEncoding, whose leading dim is 1.
            src_mask: (S, S).
            src_key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size, E is the feature number

        Returns:
            Tensor of shape (S, N, E).
        """
        src_orig = src

        # macaron style feed forward module
        src = src + self.feed_forward1(src)

        # dropout rate for submodules that interact with time.
        dynamic_dropout = self.get_dynamic_dropout_rate()

        # pooling module
        if torch.jit.is_scripting() or torch.jit.is_tracing():
            src = src + self.pooling(src, key_padding_mask=src_key_padding_mask)
        elif random.random() >= dynamic_dropout:
            src = src + self.pooling(src, key_padding_mask=src_key_padding_mask)

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            src_att, attn_weights = self.self_attn(
                src,
                pos_emb=pos_emb,
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask,
            )
            src = src + src_att

            src = src + self.conv_module1(
                src, src_key_padding_mask=src_key_padding_mask
            )

            src = src + self.feed_forward2(src)

            # second use of the attention weights computed above
            src = src + self.self_attn.forward2(src, attn_weights)

            src = src + self.conv_module2(
                src, src_key_padding_mask=src_key_padding_mask
            )
        else:
            use_self_attn = random.random() >= dynamic_dropout
            if use_self_attn:
                src_att, attn_weights = self.self_attn(
                    src,
                    pos_emb=pos_emb,
                    attn_mask=src_mask,
                    key_padding_mask=src_key_padding_mask,
                )
                src = src + src_att

            if random.random() >= dynamic_dropout:
                src = src + self.conv_module1(
                    src, src_key_padding_mask=src_key_padding_mask
                )

            src = src + self.feed_forward2(src)

            # attn_weights only exists if self-attention was not dropped above
            if use_self_attn:
                src = src + self.self_attn.forward2(src, attn_weights)

            if random.random() >= dynamic_dropout:
                src = src + self.conv_module2(
                    src, src_key_padding_mask=src_key_padding_mask
                )

        src = src + self.feed_forward3(src)

        src = self.norm_final(self.balancer(src))

        delta = src - src_orig

        src = src_orig + delta * self.get_bypass_scale()

        return self.whiten(src)
class ZipformerEncoder(nn.Module):
    r"""ZipformerEncoder is a stack of N encoder layers

    Args:
        encoder_layer: an instance of the ZipformerEncoderLayer() class (required);
            it is deep-copied num_layers times.
        num_layers: the number of sub-encoder-layers in the encoder (required).
        dropout: dropout rate of the relative positional encoding.
        warmup_begin: batch count at which the warmup of the first layer begins.
        warmup_end: batch count at which the warmup of the last layer ends.
            The interval [warmup_begin, warmup_end] is split evenly across
            the layers.

    Examples::
        >>> encoder_layer = ZipformerEncoderLayer(d_model=512, attention_dim=256, nhead=8)
        >>> zipformer_encoder = ZipformerEncoder(
        ...     encoder_layer, num_layers=6, dropout=0.1, warmup_begin=0.0, warmup_end=1000.0
        ... )
        >>> src = torch.rand(10, 32, 512)
        >>> out = zipformer_encoder(src)
    """

    def __init__(
        self,
        encoder_layer: nn.Module,
        num_layers: int,
        dropout: float,
        warmup_begin: float,
        warmup_end: float,
    ) -> None:
        super().__init__()
        # will be written to, see set_batch_count() Note: in inference time this
        # may be zero but should be treated as large, we can check if
        # self.training is true.
        self.batch_count = 0
        self.warmup_begin = warmup_begin
        self.warmup_end = warmup_end
        # module_seed is for when we need a random number that is unique to the module but
        # shared across jobs.   It's used to randomly select how many layers to drop,
        # so that we can keep this consistent across worker tasks (for efficiency).
        self.module_seed = torch.randint(0, 1000, ()).item()

        self.encoder_pos = RelPositionalEncoding(encoder_layer.d_model, dropout)

        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for i in range(num_layers)]
        )
        self.num_layers = num_layers

        assert 0 <= warmup_begin <= warmup_end, (warmup_begin, warmup_end)

        # give each layer its own consecutive, equal-length warmup window
        delta = (1.0 / num_layers) * (warmup_end - warmup_begin)
        cur_begin = warmup_begin
        for i in range(num_layers):
            self.layers[i].warmup_begin = cur_begin
            cur_begin += delta
            self.layers[i].warmup_end = cur_begin

    def _get_layerdrop_prob(self, layer: int, batch_count: float) -> float:
        """
        Return the layer-drop probability of layer index `layer` at `batch_count`.

        The probability is 0.5 before the layer's warmup window, 0.05 after it,
        and linearly interpolated between the two over the window
        [layer.warmup_begin, layer.warmup_end].
        """
        layer_warmup_begin = self.layers[layer].warmup_begin
        layer_warmup_end = self.layers[layer].warmup_end

        initial_layerdrop_prob = 0.5
        final_layerdrop_prob = 0.05

        if batch_count == 0:
            # As a special case, if batch_count == 0, return 0 (drop no
            # layers).  This is rather ugly, I'm afraid; it is intended to
            # enable our scan_pessimistic_batches_for_oom() code to work correctly
            # so if we are going to get OOM it will happen early.
            # also search for 'batch_count' with quotes in this file to see
            # how we initialize the warmup count to a random number between
            # 0 and 10.
            return 0.0
        elif batch_count < layer_warmup_begin:
            return initial_layerdrop_prob
        elif batch_count > layer_warmup_end:
            return final_layerdrop_prob
        else:
            # Linearly interpolate over the layer's own warmup window, so the
            # probability reaches final_layerdrop_prob exactly at its end.
            warmup_len = layer_warmup_end - layer_warmup_begin
            if warmup_len <= 0:
                # empty window: here batch_count == begin == end, i.e. warmup
                # has just finished.
                return final_layerdrop_prob
            t = (batch_count - layer_warmup_begin) / warmup_len
            assert 0.0 <= t < 1.001, t
            return initial_layerdrop_prob + t * (
                final_layerdrop_prob - initial_layerdrop_prob
            )

    def get_layers_to_drop(self, rnd_seed: int):
        """
        Return the set of layer indices to skip in this forward pass.

        Always empty in eval mode.  In training, the number of layers to drop
        is drawn from a generator seeded with batch_count + module_seed, so
        that it agrees across jobs; which layers are dropped is chosen with a
        generator seeded with `rnd_seed`.
        """
        ans = set()
        if not self.training:
            return ans

        batch_count = self.batch_count
        num_layers = len(self.layers)

        shared_rng = random.Random(batch_count + self.module_seed)
        independent_rng = random.Random(rnd_seed)

        layerdrop_probs = [
            self._get_layerdrop_prob(i, batch_count) for i in range(num_layers)
        ]
        tot = sum(layerdrop_probs)
        # Instead of drawing the samples independently, we first randomly decide
        # how many layers to drop out, using the same random number generator between
        # jobs so that all jobs drop out the same number (this is for speed).
        # Then we use an approximate approach to drop out the individual layers
        # with their specified probs while reaching this exact target.
        num_to_drop = int(tot) + int(shared_rng.random() < (tot - int(tot)))

        layers = list(range(num_layers))
        independent_rng.shuffle(layers)

        # go through the shuffled layers until we get the required number of samples.
        if num_to_drop > 0:
            for layer in itertools.cycle(layers):
                if independent_rng.random() < layerdrop_probs[layer]:
                    ans.add(layer)
                if len(ans) == num_to_drop:
                    break
        if shared_rng.random() < 0.005 or __name__ == "__main__":
            logging.info(
                f"warmup_begin={self.warmup_begin:.1f}, warmup_end={self.warmup_end:.1f}, "
                f"batch_count={batch_count:.1f}, num_to_drop={num_to_drop}, layers_to_drop={ans}"
            )
        return ans

    def forward(
        self,
        src: Tensor,
        # Note: The type of feature_mask should be Union[float, Tensor],
        # but to make torch.jit.script() work, we use `float` here
        feature_mask: float = 1.0,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Pass the input through the encoder layers in turn.

        Args:
            src: the sequence to the encoder (required).
            feature_mask: something that broadcasts with src, that we'll multiply `src`
               by at every layer.
            mask: the mask for the src sequence (optional).
            src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
            src: (S, N, E).
            mask: (S, S).
            src_key_padding_mask: (N, S).
            S is the source sequence length, N is the batch size, E is the feature number.
            The positional embedding is computed internally from src by
            self.encoder_pos.

        Returns:
            Tensor of shape (S, N, E).
        """
        pos_emb = self.encoder_pos(src)
        output = src

        if torch.jit.is_scripting() or torch.jit.is_tracing():
            layers_to_drop = []
        else:
            rnd_seed = src.numel() + random.randint(0, 1000)
            layers_to_drop = self.get_layers_to_drop(rnd_seed)

        output = output * feature_mask

        for i, mod in enumerate(self.layers):
            if not torch.jit.is_scripting() and not torch.jit.is_tracing():
                if i in layers_to_drop:
                    continue
            output = mod(
                output,
                pos_emb,
                src_mask=mask,
                src_key_padding_mask=src_key_padding_mask,
            )

            output = output * feature_mask

        return output
class DownsampledZipformerEncoder(nn.Module):
    r"""
    Run a Zipformer encoder at a reduced frame rate.

    The input is downsampled with AttentionDownsample, passed through the
    wrapped encoder, upsampled again with SimpleUpsample, and trimmed to the
    input length.  It is then merged with the original input by a
    SimpleCombiner, so the output has as many frames as the input and
    output_dim channels.
    """

    def __init__(
        self, encoder: nn.Module, input_dim: int, output_dim: int, downsample: int
    ):
        super(DownsampledZipformerEncoder, self).__init__()
        # NOTE: submodule registration order is kept fixed so that the order of
        # parameters (and therefore optimizer state) does not change.
        self.downsample_factor = downsample
        self.downsample = AttentionDownsample(input_dim, output_dim, downsample)
        self.encoder = encoder
        self.upsample = SimpleUpsample(output_dim, downsample)
        self.out_combiner = SimpleCombiner(
            input_dim, output_dim, min_weight=(0.0, 0.25)
        )

    def forward(
        self,
        src: Tensor,
        # Note: the type of feature_mask should be Union[float, Tensor],
        # but to make torch.jit.script() happy, we use float here
        feature_mask: float = 1.0,
        mask: Optional[Tensor] = None,
        src_key_padding_mask: Optional[Tensor] = None,
    ) -> Tensor:
        r"""Downsample, run the encoder, upsample, and combine with the input.

        Args:
            src: the sequence to the encoder (required).
            feature_mask: something that broadcasts with the downsampled src,
                multiplied into it at every layer; it must already be at the
                downsampled frame rate.
            mask: the mask for the src sequence (optional).  It is
                subsampled here with a stride of downsample_factor in both
                dims; CAUTION: this path is not known to work correctly yet.
            src_key_padding_mask: the mask for the src keys per batch
                (optional); must already be at the downsampled frame rate.

        Shape:
            src: (S, N, E).
            mask: (S, S).
            src_key_padding_mask: (N, S_downsampled).

        Returns:
            Tensor of shape (S, N, output_dim).
        """
        num_input_frames = src.shape[0]
        factor = self.downsample_factor

        reduced = self.downsample(src)
        if mask is not None:
            mask = mask[::factor, ::factor]

        encoded = self.encoder(
            reduced,
            feature_mask=feature_mask,
            mask=mask,
            src_key_padding_mask=src_key_padding_mask,
        )

        # Upsampling yields a multiple of `factor` frames; drop any frames
        # beyond the original input length.
        restored = self.upsample(encoded)[:num_input_frames]

        return self.out_combiner(src, restored)
class AttentionDownsample(torch.nn.Module):
    """
    Downsample along time by an attention-weighted sum over each group of
    `downsample` consecutive frames, optionally followed by a projection that
    adds extra output channels.
    """

    def __init__(self, in_channels: int, out_channels: int, downsample: int):
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels, normally >= in_channels.
                The first in_channels outputs are the attention-pooled input.
                If out_channels > in_channels, the remaining channels come from
                a linear projection of the concatenated frames of each group.
                If out_channels <= in_channels, no projection is created and
                the output keeps in_channels channels.
            downsample: downsampling factor along the time axis.
        """
        super(AttentionDownsample, self).__init__()
        self.query = nn.Parameter(torch.randn(in_channels) * (in_channels**-0.5))

        if out_channels > in_channels:
            # fill in the extra dimensions with a projection of the input
            self.extra_proj = nn.Linear(
                in_channels * downsample, out_channels - in_channels, bias=False
            )
        else:
            self.extra_proj = None
        self.downsample = downsample

    def forward(self, src: Tensor) -> Tensor:
        """
        Args:
            src: tensor of shape (seq_len, batch_size, in_channels).

        Returns:
            a tensor of shape
            ((seq_len + downsample - 1) // downsample, batch_size, out_channels)
            (in_channels channels when no extra projection exists).
        """
        (seq_len, batch_size, in_channels) = src.shape
        factor = self.downsample
        out_len = (seq_len + factor - 1) // factor

        # Right-pad to an exact multiple of `factor` by repeating the last
        # frame.  The pad length may be 0; the padding is done unconditionally
        # for onnx-export compatibility.
        num_pad = out_len * factor - seq_len
        last_frame = src[src.shape[0] - 1 :]
        padding = last_frame.expand(num_pad, src.shape[1], src.shape[2])
        src = torch.cat((src, padding), dim=0)
        assert src.shape[0] == out_len * factor, (src.shape[0], out_len, factor)

        grouped = src.reshape(out_len, factor, batch_size, in_channels)

        scores = (grouped * self.query).sum(dim=-1, keepdim=True)
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            scores = penalize_abs_values_gt(scores, limit=10.0, penalty=1.0e-04)

        # attention weights over the `factor` frames inside each group
        weights = scores.softmax(dim=1)

        # the first `in_channels` channels of the output
        pooled = (grouped * weights).sum(dim=1)

        if self.extra_proj is None:
            return pooled

        concatenated = grouped.permute(0, 2, 1, 3).reshape(
            out_len, batch_size, factor * in_channels
        )
        extra = self.extra_proj(concatenated)
        return torch.cat((pooled, extra), dim=2)
class SimpleUpsample(torch.nn.Module):
    """
    Upsample along time by repeating every frame `upsample` times.  A learned
    bias is added that depends only on the position within each repeated
    group.
    """

    def __init__(self, num_channels: int, upsample: int):
        super(SimpleUpsample, self).__init__()
        self.bias = nn.Parameter(torch.randn(upsample, num_channels) * 0.01)

    def forward(self, src: Tensor) -> Tensor:
        """
        Args:
            src: tensor of shape (seq_len, batch_size, num_channels).

        Returns:
            a tensor of shape (seq_len * upsample, batch_size, num_channels)
            whose frame t * upsample + k equals src[t] + bias[k].
        """
        factor = self.bias.shape[0]
        seq_len, batch_size, num_channels = src.shape
        repeated = src.unsqueeze(1).expand(seq_len, factor, batch_size, num_channels)
        # bias: (factor, num_channels) -> (factor, 1, num_channels), which
        # broadcasts over the seq_len and batch dims of `repeated`.
        shifted = repeated + self.bias.unsqueeze(1)
        return shifted.reshape(seq_len * factor, batch_size, num_channels)
class SimpleCombinerIdentity(nn.Module):
    """
    No-op stand-in for SimpleCombiner.

    Any constructor arguments are accepted and ignored.  forward(src1, src2)
    returns src1 unchanged and does not use src2.
    """

    def __init__(self, *args, **kwargs):
        super(SimpleCombinerIdentity, self).__init__()

    def forward(self, src1: Tensor, src2: Tensor) -> Tensor:
        result = src1
        return result
class SimpleCombiner(torch.nn.Module):
    """
    A very simple way of combining 2 vectors of 2 different dims, via a
    learned weighted combination in the shared part of the dim.

    The output is ``weight1 * src1 + (1 - weight1) * src2``.  Before the sum,
    src1 is zero-padded (or truncated) in its last dimension to match src2.

    Args:
        dim1: the dimension of the first input, e.g. 256
        dim2: the dimension of the second input, e.g. 384.
            The output will have the same dimension as dim2.
        min_weight: (lower bound for weight1, lower bound for 1 - weight1).
            During training, on a random 25% of calls (and only if
            min_weight != (0.0, 0.0)), weight1 is clamped to
            [min_weight[0], 1 - min_weight[1]].
    """

    def __init__(self, dim1: int, dim2: int, min_weight: Tuple[float] = (0.0, 0.0)):
        super(SimpleCombiner, self).__init__()
        assert dim2 >= dim1, (dim2, dim1)
        self.weight1 = nn.Parameter(torch.zeros(()))
        self.min_weight = min_weight

    def forward(self, src1: Tensor, src2: Tensor) -> Tensor:
        """
        Args:
            src1: tensor of shape (*, dim1)
            src2: tensor of shape (*, dim2); its leading dims must equal src1's.

        Returns:
            a tensor of shape (*, dim2).
        """
        assert src1.shape[:-1] == src2.shape[:-1], (src1.shape, src2.shape)

        weight1 = self.weight1
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            if (
                self.training
                and random.random() < 0.25
                and self.min_weight != (0.0, 0.0)
            ):
                weight1 = weight1.clamp(
                    min=self.min_weight[0], max=1.0 - self.min_weight[1]
                )

        src1 = src1 * weight1
        src2 = src2 * (1.0 - weight1)

        src1_dim = src1.shape[-1]
        src2_dim = src2.shape[-1]
        if src1_dim != src2_dim:
            if src1_dim < src2_dim:
                src1 = torch.nn.functional.pad(src1, (0, src2_dim - src1_dim))
            else:
                # Truncate the last (channel) dim.  A plain `src1[:src2_dim]`
                # would slice the leading (time) dim instead.
                src1 = src1[..., :src2_dim]

        return src1 + src2
class RelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module.

    See : Appendix B in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py

    Args:
        d_model: Embedding dimension (may be odd).
        dropout_rate: Dropout rate, applied to the returned positional embedding.
        max_len: Input length to precompute encodings for.  Longer inputs cause
            the table to be recomputed in forward().
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
        """Construct a RelPositionalEncoding object."""
        super(RelPositionalEncoding, self).__init__()
        if is_jit_tracing():
            # 10k frames correspond to ~100k ms, i.e. ~100 seconds.
            # It assumes that the maximum input won't have more than
            # 10k frames.
            #
            # TODO(fangjun): Use torch.jit.script() for this module
            max_len = 10000

        self.d_model = d_model
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(max_len))

    def extend_pe(self, x: Tensor) -> None:
        """Reset the positional encodings if needed.

        After this call, self.pe has shape (1, 2 * L - 1, d_model) for some
        L >= x.size(0), and has x's dtype and device.  The table is only
        rebuilt when it is too short for x; otherwise it is just moved or cast
        if needed.
        """
        if self.pe is not None:
            # self.pe contains both positive and negative parts
            # the length of self.pe is 2 * input_len - 1
            if self.pe.size(1) >= x.size(0) * 2 - 1:
                # Note: TorchScript doesn't implement operator== for torch.Device
                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Suppose `i` is the position of the query vector and `j` the position
        # of the key vector.  We use positive relative positions when keys are
        # to the left (i > j) and negative relative positions otherwise (i < j).
        pe_positive = torch.zeros(x.size(0), self.d_model)
        pe_negative = torch.zeros(x.size(0), self.d_model)
        position = torch.arange(0, x.size(0), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        # Even columns get sin and odd columns get cos.  When d_model is odd
        # there is one more even column than odd column, so the cos terms use
        # only the first d_model // 2 frequencies.  For even d_model this is
        # the full div_term.
        cos_div_term = div_term[: self.d_model // 2]
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * cos_div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * cos_div_term)

        # Reverse the order of positive indices and concat both positive and
        # negative indices.  This is used to support the shifting trick
        # as in "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context"
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> Tensor:
        """Return the relative positional embedding for an input.

        Args:
            x (torch.Tensor): Input tensor (time, batch, `*`).  Only its length
                along dim 0, its dtype and its device are used.

        Returns:
            torch.Tensor: Positional embedding of shape (1, 2*time-1, d_model),
            covering relative positions time-1, ..., 0, ..., -(time-1), with
            dropout applied.
        """
        self.extend_pe(x)
        # index pe.size(1) // 2 holds relative position 0
        center = self.pe.size(1) // 2
        pos_emb = self.pe[:, center - x.size(0) + 1 : center + x.size(0)]  # noqa E203
        return self.dropout(pos_emb)
class RelPositionMultiheadAttention(nn.Module):
r"""Multi-Head Attention layer with relative position encoding
This is a quite heavily modified from: "Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context",
we have to write up the differences.
Args:
embed_dim: total dimension of the model.
attention_dim: dimension in the attention module, may be less or more than embed_dim
but must be a multiple of num_heads.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
Examples::
>>> rel_pos_multihead_attn = RelPositionMultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value, pos_emb)
"""
def __init__(
self,
embed_dim: int,
attention_dim: int,
num_heads: int,
pos_dim: int,
dropout: float = 0.0,