gen_wav.py
import torch
import numpy as np
from vocoder.bigvgan.models import VocoderBigVGAN
from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config
from omegaconf import OmegaConf
import argparse
import soundfile
device = 'cuda'  # change to 'cpu' if you do not have a GPU; generation on CPU is very slow
SAMPLE_RATE = 16000
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--prompt",
        type=str,
        default="a bird chirps",
        help="the text prompt to generate audio from",
    )
    parser.add_argument(
        "--ddim_steps",
        type=int,
        default=100,
        help="number of DDIM sampling steps",
    )
    parser.add_argument(
        "--duration",
        type=int,
        default=10,
        help="audio duration in seconds",
    )
    parser.add_argument(
        "--n_samples",
        type=int,
        default=1,
        help="how many samples to produce for the given prompt",
    )
    parser.add_argument(
        "--scale",
        type=float,
        default=3.0,  # if it's 1, only the conditional prediction is used
        help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
    )
    parser.add_argument(
        "--save_name",
        type=str,
        default='test',
        help="base name for the saved audio files",
    )
    return parser.parse_args()
def initialize_model(config, ckpt, device=device):
    """Build the latent diffusion model from a config file and checkpoint, and wrap it in a DDIM sampler."""
    config = OmegaConf.load(config)
    model = instantiate_from_config(config.model)
    # strict=False tolerates keys that differ between the checkpoint and the model
    model.load_state_dict(torch.load(ckpt, map_location='cpu')["state_dict"], strict=False)
    model = model.to(device)
    # keep the text encoder (condition stage) on the same device as the diffusion model
    model.cond_stage_model.to(model.device)
    model.cond_stage_model.device = model.device
    print(model.device, device, model.cond_stage_model.device)
    sampler = DDIMSampler(model)
    return sampler
def dur_to_size(duration):
    """Convert an audio duration in seconds to the latent width, rounded up to a multiple of 4."""
    latent_width = int(duration * 7.8)
    if latent_width % 4 != 0:
        latent_width = (latent_width // 4 + 1) * 4
    return latent_width
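# Worked example of the rounding above: duration=10 gives int(10 * 7.8) = 78,
# which is not a multiple of 4, so it rounds up to 80 latent frames;
# duration=4 gives int(4 * 7.8) = 31, which rounds up to 32.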
def gen_wav(sampler, vocoder, prompt, ddim_steps, scale, duration, n_samples):
    """Sample spectrogram latents with DDIM, decode them, and vocode to waveforms."""
    latent_width = dur_to_size(duration)
    # initial Gaussian noise for the sampler; 10 is the latent height
    start_code = torch.randn(n_samples, sampler.model.first_stage_model.embed_dim, 10, latent_width).to(device=device, dtype=torch.float32)
    uc = None
    if scale != 1.0:
        # embedding of the empty prompt, used for classifier-free guidance
        uc = sampler.model.get_learned_conditioning(n_samples * [""])
    c = sampler.model.get_learned_conditioning(n_samples * [prompt])
    shape = [sampler.model.first_stage_model.embed_dim, 10, latent_width]  # 10 is the latent height
    samples_ddim, _ = sampler.sample(S=ddim_steps,
                                     conditioning=c,
                                     batch_size=n_samples,
                                     shape=shape,
                                     verbose=False,
                                     unconditional_guidance_scale=scale,
                                     unconditional_conditioning=uc,
                                     x_T=start_code)
    # decode latents back to spectrograms, then vocode each one to a waveform
    x_samples_ddim = sampler.model.decode_first_stage(samples_ddim)
    wav_list = []
    for idx, spec in enumerate(x_samples_ddim):
        wav = vocoder.vocode(spec)
        if len(wav) < SAMPLE_RATE * duration:
            # pad only at the end (np.pad with a scalar width would pad both sides)
            wav = np.pad(wav, (0, SAMPLE_RATE * duration - len(wav)), mode='constant', constant_values=0)
        wav_list.append(wav)
    return wav_list
if __name__ == '__main__':
    args = parse_args()
    sampler = initialize_model('configs/text_to_audio/txt2audio_args.yaml', 'useful_ckpts/maa1_full.ckpt')
    vocoder = VocoderBigVGAN('useful_ckpts/bigvgan', device=device)
    print("Generating audio; this may take a long time depending on your GPU performance")
    wav_list = gen_wav(sampler, vocoder, prompt=args.prompt, ddim_steps=args.ddim_steps, scale=args.scale, duration=args.duration, n_samples=args.n_samples)
    for idx, wav in enumerate(wav_list):
        soundfile.write(f'{args.save_name}_{idx}.wav', wav, samplerate=SAMPLE_RATE)
    print(f"Audio files are saved as {args.save_name}_<idx>.wav")