Skip to content

Commit

Permalink
args factories added to args.py
Browse files Browse the repository at this point in the history
  • Loading branch information
msalhab96 committed Jun 22, 2022
1 parent 42b3eac commit a0c035b
Showing 1 changed file with 43 additions and 43 deletions.
86 changes: 43 additions & 43 deletions args.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ def add_model_args(parser):
group.add_argument(
'--p_dopout', type=float, default=0.1
)
group.add_argument(
'--add_lnorm', type=bool, default=True
)
group.add_argument(
'--left_shift', type=int, default=-1,
help='The number below the center in the sliding attention'
Expand Down Expand Up @@ -62,13 +59,13 @@ def add_model_args(parser):
def add_training_args(parser):
group = parser.add_argument_group('Training')
group.add_argument(
'--train_path', type=str
'--train_path', type=str, required=True
)
group.add_argument(
'--test_path', type=str
'--test_path', type=str, required=True
)
group.add_argument(
'--checkpoint_dir', type=str
'--checkpoint_dir', type=str, required=True
)
group.add_argument(
'--pretrained_path', type=str, required=False, default=None
Expand All @@ -77,10 +74,10 @@ def add_training_args(parser):
'--steps_per_ckpt', type=int
)
group.add_argument(
'--epochs', type=int
'--epochs', type=int, required=True
)
group.add_argument(
'--batch_size', type=int
'--batch_size', type=int, required=True
)
group.add_argument(
'--opt_eps', type=float, default=1e-9
Expand All @@ -94,6 +91,9 @@ def add_training_args(parser):
group.add_argument(
'--opt_warmup_staps', type=int, default=4000
)
group.add_argument(
'--device', type=str, default='cuda'
)


def add_data_args(parser):
Expand Down Expand Up @@ -138,12 +138,12 @@ def get_args() -> dict:
def get_model_args(
args: dict, vocab_size: int, pad_idx: int, n_speakers: int
) -> dict:
d_model = args['d_model']
h = args['h']
p_dopout = args['p_dopout']
d_model = args.d_model
h = args.h
p_dopout = args.p_dopout
model_args = {
'n_layers': args['n_layers'],
'device': args['device']
'n_layers': args.n_layers,
'device': args.device
}
pos_emb_key = 'pos_emb_params'
encoder_key = 'encoder_params'
Expand Down Expand Up @@ -173,38 +173,38 @@ def get_model_args(
'd_model': d_model,
'vocab_size': vocab_size,
'pad_idx': pad_idx,
'device': args['device'],
'add_lnorm': args['add_lnorm']
'device': args.device,
'add_lnorm': args.add_lnorm
}
encoder_params[encoder_key] = {
'd_model': d_model,
'h': h,
'hidden_size': args['hidden_size'],
'hidden_size': args.hidden_size,
'p_dopout': p_dopout
}
decoder_params[decoder_key] = {
'd_model': d_model,
'h': h,
'p_dropout': p_dopout,
'left_shift': args['left_shift'],
'right_shift': args['right_shift'],
'max_steps': args['max_steps'],
'hidden_size': args['hidden_size']
'left_shift': args.left_shift,
'right_shift': args.right_shift,
'max_steps': args.max_steps,
'hidden_size': args.hidden_size
}
speaker_mod_params[speaker_mod_key] = {
'n_speakers': n_speakers,
'emb_size': args['spk_emb_size'],
'emb_size': args.spk_emb_size,
'd_model': d_model
}
prenet_params[prenet_key] = {
'inp_size': args['n_mels'],
'bottleneck_size': args['bottleneck_size'],
'inp_size': args.n_mels,
'bottleneck_size': args.bottleneck_size,
'd_model': d_model,
'p_dropout': p_dopout
}
pred_params[pred_key] = {
'd_model': d_model,
'n_mels': args['n_mels']
'n_mels': args.n_mels
}
return {
**model_args,
Expand All @@ -219,43 +219,43 @@ def get_model_args(

def get_loss_args(args: "argparse.Namespace") -> dict:
    """Build the keyword arguments for the loss function.

    Args:
        args: Parsed command-line arguments (attribute access, i.e. an
            argparse.Namespace — the old ``dict`` annotation was wrong).

    Returns:
        dict mapping loss-constructor parameter names to the CLI values.
    """
    return {
        'h': args.h,
        'dc_strength': args.ldc_lambda,
        'dc_bandwidth': args.att_bandwidth,
        'stop_weight': args.stop_weight
    }


def get_optim_args(args: "argparse.Namespace") -> dict:
    """Build the keyword arguments for the optimizer/scheduler.

    Args:
        args: Parsed command-line arguments (attribute access, i.e. an
            argparse.Namespace — the old ``dict`` annotation was wrong).

    Returns:
        dict with Adam betas/eps, the warmup step count, and d_model.
    """
    return {
        'betas': (args.opt_beta1, args.opt_beta2),
        'eps': args.opt_eps,
        # Reads the flag actually defined by add_training_args
        # (--opt_warmup_staps); the output key keeps the historical
        # 'warmup_staps' spelling [sic] that consumers expect.
        'warmup_staps': args.opt_warmup_staps,
        'd_model': args.d_model
    }


def get_aud_args(args: "argparse.Namespace") -> dict:
    """Build the keyword arguments for the audio pipeline.

    Args:
        args: Parsed command-line arguments (attribute access, i.e. an
            argparse.Namespace — the old ``dict`` annotation was wrong).

    Returns:
        dict with sampling rate, STFT window/hop sizes, mel count, and FFT
        size. Note the CLI flag is ``--window_size`` but the consumer key
        is ``win_size``.
    """
    return {
        'sampling_rate': args.sampling_rate,
        'win_size': args.window_size,
        'hop_size': args.hop_size,
        'n_mels': args.n_mels,
        'n_fft': args.n_fft
    }


def get_data_args(args: "argparse.Namespace") -> dict:
    """Build the keyword arguments for the data loader.

    Args:
        args: Parsed command-line arguments (attribute access, i.e. an
            argparse.Namespace — the old ``dict`` annotation was wrong).

    Returns:
        dict with the CSV/field separator and the batch size.
    """
    return {
        'sep': args.sep,
        'batch_size': args.batch_size
    }


def get_trainer_args(args: "argparse.Namespace") -> dict:
    """Build the keyword arguments for the trainer.

    Args:
        args: Parsed command-line arguments (attribute access, i.e. an
            argparse.Namespace — the old ``dict`` annotation was wrong).

    Returns:
        dict with the checkpoint directory (exposed as ``save_dir``),
        checkpoint cadence, epoch count, and compute device.
    """
    return {
        'save_dir': args.checkpoint_dir,
        'steps_per_ckpt': args.steps_per_ckpt,
        'epochs': args.epochs,
        'device': args.device
    }

0 comments on commit a0c035b

Please sign in to comment.