# inference.py
import os
import sys

import torch
from argparse import ArgumentParser

from preprocess import CropAndExtract
from test_audio2coeff import Audio2Coeff
from facerender.animate import AnimateFromCoeff
from generate_batch import get_data
from generate_facerender_batch import get_facerender_data
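
# End-to-end inference: given a source image and a driving audio clip, this
# script produces a talking-head video (image + audio -> motion coefficients
# -> rendered frames).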

def main(args):
    # torch.backends.cudnn.enabled = False
    pic_path = args.source_image
    audio_path = args.driven_audio
    save_dir = os.path.join(args.result_dir, "EECS-504-F23-results")
    os.makedirs(save_dir, exist_ok=True)

    pose_style = 0
    device = args.device
    batch_size = 8
    camera_yaw_list = [0]
    camera_pitch_list = [0]
    camera_roll_list = [0]

    current_code_path = sys.argv[0]
    current_root_path = os.path.split(current_code_path)[0]
    # Cache torch.hub / torchvision weight downloads in the local checkpoints
    # directory so they land next to the bundled model files.
    os.environ['TORCH_HOME'] = os.path.join(current_root_path, 'checkpoints')
    path_of_lm_croper = os.path.join(current_root_path, 'checkpoints', 'shape_predictor_68_face_landmarks.dat')
    path_of_net_recon_model = os.path.join(current_root_path, 'checkpoints', 'epoch_20.pth')
    dir_of_BFM_fitting = os.path.join(current_root_path, 'checkpoints', 'BFM_Fitting')
    wav2lip_checkpoint = os.path.join(current_root_path, 'checkpoints', 'wav2lip.pth')

    audio2pose_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2pose_00140-model.pth')
    audio2pose_yaml_path = os.path.join(current_root_path, 'config', 'auido2pose.yaml')
    audio2exp_checkpoint = os.path.join(current_root_path, 'checkpoints', 'auido2exp_00300-model.pth')
    audio2exp_yaml_path = os.path.join(current_root_path, 'config', 'auido2exp.yaml')

    free_view_checkpoint = os.path.join(current_root_path, 'checkpoints', 'facevid2vid_00189-model.pth.tar')
    mapping_checkpoint = os.path.join(current_root_path, 'checkpoints', 'mapping_00229-model.pth.tar')
    facerender_yaml_path = os.path.join(current_root_path, 'config', 'facerender.yaml')
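    # Note: the 'auido2pose'/'auido2exp' spelling above is kept as-is -- it
    # appears to match the filenames of the upstream (SadTalker-style)
    # checkpoint release this code builds on, so "fixing" it would break the
    # paths.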

    # init models
    print(path_of_net_recon_model)
    preprocess_model = CropAndExtract(path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device)

    print(audio2pose_checkpoint)
    print(audio2exp_checkpoint)
    audio_to_coeff = Audio2Coeff(audio2pose_checkpoint, audio2pose_yaml_path,
                                 audio2exp_checkpoint, audio2exp_yaml_path,
                                 wav2lip_checkpoint, device)

    print(free_view_checkpoint)
    print(mapping_checkpoint)
    animate_from_coeff = AnimateFromCoeff(free_view_checkpoint, mapping_checkpoint,
                                          facerender_yaml_path, device)
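
    # Pipeline (as assembled above): CropAndExtract crops the source image and
    # fits 3DMM coefficients, Audio2Coeff maps the driving audio to expression
    # and pose coefficients, and AnimateFromCoeff renders the coefficients
    # into the final video.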

    # crop the source image and extract its 3DMM coefficients
    first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
    os.makedirs(first_frame_dir, exist_ok=True)
    first_coeff_path, crop_pic_path = preprocess_model.generate(pic_path, first_frame_dir)
    if first_coeff_path is None:
        print("Could not extract 3DMM coefficients from the input image")
        return

    # audio2coeff: map the driving audio to motion coefficients
    batch = get_data(first_coeff_path, audio_path, device)
    coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style)

    # coeff2video: render the coefficients into the output video
    data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
                               batch_size, camera_yaw_list, camera_pitch_list, camera_roll_list)
    animate_from_coeff.generate(data, save_dir)

    video_name = data['video_name']
    print(f'The generated video is named {video_name} in {save_dir}')


if __name__ == '__main__':
    parser = ArgumentParser()
    parser.add_argument("--driven_audio", default='./examples/', help="path to the driving audio")
    parser.add_argument("--source_image", default='./examples/', help="path to the source image")
    parser.add_argument("--result_dir", default='./examples/results', help="path to the output directory")
    parser.add_argument("--cpu", dest="cpu", action="store_true", help="force CPU inference even if CUDA is available")
    args = parser.parse_args()

    if torch.cuda.is_available() and not args.cpu:
        args.device = "cuda"
    else:
        args.device = "cpu"

    main(args)
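
# Example invocation (the file paths below are illustrative, not shipped
# defaults):
#   python inference.py --driven_audio ./examples/driven_audio/sample.wav \
#       --source_image ./examples/source_image/sample.png \
#       --result_dir ./examples/results
# Add --cpu to force CPU inference on machines without a CUDA device.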