import argparse
import os
import sys
import platform
import site
import tempfile
import signal
import gradio as gr
import torch
import torchaudio
import traceback
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
import random
import gc
import time
import shutil
import psutil
import pandas
import glob
import json
from pathlib import Path
from tqdm import tqdm
from faster_whisper import WhisperModel
# Use a local Tokenizer to resolve Japanese support
# from TTS.tts.layers.xtts.tokenizer import multilingual_cleaners
from system.ft_tokenizer.tokenizer import multilingual_cleaners
import importlib.metadata as metadata
from packaging import version
# STARTUP VARIABLES
this_dir = Path(__file__).parent.resolve()
audio_folder = this_dir / "finetune" / "put-voice-samples-in-here"
out_path = this_dir / "finetune" / "tmp-trn"
progress = 0
theme = gr.themes.Default()
refresh_symbol = '🔄'
os.environ['TRAINER_TELEMETRY'] = '0'
pfc_status = "pass"
# Define the path to the modeldownload config file
modeldownload_config_file_path = this_dir / "modeldownload.json"
# Check if the JSON file exists
if modeldownload_config_file_path.exists():
    with open(modeldownload_config_file_path, "r") as config_file:
        settings = json.load(config_file)
    # Extract settings from the loaded JSON
    base_path = Path(settings.get("base_path", ""))
    model_path = Path(settings.get("model_path", ""))
    base_model_path = Path(settings.get("model_path", ""))
    files_to_download = settings.get("files_to_download", {})
else:
    # Exit if the JSON file doesn't exist, as there are no sensible defaults to fall back on
    print("[FINETUNE] \033[91mWarning\033[0m modeldownload.json is missing. Please run this script in the /alltalk_tts/ folder")
    sys.exit(1)
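# For reference, a minimal modeldownload.json might look like the sketch below.
# The key names match what is read above; the values are illustrative only and
# the real file shipped with AllTalk may differ:
#
#   {
#       "base_path": "models",
#       "model_path": "xttsv2_2.0.2",
#       "files_to_download": {
#           "model.pth": "https://...",
#           "config.json": "https://...",
#           "vocab.json": "https://...",
#           "dvae.pth": "https://...",
#           "mel_stats.pth": "https://..."
#       }
#   }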
##################################################
#### Check to see if a finetuned model exists ####
##################################################
# Set the path to the directory
trained_model_directory = this_dir / "models" / "trainedmodel"
# Check if the directory "trainedmodel" exists
finetuned_model = trained_model_directory.exists()
# If the directory exists, check for the existence of the required files.
# If they are all present, this adds an extra option in the Gradio interface for loading the finetuned XTTSv2 model.
if finetuned_model:
    required_files = ["model.pth", "config.json", "vocab.json", "mel_stats.pth", "dvae.pth"]
    finetuned_model = all((trained_model_directory / file).exists() for file in required_files)
basemodel_or_finetunedmodel = True
#######################
#### DIAGS for PFC ####
#######################
def check_disk_space():
    global pfc_status
    # Get the current working directory
    current_directory = os.getcwd()
    # Get the disk usage statistics for the current directory's disk
    disk_usage = shutil.disk_usage(current_directory)
    # Convert the free space to GB (1 GB = 1 << 30 bytes)
    free_space_gb = disk_usage.free / (1 << 30)
    # Check if the free space is more than 18 GB
    is_more_than_18gb = free_space_gb > 18
    disk_space_icon = "✅"
    if not is_more_than_18gb:
        disk_space_icon = "❌"
        pfc_status = "fail"  # Update global status if disk space check fails
    # Generate the markdown text for the disk space check
    disk_space_markdown = f"""
    ### 🟩 <u>Disk Space Check</u>
    {disk_space_icon} **Disk Space (> 18 GB):** {'' if is_more_than_18gb else 'You have less than 18GB on this disk '} {free_space_gb:.2f} GB
    """
    return disk_space_markdown
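# A quick worked example of the conversion above: a drive with 53,687,091,200
# free bytes reports 53_687_091_200 / (1 << 30) = 50.00 GB, since 1 << 30 is
# one GiB (1,073,741,824 bytes).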
def test_cuda():
    global pfc_status
    cuda_home = os.environ.get('CUDA_HOME', 'N/A')
    cuda_available = torch.cuda.is_available()
    if cuda_available:
        try:
            # Attempt to create a tensor on the GPU
            torch.tensor([1.0, 2.0]).cuda()
            cuda_status = "CUDA is available and working."
            cuda_icon = "✅"
        except Exception as e:
            cuda_status = f"CUDA is available but not working. Error: {e}"
            cuda_icon = "❌"
            pfc_status = "fail"  # Update global status
    else:
        cuda_status = "CUDA is not available."
        cuda_icon = "❌"  # Must be set here too, or the return below would raise a NameError
        pfc_status = "fail"  # Update global status
    return cuda_status, cuda_icon, cuda_home
def find_files_in_path_with_wildcard(pattern):
    # Get the site-packages directories of the current Python environment
    site_packages_path = site.getsitepackages()
    found_paths = []
    # Adjust the sub-directory based on the operating system
    sub_directory = "nvidia/cublas"
    if platform.system() == "Linux":
        sub_directory = os.path.join(sub_directory, "lib")
    else:
        sub_directory = os.path.join(sub_directory, "bin")
    # Iterate over each site-packages directory (there can be more than one)
    for directory in site_packages_path:
        # Construct the search directory path
        search_directory = os.path.join(directory, sub_directory)
        # Use glob to find all files matching the pattern in this directory
        for file_path in glob.glob(os.path.join(search_directory, pattern)):
            if os.path.isfile(file_path):  # Ensure it's a file
                found_paths.append(file_path)
    return found_paths
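# Illustrative usage, mirroring how generate_cuda_markdown() calls this below.
# The returned list holds absolute file paths; the example path is hypothetical:
#
#   matches = find_files_in_path_with_wildcard(
#       'cublas64_11.*' if platform.system() == "Windows" else 'libcublas.so.11*'
#   )
#   # e.g. ['/.../site-packages/nvidia/cublas/lib/libcublas.so.11']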
def generate_cuda_markdown():
    global pfc_status
    cuda_status, cuda_icon, cuda_home = test_cuda()
    file_name = 'cublas64_11.*' if platform.system() == "Windows" else 'libcublas.so.11*'
    found_paths = find_files_in_path_with_wildcard(file_name)
    if found_paths:
        found_paths_str = ' '.join(found_paths)
        found_path_icon = '✅'
    else:
        found_paths_str = "cublas64_11 is not accessible."
        found_path_icon = '❌'
        pfc_status = "fail"  # Update global status
    # Check if 'cu118' or 'cu121' is in the PyTorch version string
    pytorch_version = torch.__version__
    if 'cu118' in pytorch_version or 'cu121' in pytorch_version:
        pytorch_cuda_version_status = ''
        pytorch_icon = '✅'
    else:
        pytorch_cuda_version_status = 'PyTorch CUDA version problem '
        pytorch_icon = '❌'
        pfc_status = "fail"  # Update global status
    cuda_markdown = f"""
    ### 🟨 <u>CUDA Information</u><br>
    {found_path_icon} **Cublas64_11 found:** {found_paths_str}
    {pytorch_icon} **CUDA_HOME path:** {cuda_home}
    """
    pytorch_markdown = f"""
    ### 🟦 <u>Python & PyTorch Information</u>
    {pytorch_icon} **PyTorch Version:** {pytorch_cuda_version_status} {torch.__version__}
    {cuda_icon} **CUDA is working:** {cuda_status}
    """
    return cuda_markdown, pytorch_markdown
def get_system_ram_markdown():
    global pfc_status
    virtual_memory = psutil.virtual_memory()
    total_ram_gb = virtual_memory.total / (1024 ** 3)
    available_ram_gb = virtual_memory.available / (1024 ** 3)
    used_ram_percentage = virtual_memory.percent
    # Check if the available RAM is less than 8 GB
    warning_if_low_ram = available_ram_gb < 8
    # Decide the message based on the available RAM
    ram_status_message = "Warning" if warning_if_low_ram else ""
    ram_status_icon = "⚠️" if warning_if_low_ram else "✅"
    if torch.cuda.is_available():
        gpu_device_id = torch.cuda.current_device()
        gpu_device_name = torch.cuda.get_device_name(gpu_device_id)
        # Get the total memory in bytes, then convert to GB
        gpu_total_mem_gb = torch.cuda.get_device_properties(gpu_device_id).total_memory / (1024 ** 3)
        # gpu_available_mem_gb = (torch.cuda.get_device_properties(gpu_device_id).total_memory - torch.cuda.memory_allocated(gpu_device_id)) / (1024 ** 3)
        # gpu_available_mem_gb = (torch.cuda.get_device_properties(gpu_device_id).total_memory - torch.cuda.memory_reserved(gpu_device_id)) / (1024 ** 3)
        # Estimate available VRAM as total minus what PyTorch has reserved
        gpu_reserved_mem_gb = torch.cuda.memory_reserved(gpu_device_id) / (1024 ** 3)
        gpu_available_mem_gb = gpu_total_mem_gb - gpu_reserved_mem_gb
        # Check if total or available memory is less than 12 GB and set icons
        gpu_total_status_icon = "⚠️" if gpu_total_mem_gb < 12 else "✅"
        gpu_available_status_icon = "⚠️" if gpu_available_mem_gb < 12 else "✅"
        gpu_status_icon = "✅"
        gpu_total_mem_str = f"{gpu_total_mem_gb:.2f} GB"
        gpu_available_mem_str = f"{gpu_available_mem_gb:.2f} GB"
    else:
        gpu_status_icon = "⚠️"
        gpu_device_name = "Cannot detect a CUDA card"
        # Keep these as plain strings; float-formatting them in the Markdown below would raise a ValueError
        gpu_total_mem_str = "Cannot detect a CUDA card"
        gpu_available_mem_str = "Cannot detect a CUDA card"
        gpu_total_status_icon = gpu_status_icon
        gpu_available_status_icon = gpu_status_icon
    system_ram_markdown = f"""
    ### 🟪 <u>System RAM and VRAM Information</u> <br>
    {ram_status_icon} **Total RAM:** {total_ram_gb:.2f} GB<br>
    {ram_status_icon} **Available RAM:** {ram_status_message + ' - Available RAM is less than 8 GB. You have ' if warning_if_low_ram else ''} {available_ram_gb:.2f} GB available ({used_ram_percentage:.2f}% used)<br><br>
    {gpu_status_icon} **GPU Name:** {gpu_device_name}<br>
    {gpu_total_status_icon} **GPU Total RAM:** {gpu_total_mem_str}<br>
    {gpu_available_status_icon} **GPU Available RAM:** {gpu_available_mem_str}<br>
    """
    return system_ram_markdown
def check_base_model(base_model_path, files_to_download):
    global pfc_status
    # files_to_download is a dict whose keys are the expected filenames
    base_model_files = list(files_to_download.keys())
    missing_files = []
    # Check if all base model files exist. Note the paths are built from the
    # module-level base_path and model_path loaded from modeldownload.json.
    for file in base_model_files:
        file_path = this_dir / base_path / model_path / file
        if not file_path.exists():
            missing_files.append(file)
            pfc_status = "fail"
    return len(missing_files) == 0
# Assuming base_model_path and files_to_download are set from the JSON config as shown above
base_model_detected = check_base_model(base_model_path, files_to_download)
def generate_base_model_markdown(base_model_detected):
    global pfc_status
    base_model_status = 'Base model detected' if base_model_detected else 'Base model not detected'
    base_model_icon = '✅' if base_model_detected else '❌'
    base_model_markdown = f"""
    ### ⬛ <u>XTTS Base Model Detection</u>
    {base_model_icon} **Base XTTS Model Status:** {base_model_status}
    """
    return base_model_markdown
def check_tts_version(required_version="0.22.0"):
    global pfc_status
    try:
        # Get the installed version of TTS
        installed_version = metadata.version("tts")
        # Check if the installed version meets the required version
        if version.parse(installed_version) >= version.parse(required_version):
            tts_status = f"TTS version {installed_version} is installed and meets the requirement."
            tts_status_icon = "✅"
        else:
            tts_status = f"❌ Fail - TTS version {installed_version} is installed but does not meet the required version {required_version}."
            tts_status_icon = "❌"
            pfc_status = "fail"  # Update global status
    except metadata.PackageNotFoundError:
        # If TTS is not installed
        tts_status = "TTS is not installed."
        tts_status_icon = "❌"  # Must be set here too, or the f-string below would raise a NameError
        pfc_status = "fail"  # Update global status
    tts_markdown = f"""
    ### 🟥 <u>TTS Information</u><br>
    {tts_status_icon} **TTS Version:** {tts_status}
    """
    return tts_markdown
# Run the pre-flight checks once at startup and keep the Markdown results for the interface
disk_space_results = check_disk_space()
cuda_results, pytorch_results = generate_cuda_markdown()
system_ram_results = get_system_ram_markdown()
base_model_results = generate_base_model_markdown(base_model_detected)
tts_version_status = check_tts_version()
def pfc_check_fail():
    global pfc_status
    if pfc_status == "fail":
        print("[FINETUNE]")
        print("[FINETUNE] \033[91m****** WARNING PRE-FLIGHT CHECKS FAILED ******* WARNING PRE-FLIGHT CHECKS FAILED *****\033[0m")
        print("[FINETUNE] \033[91m* Please refer to the \033[93mPre-flight check tab \033[91mand resolve any issues before continuing. *\033[0m")
        print("[FINETUNE] \033[91m*********** Expect errors and failures if you do not resolve these issues. ***********\033[0m")
        print("[FINETUNE]")
    return
#####################
#### STEP 1 BITS ####
#####################
def create_temp_folder():
    temp_folder = os.path.join(os.path.dirname(__file__), 'temp_files')
    os.makedirs(temp_folder, exist_ok=True)
    return temp_folder

def create_temporary_file(folder, suffix=".wav"):
    unique_filename = f"custom_tempfile_{int(time.time())}_{random.randint(1, 1000)}{suffix}"
    return os.path.join(folder, unique_filename)
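# For illustration, create_temporary_file("/tmp") returns something like
# "/tmp/custom_tempfile_1700000000_42.wav" (timestamp and random number are
# example values only). Two calls in the same second only collide if the
# random suffix also repeats.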
def format_audio_list(target_language, whisper_model, out_path, eval_split_number, speaker_name_input, gradio_progress=progress):
    pfc_check_fail()
    audio_files = [os.path.join(audio_folder, file) for file in os.listdir(audio_folder) if file.endswith(('.mp3', '.flac', '.wav'))]
    buffer = 0.2
    eval_percentage = eval_split_number / 100.0
    speaker_name = speaker_name_input
    audio_total_size = 0
    os.makedirs(out_path, exist_ok=True)
    temp_folder = os.path.join(out_path, "temp")
    os.makedirs(temp_folder, exist_ok=True)
    print("[FINETUNE] \033[94mPart of AllTalk\033[0m https://github.com/erew123/alltalk_tts/")
    print("[FINETUNE] \033[94mCoqui Public Model License\033[0m")
    print("[FINETUNE] \033[94mhttps://coqui.ai/cpml.txt\033[0m")
    print(f"[FINETUNE] \033[94mWhisper model: \033[92m{whisper_model} \033[94mLanguage: \033[92m{target_language} \033[94mEvaluation data percentage: \033[92m{eval_split_number}%\033[0m")
    print("[FINETUNE] \033[94mStarting Step 1\033[0m - Preparing Audio/Generating the dataset")
    # Write the target language to lang.txt in the output directory
    lang_file_path = os.path.join(out_path, "lang.txt")
    # Check if lang.txt already exists and contains a different language
    current_language = None
    if os.path.exists(lang_file_path):
        with open(lang_file_path, 'r', encoding='utf-8') as existing_lang_file:
            current_language = existing_lang_file.read().strip()
    if current_language != target_language:
        # Only update lang.txt if the target language differs from the current language
        with open(lang_file_path, 'w', encoding='utf-8') as lang_file:
            lang_file.write(target_language + '\n')
        print("[FINETUNE] Updated lang.txt with the target language.")
    else:
        print("[FINETUNE] The existing language matches the target language")
    # Loading Whisper
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print("[FINETUNE] Loading Whisper Model:", whisper_model)
    print("[FINETUNE] Model will be downloaded if it's not available, which will take a few minutes.")
    asr_model = WhisperModel(whisper_model, device=device, compute_type="float32")
    metadata = {"audio_file": [], "text": [], "speaker_name": []}
    existing_metadata = {'train': None, 'eval': None}
    train_metadata_path = os.path.join(out_path, "metadata_train.csv")
    eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
    if os.path.exists(train_metadata_path):
        existing_metadata['train'] = pandas.read_csv(train_metadata_path, sep="|")
        print("[FINETUNE] Existing training metadata found and loaded.")
    if os.path.exists(eval_metadata_path):
        existing_metadata['eval'] = pandas.read_csv(eval_metadata_path, sep="|")
        print("[FINETUNE] Existing evaluation metadata found and loaded.")
    for idx, audio_path in tqdm(enumerate(audio_files)):
        if isinstance(audio_path, str):
            audio_file_name_without_ext, _ = os.path.splitext(os.path.basename(audio_path))
            # If it's a string, it's already the path to the file
            audio_path_name = audio_path
        elif hasattr(audio_path, 'read'):
            # If it has a 'read' attribute, treat it as a file-like object
            # and use a temporary file to save its content
            audio_file_name_without_ext, _ = os.path.splitext(os.path.basename(audio_path.name))
            audio_path_name = create_temporary_file(temp_folder)
            with open(audio_path, 'rb') as original_file:
                file_content = original_file.read()
            with open(audio_path_name, 'wb') as temp_file:
                temp_file.write(file_content)
        # Create a temporary file path within the new folder
        temp_audio_path = create_temporary_file(temp_folder)
        try:
            if isinstance(audio_path, str):
                audio_path_name = audio_path
            elif hasattr(audio_path, 'name'):
                audio_path_name = audio_path.name
            else:
                raise ValueError(f"Unsupported audio_path type: {type(audio_path)}")
        except Exception as e:
            print(f"[FINETUNE] Error reading original file: {e}")
            # Handle the error or raise it if needed
        print("[FINETUNE] Current working file:", audio_path_name)
        try:
            # Copy the audio content
            time.sleep(0.5)  # Introduce a small delay
            shutil.copy2(audio_path_name, temp_audio_path)
        except Exception as e:
            print(f"[FINETUNE] Error copying file: {e}")
            # Handle the error or raise it if needed
        # Load the temporary audio file (note: wav and sr are re-loaded from the
        # original path further below, once the skip check has passed)
        wav, sr = torchaudio.load(temp_audio_path, format="wav")
        wav = torch.as_tensor(wav).clone().detach().t().to(torch.float32), sr
        prefix_check = f"wavs/{audio_file_name_without_ext}_"
        # Check both training and evaluation metadata for an entry that starts with the file name.
        skip_processing = False
        for key in ['train', 'eval']:
            if existing_metadata[key] is not None:
                mask = existing_metadata[key]['audio_file'].str.startswith(prefix_check)
                if mask.any():
                    print(f"[FINETUNE] Segments from {audio_file_name_without_ext} have been previously processed; skipping...")
                    skip_processing = True
                    break
        # If we found that we've already processed this file before, continue to the next iteration.
        if skip_processing:
            continue
        wav, sr = torchaudio.load(audio_path)
        # stereo to mono if needed
        if wav.size(0) != 1:
            wav = torch.mean(wav, dim=0, keepdim=True)
        wav = wav.squeeze()
        audio_total_size += (wav.size(-1) / sr)
        segments, _ = asr_model.transcribe(audio_path, vad_filter=True, word_timestamps=True, language=target_language)
        segments = list(segments)
        i = 0
        sentence = ""
        sentence_start = None
        first_word = True
        # Collect the words from all segments into a single list
        words_list = []
        for _, segment in enumerate(segments):
            words = list(segment.words)
            words_list.extend(words)
        # Process each word, cutting a clip whenever a sentence ends
        for word_idx, word in enumerate(words_list):
            if first_word:
                sentence_start = word.start
                # If it is the first sentence, add a buffer or use the beginning of the file
                if word_idx == 0:
                    sentence_start = max(sentence_start - buffer, 0)  # Add buffer to the sentence start
                else:
                    # get the previous sentence end
                    previous_word_end = words_list[word_idx - 1].end
                    # add the buffer, or start from the midpoint of the silence between the previous sentence and this one
                    sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)
                sentence = word.word
                first_word = False
            else:
                sentence += word.word
            if word.word[-1] in ["!", ".", "?"]:
                sentence = sentence[1:]
                # Expand numbers and abbreviations, plus normalization
                sentence = multilingual_cleaners(sentence, target_language)
                audio_file_name, _ = os.path.splitext(os.path.basename(audio_path))
                audio_file = f"wavs/{audio_file_name}_{str(i).zfill(8)}.wav"
                # Check for the next word's existence
                if word_idx + 1 < len(words_list):
                    next_word_start = words_list[word_idx + 1].start
                else:
                    # If there are no more words, this is the last sentence, so use the audio length as the next word start
                    next_word_start = (wav.shape[0] - 1) / sr
                # Average the current word end and next word start
                word_end = min((word.end + next_word_start) / 2, word.end + buffer)
                absolute_path = os.path.join(out_path, audio_file)
                os.makedirs(os.path.dirname(absolute_path), exist_ok=True)
                i += 1
                first_word = True
                audio = wav[int(sr * sentence_start):int(sr * word_end)].unsqueeze(0)
                # if the audio is too short, ignore it (i.e., < 0.33 seconds)
                if audio.size(-1) >= sr / 3:
                    torchaudio.save(
                        absolute_path,
                        audio,
                        sr
                    )
                else:
                    continue
                metadata["audio_file"].append(audio_file)
                metadata["text"].append(sentence)
                metadata["speaker_name"].append(speaker_name)
        os.remove(temp_audio_path)
    if os.path.exists(train_metadata_path) and os.path.exists(eval_metadata_path):
        existing_train_df = existing_metadata['train']
        existing_eval_df = existing_metadata['eval']
        # Placeholder duration; this appears to let downstream minimum-length checks pass when existing metadata is reused
        audio_total_size = 121
    else:
        existing_train_df = pandas.DataFrame(columns=["audio_file", "text", "speaker_name"])
        existing_eval_df = pandas.DataFrame(columns=["audio_file", "text", "speaker_name"])
    new_data_df = pandas.DataFrame(metadata)
    combined_train_df = pandas.concat([existing_train_df, new_data_df], ignore_index=True).drop_duplicates().reset_index(drop=True)
    combined_train_df_shuffled = combined_train_df.sample(frac=1)
    num_val_samples = int(len(combined_train_df_shuffled) * eval_percentage)
    final_eval_set = combined_train_df_shuffled[:num_val_samples]
    final_training_set = combined_train_df_shuffled[num_val_samples:]
    final_training_set.sort_values('audio_file').to_csv(train_metadata_path, sep='|', index=False)
    final_eval_set.sort_values('audio_file').to_csv(eval_metadata_path, sep='|', index=False)
    # deallocate VRAM and RAM
    del asr_model, final_eval_set, final_training_set, new_data_df, existing_metadata
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    existing_train_df = None
    existing_eval_df = None
    print("[FINETUNE] Train CSV:", train_metadata_path)
    print("[FINETUNE] Eval CSV:", eval_metadata_path)
    print("[FINETUNE] Audio Total:", audio_total_size)
    return train_metadata_path, eval_metadata_path, audio_total_size
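# For reference, the metadata CSVs written above are pipe-separated with a
# header row and one clipped sentence per line. The rows below are illustrative
# only (filenames follow the wavs/<name>_<zero-padded index>.wav pattern used above):
#
#   audio_file|text|speaker_name
#   wavs/sample1_00000000.wav|This is the first sentence.|mySpeaker
#   wavs/sample1_00000001.wav|And this is the second one.|mySpeaker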
######################
#### STEP 2 BITS #####
######################
from trainer import Trainer, TrainerArgs
from TTS.config.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.layers.xtts.trainer.gpt_trainer import GPTArgs, GPTTrainer, GPTTrainerConfig, XttsAudioConfig
from TTS.utils.manage import ModelManager
def basemodel_or_finetunedmodel_choice(value):
    global basemodel_or_finetunedmodel
    if value == "Base Model":
        basemodel_or_finetunedmodel = True
    elif value == "Existing finetuned model":
        basemodel_or_finetunedmodel = False
def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv, learning_rate, output_path, max_audio_length=255995):
    pfc_check_fail()
    # Logging parameters
    RUN_NAME = "XTTS_FT"
    PROJECT_NAME = "XTTS_trainer"
    DASHBOARD_LOGGER = "tensorboard"
    LOGGER_URI = None
    # Set here the path that the checkpoints will be saved. Default: ./training/
    OUT_PATH = os.path.join(output_path, "training")
    print("[FINETUNE] \033[94mStarting Step 2\033[0m - Fine-tuning the XTTS Encoder")
    print(f"[FINETUNE] \033[94mLanguage: \033[92m{language} \033[94mEpochs: \033[92m{num_epochs} \033[94mBatch size: \033[92m{batch_size}\033[0m \033[94mGrad accumulation steps: \033[92m{grad_acumm}\033[0m")
    print(f"[FINETUNE] \033[94mTraining : \033[92m{train_csv}\033[0m")
    print(f"[FINETUNE] \033[94mEvaluation : \033[92m{eval_csv}\033[0m")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # Get the current device ID
        gpu_device_id = torch.cuda.current_device()
        gpu_available_mem_gb = (torch.cuda.get_device_properties(gpu_device_id).total_memory - torch.cuda.memory_allocated(gpu_device_id)) / (1024 ** 3)
        print(f"[FINETUNE] \033[94mAvailable VRAM: \033[92m{gpu_available_mem_gb:.2f} GB\033[0m")
        if gpu_available_mem_gb < 12:
            print("[FINETUNE]")
            print("[FINETUNE] \033[91m****** WARNING PRE-FLIGHT CHECKS FAILED ******* WARNING PRE-FLIGHT CHECKS FAILED *****\033[0m")
            print(f"[FINETUNE] \033[94mAvailable VRAM: \033[92m{gpu_available_mem_gb:.2f} GB\033[0m")
            print("[FINETUNE] \033[94mIf you are running on a Linux system and you have 12 GB or less of VRAM, this step\033[0m")
            print("[FINETUNE] \033[94mmay fail due to insufficient GPU VRAM. Windows systems will use system RAM as extended\033[0m")
            print("[FINETUNE] \033[94mVRAM and so should work OK. However, Windows machines will need enough system RAM\033[0m")
            print("[FINETUNE] \033[94mavailable. Please read the PFC help section available on the first tab of the web\033[0m")
            print("[FINETUNE] \033[94minterface for more information.\033[0m")
            print("[FINETUNE] \033[91m****** WARNING PRE-FLIGHT CHECKS FAILED ******* WARNING PRE-FLIGHT CHECKS FAILED *****\033[0m")
            print("[FINETUNE]")
    # Create the output directory
    os.makedirs(OUT_PATH, exist_ok=True)
    # Training Parameters
    OPTIMIZER_WD_ONLY_ON_WEIGHTS = True  # for multi-gpu training please make it False
    START_WITH_EVAL = False  # if True it will start with evaluation
    BATCH_SIZE = batch_size  # set here the batch size
    GRAD_ACUMM_STEPS = grad_acumm  # set here the grad accumulation steps
    # Define here the dataset that you want to use for the fine-tuning.
    config_dataset = BaseDatasetConfig(
        formatter="coqui",
        dataset_name="ft_dataset",
        path=os.path.dirname(train_csv),
        meta_file_train=train_csv,
        meta_file_val=eval_csv,
        language=language,
    )
    # Add here the configs of the datasets
    DATASETS_CONFIG_LIST = [config_dataset]
    if basemodel_or_finetunedmodel:
        # BASE XTTS model checkpoints for fine-tuning.
        print("[FINETUNE] Starting finetuning on \033[92mBase Model\033[0m")
        TOKENIZER_FILE = str(this_dir / base_path / model_path / "vocab.json")
        XTTS_CHECKPOINT = str(this_dir / base_path / model_path / "model.pth")
        XTTS_CONFIG_FILE = str(this_dir / base_path / model_path / "config.json")
        DVAE_CHECKPOINT = str(this_dir / base_path / model_path / "dvae.pth")
        MEL_NORM_FILE = str(this_dir / base_path / model_path / "mel_stats.pth")
    else:
        # FINETUNED XTTS model checkpoints for fine-tuning.
        print("[FINETUNE] Starting finetuning on \033[92mExisting Finetuned Model\033[0m")
        TOKENIZER_FILE = str(this_dir / base_path / "trainedmodel" / "vocab.json")
        XTTS_CHECKPOINT = str(this_dir / base_path / "trainedmodel" / "model.pth")
        XTTS_CONFIG_FILE = str(this_dir / base_path / "trainedmodel" / "config.json")
        DVAE_CHECKPOINT = str(this_dir / base_path / "trainedmodel" / "dvae.pth")
        MEL_NORM_FILE = str(this_dir / base_path / "trainedmodel" / "mel_stats.pth")
    # init args and config
    model_args = GPTArgs(
        max_conditioning_length=132300,  # 6 secs
        min_conditioning_length=66150,  # 3 secs
        debug_loading_failures=False,
        max_wav_length=max_audio_length,  # ~11.6 seconds
        max_text_length=200,
        mel_norm_file=MEL_NORM_FILE,
        dvae_checkpoint=DVAE_CHECKPOINT,
        xtts_checkpoint=XTTS_CHECKPOINT,  # checkpoint path of the model that you want to fine-tune
        tokenizer_file=TOKENIZER_FILE,
        gpt_num_audio_tokens=1026,
        gpt_start_audio_token=1024,
        gpt_stop_audio_token=1025,
        gpt_use_masking_gt_prompt_approach=True,
        gpt_use_perceiver_resampler=True,
    )
    # define audio config
    audio_config = XttsAudioConfig(sample_rate=22050, dvae_sample_rate=22050, output_sample_rate=24000)
    # Resolve Japanese threading issue
    number_of_workers = 8
    if language == "ja":
        number_of_workers = 0
    # training parameters config
    config = GPTTrainerConfig(
        epochs=num_epochs,
        output_path=OUT_PATH,
        model_args=model_args,
        run_name=RUN_NAME,
        project_name=PROJECT_NAME,
        run_description="GPT XTTS training",
        dashboard_logger=DASHBOARD_LOGGER,
        logger_uri=LOGGER_URI,
        audio=audio_config,
        batch_size=BATCH_SIZE,
        batch_group_size=48,
        eval_batch_size=BATCH_SIZE,
        num_loader_workers=number_of_workers,
        eval_split_max_size=256,
        print_step=50,
        plot_step=100,
        log_model_step=100,
        save_step=1000,
        save_n_checkpoints=1,
        save_checkpoints=True,
        # target_loss="loss",
        print_eval=False,
        # Optimizer values like tortoise; PyTorch implementation with modifications to not apply WD to non-weight parameters.
        optimizer="AdamW",
        optimizer_wd_only_on_weights=OPTIMIZER_WD_ONLY_ON_WEIGHTS,
        optimizer_params={"betas": [0.9, 0.96], "eps": 1e-8, "weight_decay": 1e-2},
        lr=learning_rate,  # learning rate
        lr_scheduler="MultiStepLR",
        # milestones were adjusted accordingly for the new step scheme
        lr_scheduler_params={"milestones": [50000 * 18, 150000 * 18, 300000 * 18], "gamma": 0.5, "last_epoch": -1},
        test_sentences=[],
    )
    # init the model from config
    model = GPTTrainer.init_from_config(config)
    # load training samples
    train_samples, eval_samples = load_tts_samples(
        DATASETS_CONFIG_LIST,
        eval_split=True,
        eval_split_max_size=config.eval_split_max_size,
        eval_split_size=config.eval_split_size,
    )
    # init the trainer
    trainer = Trainer(
        TrainerArgs(
            restore_path=None,  # the XTTS checkpoint is restored via the xtts_checkpoint key, so there is no need to restore it using Trainer's restore_path parameter
            skip_train_epoch=False,
            start_with_eval=START_WITH_EVAL,
            grad_accum_steps=GRAD_ACUMM_STEPS,
        ),
        config,
        output_path=OUT_PATH,
        model=model,
        train_samples=train_samples,
        eval_samples=eval_samples,
    )
    trainer.fit()
    # use the audio file with the longest transcript as the speaker reference
    samples_len = [len(item["text"].split(" ")) for item in train_samples]
    longest_text_idx = samples_len.index(max(samples_len))
    speaker_ref = train_samples[longest_text_idx]["audio_file"]
    trainer_out_path = trainer.output_path
    # deallocate VRAM and RAM
    del model, trainer, train_samples, eval_samples, config, model_args, config_dataset
    gc.collect()
    train_samples = None
    eval_samples = None
    config_dataset = None
    trainer = None
    model = None
    model_args = None
    return XTTS_CONFIG_FILE, XTTS_CHECKPOINT, TOKENIZER_FILE, trainer_out_path, speaker_ref
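# A note on the two batch-related knobs above: the optimizer only steps once
# every grad_acumm batches, so the effective batch size is
# batch_size * grad_acumm. Worked example (illustrative numbers): batch_size=4
# with grad_acumm=8 accumulates gradients over 32 samples per weight update
# while only ever holding 4 samples in VRAM at a time.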
##########################
#### STEP 3 AND OTHER ####
##########################
def clear_gpu_cache():
    # clear the GPU cache
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

def find_a_speaker_file(folder_path):
    # Pick the most recently created speakers_xtts.pth one level below folder_path
    search_path = folder_path / "*" / "speakers_xtts.pth"
    files = glob.glob(str(search_path), recursive=True)
    latest_file = max(files, key=os.path.getctime, default=None)
    return latest_file
XTTS_MODEL = None
def load_model(xtts_checkpoint, xtts_config, xtts_vocab):
    pfc_check_fail()
    global XTTS_MODEL
    clear_gpu_cache()
    if not xtts_checkpoint or not xtts_config or not xtts_vocab:
        return "You need to run the previous steps or manually set the `XTTS checkpoint path`, `XTTS config path`, and `XTTS vocab path` fields !!"
    xtts_speakers_pth = find_a_speaker_file(this_dir / "models")
    config = XttsConfig()
    config.load_json(xtts_config)
    XTTS_MODEL = Xtts.init_from_config(config)
    print("[FINETUNE] \033[94mStarting Step 3\033[0m Loading XTTS model!")
    print(xtts_checkpoint)
    print(xtts_vocab)
    print(xtts_speakers_pth)
    XTTS_MODEL.load_checkpoint(config, checkpoint_path=xtts_checkpoint, vocab_path=xtts_vocab, use_deepspeed=False, speaker_file_path=xtts_speakers_pth)
    if torch.cuda.is_available():
        XTTS_MODEL.cuda()
    print("[FINETUNE] Model Loaded!")
    return "Model Loaded!"
def run_tts(lang, tts_text, speaker_audio_file):
    if XTTS_MODEL is None or not speaker_audio_file:
        return "You need to run the previous step to load the model !!", None, None
    speaker_audio_file = str(speaker_audio_file)
    gpt_cond_latent, speaker_embedding = XTTS_MODEL.get_conditioning_latents(audio_path=speaker_audio_file, gpt_cond_len=XTTS_MODEL.config.gpt_cond_len, max_ref_length=XTTS_MODEL.config.max_ref_len, sound_norm_refs=XTTS_MODEL.config.sound_norm_refs)
    out = XTTS_MODEL.inference(
        text=tts_text,
        language=lang,
        gpt_cond_latent=gpt_cond_latent,
        speaker_embedding=speaker_embedding,
        temperature=XTTS_MODEL.config.temperature,  # Add custom parameters here
        length_penalty=XTTS_MODEL.config.length_penalty,
        repetition_penalty=XTTS_MODEL.config.repetition_penalty,
        top_k=XTTS_MODEL.config.top_k,
        top_p=XTTS_MODEL.config.top_p,
    )
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as fp:
        out["wav"] = torch.tensor(out["wav"]).unsqueeze(0)
        out_path = fp.name
        torchaudio.save(out_path, out["wav"], 24000)
    return "Speech generated !", out_path, speaker_audio_file
def get_available_voices(minimum_size_kb=1200):
    voice_files = [
        voice for voice in Path(f"{this_dir}/finetune/tmp-trn/wavs").glob("*.wav")
        if voice.stat().st_size > minimum_size_kb * 1024  # Convert KB to bytes (was * 1200, which contradicted the comment)
    ]
    return sorted([str(file) for file in voice_files])  # Return full path as string
def find_best_models(directory):
    """Find files named 'best_model.pth' in the given directory."""
    return [str(file) for file in Path(directory).rglob("best_model.pth")]

def find_models(directory, extension):
    """Find files with a specific extension in the given directory."""
    return [str(file) for file in Path(directory).rglob(f"*.{extension}")]

def find_jsons(directory, filename):
    """Find files with a specific filename in the given directory."""
    return [str(file) for file in Path(directory).rglob(filename)]
# Your main directory
main_directory = Path(this_dir) / "finetune" / "tmp-trn"
# XTTS checkpoint files (best_model.pth)
xtts_checkpoint_files = find_best_models(main_directory)
# XTTS config files (config.json)
xtts_config_files = find_jsons(main_directory, "config.json")
# XTTS vocab files (vocab.json)
xtts_vocab_files = find_jsons(main_directory, "vocab.json")
##########################
#### STEP 4 AND OTHER ####
##########################
def find_latest_best_model(folder_path):
    # Pick the most recently created best_model.pth from any XTTS_FT-* run folder
    search_path = folder_path / "XTTS_FT-*" / "best_model.pth"
    files = glob.glob(str(search_path), recursive=True)
    latest_file = max(files, key=os.path.getctime, default=None)
    return latest_file
def compact_model(xtts_checkpoint_copy):
    this_dir = Path(__file__).parent.resolve()
    print("THIS DIR:", this_dir)
    # Check that a best model file was supplied before converting it to a string,
    # since str(None) would be the truthy string "None" and slip past this check
    if not xtts_checkpoint_copy:
        print("[FINETUNE] No trained model was found.")
        return "No trained model was found."
    best_model_path_str = str(xtts_checkpoint_copy)  # Convert to string
    print("best_model_path_str", best_model_path_str)
    print(f"[FINETUNE] Best model path: {best_model_path_str}")
    # Attempt to load the model
    try:
        checkpoint = torch.load(best_model_path_str, map_location=torch.device("cpu"))
        print(f"[FINETUNE] Checkpoint loaded: {best_model_path_str}")
    except Exception as e:
        print("[FINETUNE] Error loading checkpoint:", e)
        raise
    # Define the target directory
    target_dir = this_dir / "models" / "trainedmodel"
    # Create the target directory if it doesn't exist
    target_dir.mkdir(parents=True, exist_ok=True)
    # Strip the optimizer state and the DVAE weights to shrink the checkpoint
    del checkpoint["optimizer"]
    for key in list(checkpoint["model"].keys()):
        if "dvae" in key:
            del checkpoint["model"][key]
    # Save the modified checkpoint in the target directory
    torch.save(checkpoint, str(target_dir / "model.pth"))  # Convert to string
    # Specify the files you want to copy
    files_to_copy = ["vocab.json", "config.json", "speakers_xtts.pth", "mel_stats.pth", "dvae.pth"]
    for file_name in files_to_copy:
        src_path = this_dir / base_path / base_model_path / file_name
        dest_path = target_dir / file_name
        shutil.copy(str(src_path), str(dest_path))  # Convert to string
    source_wavs_dir = this_dir / "finetune" / "tmp-trn" / "wavs"
    target_wavs_dir = target_dir / "wavs"
    target_wavs_dir.mkdir(parents=True, exist_ok=True)
    # Iterate through files in the source directory
    for file_path in source_wavs_dir.iterdir():
        # Check if it's a file and larger than 1000 KB
        if file_path.is_file() and file_path.stat().st_size > 1000 * 1024:
            # Copy the file to the target directory
            shutil.copy(str(file_path), str(target_wavs_dir / file_path.name))  # Convert to string
    print("[FINETUNE] Model copied to '/models/trainedmodel/'")
    return "Model copied to '/models/trainedmodel/'"
def compact_lastfinetuned_model(xtts_checkpoint_copy):
    this_dir = Path(__file__).parent.resolve()
    best_model_path_str = xtts_checkpoint_copy
    print(f"[FINETUNE] Best model path: {best_model_path_str}")
    # Check if the best model file exists
    if best_model_path_str is None:
        print("[FINETUNE] No trained model was found.")
        return "No trained model was found."
    # Convert best_model_path_str to a Path
    best_model_path = Path(best_model_path_str)
    # Attempt to load the model
    try:
        checkpoint = torch.load(best_model_path, map_location=torch.device("cpu"))
    except Exception as e:
        print("[FINETUNE] Error loading checkpoint:", e)
        raise
    del checkpoint["optimizer"]
    # Define the target directory
    target_dir = this_dir / "models" / "lastfinetuned"
    # Create the target directory if it doesn't exist
    target_dir.mkdir(parents=True, exist_ok=True)
    for key in list(checkpoint["model"].keys()):
        if "dvae" in key:
            del checkpoint["model"][key]
    # Save the modified checkpoint in the target directory
    torch.save(checkpoint, target_dir / "model.pth")
    # Specify the files you want to copy
    files_to_copy = ["vocab.json", "config.json", "speakers_xtts.pth", "mel_stats.pth", "dvae.pth"]
    for file_name in files_to_copy:
        src_path = this_dir / base_path / base_model_path / file_name
        dest_path = target_dir / file_name
        shutil.copy(str(src_path), str(dest_path))
    source_wavs_dir = this_dir / "finetune" / "tmp-trn" / "wavs"
    target_wavs_dir = target_dir / "wavs"
    target_wavs_dir.mkdir(parents=True, exist_ok=True)
    # Iterate through files in the source directory
    for file_path in source_wavs_dir.iterdir():
        # Check if it's a file and larger than 1000 KB
        if file_path.is_file() and file_path.stat().st_size > 1000 * 1024:
            # Copy the file to the target directory
            shutil.copy(str(file_path), str(target_wavs_dir / file_path.name))
    print("[FINETUNE] Model copied to '/models/lastfinetuned/'")
    return "Model copied to '/models/lastfinetuned/'"
def compact_custom_model(xtts_checkpoint_copy, folder_path):
    this_dir = Path(__file__).parent.resolve()
    best_model_path_str = xtts_checkpoint_copy
    print(f"[FINETUNE] Best model path: {best_model_path_str}")
    # Check if the best model file exists
    if best_model_path_str is None:
        print("[FINETUNE] No trained model was found.")
        return "No trained model was found."
    # Convert best_model_path_str to a Path
    best_model_path = Path(best_model_path_str)
    # Attempt to load the model
    try:
        checkpoint = torch.load(best_model_path, map_location=torch.device("cpu"))
    except Exception as e:
        print("[FINETUNE] Error loading checkpoint:", e)
        raise
    del checkpoint["optimizer"]
    # Define the target directory
    target_dir = this_dir / "models" / folder_path
    # Create the target directory if it doesn't exist
    target_dir.mkdir(parents=True, exist_ok=True)
    for key in list(checkpoint["model"].keys()):
        if "dvae" in key:
            del checkpoint["model"][key]
    # Save the modified checkpoint in the target directory
    torch.save(checkpoint, target_dir / "model.pth")
    # Specify the files you want to copy
    files_to_copy = ["vocab.json", "config.json", "speakers_xtts.pth", "mel_stats.pth", "dvae.pth"]
    for file_name in files_to_copy:
        src_path = this_dir / base_path / base_model_path / file_name
        dest_path = target_dir / file_name
        shutil.copy(str(src_path), str(dest_path))
    source_wavs_dir = this_dir / "finetune" / "tmp-trn" / "wavs"
    target_wavs_dir = target_dir / "wavs"
    target_wavs_dir.mkdir(parents=True, exist_ok=True)
    # Iterate through files in the source directory
    for file_path in source_wavs_dir.iterdir():
        # Check if it's a file and larger than 1000 KB
        if file_path.is_file() and file_path.stat().st_size > 1000 * 1024:
            # Copy the file to the target directory
            shutil.copy(str(file_path), str(target_wavs_dir / file_path.name))
    print(f"[FINETUNE] Model copied to '/models/{folder_path}/'")
    return f"Model copied to '/models/{folder_path}/'"
def delete_training_data():
    # Define the folder to be deleted
    folder_to_delete = Path(this_dir / "finetune" / "tmp-trn")
    # Check if the folder exists before deleting
    if folder_to_delete.exists():
        # Iterate over all files and subdirectories
        for item in folder_to_delete.iterdir():
            # Exclude trainer_0_log.txt from deletion
            if item.name != "trainer_0_log.txt":
                try:
                    if item.is_file():
                        item.unlink()
                    elif item.is_dir():