Skip to content

Commit

Permalink
Run make style
Browse files Browse the repository at this point in the history
  • Loading branch information
akx committed Dec 4, 2023
1 parent 4aee106 commit f770922
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 67 deletions.
39 changes: 23 additions & 16 deletions TTS/demos/xtts_ft_demo/utils/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ def list_audios(basePath, contains=None):
# return the set of files that are valid
return list_files(basePath, validExts=audio_types, contains=contains)


def list_files(basePath, validExts=None, contains=None):
# loop over the directory structure
for (rootDir, dirNames, filenames) in os.walk(basePath):
for rootDir, dirNames, filenames in os.walk(basePath):
# loop over the filenames in the current directory
for filename in filenames:
# if the contains string is not none and the filename does not contain
Expand All @@ -30,15 +31,24 @@ def list_files(basePath, validExts=None, contains=None):
continue

# determine the file extension of the current file
ext = filename[filename.rfind("."):].lower()
ext = filename[filename.rfind(".") :].lower()

# check to see if the file is an audio and should be processed
if validExts is None or ext.endswith(validExts):
# construct the path to the audio and yield it
audioPath = os.path.join(rootDir, filename)
yield audioPath

def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0.2, eval_percentage=0.15, speaker_name="coqui", gradio_progress=None):

def format_audio_list(
audio_files,
target_language="en",
out_path=None,
buffer=0.2,
eval_percentage=0.15,
speaker_name="coqui",
gradio_progress=None,
):
audio_total_size = 0
# make sure that ooutput file exists
os.makedirs(out_path, exist_ok=True)
Expand All @@ -63,7 +73,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
wav = torch.mean(wav, dim=0, keepdim=True)

wav = wav.squeeze()
audio_total_size += (wav.size(-1) / sr)
audio_total_size += wav.size(-1) / sr

segments, _ = asr_model.transcribe(audio_path, word_timestamps=True, language=target_language)
segments = list(segments)
Expand All @@ -88,7 +98,7 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0
# get previous sentence end
previous_word_end = words_list[word_idx - 1].end
# add buffer or get the silence midle between the previous sentence and the current one
sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start)/2)
sentence_start = max(sentence_start - buffer, (previous_word_end + sentence_start) / 2)

sentence = word.word
first_word = False
Expand All @@ -112,19 +122,16 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0

# Average the current word end and next word start
word_end = min((word.end + next_word_start) / 2, word.end + buffer)

absoulte_path = os.path.join(out_path, audio_file)
os.makedirs(os.path.dirname(absoulte_path), exist_ok=True)
i += 1
first_word = True

audio = wav[int(sr*sentence_start):int(sr*word_end)].unsqueeze(0)
audio = wav[int(sr * sentence_start) : int(sr * word_end)].unsqueeze(0)
# if the audio is too short ignore it (i.e < 0.33 seconds)
if audio.size(-1) >= sr/3:
torchaudio.save(absoulte_path,
audio,
sr
)
if audio.size(-1) >= sr / 3:
torchaudio.save(absoulte_path, audio, sr)
else:
continue

Expand All @@ -134,21 +141,21 @@ def format_audio_list(audio_files, target_language="en", out_path=None, buffer=0

df = pandas.DataFrame(metadata)
df = df.sample(frac=1)
num_val_samples = int(len(df)*eval_percentage)
num_val_samples = int(len(df) * eval_percentage)

df_eval = df[:num_val_samples]
df_train = df[num_val_samples:]

df_train = df_train.sort_values('audio_file')
df_train = df_train.sort_values("audio_file")
train_metadata_path = os.path.join(out_path, "metadata_train.csv")
df_train.to_csv(train_metadata_path, sep="|", index=False)

eval_metadata_path = os.path.join(out_path, "metadata_eval.csv")
df_eval = df_eval.sort_values('audio_file')
df_eval = df_eval.sort_values("audio_file")
df_eval.to_csv(eval_metadata_path, sep="|", index=False)

# deallocate VRAM and RAM
del asr_model, df_train, df_eval, df, metadata
gc.collect()

return train_metadata_path, eval_metadata_path, audio_total_size
return train_metadata_path, eval_metadata_path, audio_total_size
9 changes: 4 additions & 5 deletions TTS/demos/xtts_ft_demo/utils/gpt_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
BATCH_SIZE = batch_size # set here the batch size
GRAD_ACUMM_STEPS = grad_acumm # set here the grad accumulation steps


# Define here the dataset that you want to use for the fine-tuning on.
config_dataset = BaseDatasetConfig(
formatter="coqui",
Expand All @@ -43,7 +42,6 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
CHECKPOINTS_OUT_PATH = os.path.join(OUT_PATH, "XTTS_v2.0_original_model_files/")
os.makedirs(CHECKPOINTS_OUT_PATH, exist_ok=True)


# DVAE files
DVAE_CHECKPOINT_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/dvae.pth"
MEL_NORM_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/mel_stats.pth"
Expand All @@ -55,8 +53,9 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,
# download DVAE files if needed
if not os.path.isfile(DVAE_CHECKPOINT) or not os.path.isfile(MEL_NORM_FILE):
print(" > Downloading DVAE files!")
ModelManager._download_model_files([MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True)

ModelManager._download_model_files(
[MEL_NORM_LINK, DVAE_CHECKPOINT_LINK], CHECKPOINTS_OUT_PATH, progress_bar=True
)

# Download XTTS v2.0 checkpoint if needed
TOKENIZER_FILE_LINK = "https://coqui.gateway.scarf.sh/hf-coqui/XTTS-v2/main/vocab.json"
Expand Down Expand Up @@ -160,7 +159,7 @@ def train_gpt(language, num_epochs, batch_size, grad_acumm, train_csv, eval_csv,

# get the longest text audio file to use as speaker reference
samples_len = [len(item["text"].split(" ")) for item in train_samples]
longest_text_idx = samples_len.index(max(samples_len))
longest_text_idx = samples_len.index(max(samples_len))
speaker_ref = train_samples[longest_text_idx]["audio_file"]

trainer_out_path = trainer.output_path
Expand Down
Loading

0 comments on commit f770922

Please sign in to comment.