forked from rwth-i6/returnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
NetworkCopyUtils.py
40 lines (31 loc) · 1.91 KB
/
NetworkCopyUtils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from Log import log
class LayerDoNotMatchForCopy(Exception):
    """Raised by intelli_copy_layer() when the old and new layer parameters cannot be paired up."""
def intelli_copy_layer(old_layer, new_layer):
    """
    Copies all parameter values from old_layer to new_layer.

    :type old_layer: NetworkBaseLayer.Layer
    :type new_layer: NetworkBaseLayer.Layer
    :raises LayerDoNotMatchForCopy: if the number of parameters differs, or if any translated
      parameter has a different shape in the new layer. All shapes are checked before anything
      is assigned, so in that case new_layer is left unmodified.

    We support slightly different param names. That can happen because the param names
    could encode the source/target layer name, e.g. "hidden_N".
    Thus we need to translate the parameter names for the new network.
    For the translation, we expect that the sorted list of the old parameter names matches,
    position by position, the sorted list of the new parameter names.
    """
    old_names = sorted(old_layer.params.keys())
    new_names = sorted(new_layer.params.keys())
    if len(old_names) != len(new_names):
        raise LayerDoNotMatchForCopy("num parameters do not match. old layer: %s, new layer: %s" %
                                     (old_names, new_names))

    # Pair the names up by their position in the two sorted lists.
    name_map = dict(zip(old_names, new_names))
    # Equivalent to the former `print >> log.v5, ...` (write the text, then a newline),
    # but valid on both Python 2 and Python 3.
    log.v5.write("Copy map: %s\n" % (sorted(name_map.items()),))

    old_values = old_layer.get_params_dict()
    new_values = {name_map[old_name]: value for old_name, value in old_values.items()}

    # Phase 1: validate every shape before touching new_layer, so that a mismatch on a later
    # parameter cannot leave new_layer partially overwritten with old values.
    for name, value in new_values.items():
        # new_layer.params[...] presumably holds Theano shared variables (get_value/set_value API).
        expected_shape = new_layer.params[name].get_value(borrow=True, return_internal_type=True).shape
        if expected_shape != value.shape:
            raise LayerDoNotMatchForCopy(
                "In %s, param %s shape does not match. Expected (new layer) %s, got (old layer) %s." %
                (new_layer, name, expected_shape, value.shape))

    # Phase 2: all shapes agree, so assign the values.
    for name, value in new_values.items():
        new_layer.params[name].set_value(value, borrow=True)