-
Notifications
You must be signed in to change notification settings - Fork 1
/
image_encoder3D_deform.py
744 lines (640 loc) · 28.6 KB
/
image_encoder3D_deform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple, Type
import einops
from timm.models.layers import to_2tuple, trunc_normal_
class MLPBlock(nn.Module):
    """Two-layer feed-forward network: Linear -> activation -> Linear.

    Args:
        embedding_dim (int): Size of the input and output features.
        mlp_dim (int): Size of the hidden layer.
        act (Type[nn.Module]): Activation class, instantiated with no arguments.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply expand -> activate -> project on the last dimension of ``x``."""
        hidden = self.lin1(x)
        hidden = self.act(hidden)
        return self.lin2(hidden)
class LayerNorm3d(nn.Module):
    """Layer normalization over the channel axis (dim 1) of a channels-first tensor.

    The affine parameters are broadcast as ``(C, 1, 1, 1)``, so the input is
    expected to be ``(B, C, D, H, W)``.

    Args:
        num_channels (int): Number of channels C.
        eps (float): Value added to the variance for numerical stability.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        # Per-channel scale and shift, each of shape (num_channels,).
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize each voxel's channel vector, then apply the per-channel affine."""
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        # Biased (population) variance across channels.
        var = centered.pow(2).mean(1, keepdim=True)
        normed = centered / torch.sqrt(var + self.eps)
        # Add three trailing singleton axes so (C,) broadcasts against (B, C, D, H, W).
        scale = self.weight[:, None, None, None]
        shift = self.bias[:, None, None, None]
        return scale * normed + shift
class LayerNormProxy3D(nn.Module):
    """Apply ``nn.LayerNorm`` over the channel axis of a ``(B, C, D, H, W)`` tensor.

    The tensor is moved to channels-last, normalized over its last axis, and
    moved back to channels-first.

    Args:
        dim (int): Number of channels C.
    """

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        channels_last = x.permute(0, 2, 3, 4, 1)   # (B, C, D, H, W) -> (B, D, H, W, C)
        normed = self.norm(channels_last)
        return normed.permute(0, 4, 1, 2, 3)        # (B, D, H, W, C) -> (B, C, D, H, W)
# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa
class ImageEncoderViT3D(nn.Module):
    """3D ViT image encoder: patch embedding, a stack of ``Block3D`` layers, and a conv neck.

    Per block index ``i``:
      * ``i in d_attn_indexes``: deformable attention over the full token grid;
      * ``i in global_attn_indexes`` (and not deformable): global attention;
      * otherwise: windowed attention with ``window_size``.
    """

    def __init__(
        self,
        img_size: int = 128,
        patch_size: int = 16,
        in_chans: int = 1,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        out_chans: int = 384,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_abs_pos: bool = True,
        use_rel_pos: bool = True,
        rel_pos_zero_init: bool = True,
        window_size: int = 14,
        global_attn_indexes: Tuple[int, ...] = (2, 5, 8, 11),
        d_attn_indexes: Tuple[int, ...] = (8, 11),
    ) -> None:
        """
        Args:
            img_size (int): Input volume edge length (input is img_size^3).
            patch_size (int): Patch edge length; the token grid is (img_size // patch_size)^3.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Number of transformer blocks.
            num_heads (int): Number of attention heads in each block.
            mlp_ratio (float): Ratio of MLP hidden dim to embedding dim.
            out_chans (int): Number of channels produced by the neck.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            window_size (int): Window size for window attention blocks.
            global_attn_indexes (tuple): Indexes of blocks using global attention.
            d_attn_indexes (tuple): Indexes of blocks using deformable attention
                (these also get window_size 0, i.e. the full grid).
        """
        super().__init__()
        self.img_size = img_size
        grid = img_size // patch_size
        self.patch_embed = PatchEmbed3D(
            kernel_size=(patch_size, patch_size, patch_size),
            stride=(patch_size, patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        # Either a learnable (1, G, G, G, embed_dim) tensor or None when disabled.
        self.pos_embed: Optional[nn.Parameter] = None
        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            self.pos_embed = nn.Parameter(torch.zeros(1, grid, grid, grid, embed_dim))

        self.blocks = nn.ModuleList()
        for i in range(depth):
            is_deform = i in d_attn_indexes
            uses_window = (i not in global_attn_indexes) and not is_deform
            block = Block3D(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                norm_layer=norm_layer,
                act_layer=act_layer,
                use_rel_pos=use_rel_pos,
                rel_pos_zero_init=rel_pos_zero_init,
                window_size=window_size if uses_window else 0,
                input_size=(grid, grid, grid),
                d_attn=is_deform,
            )
            self.blocks.append(block)

        self.neck = nn.Sequential(
            nn.Conv3d(
                embed_dim,
                out_chans,
                kernel_size=1,
                bias=False,
            ),
            LayerNorm3d(out_chans),
            nn.Conv3d(
                out_chans,
                out_chans,
                kernel_size=3,
                padding=1,
                bias=False,
            ),
            LayerNorm3d(out_chans),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a channels-first volume (B, in_chans, D, H, W) into (B, out_chans, D', H', W')."""
        # PatchEmbed3D returns channels-last tokens: (B, D', H', W', embed_dim).
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + self.pos_embed
        for blk in self.blocks:
            x = blk(x)
        # The Conv3d neck expects channels-first input.
        return self.neck(x.permute(0, 4, 1, 2, 3))
class Block3D(nn.Module):
    """Pre-norm transformer block (attention + MLP, both residual) on channels-last tokens.

    The attention is either deformable (``DAttention3D``, when ``d_attn``) or
    ``LoRA_Attention``. ``window_size > 0`` runs it inside non-overlapping
    windows; ``window_size == 0`` runs it over the whole grid.
    """

    def __init__(
        self,
        dim: int = 768,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        norm_layer: Type[nn.Module] = nn.LayerNorm,
        act_layer: Type[nn.Module] = nn.GELU,
        use_rel_pos: bool = True,
        rel_pos_zero_init: bool = True,
        window_size: int = 14,
        input_size: Optional[Tuple[int, int, int]] = None,
        d_attn: bool = False
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value (LoRA attention only).
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map
                (LoRA attention only).
            rel_pos_zero_init (bool): Accepted for interface compatibility; not forwarded,
                because LoRA_Attention has no such parameter.
            window_size (int): Window size for window attention. If it equals 0, then
                use global attention.
            input_size (tuple(int, int, int) or None): Token grid size (D, H, W) used to size
                relative positional parameters.
            d_attn (bool): If True, use deformable attention instead of LoRA attention.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.d_attn = d_attn
        if d_attn:
            # Deformable-attention hyper-parameters are fixed here.
            self.deform_attn = DAttention3D(
                q_size=input_size,
                n_heads=num_heads,
                n_head_channels=dim // num_heads,
                n_groups=6,
                attn_drop=0.2,
                proj_drop=0.2,
                stride=2,
                offset_range_factor=1,
                ksize=2,
                use_pe=True,
                no_off=False
            )
        else:
            self.attn = LoRA_Attention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                use_rel_pos=use_rel_pos,
                input_size=input_size if window_size == 0 else (window_size, window_size, window_size),
            )

        self.norm2 = norm_layer(dim)
        self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer)

        self.window_size = window_size

    def _attend(self, x: torch.Tensor) -> torch.Tensor:
        """Run the configured attention on channels-last tokens and return channels-last tokens."""
        if self.d_attn:
            # DAttention3D works channels-first; convert in and back out here so the
            # (optional) window unpartition always receives channels-last data.
            x = x.permute(0, 4, 1, 2, 3).contiguous()
            x = self.deform_attn(x)
            return x.permute(0, 2, 3, 4, 1).contiguous()
        return self.attn(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        if self.window_size > 0:
            D, H, W = x.shape[1], x.shape[2], x.shape[3]
            x, pad_dhw = window_partition3D(x, self.window_size)
            x = self._attend(x)
            x = window_unpartition3D(x, self.window_size, pad_dhw, (D, H, W))
        else:
            x = self._attend(x)

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x
class DAttention3D(nn.Module):
    """Deformable multi-head attention over a channels-first 3D feature volume.

    Queries are computed at every voxel of the input. A strided offset network
    predicts, per group, a 3D offset for every point of a coarser reference grid.
    The input is sampled (``F.grid_sample``) at the shifted points, and keys and
    values are computed from those samples. Optionally, a bias interpolated from
    a learnable table is added to the attention logits.

    Input and output shape: (B, C, D, H, W), with C == n_heads * n_head_channels.
    """

    def __init__(
        self, q_size, n_heads, n_head_channels, n_groups,
        attn_drop, proj_drop, stride,
        offset_range_factor, ksize, use_pe=True,
        no_off=False
    ):
        """
        Args:
            q_size (tuple(int, int, int)): Query grid size (D, H, W); q_h and q_w size the bias table.
            n_heads (int): Number of attention heads.
            n_head_channels (int): Channels per head.
            n_groups (int): Number of offset groups; must divide n_heads * n_head_channels and n_heads.
            attn_drop (float): Dropout probability on the attention weights.
            proj_drop (float): Dropout probability after the output projection.
            stride (int): Stride of the offset network's first conv.
            offset_range_factor (float): If >= 0, offsets are tanh-squashed and scaled by
                offset_range_factor / (size - 1) per axis; if < 0 they are unbounded and the
                sampling positions are clamped to [-1, 1].
            ksize (int): Kernel size of the offset network's first conv.
            use_pe (bool): Add the interpolated position bias (only when offsets are enabled).
            no_off (bool): Disable offsets; keys/values then come from average pooling.

        Raises:
            ValueError: If n_groups does not divide the channel count or the head count.
        """
        super().__init__()
        self.n_head_channels = n_head_channels
        self.scale = self.n_head_channels ** -0.5
        self.n_heads = n_heads
        self.q_d, self.q_h, self.q_w = q_size
        self.nc = n_head_channels * n_heads
        self.n_groups = n_groups
        if self.nc % n_groups or n_heads % n_groups:
            raise ValueError(
                f"n_groups={n_groups} must divide both channels ({self.nc}) and heads ({n_heads})"
            )
        self.n_group_channels = self.nc // self.n_groups  # channels per offset group
        self.n_group_heads = self.n_heads // self.n_groups  # heads per offset group
        self.use_pe = use_pe  # whether to add the interpolated position bias
        self.no_off = no_off  # if True, offsets are disabled
        self.offset_range_factor = offset_range_factor
        self.ksize = ksize
        self.stride = stride  # stride of the offset network
        kk = self.ksize
        pad_size = kk // 2 if kk != stride else 0

        self.conv_offset = nn.Sequential(
            # Depthwise conv: groups == channels, so each channel is filtered independently.
            nn.Conv3d(self.n_group_channels, self.n_group_channels, kk, stride, pad_size, groups=self.n_group_channels),
            LayerNormProxy3D(self.n_group_channels),
            nn.GELU(),
            # 3 output channels: one offset component per spatial axis, ordered (d, h, w).
            nn.Conv3d(self.n_group_channels, 3, 1, 1, 0, bias=False)
        )
        if self.no_off:
            for m in self.conv_offset.parameters():
                m.requires_grad_(False)

        self.proj_q = nn.Conv3d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )
        # Keys/values are projected from samples laid out as (B, C, 1, n_sample), hence Conv2d.
        self.proj_k = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )
        self.proj_v = nn.Conv2d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )
        self.proj_out = nn.Conv3d(
            self.nc, self.nc,
            kernel_size=1, stride=1, padding=0
        )

        self.proj_drop = nn.Dropout(proj_drop, inplace=True)
        self.attn_drop = nn.Dropout(attn_drop, inplace=True)

        if self.use_pe and not self.no_off:
            # NOTE(review): the table is 2-D, sized from q_h and q_w only.
            self.rpe_table = nn.Parameter(
                torch.zeros(self.n_heads, self.q_h * 2 - 1, self.q_w * 2 - 1)
            )
            trunc_normal_(self.rpe_table, std=0.01)
        else:
            self.rpe_table = None

    @torch.no_grad()
    def _get_ref_points(self, D_key, H_key, W_key, B, dtype, device):
        """Return reference points (B * n_groups, D_key, H_key, W_key, 3), components (d, h, w) in [-1, 1]."""
        ref_d, ref_h, ref_w = torch.meshgrid(
            torch.linspace(0.5, D_key - 0.5, D_key, dtype=dtype, device=device),
            torch.linspace(0.5, H_key - 0.5, H_key, dtype=dtype, device=device),
            torch.linspace(0.5, W_key - 0.5, W_key, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_d, ref_h, ref_w), -1)
        # max(.., 1) avoids dividing by zero on size-1 axes.
        ref[..., 0].div_(max(D_key - 1.0, 1.0)).mul_(2.0).sub_(1.0)
        ref[..., 1].div_(max(H_key - 1.0, 1.0)).mul_(2.0).sub_(1.0)
        ref[..., 2].div_(max(W_key - 1.0, 1.0)).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1, -1)
        return ref

    @torch.no_grad()
    def _get_q_grid(self, D, H, W, B, dtype, device):
        """Return query coordinates (B * n_groups, D, H, W, 3), components (d, h, w) in [-1, 1].

        The (D, H, W) layout matches the order in which ``q`` is flattened in ``forward``.
        """
        ref_d, ref_h, ref_w = torch.meshgrid(
            torch.arange(0, D, dtype=dtype, device=device),
            torch.arange(0, H, dtype=dtype, device=device),
            torch.arange(0, W, dtype=dtype, device=device),
            indexing='ij'
        )
        ref = torch.stack((ref_d, ref_h, ref_w), -1)
        ref[..., 0].div_(max(D - 1.0, 1.0)).mul_(2.0).sub_(1.0)
        ref[..., 1].div_(max(H - 1.0, 1.0)).mul_(2.0).sub_(1.0)
        ref[..., 2].div_(max(W - 1.0, 1.0)).mul_(2.0).sub_(1.0)
        ref = ref[None, ...].expand(B * self.n_groups, -1, -1, -1, -1)
        return ref

    def forward(self, x):
        B, C, D, H, W = x.size()
        dtype, device = x.dtype, x.device

        q = self.proj_q(x)
        q_off = einops.rearrange(q, 'b (g c) d h w -> (b g) c d h w', g=self.n_groups, c=self.n_group_channels)
        offset = self.conv_offset(q_off).contiguous()  # (B * g, 3, Dk, Hk, Wk)
        Dk, Hk, Wk = offset.size(2), offset.size(3), offset.size(4)
        n_sample = Dk * Hk * Wk

        if self.offset_range_factor >= 0 and not self.no_off:
            # Created in x's dtype so half-precision inputs give a half-precision grid.
            offset_range = torch.tensor(
                [1.0 / max(Dk - 1.0, 1.0), 1.0 / max(Hk - 1.0, 1.0), 1.0 / max(Wk - 1.0, 1.0)],
                dtype=dtype, device=device,
            ).reshape(1, 3, 1, 1, 1)
            offset = offset.tanh().mul(offset_range).mul(self.offset_range_factor)

        offset = einops.rearrange(offset, 'b p d h w -> b d h w p')
        reference = self._get_ref_points(Dk, Hk, Wk, B, dtype, device)

        if self.no_off:
            offset = offset.fill_(0.0)

        if self.offset_range_factor >= 0:
            pos = offset + reference
        else:
            pos = (offset + reference).clamp(-1., +1.)

        if self.no_off:
            x_sampled = F.avg_pool3d(x, kernel_size=self.stride, stride=self.stride)
            assert x_sampled.size(2) == Dk and x_sampled.size(3) == Hk and x_sampled.size(
                4) == Wk, f"Size is {x_sampled.size()}"
        else:
            x_sampled = F.grid_sample(
                input=x.reshape(B * self.n_groups, self.n_group_channels, D, H, W),
                # (d, h, w) -> (w, h, d): grid_sample wants (x, y, z) indexing (W, H, D).
                grid=pos[..., (2, 1, 0)],
                mode='bilinear', align_corners=True)

        x_sampled = x_sampled.reshape(B, C, 1, n_sample)

        q = q.reshape(B * self.n_heads, self.n_head_channels, D * H * W)
        k = self.proj_k(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)
        v = self.proj_v(x_sampled).reshape(B * self.n_heads, self.n_head_channels, n_sample)

        attn = torch.einsum('b c m, b c n -> b m n', q, k)  # (B * heads, DHW, n_sample)
        attn = attn.mul(self.scale)

        if self.use_pe and (not self.no_off):
            rpe_bias = self.rpe_table[None, ...].expand(B, -1, -1, -1)
            q_grid = self._get_q_grid(D, H, W, B, dtype, device)
            displacement = (
                q_grid.reshape(B * self.n_groups, D * H * W, 3).unsqueeze(2)
                - pos.reshape(B * self.n_groups, n_sample, 3).unsqueeze(1)
            ).mul(0.5)
            # Components are (d, h, w); reversed to (w, h, d), then [..., 1:] keeps (h, d).
            # NOTE(review): the w displacement is not used by the bias; this preserves the
            # original behaviour on cubic grids — confirm it is intended.
            displacement = displacement[..., (2, 1, 0)]
            attn_bias = F.grid_sample(
                input=einops.rearrange(rpe_bias, 'b (g c) h w -> (b g) c h w', c=self.n_group_heads, g=self.n_groups),
                grid=displacement[..., 1:],
                mode='bilinear', align_corners=True)
            attn_bias = attn_bias.reshape(B * self.n_heads, D * H * W, n_sample)
            attn = attn + attn_bias

        # clone(): the in-place dropout below must not overwrite the softmax output autograd saved.
        attn = torch.softmax(attn, dim=2).clone()
        attn = self.attn_drop(attn)

        out = torch.einsum('b m n, b c n -> b c m', attn, v)
        out = out.reshape(B, C, D, H, W)
        y = self.proj_drop(self.proj_out(out))
        return y
class Attention(nn.Module):
    """Multi-head self-attention on channels-last 3D tokens, with optional decomposed relative positions.

    Input and output shape: (B, D, H, W, dim).
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: Optional[Tuple[int, int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): Accepted but unused here; the tables always start at zero.
            input_size (tuple(int, int, int) or None): Token grid (D, H, W) used to size the
                relative positional tables; required when use_rel_pos is True.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            # One table per axis, each with 2 * size - 1 relative offsets.
            self.rel_pos_d = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[2] - 1, head_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, D, H, W, _ = x.shape
        n_tokens = D * H * W
        heads = self.num_heads

        # (B, N, 3 * dim) -> (3, B, heads, N, head_dim) -> three tensors of (B * heads, N, head_dim)
        packed = self.qkv(x).reshape(B, n_tokens, 3, heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = packed.reshape(3, B * heads, n_tokens, -1).unbind(0)

        logits = (q * self.scale) @ k.transpose(-2, -1)
        if self.use_rel_pos:
            logits = add_decomposed_rel_pos(
                logits, q, self.rel_pos_d, self.rel_pos_h, self.rel_pos_w, (D, H, W), (D, H, W)
            )
        weights = logits.softmax(dim=-1)

        merged = (weights @ v).view(B, heads, D, H, W, -1)
        merged = merged.permute(0, 2, 3, 4, 1, 5).reshape(B, D, H, W, -1)
        return self.proj(merged)
class LoRA_Attention(nn.Module):
    """Multi-head self-attention whose q and v projections carry a rank-``r`` LoRA update.

    Input and output shape: (B, D, H, W, dim). The LoRA "B" matrices start at zero,
    so at initialization the block equals plain attention with the ``qkv`` projection.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        r: int = 4,
        input_size: Optional[Tuple[int, int, int]] = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            r (int): Rank of the low-rank q/v updates.
            input_size (tuple(int, int, int) or None): Token grid (D, H, W) used to size the
                relative positional tables; required when use_rel_pos is True.
        """
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.r = r
        self.dim = dim

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Low-rank adapters: x -> A (dim -> r) -> B (r -> dim), added to the q and v slices.
        self.lora_linear_a_q = nn.Linear(dim, self.r, bias=False)
        self.lora_linear_b_q = nn.Linear(self.r, dim, bias=False)
        self.lora_linear_a_v = nn.Linear(dim, self.r, bias=False)
        self.lora_linear_b_v = nn.Linear(self.r, dim, bias=False)
        self.proj = nn.Linear(dim, dim)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "Input size must be provided if using relative positional encoding."
            self.rel_pos_d = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[2] - 1, head_dim))

        self.reset_parameters()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, D, H, W, _ = x.shape
        n_tokens = D * H * W
        heads = self.num_heads

        qkv = self.qkv(x)
        delta_q = self.lora_linear_b_q(self.lora_linear_a_q(x))
        delta_v = self.lora_linear_b_v(self.lora_linear_a_v(x))
        # The last axis is [q | k | v], each of width dim; adapt only q and v.
        qkv[..., :self.dim] += delta_q
        qkv[..., -self.dim:] += delta_v

        packed = qkv.reshape(B, n_tokens, 3, heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = packed.reshape(3, B * heads, n_tokens, -1).unbind(0)

        logits = (q * self.scale) @ k.transpose(-2, -1)
        if self.use_rel_pos:
            logits = add_decomposed_rel_pos(
                logits, q, self.rel_pos_d, self.rel_pos_h, self.rel_pos_w, (D, H, W), (D, H, W)
            )
        weights = logits.softmax(dim=-1)

        merged = (weights @ v).view(B, heads, D, H, W, -1)
        merged = merged.permute(0, 2, 3, 4, 1, 5).reshape(B, D, H, W, -1)
        return self.proj(merged)

    def reset_parameters(self) -> None:
        """Initialize the LoRA adapters: Kaiming-uniform A matrices, zero B matrices."""
        for down in (self.lora_linear_a_q, self.lora_linear_a_v):
            nn.init.kaiming_uniform_(down.weight, a=math.sqrt(5))
        for up in (self.lora_linear_b_q, self.lora_linear_b_v):
            nn.init.zeros_(up.weight)
def window_partition3D(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
    """
    Split a channels-last volume into non-overlapping cubic windows, zero-padding if needed.

    Args:
        x (tensor): input tokens with [B, D, H, W, C].
        window_size (int): window edge length.
    Returns:
        windows: [B * num_windows, window_size, window_size, window_size, C].
        (Dp, Hp, Wp): padded depth, height and width before partition.
    """
    B, D, H, W, C = x.shape
    ws = window_size
    pad_d, pad_h, pad_w = ((ws - size % ws) % ws for size in (D, H, W))

    if pad_d > 0 or pad_h > 0 or pad_w > 0:
        # F.pad lists pairs from the last axis backwards: C untouched, then W, H, D padded at the end.
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h, 0, pad_d))
    Dp, Hp, Wp = D + pad_d, H + pad_h, W + pad_w

    blocks = x.view(B, Dp // ws, ws, Hp // ws, ws, Wp // ws, ws, C)
    # Gather the three window-index axes in front of the three intra-window axes.
    windows = blocks.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, ws, ws, ws, C)
    return windows, (Dp, Hp, Wp)
def window_unpartition3D(
    windows: torch.Tensor, window_size: int, pad_dhw: Tuple[int, int, int], dhw: Tuple[int, int, int]
) -> torch.Tensor:
    """
    Reassemble cubic windows into a channels-last volume and remove padding.

    Args:
        windows (tensor): [B * num_windows, window_size, window_size, window_size, C].
        window_size (int): window edge length.
        pad_dhw (Tuple): padded sizes (Dp, Hp, Wp).
        dhw (Tuple): original sizes (D, H, W) before padding.
    Returns:
        x: unpartitioned tokens with [B, D, H, W, C].
    """
    Dp, Hp, Wp = pad_dhw
    D, H, W = dhw
    ws = window_size
    B = windows.shape[0] // (Dp * Hp * Wp // ws // ws // ws)
    x = windows.view(B, Dp // ws, Hp // ws, Wp // ws, ws, ws, ws, -1)
    # Interleave each window-index axis with its intra-window axis: the result is laid out
    # (B, Dp, Hp, Wp, C). The view must use that order; (B, Hp, Wp, Dp) only worked for cubes.
    x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, Dp, Hp, Wp, -1)

    if Dp > D or Hp > H or Wp > W:
        x = x[:, :D, :H, :W, :].contiguous()
    return x
def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """
    Look up relative positional embeddings for every (query, key) index pair along one axis.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C). They are linearly resized
            if L != 2 * max(q_size, k_size) - 1.
    Returns:
        Embeddings of shape (q_size, k_size, C).
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)

    if rel_pos.shape[0] == max_rel_dist:
        table = rel_pos
    else:
        # Treat the C channels as a 1-D signal of length L and resample it to max_rel_dist.
        table = F.interpolate(
            rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
            size=max_rel_dist,
            mode="linear",
        )
        table = table.reshape(-1, max_rel_dist).permute(1, 0)

    # Stretch the shorter axis so that query and key coordinates share one scale.
    q_scale = max(k_size / q_size, 1.0)
    k_scale = max(q_size / k_size, 1.0)
    q_coords = torch.arange(q_size)[:, None] * q_scale
    k_coords = torch.arange(k_size)[None, :] * k_scale
    offsets = (q_coords - k_coords) + (k_size - 1) * k_scale

    return table[offsets.long()]
def add_decomposed_rel_pos(
    attn: torch.Tensor,
    q: torch.Tensor,
    rel_pos_d: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int, int],
    k_size: Tuple[int, int, int],
) -> torch.Tensor:
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`, extended to 3D.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950

    Args:
        attn (Tensor): attention map (B, q_d * q_h * q_w, k_d * k_h * k_w).
        q (Tensor): query q in the attention layer with shape (B, q_d * q_h * q_w, C).
        rel_pos_d (Tensor): relative position embeddings (Ld, C) for the depth axis.
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for the height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for the width axis.
        q_size (Tuple): spatial size of query q, (q_d, q_h, q_w).
        k_size (Tuple): spatial size of key k, (k_d, k_h, k_w).
    Returns:
        attn (Tensor): attention map with added relative positional embeddings, same shape as input.

    NOTE(review): earlier code broadcast each bias onto the wrong key axis (only shape-valid
    for cubic grids); checkpoints trained with that behaviour may need re-tuning.
    """
    q_d, q_h, q_w = q_size
    k_d, k_h, k_w = k_size
    Rd = get_rel_pos(q_d, k_d, rel_pos_d)  # (q_d, k_d, C)
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)  # (q_h, k_h, C)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)  # (q_w, k_w, C)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_d, q_h, q_w, dim)
    rel_d = torch.einsum("bdhwc,dkc->bdhwk", r_q, Rd)  # (B, q_d, q_h, q_w, k_d)
    rel_h = torch.einsum("bdhwc,hkc->bdhwk", r_q, Rh)  # (B, q_d, q_h, q_w, k_h)
    rel_w = torch.einsum("bdhwc,wkc->bdhwk", r_q, Rw)  # (B, q_d, q_h, q_w, k_w)

    attn = attn.view(B, q_d, q_h, q_w, k_d, k_h, k_w)
    # Keep all four query axes, then place each bias's key axis at its own position (4, 5 or 6).
    attn = (
        attn
        + rel_d[:, :, :, :, :, None, None]
        + rel_h[:, :, :, :, None, :, None]
        + rel_w[:, :, :, :, None, None, :]
    )
    return attn.view(B, q_d * q_h * q_w, k_d * k_h * k_w)
class PatchEmbed3D(nn.Module):
    """
    Volume to patch embedding: a strided Conv3d followed by a move to channels-last.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int, int] = (16, 16, 16),
        stride: Tuple[int, int, int] = (16, 16, 16),
        padding: Tuple[int, int, int] = (0, 0, 0),
        in_chans: int = 1,
        embed_dim: int = 768,
    ) -> None:
        """
        Args:
            kernel_size (Tuple[int, int, int]): kernel size of the projection layer.
            stride (Tuple[int, int, int]): stride of the projection layer.
            padding (Tuple[int, int, int]): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()
        self.proj = nn.Conv3d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (B, in_chans, D, H, W) to channels-last tokens (B, D', H', W', embed_dim)."""
        x = self.proj(x)
        # B C X Y Z -> B X Y Z C
        return x.permute(0, 2, 3, 4, 1)
if __name__ == '__main__':
    # Smoke test: push a random batch through the default encoder and print the output shape.
    # No gradients are needed, so skip building the autograd graph (saves a lot of memory
    # on a 2 x 1 x 128^3 input).
    b, c, d, h, w = 2, 1, 128, 128, 128
    input_images = torch.randn((b, c, d, h, w))
    image_encoder = ImageEncoderViT3D()
    with torch.no_grad():
        image_embeddings = image_encoder(input_images)
    print(image_embeddings.shape)