-
Notifications
You must be signed in to change notification settings - Fork 0
/
hparams.py
105 lines (89 loc) · 3.22 KB
/
hparams.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from text import symbols
class AttrDict(dict):
    """A ``dict`` whose entries can also be read and written as attributes.

    ``d["name"]`` and ``d.name`` refer to the same entry, so setting either
    one updates the other.
    """

    def __init__(self, *args, **kwargs):
        # Use the mapping itself as the instance attribute namespace: every
        # attribute read/write is then a key read/write on this dict.
        self.__dict__ = self
        super(AttrDict, self).__init__(*args, **kwargs)
def _parse_hparams_string(hparams_string):
    """Split an override string into a list of ``(name, raw_value)`` pairs.

    Pairs are separated by ``-`` and written ``name:value``. Only the first
    ``:`` separates name from value, so values such as
    ``tcp://localhost:14897`` are kept intact. Because ``-`` is the pair
    separator, values containing ``-`` (negative numbers, ``1e-3``) cannot
    be expressed.

    Raises:
        ValueError: if a non-empty segment has no ``:``.
    """
    # NOTE(review): this drops the first character and the LAST TWO
    # characters of the string (e.g. an opening bracket plus two trailing
    # characters). Kept exactly as before for caller compatibility; confirm
    # the expected wrapping of the string against the callers.
    body = hparams_string[1:-2]
    pairs = []
    for token in body.split("-"):
        token = token.strip()
        if not token:
            # Tolerate empty segments such as a trailing or doubled "-".
            continue
        name, sep, value = token.partition(":")
        if not sep:
            raise ValueError(
                "Malformed hparam override {!r}: expected 'name:value'".format(token))
        pairs.append((name.strip(), value.strip()))
    return pairs


def _coerce_hparam(name, value, default):
    """Convert the string ``value`` to the type of ``default``.

    Without this, ``"False"`` would be stored as a truthy string and ``"16"``
    would not be an int. List defaults take a comma-separated value. Defaults
    of any other type keep the raw string.

    Raises:
        ValueError: if ``value`` cannot be converted to the default's type.
    """
    # bool must be tested before int/float: bool is a subclass of int.
    if isinstance(default, bool):
        lowered = value.lower()
        if lowered in ("true", "1", "yes", "y", "on"):
            return True
        if lowered in ("false", "0", "no", "n", "off"):
            return False
        raise ValueError(
            "hparam {!r} expects a boolean, got {!r}".format(name, value))
    if isinstance(default, list):
        return [item.strip() for item in value.split(",") if item.strip()]
    if isinstance(default, (int, float)):
        try:
            return type(default)(value)
        except ValueError:
            raise ValueError("hparam {!r} expects {}, got {!r}".format(
                name, type(default).__name__, value))
    return value


def create_hparams(hparams_string=None, verbose=False):
    """Create model hyperparameters, applying overrides from ``hparams_string``.

    Args:
        hparams_string: Optional override string. After its outer characters
            are sliced off (see ``_parse_hparams_string``), it holds
            ``name:value`` pairs separated by ``-``, e.g.
            ``batch_size:16-fp16_run:true``. Each value is converted to the
            type of the matching default. Names that are not defined below
            are reported and ignored.
        verbose: If true, print the final hyperparameter set.

    Returns:
        AttrDict: the hyperparameters, readable as keys or attributes.

    Raises:
        ValueError: if an override is malformed or cannot be converted to
            the type of its default.
    """
    hparams = AttrDict({
        ################################
        # Experiment Parameters        #
        ################################
        "epochs": 500,
        "iters_per_checkpoint": 500,
        "seed": 1234,
        "dynamic_loss_scaling": True,
        "fp16_run": False,
        "distributed_run": False,
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:14897",
        "cudnn_enabled": True,
        "cudnn_benchmark": False,
        "ignore_layers": ['embedding.weight'],
        # Disabled option: uncomment to freeze layers when finetuning.
        # "freeze_layers": ['encoder'],  # Freeze tacotron2 layer for finetuning

        ################################
        # Data Parameters              #
        ################################
        "load_mel_from_disk": False,
        "load_phone_from_disk": True,
        "training_files": 'filelists/train.txt',
        "validation_files": 'filelists/test.txt',
        "text_cleaners": ['transliteration_cleaners'],

        ################################
        # Audio Parameters             #
        ################################
        "max_wav_value": 32768.0,
        "sampling_rate": 22050,
        "filter_length": 1024,
        "hop_length": 256,
        "win_length": 1024,
        "n_mel_channels": 80,
        "mel_fmin": 0.0,
        "mel_fmax": 8000.0,

        ################################
        # Model Parameters             #
        ################################
        "n_symbols": len(symbols),
        "symbols_embedding_dim": 512,
        "alignloss": "L2",
        "attention": "StepwiseMonotonicAttention",

        # Encoder parameters
        "encoder_kernel_size": 5,
        "encoder_n_convolutions": 3,
        "encoder_embedding_dim": 512,

        # Decoder parameters
        "n_frames_per_step": 1,  # currently only 1 is supported
        "decoder_rnn_dim": 1024,
        "prenet_dim": 256,
        "max_decoder_steps": 1000,
        "gate_threshold": 0.5,
        "p_attention_dropout": 0.1,
        "p_decoder_dropout": 0.1,

        # Attention parameters
        "attention_rnn_dim": 1024,
        "attention_dim": 128,

        # Location Layer parameters
        "attention_location_n_filters": 32,
        "attention_location_kernel_size": 31,

        # Mel-post processing network parameters
        "postnet_embedding_dim": 512,
        "postnet_kernel_size": 5,
        "postnet_n_convolutions": 5,

        ################################
        # Optimization Hyperparameters #
        ################################
        "use_saved_learning_rate": True,
        "learning_rate": 1e-3,
        "weight_decay": 1e-6,
        "grad_clip_thresh": 1.0,
        "batch_size": 8,  # each gpus
        "mask_padding": True  # set model's padded outputs to padded values
    })

    if hparams_string:
        for name, raw_value in _parse_hparams_string(hparams_string):
            if name not in hparams:
                # Previously ignored silently; report it so typos are visible.
                print("Ignoring unknown hparam: " + name)
                continue
            # Coerce to the current value's type so e.g. "False" -> False.
            hparams[name] = _coerce_hparam(name, raw_value, hparams[name])
            print("Set hparam: " + name + " to " + raw_value)

    if verbose:
        print("Final hparams: " + str(dict(hparams)))
    return hparams