-
Notifications
You must be signed in to change notification settings - Fork 43
/
capsule.py
59 lines (47 loc) · 2.31 KB
/
capsule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable
import config
class CapsuleLayer(nn.Module):
    """Capsule layer operating in one of two modes, selected by ``num_route_nodes``.

    * ``num_route_nodes == -1`` (convolutional mode): ``num_capsules`` independent
      ``nn.Conv2d`` units are applied to the input and their outputs are stacked
      as the capsule dimension.
    * otherwise (routing mode): a learnable tensor maps every input capsule to a
      prediction for every output capsule, and ``forward`` combines the
      predictions with iterative routing-by-agreement.

    Args:
        num_capsules: Number of output capsules (number of conv units in
            convolutional mode).
        num_route_nodes: Number of input capsules routed from, or ``-1`` to
            select convolutional mode.
        in_channels: Conv input channels (convolutional mode) or input capsule
            dimension (routing mode).
        out_channels: Conv output channels per unit (convolutional mode) or
            output capsule dimension (routing mode).
        kernel_size: Conv kernel size; only used in convolutional mode.
        stride: Conv stride; only used in convolutional mode.
        num_iterations: Number of routing iterations; only used in routing mode.

    Raises:
        ValueError: If routing mode is requested with ``num_iterations < 1``
            (``forward`` would otherwise have no output to return).
    """

    def __init__(self, num_capsules, num_route_nodes, in_channels, out_channels, kernel_size=None, stride=None,
                 num_iterations=config.NUM_ROUTING_ITERATIONS):
        super().__init__()
        self.num_route_nodes = num_route_nodes
        self.num_iterations = num_iterations
        self.num_capsules = num_capsules

        if num_route_nodes != -1:
            if num_iterations < 1:
                # The routing loop in forward() is what defines the output tensor.
                raise ValueError(f"num_iterations must be >= 1 for a routing layer, got {num_iterations}")
            # Shape (num_capsules, num_route_nodes, in_channels, out_channels): one
            # in->out transformation per (output capsule, input capsule) pair.
            self.route_weights = nn.Parameter(torch.randn(num_capsules, num_route_nodes, in_channels, out_channels))
        else:
            self.capsules = nn.ModuleList(
                [nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=0)
                 for _ in range(num_capsules)])

    @staticmethod
    def squash(tensor, dim=-1, eps=1e-8):
        """Rescale vectors along ``dim`` to length ``|v|^2 / (1 + |v|^2)``, keeping direction.

        Args:
            tensor: Input tensor.
            dim: Axis holding the vector components.
            eps: Floor applied to the squared norm before the square root, so an
                all-zero vector maps to zero instead of NaN and its gradient
                stays finite.

        Returns:
            Tensor with the same shape as ``tensor``.
        """
        squared_norm = (tensor ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        # Clamp before sqrt: clamping after it would still back-propagate 0/0 = NaN.
        return scale * tensor / torch.sqrt(squared_norm.clamp(min=eps))

    def forward(self, x):
        """Compute the output capsules.

        Routing mode: ``x`` is expected as ``(batch, num_route_nodes, in_channels)``
        (it must broadcast against ``route_weights`` in the matmul below). The
        result has shape ``(num_capsules, batch, 1, 1, out_channels)``.

        Convolutional mode: ``x`` is passed to every ``nn.Conv2d`` unit. Each
        unit's output is flattened per sample and the units are stacked on the
        last axis, giving ``(batch, -1, num_capsules)``.

        Both modes apply ``squash`` along the last axis.
        """
        if self.num_route_nodes != -1:
            # (1, B, R, 1, in) @ (C, 1, R, in, out) -> (C, B, R, 1, out): each input
            # capsule's prediction for each output capsule.
            priors = x[None, :, :, None, :] @ self.route_weights[:, None, :, :, :]

            # Routing logits start at zero. zeros_like keeps them on the same device
            # and dtype as priors, whether or not CUDA happens to be available.
            logits = torch.zeros_like(priors)
            for i in range(self.num_iterations):
                # NOTE(review): coupling coefficients are normalised over the route
                # nodes (dim 2). Sabour et al. (2017) normalise over the output
                # capsules (dim 0 here) -- confirm this is intended.
                probs = F.softmax(logits, dim=2)
                outputs = self.squash((probs * priors).sum(dim=2, keepdim=True))
                if i != self.num_iterations - 1:
                    # Agreement = dot product between each prediction and the current output.
                    delta_logits = (priors * outputs).sum(dim=-1, keepdim=True)
                    logits = logits + delta_logits
        else:
            outputs = [capsule(x).view(x.size(0), -1, 1) for capsule in self.capsules]
            outputs = torch.cat(outputs, dim=-1)
            outputs = self.squash(outputs)
        return outputs
if __name__ == "__main__":
    # Smoke check: build one layer of each mode and print its module summary.
    conv_caps = CapsuleLayer(
        num_capsules=8,
        num_route_nodes=-1,  # -1 selects the convolutional (non-routing) mode
        in_channels=256,
        out_channels=32,
        kernel_size=9,
        stride=2,
    )
    print(conv_caps)

    # 32 * 6 * 6 route nodes of dimension 8: presumably the 32-channel output of
    # the 8 conv units above on a 6x6 grid (the grid size depends on the input
    # resolution -- confirm against the caller).
    routed_caps = CapsuleLayer(
        num_capsules=config.NUM_CLASSES,
        num_route_nodes=32 * 6 * 6,
        in_channels=8,
        out_channels=16,
    )
    print(routed_caps)