Add support for VITS (multilingual TTS) (#466)
* Add custom VITS tokenizer converter

* Do not decode if expected input_ids is empty

* Update vits tokenizer tests

* Implement `VitsTokenizer`

* Add support for VITS model

* Support VITS through pipeline API

* Update JSDoc

* Add TTS unit test

* Add speecht5 unit test

* Fix typo

* Fix typo

* Update speecht5 model id

* Add note about using quantized speecht5 in unit tests

* Monkey-patch `BigInt64Array` and `BigUint64Array`
xenova authored Dec 26, 2023
1 parent f5bc758 commit 1394f73
Showing 12 changed files with 336 additions and 6 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -336,6 +336,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang.
1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
1. **[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)** (from OpenAI) released with the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf) by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever.
1 change: 1 addition & 0 deletions docs/snippets/6_supported-models.snippet
@@ -71,6 +71,7 @@
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang.
1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
1. **[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)** (from OpenAI) released with the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf) by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever.
8 changes: 8 additions & 0 deletions scripts/convert.py
@@ -334,7 +334,15 @@ def main():

        with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
            json.dump(tokenizer_json, fp, indent=4)

    elif config.model_type == 'vits':
        if tokenizer is not None:
            from .extra.vits import generate_tokenizer_json
            tokenizer_json = generate_tokenizer_json(tokenizer)

            with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
                json.dump(tokenizer_json, fp, indent=4)

    elif config.model_type == 'speecht5':
        # TODO allow user to specify vocoder path
        export_kwargs["model_kwargs"] = {"vocoder": "microsoft/speecht5_hifigan"}
100 changes: 100 additions & 0 deletions scripts/extra/vits.py
@@ -0,0 +1,100 @@


def generate_tokenizer_json(tokenizer):
    vocab = tokenizer.get_vocab()

    normalizers = []

    if tokenizer.normalize:
        # Lowercase the input string
        normalizers.append({
            "type": "Lowercase",
        })

    if tokenizer.language == 'ron':
        # Replace diacritics
        normalizers.append({
            "type": "Replace",
            "pattern": {
                "String": "ț",
            },
            "content": "ţ",
        })

    if tokenizer.phonemize:
        raise NotImplementedError("Phonemization is not implemented yet")

    elif tokenizer.normalize:
        # Strip any characters outside of the vocab (e.g. punctuation)
        chars = ''.join(x for x in vocab if len(x) == 1)
        escaped = chars.replace('-', r'\-').replace(']', r'\]')
        normalizers.append({
            "type": "Replace",
            "pattern": {
                "Regex": f"[^{escaped}]",
            },
            "content": "",
        })
        normalizers.append({
            "type": "Strip",
            "strip_left": True,
            "strip_right": True,
        })

    if tokenizer.add_blank:
        # Add the pad token between each character
        normalizers.append({
            "type": "Replace",
            "pattern": {
                # Insert the pad token before every character, and after the last one
                # (matches nothing when the input is empty)
                "Regex": "(?=.)|(?<!^)$",
            },
            "content": tokenizer.pad_token,
        })

    if len(normalizers) == 0:
        normalizer = None
    elif len(normalizers) == 1:
        normalizer = normalizers[0]
    else:
        normalizer = {
            "type": "Sequence",
            "normalizers": normalizers,
        }

    tokenizer_json = {
        "version": "1.0",
        "truncation": None,
        "padding": None,
        "added_tokens": [
            {
                "id": vocab[token],
                "content": token,
                "single_word": False,
                "lstrip": False,
                "rstrip": False,
                "normalized": False,
                "special": True
            }
            for token in vocab

            # `tokenizer.pad_token` should not be considered an added token
            if token in (tokenizer.unk_token, )
        ],
        "normalizer": normalizer,
        "pre_tokenizer": {
            "type": "Split",
            "pattern": {
                "Regex": ""
            },
            "behavior": "Isolated",
            "invert": False
        },
        "post_processor": None,
        "decoder": None,  # Custom decoder implemented in JS
        "model": {
            "vocab": vocab
        },
    }

    return tokenizer_json
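
To make the normalizer chain concrete, here is a rough JavaScript sketch of what the generated rules do to an input string, assuming a checkpoint with `normalize` and `add_blank` enabled and an underscore as the pad token (illustrative values only; each checkpoint defines its own vocabulary and pad token):

```js
// Sketch of the generated normalizer chain (hypothetical vocab and pad token).
const vocabChars = 'abcdefghijklmnopqrstuvwxyz _';
const PAD = '_';

function normalize(text) {
  let out = text.toLowerCase();                            // "Lowercase" normalizer
  const escaped = vocabChars.replace('-', '\\-').replace(']', '\\]');
  out = out.replace(new RegExp(`[^${escaped}]`, 'g'), ''); // strip out-of-vocab characters
  out = out.trim();                                        // "Strip" normalizer
  return out.replace(/(?=.)|(?<!^)$/g, PAD);               // interleave the pad token
}

console.log(normalize('Hello, world!'));
// "_h_e_l_l_o_ _w_o_r_l_d_" (2n + 1 characters for an n-character input)
```

The interleaving regex is the same one the converter emits: it matches the zero-width position before every character, plus the end of any non-empty string, so an empty input is left untouched.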
21 changes: 21 additions & 0 deletions scripts/supported_models.py
@@ -858,6 +858,27 @@
            'hustvl/vitmatte-base-composition-1k',
        ],
    },
    'vits': {
        # Text-to-audio/Text-to-speech/Text-to-waveform
        'text-to-waveform': {
            # NOTE: requires --task text-to-waveform --skip_validation
            'echarlaix/tiny-random-vits',
            'facebook/mms-tts-eng',
            'facebook/mms-tts-rus',
            'facebook/mms-tts-hin',
            'facebook/mms-tts-yor',
            'facebook/mms-tts-spa',
            'facebook/mms-tts-fra',
            'facebook/mms-tts-ara',
            'facebook/mms-tts-ron',
            'facebook/mms-tts-vie',
            'facebook/mms-tts-deu',
            'facebook/mms-tts-kor',
            'facebook/mms-tts-por',
            # TODO add more checkpoints from
            # https://huggingface.co/models?other=vits&sort=trending&search=facebook-tts
        }
    },
    'wav2vec2': {
        # Feature extraction # NOTE: requires --task feature-extraction
        'feature-extraction': [
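For reference, the NOTE above ("requires --task text-to-waveform --skip_validation") corresponds to an invocation of the conversion script along these lines. The two flags come straight from the note; the `--model_id` argument and module path follow the repo's usual conversion workflow and should be treated as an assumption:

```
python -m scripts.convert --model_id facebook/mms-tts-eng --task text-to-waveform --skip_validation
```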
79 changes: 77 additions & 2 deletions src/models.js
@@ -4696,6 +4696,47 @@ export class ClapAudioModelWithProjection extends ClapPreTrainedModel {
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// VITS models
export class VitsPreTrainedModel extends PreTrainedModel { }

/**
* The complete VITS model, for text-to-speech synthesis.
*
* **Example:** Generate speech from text with `VitsModel`.
* ```javascript
* import { AutoTokenizer, VitsModel } from '@xenova/transformers';
*
* // Load the tokenizer and model
* const tokenizer = await AutoTokenizer.from_pretrained('Xenova/mms-tts-eng');
* const model = await VitsModel.from_pretrained('Xenova/mms-tts-eng');
*
* // Run tokenization
* const inputs = tokenizer('I love transformers');
*
* // Generate waveform
* const { waveform } = await model(inputs);
* // Tensor {
* // dims: [ 1, 35328 ],
* // type: 'float32',
* // data: Float32Array(35328) [ ... ],
* // size: 35328,
* // }
* ```
*/
export class VitsModel extends VitsPreTrainedModel {
/**
* Calls the model on new inputs.
* @param {Object} model_inputs The inputs to the model.
* @returns {Promise<VitsModelOutput>} The outputs for the VITS model.
*/
async _call(model_inputs) {
return new VitsModelOutput(await super._call(model_inputs));
}
}
//////////////////////////////////////////////////


//////////////////////////////////////////////////
// AutoModels, used to simplify construction of PreTrainedModels
// (uses config to instantiate correct class)
@@ -4789,6 +4830,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['hubert', ['HubertModel', HubertModel]],
['wavlm', ['WavLMModel', WavLMModel]],
['audio-spectrogram-transformer', ['ASTModel', ASTModel]],
['vits', ['VitsModel', VitsModel]],

['detr', ['DetrModel', DetrModel]],
['table-transformer', ['TableTransformerModel', TableTransformerModel]],
@@ -4846,11 +4888,15 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
const MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = new Map([
['speecht5', ['SpeechT5ForSpeechToText', SpeechT5ForSpeechToText]],
['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]],
])
]);

const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = new Map([
['speecht5', ['SpeechT5ForTextToSpeech', SpeechT5ForTextToSpeech]],
])
]);

const MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = new Map([
['vits', ['VitsModel', VitsModel]],
]);

const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
['bert', ['BertForSequenceClassification', BertForSequenceClassification]],
@@ -5044,6 +5090,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
[MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
];

for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
@@ -5136,6 +5183,17 @@ export class AutoModelForTextToSpectrogram extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES];
}

/**
* Helper class which is used to instantiate pretrained text-to-waveform models with the `from_pretrained` function.
* The chosen model class is determined by the type specified in the model config.
*
* @example
* let model = await AutoModelForTextToWaveform.from_pretrained('facebook/mms-tts-eng');
*/
export class AutoModelForTextToWaveform extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES];
}

/**
* Helper class which is used to instantiate pretrained causal language models with the `from_pretrained` function.
* The chosen model class is determined by the type specified in the model config.
@@ -5375,3 +5433,20 @@ export class ImageMattingOutput extends ModelOutput {
this.alphas = alphas;
}
}

/**
* Describes the outputs for the VITS model.
*/
export class VitsModelOutput extends ModelOutput {
/**
* @param {Object} output The output of the model.
* @param {Tensor} output.waveform The final audio waveform predicted by the model, of shape `(batch_size, sequence_length)`.
* @param {Tensor} output.spectrogram The log-mel spectrogram predicted at the output of the flow model.
* This spectrogram is passed to the Hi-Fi GAN decoder model to obtain the final audio waveform.
*/
constructor({ waveform, spectrogram }) {
super();
this.waveform = waveform;
this.spectrogram = spectrogram;
}
}
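
As a follow-on to the `VitsModel` example above, here is a short sketch of writing the returned waveform to disk, reusing the `wavefile` pattern from the pipeline docs (Node.js and the `wavefile` package assumed):

```js
import fs from 'fs';
import wavefile from 'wavefile';
import { AutoTokenizer, VitsModel } from '@xenova/transformers';

const tokenizer = await AutoTokenizer.from_pretrained('Xenova/mms-tts-eng');
const model = await VitsModel.from_pretrained('Xenova/mms-tts-eng');

// Tokenize the input text and synthesize the waveform
const inputs = tokenizer('I love transformers');
const { waveform } = await model(inputs);

// Write a mono, 32-bit float WAV file at the model's sampling rate
const wav = new wavefile.WaveFile();
wav.fromScratch(1, model.config.sampling_rate, '32f', waveform.data);
fs.writeFileSync('out.wav', wav.toBuffer());
```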
50 changes: 47 additions & 3 deletions src/pipelines.js
@@ -26,6 +26,7 @@ import {
AutoModelForMaskedLM,
AutoModelForSeq2SeqLM,
AutoModelForSpeechSeq2Seq,
AutoModelForTextToWaveform,
AutoModelForTextToSpectrogram,
AutoModelForCTC,
AutoModelForCausalLM,
@@ -37,7 +38,6 @@
AutoModelForDocumentQuestionAnswering,
AutoModelForImageToImage,
AutoModelForDepthEstimation,
// AutoModelForTextToWaveform,
PreTrainedModel,
} from './models.js';
import {
@@ -2112,6 +2112,16 @@ export class DocumentQuestionAnsweringPipeline extends Pipeline {
* wav.fromScratch(1, out.sampling_rate, '32f', out.audio);
* fs.writeFileSync('out.wav', wav.toBuffer());
* ```
*
* **Example:** Multilingual speech generation with `Xenova/mms-tts-fra`. See [here](https://huggingface.co/models?pipeline_tag=text-to-speech&other=vits&sort=trending) for the full list of available languages (1,107 in total).
* ```js
* let synthesizer = await pipeline('text-to-speech', 'Xenova/mms-tts-fra');
* let out = await synthesizer('Bonjour');
* // {
* // audio: Float32Array(23808) [-0.00037693005288019776, 0.0003325853613205254, ...],
* // sampling_rate: 16000
* // }
* ```
*/
export class TextToAudioPipeline extends Pipeline {
DEFAULT_VOCODER_ID = "Xenova/speecht5_hifigan"
@@ -2143,6 +2153,34 @@ export class TextToAudioPipeline extends Pipeline {
async _call(text_inputs, {
speaker_embeddings = null,
} = {}) {
// If this.processor is not set, we are using an `AutoModelForTextToWaveform` model
if (this.processor) {
return this._call_text_to_spectrogram(text_inputs, { speaker_embeddings });
} else {
return this._call_text_to_waveform(text_inputs);
}
}

async _call_text_to_waveform(text_inputs) {

// Run tokenization
const inputs = this.tokenizer(text_inputs, {
padding: true,
truncation: true
});

// Generate waveform
const { waveform } = await this.model(inputs);

const sampling_rate = this.model.config.sampling_rate;
return {
audio: waveform.data,
sampling_rate,
}
}

async _call_text_to_spectrogram(text_inputs, { speaker_embeddings }) {

// Load vocoder, if not provided
if (!this.vocoder) {
console.log('No vocoder specified, using default HifiGan vocoder.');
@@ -2412,8 +2450,8 @@ const SUPPORTED_TASKS = {
"text-to-audio": {
"tokenizer": AutoTokenizer,
"pipeline": TextToAudioPipeline,
"model": [ /* TODO: AutoModelForTextToWaveform, */ AutoModelForTextToSpectrogram],
"processor": AutoProcessor,
"model": [AutoModelForTextToWaveform, AutoModelForTextToSpectrogram],
"processor": [AutoProcessor, /* Some don't use a processor */ null],
"default": {
// TODO: replace with original
// "model": "microsoft/speecht5_tts",
@@ -2673,6 +2711,12 @@ async function loadItems(mapping, model, pretrainedOptions) {
promise = new Promise(async (resolve, reject) => {
let e;
for (let c of cls) {
if (c === null) {
// If null, we resolve it immediately, meaning the relevant
// class was not found, but it is optional.
resolve(null);
return;
}
try {
resolve(await c.from_pretrained(model, pretrainedOptions));
return;
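Condensed, the candidate-class fallback that `loadItems` performs for entries such as `[AutoModelForTextToWaveform, AutoModelForTextToSpectrogram]` looks roughly like this (a simplified sketch, not the exact implementation):

```js
// Simplified sketch: try each candidate class in order; a `null` entry means
// the component (e.g. the processor for VITS) is optional and resolves to null.
async function loadFirstSupported(classes, model, options) {
  let lastError;
  for (const cls of classes) {
    if (cls === null) return null; // optional component; nothing to load
    try {
      return await cls.from_pretrained(model, options);
    } catch (e) {
      lastError = e; // this class does not match the checkpoint; try the next
    }
  }
  throw lastError;
}
```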