-
Notifications
You must be signed in to change notification settings - Fork 26
/
model.py
101 lines (82 loc) · 3.34 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn
class Filter_Module(nn.Module):
    """Foreground attention filter.

    Maps each temporal snippet's feature vector to a scalar weight in
    (0, 1) indicating how likely the snippet is foreground (action).
    """

    def __init__(self, len_feature):
        super(Filter_Module, self).__init__()
        self.len_feature = len_feature
        # 1x1 temporal conv reducing the feature dimension to 512.
        self.conv_1 = nn.Sequential(
            nn.Conv1d(in_channels=self.len_feature, out_channels=512,
                      kernel_size=1, stride=1, padding=0),
            nn.LeakyReLU()
        )
        # Second 1x1 conv collapses to one channel; sigmoid bounds the
        # output to (0, 1) so it can be used as a multiplicative weight.
        self.conv_2 = nn.Sequential(
            nn.Conv1d(in_channels=512, out_channels=1,
                      kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        """x: (B, T, F) -> foreground weights of shape (B, T, 1)."""
        # Conv1d expects channels first, so move features to dim 1.
        feats = x.permute(0, 2, 1)           # (B, F, T)
        weights = self.conv_2(self.conv_1(feats))
        return weights.permute(0, 2, 1)      # (B, T, 1)
class CAS_Module(nn.Module):
    """Class Activation Sequence (CAS) module.

    Produces per-snippet class logits — including one extra background
    class — from a sequence of snippet features.
    """

    def __init__(self, len_feature, num_classes):
        super(CAS_Module, self).__init__()
        self.len_feature = len_feature
        # Two temporal convs (kernel 3, padding 1) keep the sequence
        # length unchanged while mixing neighboring snippets.
        self.conv_1 = nn.Sequential(
            nn.Conv1d(in_channels=self.len_feature, out_channels=2048,
                      kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU()
        )
        self.conv_2 = nn.Sequential(
            nn.Conv1d(in_channels=2048, out_channels=2048,
                      kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU()
        )
        # Final 1x1 conv emits num_classes + 1 logits; the extra channel
        # is the background class.
        self.conv_3 = nn.Sequential(
            nn.Conv1d(in_channels=2048, out_channels=num_classes + 1,
                      kernel_size=1, stride=1, padding=0, bias=False)
        )
        self.drop_out = nn.Dropout(p=0.7)

    def forward(self, x):
        """x: (B, T, F) -> class activation sequence of shape (B, T, C + 1)."""
        feats = x.permute(0, 2, 1)               # (B, F, T)
        feats = self.conv_1(feats)
        feats = self.conv_2(feats)
        logits = self.conv_3(self.drop_out(feats))
        return logits.permute(0, 2, 1)           # (B, T, C + 1)
class BaS_Net(nn.Module):
    """Background Suppression Network (BaS-Net).

    Runs the CAS module on two branches: the raw features (base branch)
    and features modulated by the foreground attention filter
    (suppression branch). Video-level scores come from top-k mean
    pooling over time followed by a softmax over classes.
    """

    def __init__(self, len_feature, num_classes, num_segments):
        super(BaS_Net, self).__init__()
        self.filter_module = Filter_Module(len_feature)
        self.len_feature = len_feature
        self.num_classes = num_classes
        self.cas_module = CAS_Module(len_feature, num_classes)
        self.softmax = nn.Softmax(dim=1)
        self.num_segments = num_segments
        # Top-k pooling size over the temporal axis. Clamp to at least 1:
        # num_segments // 8 is 0 for short inputs (num_segments < 8),
        # which made the top-k slice empty and the mean NaN.
        self.k = max(1, num_segments // 8)

    def _topk_mean(self, cas):
        """Mean of the k highest activations along time, per class.

        cas: (B, T, C + 1) -> (B, C + 1)
        """
        # Slicing after a full sort is much faster than torch.topk
        # (https://github.com/pytorch/pytorch/issues/22812).
        sorted_scores, _ = cas.sort(descending=True, dim=1)
        return torch.mean(sorted_scores[:, :self.k, :], dim=1)

    def forward(self, x):
        """x: (B, T, F) -> (score_base, cas_base, score_supp, cas_supp,
        fore_weights).

        score_*: (B, C + 1) softmax-normalized video-level class scores;
        cas_*: (B, T, C + 1) per-snippet activations;
        fore_weights: (B, T, 1) foreground attention weights.
        """
        fore_weights = self.filter_module(x)   # (B, T, 1)
        # Suppress likely-background snippets before the second pass.
        x_supp = fore_weights * x
        cas_base = self.cas_module(x)
        cas_supp = self.cas_module(x_supp)
        score_base = self.softmax(self._topk_mean(cas_base))
        score_supp = self.softmax(self._topk_mean(cas_supp))
        return score_base, cas_base, score_supp, cas_supp, fore_weights