From 1394f73107ca3b0ac5affcf721ab7f4b52195412 Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Tue, 26 Dec 2023 16:34:52 +0200
Subject: [PATCH] Add support for VITS (multilingual TTS) (#466)

* Add custom VITS tokenizer converter
* Do not decode if expected input_ids is empty
* Update vits tokenizer tests
* Implement `VitsTokenizer`
* Add support for VITS model
* Support VITS through pipeline API
* Update JSDoc
* Add TTS unit test
* Add speecht5 unit test
* Fix typo
* Fix typo
* Update speecht5 model id
* Add note about using quantized speecht5 in unit tests
* Monkey-patch `BigInt64Array` and `BigUint64Array`
---
 README.md                                |   1 +
 docs/snippets/6_supported-models.snippet |   1 +
 scripts/convert.py                       |   8 ++
 scripts/extra/vits.py                    | 100 +++++++++++++++++++++++
 scripts/supported_models.py              |  21 +++++
 src/models.js                            |  79 +++++++++++++++++-
 src/pipelines.js                         |  50 +++++++++++-
 src/tokenizers.js                        |  22 +++++
 tests/generate_tests.py                  |  13 ++-
 tests/init.js                            |   2 +
 tests/pipelines.test.js                  |  42 ++++++++++
 tests/tokenizers.test.js                 |   3 +
 12 files changed, 336 insertions(+), 6 deletions(-)
 create mode 100644 scripts/extra/vits.py

diff --git a/README.md b/README.md
index 0264a1dca..307f92e53 100644
--- a/README.md
+++ b/README.md
@@ -336,6 +336,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
 1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang.
+1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son.
 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
 1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
 1. **[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)** (from OpenAI) released with the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf) by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever.
diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet
index 4733a2e66..7fdbe4c18 100644
--- a/docs/snippets/6_supported-models.snippet
+++ b/docs/snippets/6_supported-models.snippet
@@ -71,6 +71,7 @@
 1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
 1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
 1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang.
+1. **[VITS](https://huggingface.co/docs/transformers/model_doc/vits)** (from Kakao Enterprise) released with the paper [Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech](https://arxiv.org/abs/2106.06103) by Jaehyeon Kim, Jungil Kong, Juhee Son.
 1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
 1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
 1. **[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)** (from OpenAI) released with the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf) by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever.
diff --git a/scripts/convert.py b/scripts/convert.py
index f8aab906a..9d6f6ae7a 100644
--- a/scripts/convert.py
+++ b/scripts/convert.py
@@ -334,7 +334,15 @@ def main():
         with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
             json.dump(tokenizer_json, fp, indent=4)
 
+
+    elif config.model_type == 'vits':
+        if tokenizer is not None:
+            from .extra.vits import generate_tokenizer_json
+            tokenizer_json = generate_tokenizer_json(tokenizer)
+            with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
+                json.dump(tokenizer_json, fp, indent=4)
+
     elif config.model_type == 'speecht5':
         # TODO allow user to specify vocoder path
         export_kwargs["model_kwargs"] = {"vocoder": "microsoft/speecht5_hifigan"}
diff --git a/scripts/extra/vits.py b/scripts/extra/vits.py
new file mode 100644
index 000000000..d4b75f22c
--- /dev/null
+++ b/scripts/extra/vits.py
@@ -0,0 +1,100 @@
+
+
+def generate_tokenizer_json(tokenizer):
+    vocab = tokenizer.get_vocab()
+
+    normalizers = []
+
+    if tokenizer.normalize:
+        # Lowercase the input string
+        normalizers.append({
+            "type": "Lowercase",
+        })
+
+        if tokenizer.language == 'ron':
+            # Replace diacritics
+            normalizers.append({
+                "type": "Replace",
+                "pattern": {
+                    "String": "ț",
+                },
+                "content": "ţ",
+            })
+
+    if tokenizer.phonemize:
+        raise NotImplementedError("Phonemization is not implemented yet")
+
+    elif tokenizer.normalize:
+        # strip any chars outside of the vocab (punctuation)
+        chars = ''.join(x for x in vocab if len(x) == 1)
+        escaped = chars.replace('-', r'\-').replace(']', r'\]')
+        normalizers.append({
+            "type": "Replace",
+            "pattern": {
+                "Regex": f"[^{escaped}]",
+            },
+            "content": "",
+        })
+        normalizers.append({
+            "type": "Strip",
+            "strip_left": True,
+            "strip_right": True,
+        })
+
+    if tokenizer.add_blank:
+        # add pad token between each char
+        normalizers.append({
+            "type": "Replace",
+            "pattern": {
+                # Add a blank token between each char, except when blank (then do nothing)
+                "Regex": "(?=.)|(?
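For context, the `add_blank` branch above is what makes converted VITS checkpoints expect a blank (pad) token interspersed between every input character; the `VitsDecoder` added to `src/tokenizers.js` later in this patch undoes that by keeping every second token. The following is an illustrative JavaScript sketch of that round trip, not part of the patch; the `'k'` pad character is a made-up placeholder, as the real value comes from the checkpoint's vocabulary:

```js
// Illustrative only: mirrors the interspersing performed by the added Replace
// normalizer, and the inverse performed by VitsDecoder.decode_chain.
function intersperseBlank(text, pad) {
    // pad, char, pad, char, ..., pad
    return pad + [...text].map((char) => char + pad).join('');
}

function dropBlanks(tokens) {
    // Keep every second token (the original characters), as VitsDecoder does.
    return tokens.filter((_, i) => i % 2 === 1).join('');
}

const interspersed = intersperseBlank('abc', 'k'); // 'kakbkck'
console.log(dropBlanks([...interspersed]));        // 'abc'
```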
+     * @returns {Promise<VitsModelOutput>} The outputs for the VITS model.
+     */
+    async _call(model_inputs) {
+        return new VitsModelOutput(await super._call(model_inputs));
+    }
+}
+//////////////////////////////////////////////////
+
 
 //////////////////////////////////////////////////
 // AutoModels, used to simplify construction of PreTrainedModels
 // (uses config to instantiate correct class)
@@ -4789,6 +4830,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
     ['hubert', ['HubertModel', HubertModel]],
     ['wavlm', ['WavLMModel', WavLMModel]],
     ['audio-spectrogram-transformer', ['ASTModel', ASTModel]],
+    ['vits', ['VitsModel', VitsModel]],
 
     ['detr', ['DetrModel', DetrModel]],
     ['table-transformer', ['TableTransformerModel', TableTransformerModel]],
@@ -4846,11 +4888,15 @@ const MODEL_MAPPING_NAMES_DECODER_ONLY = new Map([
 const MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = new Map([
     ['speecht5', ['SpeechT5ForSpeechToText', SpeechT5ForSpeechToText]],
     ['whisper', ['WhisperForConditionalGeneration', WhisperForConditionalGeneration]],
-])
+]);
 
 const MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES = new Map([
     ['speecht5', ['SpeechT5ForTextToSpeech', SpeechT5ForTextToSpeech]],
-])
+]);
+
+const MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES = new Map([
+    ['vits', ['VitsModel', VitsModel]],
+]);
 
 const MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = new Map([
     ['bert', ['BertForSequenceClassification', BertForSequenceClassification]],
@@ -5044,6 +5090,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES, MODEL_TYPES.Seq2Seq],
+    [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
 ];
 
 for (const [mappings, type] of MODEL_CLASS_TYPE_MAPPING) {
@@ -5136,6 +5183,17 @@ export class AutoModelForTextToSpectrogram extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_SPECTROGRAM_MAPPING_NAMES];
 }
 
+/**
+ * Helper class which is used to instantiate pretrained text-to-waveform models with the `from_pretrained` function.
+ * The chosen model class is determined by the type specified in the model config.
+ *
+ * @example
+ * let model = await AutoModelForTextToWaveform.from_pretrained('facebook/mms-tts-eng');
+ */
+export class AutoModelForTextToWaveform extends PretrainedMixin {
+    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_TEXT_TO_WAVEFORM_MAPPING_NAMES];
+}
+
 /**
  * Helper class which is used to instantiate pretrained causal language models with the `from_pretrained` function.
  * The chosen model class is determined by the type specified in the model config.
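With the new `vits` entries in the mappings above, a converted MMS-TTS checkpoint can also be driven without the pipeline API. The following is a minimal sketch, assuming the converted `Xenova/mms-tts-fra` checkpoint exists and mirroring what `TextToAudioPipeline` does internally in `src/pipelines.js` below; it is illustrative rather than part of the patch:

```js
import { AutoTokenizer, AutoModelForTextToWaveform } from '@xenova/transformers';

// Tokenize, run the model, and read the waveform tensor directly.
const tokenizer = await AutoTokenizer.from_pretrained('Xenova/mms-tts-fra');
const model = await AutoModelForTextToWaveform.from_pretrained('Xenova/mms-tts-fra');

const inputs = tokenizer('Bonjour');
const { waveform } = await model(inputs);          // VitsModelOutput
const sampling_rate = model.config.sampling_rate;  // 16000 for the MMS-TTS checkpoints
```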
@@ -5375,3 +5433,20 @@ export class ImageMattingOutput extends ModelOutput {
         this.alphas = alphas;
     }
 }
+
+/**
+ * Describes the outputs for the VITS model.
+ */
+export class VitsModelOutput extends ModelOutput {
+    /**
+     * @param {Object} output The output of the model.
+     * @param {Tensor} output.waveform The final audio waveform predicted by the model, of shape `(batch_size, sequence_length)`.
+     * @param {Tensor} output.spectrogram The log-mel spectrogram predicted at the output of the flow model.
+     * This spectrogram is passed to the Hi-Fi GAN decoder model to obtain the final audio waveform.
+     */
+    constructor({ waveform, spectrogram }) {
+        super();
+        this.waveform = waveform;
+        this.spectrogram = spectrogram;
+    }
+}
diff --git a/src/pipelines.js b/src/pipelines.js
index a34dee9be..d65b9e189 100644
--- a/src/pipelines.js
+++ b/src/pipelines.js
@@ -26,6 +26,7 @@ import {
     AutoModelForMaskedLM,
     AutoModelForSeq2SeqLM,
     AutoModelForSpeechSeq2Seq,
+    AutoModelForTextToWaveform,
     AutoModelForTextToSpectrogram,
     AutoModelForCTC,
     AutoModelForCausalLM,
@@ -37,7 +38,6 @@ import {
     AutoModelForDocumentQuestionAnswering,
     AutoModelForImageToImage,
     AutoModelForDepthEstimation,
-    // AutoModelForTextToWaveform,
     PreTrainedModel,
 } from './models.js';
 import {
@@ -2112,6 +2112,16 @@ export class DocumentQuestionAnsweringPipeline extends Pipeline {
  * wav.fromScratch(1, out.sampling_rate, '32f', out.audio);
  * fs.writeFileSync('out.wav', wav.toBuffer());
  * ```
+ *
+ * **Example:** Multilingual speech generation with `Xenova/mms-tts-fra`. See [here](https://huggingface.co/models?pipeline_tag=text-to-speech&other=vits&sort=trending) for the full list of available languages (1107).
+ * ```js
+ * let synthesizer = await pipeline('text-to-speech', 'Xenova/mms-tts-fra');
+ * let out = await synthesizer('Bonjour');
+ * // {
+ * //   audio: Float32Array(23808) [-0.00037693005288019776, 0.0003325853613205254, ...],
+ * //   sampling_rate: 16000
+ * // }
+ * ```
  */
 export class TextToAudioPipeline extends Pipeline {
     DEFAULT_VOCODER_ID = "Xenova/speecht5_hifigan"
@@ -2143,6 +2153,34 @@ export class TextToAudioPipeline extends Pipeline {
     async _call(text_inputs, {
         speaker_embeddings = null,
     } = {}) {
+        // If this.processor is not set, we are using an `AutoModelForTextToWaveform` model
+        if (this.processor) {
+            return this._call_text_to_spectrogram(text_inputs, { speaker_embeddings });
+        } else {
+            return this._call_text_to_waveform(text_inputs);
+        }
+    }
+
+    async _call_text_to_waveform(text_inputs) {
+
+        // Run tokenization
+        const inputs = this.tokenizer(text_inputs, {
+            padding: true,
+            truncation: true
+        });
+
+        // Generate waveform
+        const { waveform } = await this.model(inputs);
+
+        const sampling_rate = this.model.config.sampling_rate;
+        return {
+            audio: waveform.data,
+            sampling_rate,
+        }
+    }
+
+    async _call_text_to_spectrogram(text_inputs, { speaker_embeddings }) {
+
         // Load vocoder, if not provided
         if (!this.vocoder) {
             console.log('No vocoder specified, using default HifiGan vocoder.');
@@ -2412,8 +2450,8 @@ const SUPPORTED_TASKS = {
     "text-to-audio": {
         "tokenizer": AutoTokenizer,
         "pipeline": TextToAudioPipeline,
-        "model": [ /* TODO: AutoModelForTextToWaveform, */ AutoModelForTextToSpectrogram],
-        "processor": AutoProcessor,
+        "model": [AutoModelForTextToWaveform, AutoModelForTextToSpectrogram],
+        "processor": [AutoProcessor, /* Some don't use a processor */ null],
         "default": {
             // TODO: replace with original
             // "model": "microsoft/speecht5_tts",
@@ -2673,6 +2711,12 @@ async function loadItems(mapping, model, pretrainedOptions) {
             promise = new Promise(async (resolve, reject) => {
                 let e;
                 for (let c of cls) {
+                    if (c === null) {
+                        // If null, we resolve it immediately, meaning the relevant
+                        // class was not found, but it is optional.
+                        resolve(null);
+                        return;
+                    }
                     try {
                         resolve(await c.from_pretrained(model, pretrainedOptions));
                         return;
diff --git a/src/tokenizers.js b/src/tokenizers.js
index 5cb4997e9..0dc949894 100644
--- a/src/tokenizers.js
+++ b/src/tokenizers.js
@@ -2080,6 +2080,18 @@ class BPEDecoder extends Decoder {
     }
 }
 
+// Custom decoder for VITS
+class VitsDecoder extends Decoder {
+    /** @type {Decoder['decode_chain']} */
+    decode_chain(tokens) {
+        let decoded = '';
+        for (let i = 1; i < tokens.length; i += 2) {
+            decoded += tokens[i];
+        }
+        return [decoded];
+    }
+}
+
 
 /**
  * This PreTokenizer replaces spaces with the given replacement character, adds a prefix space if requested,
@@ -4169,6 +4181,15 @@ export class SpeechT5Tokenizer extends PreTrainedTokenizer { }
 
 export class NougatTokenizer extends PreTrainedTokenizer { }
 
+export class VitsTokenizer extends PreTrainedTokenizer {
+
+    constructor(tokenizerJSON, tokenizerConfig) {
+        super(tokenizerJSON, tokenizerConfig);
+
+        // Custom decoder function
+        this.decoder = new VitsDecoder({});
+    }
+}
 /**
  * Helper class which is used to instantiate pretrained tokenizers with the `from_pretrained` function.
  * The chosen tokenizer class is determined by the type specified in the tokenizer config.
@@ -4216,6 +4237,7 @@ export class AutoTokenizer {
         BlenderbotSmallTokenizer,
         SpeechT5Tokenizer,
         NougatTokenizer,
+        VitsTokenizer,
 
         // Base case:
         PreTrainedTokenizer,
diff --git a/tests/generate_tests.py b/tests/generate_tests.py
index 9d9261bdb..2cd03f40d 100644
--- a/tests/generate_tests.py
+++ b/tests/generate_tests.py
@@ -42,6 +42,11 @@
 
     # TODO: remove when https://github.com/huggingface/transformers/issues/28164 is fixed
     'roformer',
+
+    # TODO: remove when https://github.com/huggingface/transformers/issues/28173 is fixed. Issues include:
+    # - decoding with `skip_special_tokens=True`.
+    # - interspersing the pad token is broken.
+    'vits',
 ]
 
 TOKENIZERS_TO_IGNORE = [
@@ -118,7 +123,13 @@
             "The Heavenly Llama is said to drink water from the ocean and urinates as it rains.[6] According to " \
             "Aymara eschatology, llamas will return to the water springs and lagoons where they come from at the " \
             "end of time.[6]",
-        ]
+        ],
+
+        "vits": [
+            "abcdefghijklmnopqrstuvwxyz01234567890",
+            # Special treatment of characters in certain languages
+            "ț ţ",
+        ],
     },
     "custom": {
         "facebook/blenderbot_small-90M": [
diff --git a/tests/init.js b/tests/init.js
index b01fe1000..cda487ec9 100644
--- a/tests/init.js
+++ b/tests/init.js
@@ -26,10 +26,12 @@ export function init() {
         "Int8Array",
         "Int16Array",
         "Int32Array",
+        "BigInt64Array",
         "Uint8Array",
         "Uint8ClampedArray",
         "Uint16Array",
         "Uint32Array",
+        "BigUint64Array",
         "Float32Array",
         "Float64Array",
     ];
diff --git a/tests/pipelines.test.js b/tests/pipelines.test.js
index 5ccbb6ed8..4d9619b65 100644
--- a/tests/pipelines.test.js
+++ b/tests/pipelines.test.js
@@ -909,6 +909,48 @@ describe('Pipelines', () => {
         }, MAX_TEST_EXECUTION_TIME);
     });
 
+    describe('Text-to-speech generation', () => {
+
+        // List all models which will be tested
+        const models = [
+            'microsoft/speecht5_tts',
+            'facebook/mms-tts-fra',
+        ];
+
+        it(models[0], async () => {
+            let synthesizer = await pipeline('text-to-speech', m(models[0]), {
+                // NOTE: Although the quantized version produces incoherent results,
+                // it is okay to use for testing.
+                // quantized: false,
+            });
+
+            let speaker_embeddings = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/speaker_embeddings.bin';
+
+            { // Generate English speech
+                let output = await synthesizer('Hello, my dog is cute', { speaker_embeddings });
+                expect(output.audio.length).toBeGreaterThan(0);
+                expect(output.sampling_rate).toEqual(16000);
+            }
+
+            await synthesizer.dispose();
+
+        }, MAX_TEST_EXECUTION_TIME);
+
+        it(models[1], async () => {
+            let synthesizer = await pipeline('text-to-speech', m(models[1]));
+
+            { // Generate French speech
+                let output = await synthesizer('Bonjour');
+                expect(output.audio.length).toBeGreaterThan(0);
+                expect(output.sampling_rate).toEqual(16000);
+            }
+
+            await synthesizer.dispose();
+
+        }, MAX_TEST_EXECUTION_TIME);
+
+    });
+
     describe('Audio classification', () => {
 
         // List all models which will be tested
diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js
index bc2bf1208..6bafa2365 100644
--- a/tests/tokenizers.test.js
+++ b/tests/tokenizers.test.js
@@ -29,6 +29,9 @@ describe('Tokenizers (dynamic)', () => {
 
             expect(encoded).toEqual(test.encoded);
 
+            // Skip decoding tests if encoding produces zero tokens
+            if (test.encoded.input_ids.length === 0) continue;
+
             // Test decoding
             let decoded_with_special = tokenizer.decode(encoded.input_ids, { skip_special_tokens: false });
             expect(decoded_with_special).toEqual(test.decoded_with_special);
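With the pipeline changes above, the `wavefile` recipe already shown for SpeechT5 in the `TextToAudioPipeline` documentation applies unchanged to the new VITS path. A short end-to-end sketch for Node.js, assuming the `wavefile` package from that earlier example is installed:

```js
import { pipeline } from '@xenova/transformers';
import wavefile from 'wavefile';
import fs from 'fs';

// Generate French speech with a VITS-based MMS-TTS checkpoint.
const synthesizer = await pipeline('text-to-speech', 'Xenova/mms-tts-fra');
const out = await synthesizer('Bonjour');

// Wrap the raw Float32Array in a WAV container and save it to disk.
const wav = new wavefile.WaveFile();
wav.fromScratch(1, out.sampling_rate, '32f', out.audio);
fs.writeFileSync('out.wav', wav.toBuffer());
```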