-
Notifications
You must be signed in to change notification settings - Fork 641
/
convert_weights.py
140 lines (101 loc) · 4.03 KB
/
convert_weights.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import argparse
import numpy as np
import os
import tensorflow as tf
from AnimeGANv2.net import generator as tf_generator
import torch
from model import Generator
def load_tf_weights(tf_path):
    """Restore the TF1 generator checkpoint in ``tf_path`` and return its trainable weights.

    Builds the generator graph (``tf_generator.G_net``) under the
    ``"generator"`` variable scope on a ``[1, None, None, 3]`` float32
    placeholder, restores the checkpoint recorded in ``tf_path`` in a
    CPU-only session (``device_count={'GPU': 0}``), and evaluates every
    trainable variable.

    Args:
        tf_path: checkpoint directory passed to ``tf.train.get_checkpoint_state``.

    Returns:
        dict mapping each trainable TF variable name (``v.name``) to the value
        returned by ``v.eval()``.

    Raises:
        FileNotFoundError: if no checkpoint state / checkpoint path is found
            in ``tf_path``.
    """
    test_real = tf.placeholder(tf.float32, [1, None, None, 3], name='test')
    with tf.variable_scope("generator", reuse=False):
        # Built only for its side effect: it creates the generator variables
        # that the Saver below restores. The output tensor itself is unused.
        _ = tf_generator.G_net(test_real).fake

    saver = tf.train.Saver()

    config = tf.ConfigProto(allow_soft_placement=True, device_count={'GPU': 0})
    with tf.Session(config=config) as sess:
        ckpt = tf.train.get_checkpoint_state(tf_path)
        # Explicit check instead of `assert` (stripped under -O). The path is a
        # protobuf string field, so "unset" shows up as "" rather than None.
        if ckpt is None or not ckpt.model_checkpoint_path:
            raise FileNotFoundError(f"Failed to load checkpoint {tf_path}")

        saver.restore(sess, ckpt.model_checkpoint_path)
        print(f"Tensorflow model checkpoint {ckpt.model_checkpoint_path} loaded")

        # v.eval() uses the default session, so values must be read inside `with`.
        tf_weights = {v.name: v.eval() for v in tf.trainable_variables()}

    return tf_weights
def convert_keys(k):
    """Map a TF generator variable name to the matching PyTorch state_dict key.

    The name (minus its first two ``/``-separated scope components) is split
    into three parts ``[block, layer, parameter]``, and each part is rewritten
    to the PyTorch ``Generator`` naming. Four-part names inside block ``C``
    (residual sub-blocks, ``C/r<i>/<layer>/<param>``) are first flattened into
    three parts.

    Args:
        k: TF variable name, e.g. ``"generator/<scope>/A/Conv/weights:0"``.

    Returns:
        ``(torch_key, is_dconv)``. ``is_dconv`` is True when the residual
        sub-block scope repeats its own name (e.g. ``r1/r1``). Presumably this
        is the depthwise conv, since ``convert_and_save`` permutes those
        kernels differently.

    Raises:
        ValueError: if ``k`` does not have the expected structure or ends in
            an unrecognized parameter suffix.
    """
    # 1. divide the TF weight name into three parts [block_idx, layer_idx, weight/bias]
    # 2. rewrite each part and join them into a PyTorch model key

    # The first Conv/LayerNorm in a scope has no index suffix in TF; give it an
    # explicit "_0" so every layer name has the form "<Type>_<idx>".
    name = k.replace("Conv/", "Conv_0/").replace("LayerNorm/", "LayerNorm_0/")
    keys = name.split("/")[2:]
    if len(keys) not in (3, 4):
        raise ValueError(f"unexpected TF variable name structure: {k!r}")

    is_dconv = False

    # Block C: renumber its Conv_1/LayerNorm_1 to index 5 and flatten the
    # four-part residual sub-block names.
    if keys[0] == "C":
        if keys[1] in ("Conv_1", "LayerNorm_1"):
            keys[1] = keys[1].replace("1", "5")

        if len(keys) == 4:
            if "r" not in keys[1]:
                raise ValueError(f"expected a residual sub-block (r<i>) in {k!r}")
            block, sub_block, layer, param = keys
            if sub_block == layer:
                # Scope repeats the sub-block name, e.g. "r1/r1/w:0".
                is_dconv = True
                layer = "1.1"
            block_c_maps = {
                "1": "1.2",
                "Conv_1": "2",
                "2": "3",
            }
            layer = block_c_maps.get(layer, layer)
            keys = [block, sub_block.replace("r", "") + ".layers." + layer, param]

    if len(keys) != 3:
        raise ValueError(f"unexpected TF variable name structure: {k!r}")
    block, layer, param = keys

    # handle output block: its layer part is always "0"
    if "out" in block:
        layer = "0"

    # first part: "A".."E" -> "block_a".."block_e"
    if block in ("A", "B", "C", "D", "E"):
        block = "block_" + block.lower()

    # second part: LayerNorm_<i> -> "<i>.2", Conv_<i> -> "<i>.1"
    if "LayerNorm_" in layer:
        layer = layer.replace("LayerNorm_", "") + ".2"
    if "Conv_" in layer:
        layer = layer.replace("Conv_", "") + ".1"

    # third part: TF parameter suffix -> PyTorch parameter name
    param_names = {
        "weights:0": "weight",
        "w:0": "weight",
        "bias:0": "bias",
        "gamma:0": "weight",
        "beta:0": "bias",
    }
    if param not in param_names:
        raise ValueError(f"unrecognized parameter {param!r} in TF variable name {k!r}")

    return ".".join((block, layer, param_names[param])), is_dconv
def convert_and_save(tf_checkpoint_path, save_name):
    """Convert a TF generator checkpoint into a PyTorch ``Generator`` state_dict file.

    Every trainable TF variable is renamed with ``convert_keys``. 4-D conv
    kernels are permuted to PyTorch's layout, and the result is loaded into a
    fresh ``Generator`` (strict key matching) before being written with
    ``torch.save``.

    Args:
        tf_checkpoint_path: checkpoint directory passed to ``load_tf_weights``.
        save_name: output path for the saved PyTorch state_dict.

    Raises:
        ValueError: if a TF variable maps to a key the PyTorch model lacks,
            two TF variables map to the same key, a converted tensor's shape
            differs from the PyTorch parameter, or some PyTorch weights were
            not produced by the conversion.
    """
    tf_weights = load_tf_weights(tf_checkpoint_path)
    torch_net = Generator()
    torch_weights = torch_net.state_dict()

    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    torch_converted_weights = {}
    for k, v in tf_weights.items():
        torch_k, is_dconv = convert_keys(k)
        if torch_k not in torch_weights:
            raise ValueError(f"weight name mismatch: {k} -> {torch_k}")
        if torch_k in torch_converted_weights:
            # Would otherwise silently overwrite an earlier weight.
            raise ValueError(f"duplicate weight mapping: {k} -> {torch_k}")

        converted_weight = torch.from_numpy(v)
        if converted_weight.dim() == 4:
            if is_dconv:
                # TF depthwise filters are (H, W, in, multiplier); this gives
                # (in, multiplier, H, W), which equals PyTorch's grouped-conv
                # (in, 1, H, W) when multiplier == 1. The shape check below
                # enforces that assumption.
                converted_weight = converted_weight.permute(2, 3, 0, 1)
            else:
                # TF conv filters are (H, W, in, out); PyTorch wants (out, in, H, W).
                converted_weight = converted_weight.permute(3, 2, 0, 1)

        expected_shape = torch_weights[torch_k].shape
        if converted_weight.shape != expected_shape:
            raise ValueError(
                f"shape mismatch: {k} -> {torch_k}: "
                f"{tuple(converted_weight.shape)} vs {tuple(expected_shape)}"
            )
        torch_converted_weights[torch_k] = converted_weight

    # Every key was already checked against torch_weights, so only missing
    # keys remain possible here.
    missing = sorted(set(torch_weights) - set(torch_converted_weights))
    if missing:
        raise ValueError(f"some weights are missing: {missing}")

    torch_net.load_state_dict(torch_converted_weights)
    torch.save(torch_net.state_dict(), save_name)
    print(f"PyTorch model saved at {save_name}")
if __name__ == '__main__':
    # Command-line entry point: convert one TF generator checkpoint directory
    # into a PyTorch state_dict file. Defaults refer to the Paprika checkpoint.
    cli = argparse.ArgumentParser()
    cli.add_argument('--tf_checkpoint_path', type=str,
                     default='AnimeGANv2/checkpoint/generator_Paprika_weight')
    cli.add_argument('--save_name', type=str,
                     default='pytorch_generator_Paprika.pt')
    opts = cli.parse_args()
    convert_and_save(opts.tf_checkpoint_path, opts.save_name)