From f770922c4ec550e67c08b45ff51981703198a21b Mon Sep 17 00:00:00 2001 From: Aarni Koskela Date: Mon, 4 Dec 2023 10:38:07 +0200 Subject: [PATCH] Run `make style` --- TTS/demos/xtts_ft_demo/utils/formatter.py | 39 ++++---- TTS/demos/xtts_ft_demo/utils/gpt_train.py | 9 +- TTS/demos/xtts_ft_demo/xtts_demo.py | 111 +++++++++++++--------- 3 files changed, 92 insertions(+), 67 deletions(-) diff --git a/TTS/demos/xtts_ft_demo/utils/formatter.py b/TTS/demos/xtts_ft_demo/utils/formatter.py index 6d7b770ff5..40e8b8ed32 100644 --- a/TTS/demos/xtts_ft_demo/utils/formatter.py +++ b/TTS/demos/xtts_ft_demo/utils/formatter.py @@ -19,9 +19,10 @@ def list_audios(basePath, contains=None): # return the set of files that are valid return list_files(basePath, validExts=audio_types, contains=contains) + def list_files(basePath, validExts=None, contains=None): # loop over the directory structure - for (rootDir, dirNames, filenames) in os.walk(basePath): + for rootDir, dirNames, filenames in os.walk(basePath): # loop over the filenames in the current directory for filename in filenames: # if the contains string is not none and the filename does not contain @@ -30,7 +31,7 @@ def list_files(basePath, validExts=None, contains=None): continue # determine the file extension of the current file - ext = filename[filename.rfind("."):].lower() + ext = filename[filename.rfind(".") :].lower() # check to see if the file is an audio and should be processed if validExts is None or ext.endswith(validExts): @@ -38,7 +39,16 @@ def list_files(basePath, validExts=None, contains=None): audioPath = os.path.join(rootDir, filename) yield audioPath -def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None): + +def format_audio_list( + audio_files, + target_language="en", + out_path=None, + buffer=0.2, + eval_percentage=0.15, + speaker_name="coqui", + gradio_progress=None, +): audio_total_size = 0 # make sure that ooutput file exists os.makedirs(out_path, exist_ok=True) @@ -63,7 +73,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0 wav = torch.mean(wav, dim=0, keepdim=True) wav = wav.squeeze() - audio_total_size += (wav.size(-1) / sr) + audio_total_size += wav.size(-1) / sr segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language) segments = list(segments) @@ -88,7 +98,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0 # get previous sentence end previous_word_end = words_list[word_idx - 1].end # add buffer or get the silence midle between the previous sentence and the current one - sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2) + sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2) sentence = word.word first_word = False @@ -112,19 +122,16 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0 # Average the current word end and next word start word_end = min((word.end + next_word_start) / 2, word.end + buffer) - + absoulte_path = os.path.join(out_path, audio_file) os.makedirs(os.path.dirname(absoulte_path), exist_ok=True) i += 1 first_word = True - audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0) + audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0) # if the audio is too short ignore it (i.e < 0.33 seconds) - if audio.size(-1) >= sr/3: - torchaudio.save(absoulte_path, - audio, - sr - ) + if audio.size(-1) >= sr / 3: + torchaudio.save(absoulte_path, audio, sr) else: continue @@ -134,21 +141,21 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0 df = pandas.DataFrame(metadata) df = df.sample(frac=1) - num_val_samples = int(len(df)*eval_percentage) + num_val_samples = int(len(df) * eval_percentage) df_eval = df[:num_val_samples] df_train = df[num_val_samples:] - df_train = df_train.sort_values('audio_file') + df_train = df_train.sort_values("audio_file") train_metadata_path = os.path.join(out_path, "metadata_train.csv") df_train.to_csv(train_metadata_path, sep="|", index=False) eval_metadata_path = os.path.join(out_path, "metadata_eval.csv") - df_eval = df_eval.sort_values('audio_file') + df_eval = df_eval.sort_values("audio_file") df_eval.to_csv(eval_metadata_path, sep="|", index=False) # deallocate VRAM and RAM del asr_model, df_train, df_eval, df, metadata gc.collect() - return train_metadata_path, eval_metadata_path, audio_total_size \ No newline at end of file + return train_metadata_path, eval_metadata_path, audio_total_size diff --git a/TTS/demos/xtts_ft_demo/utils/gpt_train.py b/TTS/demos/xtts_ft_demo/utils/gpt_train.py index 80be4fab40..7b41966b8f 100644 --- a/TTS/demos/xtts_ft_demo/utils/gpt_train.py +++ b/TTS/demos/xtts_ft_demo/utils/gpt_train.py @@ -25,7 +25,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, BATCH_SIZE = batch_size # set here the batch size GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps - # Define here the dataset that you want to use for the fine-tuning on. config_dataset = BaseDatasetConfig( formatter="coqui", @@ -43,7 +42,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/") os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True) - # DVAE files DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth" MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth" @@ -55,8 +53,9 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, # download DVAE files if needed if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE): print(" > Downloading DVAE files!") - ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True) - + ModelManager._download_model_files( + [MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True + ) # Download XTTS v2.0 checkpoint if needed TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json" @@ -160,7 +159,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, # get the longest text audio file to use as speaker reference samples_len = [len(item["text"].split(" ")) for item in train_samples] - longest_text_idx = samples_len.index(max(samples_len)) + longest_text_idx = samples_len.index(max(samples_len)) speaker_ref = train_samples[longest_text_idx]["audio_file"] trainer_out_path = trainer.output_path diff --git a/TTS/demos/xtts_ft_demo/xtts_demo.py b/TTS/demos/xtts_ft_demo/xtts_demo.py index b8ffb231dd..85168c641d 100644 --- a/TTS/demos/xtts_ft_demo/xtts_demo.py +++ b/TTS/demos/xtts_ft_demo/xtts_demo.py @@ -20,7 +20,10 @@ def clear_gpu_cache(): if torch.cuda.is_available(): torch.cuda.empty_cache() + XTTS_MODEL = None + + def load_model(xtts_checkpoint, xtts_config, xtts_vocab): global XTTS_MODEL clear_gpu_cache() @@ -37,17 +40,23 @@ def load_model(xtts_checkpoint, xtts_config, xtts_vocab): print("Model Loaded!") return "Model Loaded!" + def run_tts(lang, tts_text, speaker_audio_file): if XTTS_MODEL is None or not speaker_audio_file: return "You need to run the previous step to load the model !!", None, None - gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs) + gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents( + audio_path=speaker_audio_file, + gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, + max_ref_length=XTTS_MODEL.config.max_ref_len, + sound_norm_refs=XTTS_MODEL.config.sound_norm_refs, + ) out = XTTS_MODEL.inference( text=tts_text, language=lang, gpt_cond_latent=gpt_cond_latent, speaker_embedding=speaker_embedding, - temperature=XTTS_MODEL.config.temperature, # Add custom parameters here + temperature=XTTS_MODEL.config.temperature, # Add custom parameters here length_penalty=XTTS_MODEL.config.length_penalty, repetition_penalty=XTTS_MODEL.config.repetition_penalty, top_k=XTTS_MODEL.config.top_k, @@ -62,8 +71,6 @@ def run_tts(lang, tts_text, speaker_audio_file): return "Speech generated !", out_path, speaker_audio_file - - # define a logger to redirect class Logger: def __init__(self, filename="log.out"): @@ -82,6 +89,7 @@ def flush(self): def isatty(self): return False + # redirect stdout and stderr to a file sys.stdout = Logger() sys.stderr = sys.stdout @@ -90,13 +98,10 @@ def isatty(self): # logging.basicConfig(stream=sys.stdout, level=logging.INFO) logging.basicConfig( - level=logging.INFO, - format="%(asctime)s [%(levelname)s] %(message)s", - handlers=[ - logging.StreamHandler(sys.stdout) - ] + level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", handlers=[logging.StreamHandler(sys.stdout)] ) + def read_logs(): sys.stdout.flush() with open(sys.stdout.log_file, "r") as f: @@ -104,7 +109,6 @@ def read_logs(): if __name__ == "__main__": - parser = argparse.ArgumentParser( description="""XTTS fine-tuning demo\n\n""" """ @@ -187,12 +191,10 @@ def read_logs(): "zh", "hu", "ko", - "ja" + "ja", ], ) - progress_data = gr.Label( - label="Progress:" - ) + progress_data = gr.Label(label="Progress:") logs = gr.Textbox( label="Logs:", interactive=False, @@ -200,20 +202,30 @@ def read_logs(): demo.load(read_logs, None, logs, every=1) prompt_compute_btn = gr.Button(value="Step 1 - Create dataset") - + def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(track_tqdm=True)): clear_gpu_cache() out_path = os.path.join(out_path, "dataset") os.makedirs(out_path, exist_ok=True) if audio_path is None: - return "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", "", "" + return ( + "You should provide one or multiple audio files! If you provided it, probably the upload of the files is not finished yet!", + "", + "", + ) else: try: - train_meta, eval_meta, audio_total_size = format_audio_list(audio_path, target_language=language, out_path=out_path, gradio_progress=progress) + train_meta, eval_meta, audio_total_size = format_audio_list( + audio_path, target_language=language, out_path=out_path, gradio_progress=progress + ) except: traceback.print_exc() error = traceback.format_exc() - return f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", "", "" + return ( + f"The data processing was interrupted due an error !! Please check the console to verify the full error message! \n Error summary: {error}", + "", + "", + ) clear_gpu_cache() @@ -233,7 +245,7 @@ def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(trac eval_csv = gr.Textbox( label="Eval CSV:", ) - num_epochs = gr.Slider( + num_epochs = gr.Slider( label="Number of epochs:", minimum=1, maximum=100, @@ -261,9 +273,7 @@ def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(trac step=1, value=args.max_audio_length, ) - progress_train = gr.Label( - label="Progress:" - ) + progress_train = gr.Label(label="Progress:") logs_tts_train = gr.Textbox( label="Logs:", interactive=False, @@ -271,18 +281,41 @@ def preprocess_dataset(audio_path, language, out_path, progress=gr.Progress(trac demo.load(read_logs, None, logs_tts_train, every=1) train_btn = gr.Button(value="Step 2 - Run the training") - def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length): + def train_model( + language, train_csv, eval_csv, num_epochs, batch_size, grad_acumm, output_path, max_audio_length + ): clear_gpu_cache() if not train_csv or not eval_csv: - return "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", "", "", "", "" + return ( + "You need to run the data processing step or manually set `Train CSV` and `Eval CSV` fields !", + "", + "", + "", + "", + ) try: # convert seconds to waveform frames max_audio_length = int(max_audio_length * 22050) - config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, output_path=output_path, max_audio_length=max_audio_length) + config_path, original_xtts_checkpoint, vocab_file, exp_path, speaker_wav = train_gpt( + language, + num_epochs, + batch_size, + grad_acumm, + train_csv, + eval_csv, + output_path=output_path, + max_audio_length=max_audio_length, + ) except: traceback.print_exc() error = traceback.format_exc() - return f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", "", "", "", "" + return ( + f"The training was interrupted due an error !! Please check the console to check the full error message! \n Error summary: {error}", + "", + "", + "", + "", + ) # copy original files to avoid parameters changes issues os.system(f"cp {config_path} {exp_path}") @@ -309,9 +342,7 @@ def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acum label="XTTS vocab path:", value="", ) - progress_load = gr.Label( - label="Progress:" - ) + progress_load = gr.Label(label="Progress:") load_btn = gr.Button(value="Step 3 - Load Fine-tuned XTTS model") with gr.Column() as col2: @@ -339,7 +370,7 @@ def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acum "hu", "ko", "ja", - ] + ], ) tts_text = gr.Textbox( label="Input Text.", @@ -348,9 +379,7 @@ def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acum tts_btn = gr.Button(value="Step 4 - Inference") with gr.Column() as col3: - progress_gen = gr.Label( - label="Progress:" - ) + progress_gen = gr.Label(label="Progress:") tts_output_audio = gr.Audio(label="Generated Audio.") reference_audio = gr.Audio(label="Reference audio used.") @@ -368,7 +397,6 @@ def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acum ], ) - train_btn.click( fn=train_model, inputs=[ @@ -383,14 +411,10 @@ def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acum ], outputs=[progress_train, xtts_config, xtts_vocab, xtts_checkpoint, speaker_reference_audio], ) - + load_btn.click( fn=load_model, - inputs=[ - xtts_checkpoint, - xtts_config, - xtts_vocab - ], + inputs=[xtts_checkpoint, xtts_config, xtts_vocab], outputs=[progress_load], ) @@ -404,9 +428,4 @@ def train_model(language, train_csv, eval_csv, num_epochs, batch_size, grad_acum outputs=[progress_gen, tts_output_audio, reference_audio], ) - demo.launch( - share=True, - debug=False, - server_port=args.port, - server_name="0.0.0.0" - ) + demo.launch(share=True, debug=False, server_port=args.port, server_name="0.0.0.0")