import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
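
# Dynamic filter layers in the spirit of Dynamic Filter Networks
# (De Brabandere et al., NIPS 2016): the convolution kernels are not learned
# weights of the layer itself but are predicted at runtime by another part of
# the network and passed in together with the input.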
class DynamicFilterLayer(nn.Module):
    def __init__(self, filter_size, stride=(1, 1), pad=(0, 0), flip_filters=False, grouping=False):
        super(DynamicFilterLayer, self).__init__()
        self.filter_size = filter_size    # 3-tuple: (height, width, in_channels)
        self.stride = stride              # 2-tuple
        self.pad = pad                    # 2-tuple, or 'same'
        self.flip_filters = flip_filters
        self.grouping = grouping
    def get_output_shape_for(self, input_shapes):
        # input_shapes[0] is the shape of the image input: (batch, channels, height, width)
        if self.grouping:
            # each input channel is filtered separately, so the channel count is preserved
            shape = (input_shapes[0][0], input_shapes[0][1], input_shapes[0][2], input_shapes[0][3])
        else:
            # all input channels are combined into a single output channel
            shape = (input_shapes[0][0], 1, input_shapes[0][2], input_shapes[0][3])
        return shape
    def forward(self, _input, **kwargs):
        image = _input[0]    # (batch, channels, height, width)
        filters = _input[1]  # per-pixel filters predicted by the network

        border_mode = self.pad
        if border_mode == 'same':
            border_mode = (self.filter_size[0] // 2, self.filter_size[1] // 2)
        filter_size = self.filter_size

        if self.grouping:
            # Expand each input channel into prod(filter_size) shifted copies using a
            # one-hot "identity" kernel, then weight the copies by the dynamic filters.
            # This path assumes filter_size[2] == 1 (one input channel per group).
            filter_localexpand_np = np.reshape(
                np.eye(np.prod(filter_size)),
                (np.prod(filter_size), 1, filter_size[0], filter_size[1]))
            filter_localexpand = torch.from_numpy(
                filter_localexpand_np.astype('float32')).to(image.device)
            if self.flip_filters:
                # F.conv2d computes cross-correlation; flip the kernel for true convolution
                filter_localexpand = torch.flip(filter_localexpand, dims=(2, 3))
            outputs = []
            for i in range(3):  # hard-coded for 3-channel (RGB) input, as in the original
                input_localexpand = F.conv2d(image[:, [i], :, :], filter_localexpand,
                                             stride=self.stride, padding=border_mode)
                output = torch.sum(input_localexpand * filters[i], dim=1, keepdim=True)
                outputs.append(output)
            output = torch.cat(outputs, dim=1)
        else:
            filter_localexpand_np = np.reshape(
                np.eye(np.prod(filter_size)),
                (np.prod(filter_size), filter_size[2], filter_size[0], filter_size[1]))
            filter_localexpand = torch.from_numpy(
                filter_localexpand_np.astype('float32')).to(image.device)
            if self.flip_filters:
                filter_localexpand = torch.flip(filter_localexpand, dims=(2, 3))
            input_localexpand = F.conv2d(image, filter_localexpand, padding=border_mode)
            output = torch.sum(input_localexpand * filters, dim=1, keepdim=True)
        return output
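
# A minimal usage sketch (not from the original repo; all shapes are assumed):
# with filter_size (9, 9, 1), `filters` must carry prod(filter_size) = 81 values
# per output position, e.g. predicted by a separate filter-generating branch.
#
#     dfl = DynamicFilterLayer((9, 9, 1), pad='same')
#     image = torch.randn(4, 1, 64, 64)                            # (batch, channels, H, W)
#     filters = torch.softmax(torch.randn(4, 81, 64, 64), dim=1)   # per-pixel kernels
#     out = dfl([image, filters])                                  # -> (4, 1, 64, 64)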
class DynamicFilterLayer1D(nn.Module):
    def __init__(self, filter_size, stride=1, pad=0):
        super(DynamicFilterLayer1D, self).__init__()
        self.filter_size = filter_size  # int: kernel length
        self.stride = stride            # int
        self.pad = pad                  # int
    def forward(self, _input, **kwargs):
        image = _input[0]    # (steps, in_channels, length)
        filters = _input[1]  # one conv1d weight per step: (steps, out_channels, in_channels, kernel)
        image = image.unsqueeze(0)  # add a batch dimension for F.conv1d
        output = []
        for i in range(image.shape[1]):  # one convolution per step (e.g. 60 steps)
            result = F.conv1d(image[:, i], filters[i], padding=self.pad, stride=self.stride)
            output.append(result)
        output = torch.cat(output, 0)
        return output
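
# A minimal usage sketch (not from the original repo; all shapes are assumed):
# each of the `steps` slices gets its own conv1d weight, so a distinct kernel
# is applied at every step.
#
#     dfl1d = DynamicFilterLayer1D(filter_size=5, pad=2)
#     signal = torch.randn(60, 1, 128)     # (steps, in_channels, length)
#     kernels = torch.randn(60, 1, 1, 5)   # (steps, out_channels, in_channels, kernel)
#     out = dfl1d([signal, kernels])       # -> (60, 1, 128)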