-
Notifications
You must be signed in to change notification settings - Fork 0
/
TF_nn_statistics.py
72 lines (59 loc) · 2.04 KB
/
TF_nn_statistics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import pickle
from collections import Counter
from pathlib import Path
class ConvLayer:
    """Lightweight record describing one convolution layer.

    The text produced by ``__str__`` is used as a grouping key when
    counting identical layer configurations, so its format must stay stable.

    Args:
        aperture: Kernel size exactly as read from the layer config; it is
            printed verbatim.
        zin: Depth of the incoming tensor (number of input channels).
        zout: Number of filters (output channels).
    """

    def __init__(self, aperture, zin=0, zout=0):
        self._aperture, self._zin, self._zout = aperture, zin, zout

    def __str__(self):
        # Depth and filter counts are right-aligned in 4-character columns.
        template = 'Aperture = {} | input depth = {:4} | filters = {:4}'
        return template.format(self._aperture, self._zin, self._zout)
def get_layer(layers : list, name : str) -> dict:
    """Return the first layer dict in *layers* whose ``'name'`` equals *name*.

    Args:
        layers: Sequence of layer dicts, each carrying a ``'name'`` key.
        name: Name of the layer to look up.

    Returns:
        The matching layer dict itself (not a copy).

    Raises:
        IndexError: If no layer in *layers* has that name.
    """
    for layer in layers:
        if layer['name'] == name:
            return layer
    # Report the requested name (repr'd so empty/whitespace names are visible).
    raise IndexError(f'Layer with name {name!r} does not exist!')
def get_zin(layers : list, inbound_nodes : dict) -> int:
    """Return the input depth (channel count) that feeds a layer.

    Walks backwards along the *first* inbound connection
    (``inbound_nodes[0][0][0]`` is taken as the source layer's name) until it
    reaches either the layer named ``'input_1'`` or a ``Conv2D`` layer.
    Every other layer type on the way is treated as depth-preserving.

    Args:
        layers: All layer dicts of the network (searched via ``get_layer``).
        inbound_nodes: The ``'inbound_nodes'`` entry of the layer whose input
            depth is wanted; despite the annotation it is indexed like a
            nested list.

    Returns:
        3 when the chain reaches ``'input_1'`` (hard-coded; presumably an RGB
        image input -- confirm), otherwise the ``'filters'`` value of the
        nearest upstream ``Conv2D`` layer.

    Raises:
        IndexError: Propagated from ``get_layer`` when a referenced layer name
            is not present in *layers*.

    NOTE(review): only the first inbound branch is followed, so for layers
    with several inputs (e.g. concatenations) the result reflects just that
    branch.
    """
    nodes = inbound_nodes
    while True:
        source_name = nodes[0][0][0]
        if source_name == 'input_1':
            return 3
        source = get_layer(layers, source_name)
        if source['class_name'] == 'Conv2D':
            return source['config']['filters']
        # Non-convolution layer: keep walking upstream through its inputs.
        nodes = source['inbound_nodes']
def calc_stat(file_name):
    """Print layer-type and Conv2D-configuration statistics for one network.

    The pickle must contain a dict with a ``'layers'`` list. Each layer is a
    dict with at least ``'class_name'``. ``Conv2D`` layers must also carry
    ``'config'`` (with ``'kernel_size'`` and ``'filters'``) and
    ``'inbound_nodes'`` (passed to ``get_zin``).

    Prints, in order of first appearance:
      * the list of distinct layer class names;
      * one line per distinct Conv2D configuration (aperture, input depth,
        filter count), with its occurrence count and its share of all
        Conv2D layers in percent.

    Args:
        file_name: Path (``str`` or ``pathlib.Path``) of the pickle file.
    """
    # NOTE(review): pickle.load can execute arbitrary code while loading;
    # only run this on trusted, locally generated config files.
    with open(file_name, "rb") as f:
        struct = pickle.load(f)
    layers = struct['layers']

    # dict.fromkeys deduplicates in O(n) while keeping first-seen order.
    layer_types = list(dict.fromkeys(layer['class_name'] for layer in layers))

    # Group identical Conv2D configurations by their printed description;
    # Counter preserves first-seen insertion order.
    conv_counts = Counter(
        str(ConvLayer(layer['config']['kernel_size'],
                      get_zin(layers, layer['inbound_nodes']),
                      layer['config']['filters']))
        for layer in layers
        if layer['class_name'] == 'Conv2D'
    )
    total = sum(conv_counts.values())

    print(layer_types)
    # When there are no Conv2D layers the loop body never runs, so `total`
    # cannot be zero inside it.
    for i, (description, count) in enumerate(conv_counts.items()):
        percents = count * 100 / total
        # Fixed-point formatting: the general format (e.g. `5.3`) renders
        # 100 as '1e+02'.
        print(f'{i:3}: {description} => {count:3} = {percents:6.2f}%')
if __name__ == "__main__":
    # The config directory is resolved against the current working
    # directory, so the script is expected to be run from the project root.
    path = Path.cwd() / "configs"
    # Path.glob yields entries in filesystem order, which differs between
    # systems; sorting makes the report order deterministic.
    for net in sorted(path.glob('*.pkl')):
        print(f'Statistics from file {net.name}')
        calc_stat(net)
        print("===================================================")