-
Notifications
You must be signed in to change notification settings - Fork 0
/
MobilenetBase.py
127 lines (103 loc) · 4.5 KB
/
MobilenetBase.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
#MobileNetV3Base
from keras.layers import Conv2D, DepthwiseConv2D, Dense, GlobalAveragePooling2D
from keras.layers import Activation, BatchNormalization, Add, Multiply, Reshape
from keras import backend as K
class MobilenetBase:
    """Shared building blocks for MobileNetV3-style Keras models.

    The class only stores hyper-parameters and provides layer helpers
    (activation selection, conv block, squeeze-and-excitation, bottleneck).
    `build` is an empty hook in this class.
    """

    def __init__(self, shape, n_class, alpha=1.0):
        """Init

        # Arguments
            shape: An integer or tuple/list of 3 integers, shape
                of input tensor.
            n_class: Integer, number of classes.
            alpha: Float, width multiplier applied to the output filters
                of every bottleneck.
        """
        self.shape = shape
        self.n_class = n_class
        self.alpha = alpha

    @staticmethod
    def _channel_axis():
        """Return the channel axis for the active Keras image data format."""
        return 1 if K.image_data_format() == 'channels_first' else -1

    @staticmethod
    def _use_residual(stride, in_channels, out_channels):
        """Tell whether a bottleneck may add its input to its output.

        # Arguments
            stride: Integer, stride of the depthwise convolution.
            in_channels: Integer or None, channel count of the block input.
            out_channels: Integer, channel count the block actually emits.

        # Returns
            True only when the spatial size is preserved (stride 1) and the
            input and output channel counts are equal, i.e. when the two
            tensors can be summed element-wise.
        """
        return stride == 1 and in_channels == out_channels

    def _relu6(self, x):
        """ReLU6: ReLU clipped at 6."""
        return K.relu(x, max_value=6.0)

    def _hard_swish(self, x):
        """Hard-swish: x * ReLU6(x + 3) / 6."""
        return x * K.relu(x + 3.0, max_value=6.0) / 6.0

    def _return_activation(self, x, nl):
        """Activation selection

        This function applies the nonlinearity named by `nl`.

        # Arguments
            x: Tensor, input tensor.
            nl: String, nonlinearity activation type: 'HS' for hard-swish,
                'RE' for ReLU6 (see the MobileNetV3 paper). Any other
                value leaves `x` unchanged.

        # Returns
            Output tensor.
        """
        if nl == 'HS':
            return Activation(self._hard_swish)(x)
        if nl == 'RE':
            return Activation(self._relu6)(x)
        return x

    def _conv_block(self, inputs, filters, kernel, strides, nl):
        """Convolution Block

        This function defines a 2D convolution operation with BN and
        activation.

        # Arguments
            inputs: Tensor, input tensor of conv layer.
            filters: Integer, the dimensionality of the output space.
            kernel: An integer or tuple/list of 2 integers, specifying the
                width and height of the 2D convolution window.
            strides: An integer or tuple/list of 2 integers,
                specifying the strides of the convolution along the width
                and height. Can be a single integer to specify the same
                value for all spatial dimensions.
            nl: String, nonlinearity activation type.

        # Returns
            Output tensor.
        """
        x = Conv2D(filters, kernel, padding='same', strides=strides)(inputs)
        x = BatchNormalization(axis=self._channel_axis())(x)
        return self._return_activation(x, nl)

    def _squeeze(self, inputs):
        """Squeeze and Excitation.

        This function defines a squeeze structure (SE-Net): per-channel
        weights are computed from the globally pooled input and multiplied
        back onto it.

        # Arguments
            inputs: Tensor, input tensor of conv layer.

        # Returns
            Output tensor, same shape as `inputs`.
        """
        channel_axis = self._channel_axis()
        # Read the channel count from the real channel axis so that the
        # block also works with the 'channels_first' data format.
        input_channels = int(K.int_shape(inputs)[channel_axis])

        x = GlobalAveragePooling2D()(inputs)
        x = Dense(input_channels, activation='relu')(x)
        x = Dense(input_channels, activation='hard_sigmoid')(x)

        # Shape the weights so they broadcast over the spatial axes.
        if channel_axis == 1:
            x = Reshape((input_channels, 1, 1))(x)
        else:
            x = Reshape((1, 1, input_channels))(x)

        return Multiply()([inputs, x])

    def _bottleneck(self, inputs, filters, kernel, e, s, squeeze, nl):
        """Bottleneck

        This function defines a basic bottleneck structure:
        1x1 expansion conv -> depthwise conv -> optional squeeze ->
        1x1 linear projection, with a residual connection when shapes allow.

        # Arguments
            inputs: Tensor, input tensor of conv layer.
            filters: Integer, the dimensionality of the output space
                before the width multiplier `alpha` is applied.
            kernel: An integer or tuple/list of 2 integers, specifying the
                width and height of the 2D convolution window.
            e: Integer, number of channels of the expansion layer. It is
                used as given (not multiplied by the input size or alpha).
            s: Integer, stride of the depthwise convolution, applied to
                both spatial dimensions.
            squeeze: Boolean, Whether to use the squeeze.
            nl: String, nonlinearity activation type.

        # Returns
            Output tensor.
        """
        channel_axis = self._channel_axis()
        in_channels = K.int_shape(inputs)[channel_axis]

        tchannel = int(e)
        cchannel = int(self.alpha * filters)

        # Compare against the channel count this block really produces
        # (alpha-scaled) and on the real channel axis; otherwise Add() can
        # receive mismatched shapes or a valid residual is skipped.
        residual = self._use_residual(s, in_channels, cchannel)

        x = self._conv_block(inputs, tchannel, (1, 1), (1, 1), nl)

        x = DepthwiseConv2D(kernel, strides=(s, s), depth_multiplier=1,
                            padding='same')(x)
        x = BatchNormalization(axis=channel_axis)(x)
        x = self._return_activation(x, nl)

        if squeeze:
            x = self._squeeze(x)

        # Linear projection: no activation after the last 1x1 convolution.
        x = Conv2D(cchannel, (1, 1), strides=(1, 1), padding='same')(x)
        x = BatchNormalization(axis=channel_axis)(x)

        if residual:
            x = Add()([x, inputs])

        return x

    def build(self):
        """Empty hook; does nothing and returns None in this class.

        NOTE(review): presumably overridden by concrete model classes to
        assemble the network - confirm against the subclasses.
        """
        pass