-
Notifications
You must be signed in to change notification settings - Fork 0
/
CU_Net3d.py
126 lines (110 loc) · 4.82 KB
/
CU_Net3d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 24 15:13:17 2021
@author: Ding
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
class Prediction(nn.Module):
    """Unrolled soft-thresholding encoder built from 3-D convolutions.

    The forward pass computes::

        p       = W_in * x
        t_0     = S(p, lam_in)
        t_{k+1} = S(p + t_k - U_k * (D_k * t_k), lam_k)   for k = 0 .. num_layers-1

    where ``*`` is a bias-free 3-D convolution. ``D_k`` (``down_conv_k``) maps
    ``num_filters -> num_channels`` and ``U_k`` (``up_conv_k``) maps
    ``num_channels -> num_filters``. ``S(p, lam) = sign(p) * relu(|p| - lam)``
    is element-wise soft-thresholding with a learnable scalar threshold.

    Args:
        num_channels: number of channels of the input volume.
        num_layers: number of unrolled refinement iterations (default 4).
        kernel_size: cubic kernel size of every convolution (default 9).
            Padding is ``kernel_size // 2``, which preserves the spatial size
            for odd kernel sizes.
        num_filters: number of channels of the produced feature map
            (default 32).

    Submodules and parameters are registered as ``layer_in``, ``lam_in``,
    ``down_conv_{i}``, ``up_conv_{i}`` and ``lam_{i}``, so ``state_dict`` keys
    match earlier checkpoints.
    """

    def __init__(self, num_channels, num_layers=4, kernel_size=9, num_filters=32):
        super(Prediction, self).__init__()
        self.num_layers = num_layers
        self.in_channel = num_channels
        self.kernel_size = kernel_size
        self.num_filters = num_filters
        self.layer_in = self._make_conv(self.in_channel, self.num_filters)
        self.lam_in = nn.Parameter(torch.Tensor([0.01]))
        for i in range(self.num_layers):
            # Keep the original construction/initialisation order
            # (down conv, up conv, threshold) so seeded runs get identical
            # initial weights.
            setattr(self, 'down_conv_{}'.format(i),
                    self._make_conv(self.num_filters, self.in_channel))
            setattr(self, 'up_conv_{}'.format(i),
                    self._make_conv(self.in_channel, self.num_filters))
            setattr(self, 'lam_{}'.format(i), nn.Parameter(torch.Tensor([0.01])))

    def _make_conv(self, in_channels, out_channels):
        """Build a bias-free, Xavier-initialised, 'same'-padded 3-D convolution."""
        conv = nn.Conv3d(in_channels=in_channels, out_channels=out_channels,
                         kernel_size=self.kernel_size,
                         padding=self.kernel_size // 2, stride=1, bias=False)
        nn.init.xavier_uniform_(conv.weight)
        return conv

    @staticmethod
    def _soft_threshold(p, lam):
        """Element-wise soft-thresholding: sign(p) * max(|p| - lam, 0)."""
        return torch.mul(torch.sign(p), F.relu(torch.abs(p) - lam))

    # The per-iteration layers are resolved by name on every access instead of
    # being cached in plain Python lists. A cached list keeps pointing at the
    # original objects when nn.DataParallel replicates the module (replicas
    # shallow-copy __dict__) or when a parameter is re-assigned. Forward would
    # then silently use stale or wrong-device weights.
    @property
    def layer_down(self):
        """List of ``down_conv_{i}`` modules, in iteration order."""
        return [getattr(self, 'down_conv_{}'.format(i)) for i in range(self.num_layers)]

    @property
    def layer_up(self):
        """List of ``up_conv_{i}`` modules, in iteration order."""
        return [getattr(self, 'up_conv_{}'.format(i)) for i in range(self.num_layers)]

    @property
    def lam_i(self):
        """List of per-iteration threshold parameters ``lam_{i}``."""
        return [getattr(self, 'lam_{}'.format(i)) for i in range(self.num_layers)]

    def forward(self, mod):
        """Encode ``mod`` into a ``num_filters``-channel sparse feature map.

        Args:
            mod: input accepted by ``nn.Conv3d`` with ``num_channels`` channels,
                e.g. (N, num_channels, D, H, W).

        Returns:
            Tensor with ``num_filters`` channels; the spatial size equals the
            input's for odd ``kernel_size``.
        """
        p1 = self.layer_in(mod)
        tensor = self._soft_threshold(p1, self.lam_in)
        for down, up, lam in zip(self.layer_down, self.layer_up, self.lam_i):
            p4 = up(down(tensor))
            # Same evaluation order as before: p1 + (tensor - p4).
            p6 = torch.add(p1, tensor - p4)
            tensor = self._soft_threshold(p6, lam)
        return tensor
class decoder(nn.Module):
    """Fuse three feature maps into one single-channel output.

    Each input goes through its own bias-free, Xavier-initialised 3-D
    convolution mapping ``filters`` (32) channels to ``channel`` (1) channel.
    The three results are summed. The input-to-layer mapping is:
    ``u -> conv_1``, ``z -> conv_2``, ``v -> conv_3``.
    """

    def __init__(self):
        super().__init__()
        self.channel = 1
        self.kernel_size = 9
        self.filters = 32
        # Created in this order (conv_1, conv_2, conv_3) so the random
        # initialisation sequence is unchanged.
        self.conv_1 = self._make_branch()
        self.conv_2 = self._make_branch()
        self.conv_3 = self._make_branch()

    def _make_branch(self):
        """Return one Xavier-initialised ``filters -> channel`` projection."""
        branch = nn.Conv3d(in_channels=self.filters, out_channels=self.channel,
                           kernel_size=self.kernel_size, stride=1, padding=4,
                           bias=False)
        nn.init.xavier_uniform_(branch.weight)
        return branch

    def forward(self, u, v, z):
        """Return ``conv_1(u) + conv_2(z) + conv_3(v)``."""
        return self.conv_1(u) + self.conv_2(z) + self.conv_3(v)
class CUNet(nn.Module):
    """Two-input network built from three ``Prediction`` encoders and a ``decoder``.

    ``net_u`` encodes ``x`` and ``net_v`` encodes ``y``. ``net_z`` encodes the
    channel-wise concatenation of the residuals ``x - conv_u(u)`` and
    ``y - conv_v(v)``. ``decoder`` then fuses ``u``, ``v`` and ``z`` into a
    single-channel output. Presumably this is a common/unique feature
    decomposition of two modalities (inferred from the names -- TODO confirm).
    """

    def __init__(self):
        super(CUNet, self).__init__()
        self.channel = 1
        self.num_filters = 32
        self.kernel_size = 9
        # Construction order (net_u, conv_u, net_v, conv_v, net_z, decoder) is
        # kept so seeded runs produce the same initial weights.
        self.net_u = Prediction(num_channels=self.channel)
        self.conv_u = self._make_conv()
        self.net_v = Prediction(num_channels=self.channel)
        self.conv_v = self._make_conv()
        self.net_z = Prediction(num_channels=2 * self.channel)
        self.decoder = decoder()

    def _make_conv(self):
        """Build a bias-free, Xavier-initialised ``num_filters -> channel`` conv."""
        conv = nn.Conv3d(in_channels=self.num_filters, out_channels=self.channel,
                         kernel_size=self.kernel_size, stride=1,
                         padding=self.kernel_size // 2, bias=False)
        nn.init.xavier_uniform_(conv.weight)
        return conv

    def forward(self, x, y):
        """Predict a single-channel volume from the two inputs.

        Args:
            x: 5-D tensor (N, 1, D, H, W); its spatial size defines the output grid.
            y: 5-D tensor (N, 1, D', H', W'); it is nearest-neighbour resampled
                to ``x``'s spatial size before encoding.

        Returns:
            Tensor of shape (N, 1, D, H, W).
        """
        u = self.net_u(x)
        # Resample y onto x's grid so that p_x / p_y can be concatenated and
        # u, v, z summed in the decoder. This used to be hard-coded to
        # (27, 27, 20), which only worked when x had exactly that size.
        y = F.interpolate(y, size=tuple(x.shape[2:]), mode='nearest')
        v = self.net_v(y)
        p_x = x - self.conv_u(u)
        p_y = y - self.conv_v(v)
        p_xy = torch.cat((p_x, p_y), dim=1)
        z = self.net_z(p_xy)
        return self.decoder(u, v, z)
#if __name__ == '__main__':
# cu = CUNet()
# print(cu)
# a = torch.rand([1, 1, 27, 27,20]).cuda()
# b = torch.rand([1, 1, 27, 28,21]).cuda()
# c=nn.functional.interpolate(b,(27,27,20))
# c.shape
# net = CUNet().cuda()
# a_out = net(a, b)
# print(a.shape)
# print(b.shape)
# print(a_out.shape)