Add support for ViTMatte models (#448)
* Add support for `VitMatte` models

* Add `VitMatteImageProcessor`

* Add `VitMatteImageProcessor` unit test

* Fix typo

* Add example code for `VitMatteForImageMatting`

* Fix JSDoc

* Fix typo
xenova authored Dec 13, 2023
1 parent 80d22da commit b978ff8
Showing 6 changed files with 202 additions and 29 deletions.
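At a glance, the commit wires a new `image-matting` task through the library: `VitMatteImageProcessor` prepares the image + trimap pair, `VitMatteForImageMatting` (reachable via the new `AutoModelForImageMatting` mapping) predicts the alpha matte, and the result is wrapped in an `ImageMattingOutput`. A minimal end-to-end sketch of how the pieces fit together, assuming the converted `Xenova/vitmatte-small-distinctions-646` checkpoint is available on the Hub and that `AutoModelForImageMatting` is re-exported from the package root like the other Auto classes:

```javascript
import { AutoProcessor, AutoModelForImageMatting, RawImage } from '@xenova/transformers';

// Load the image processor and the ViTMatte model (ONNX weights fetched from the Hub).
const processor = await AutoProcessor.from_pretrained('Xenova/vitmatte-small-distinctions-646');
const model = await AutoModelForImageMatting.from_pretrained('Xenova/vitmatte-small-distinctions-646');

// Load an input image and its trimap (known foreground/background plus an unknown region).
const image = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_image.png');
const trimap = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_trimap.png');

// Preprocess both inputs and run the model; `alphas` is a float32 tensor of shape [1, 1, height, width].
const inputs = await processor(image, trimap);
const { alphas } = await model(inputs);
```

The same flow appears in the `VitMatteForImageMatting` JSDoc further down, using the concrete model class instead of the Auto mapping.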
1 change: 1 addition & 0 deletions README.md
@@ -330,6 +330,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
1. **[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)** (from OpenAI) released with the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf) by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever.
1 change: 1 addition & 0 deletions docs/snippets/6_supported-models.snippet
@@ -66,6 +66,7 @@
1. **[T5v1.1](https://huggingface.co/docs/transformers/model_doc/t5v1.1)** (from Google AI) released in the repository [google-research/text-to-text-transfer-transformer](https://github.com/google-research/text-to-text-transfer-transformer/blob/main/released_checkpoints.md#t511) by Colin Raffel and Noam Shazeer and Adam Roberts and Katherine Lee and Sharan Narang and Michael Matena and Yanqi Zhou and Wei Li and Peter J. Liu.
1. **[TrOCR](https://huggingface.co/docs/transformers/model_doc/trocr)** (from Microsoft), released together with the paper [TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models](https://arxiv.org/abs/2109.10282) by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
1. **[Vision Transformer (ViT)](https://huggingface.co/docs/transformers/model_doc/vit)** (from Google AI) released with the paper [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929) by Alexey Dosovitskiy, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, Matthias Minderer, Georg Heigold, Sylvain Gelly, Jakob Uszkoreit, Neil Houlsby.
1. **[ViTMatte](https://huggingface.co/docs/transformers/model_doc/vitmatte)** (from HUST-VL) released with the paper [ViTMatte: Boosting Image Matting with Pretrained Plain Vision Transformers](https://arxiv.org/abs/2305.15272) by Jingfeng Yao, Xinggang Wang, Shusheng Yang, Baoyuan Wang.
1. **[Wav2Vec2](https://huggingface.co/docs/transformers/model_doc/wav2vec2)** (from Facebook AI) released with the paper [wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations](https://arxiv.org/abs/2006.11477) by Alexei Baevski, Henry Zhou, Abdelrahman Mohamed, Michael Auli.
1. **[WavLM](https://huggingface.co/docs/transformers/model_doc/wavlm)** (from Microsoft Research) released with the paper [WavLM: Large-Scale Self-Supervised Pre-Training for Full Stack Speech Processing](https://arxiv.org/abs/2110.13900) by Sanyuan Chen, Chengyi Wang, Zhengyang Chen, Yu Wu, Shujie Liu, Zhuo Chen, Jinyu Li, Naoyuki Kanda, Takuya Yoshioka, Xiong Xiao, Jian Wu, Long Zhou, Shuo Ren, Yanmin Qian, Yao Qian, Jian Wu, Michael Zeng, Furu Wei.
1. **[Whisper](https://huggingface.co/docs/transformers/model_doc/whisper)** (from OpenAI) released with the paper [Robust Speech Recognition via Large-Scale Weak Supervision](https://cdn.openai.com/papers/whisper.pdf) by Alec Radford, Jong Wook Kim, Tao Xu, Greg Brockman, Christine McLeavey, Ilya Sutskever.
9 changes: 9 additions & 0 deletions scripts/supported_models.py
@@ -775,6 +775,15 @@
            'google/vit-base-patch16-224',
        ],
    },
    'vitmatte': {
        # Image matting
        'image-matting': [
            'hustvl/vitmatte-small-distinctions-646',
            'hustvl/vitmatte-base-distinctions-646',
            'hustvl/vitmatte-small-composition-1k',
            'hustvl/vitmatte-base-composition-1k',
        ],
    },
    'wav2vec2': {
        # Feature extraction # NOTE: requires --task feature-extraction
        'feature-extraction': [
88 changes: 87 additions & 1 deletion src/models.js
@@ -3441,6 +3441,74 @@ export class ViTForImageClassification extends ViTPreTrainedModel {
}
//////////////////////////////////////////////////

//////////////////////////////////////////////////
export class VitMattePreTrainedModel extends PreTrainedModel { }

/**
* ViTMatte framework leveraging any vision backbone (e.g. a plain ViT) for image matting, i.e. estimating an alpha matte from an image and its trimap.
*
* **Example:** Perform image matting with a `VitMatteForImageMatting` model.
* ```javascript
* import { AutoProcessor, VitMatteForImageMatting, RawImage } from '@xenova/transformers';
*
* // Load processor and model
* const processor = await AutoProcessor.from_pretrained('Xenova/vitmatte-small-distinctions-646');
* const model = await VitMatteForImageMatting.from_pretrained('Xenova/vitmatte-small-distinctions-646');
*
* // Load image and trimap
* const image = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_image.png');
* const trimap = await RawImage.fromURL('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/vitmatte_trimap.png');
*
* // Prepare image + trimap for the model
* const inputs = await processor(image, trimap);
*
* // Predict alpha matte
* const { alphas } = await model(inputs);
* // Tensor {
* // dims: [ 1, 1, 640, 960 ],
* // type: 'float32',
* // size: 614400,
* // data: Float32Array(614400) [ 0.9894027709960938, 0.9970508813858032, ... ]
* // }
* ```
*
* You can visualize the alpha matte as follows:
* ```javascript
* import { Tensor, RawImage, cat } from '@xenova/transformers';
*
* // Visualize predicted alpha matte
* const imageTensor = new Tensor(
*     'uint8',
*     new Uint8Array(image.data),
*     [image.height, image.width, image.channels]
* ).transpose(2, 0, 1);
*
* // Convert float (0-1) alpha matte to uint8 (0-255)
* const alphaChannel = alphas
*     .squeeze(0)
*     .mul_(255)
*     .clamp_(0, 255)
*     .round_()
*     .to('uint8');
*
* // Concatenate original image with predicted alpha
* const imageData = cat([imageTensor, alphaChannel], 0);
*
* // Save output image
* const outputImage = RawImage.fromTensor(imageData);
* outputImage.save('output.png');
* ```
*/
export class VitMatteForImageMatting extends VitMattePreTrainedModel {
    /**
     * @param {any} model_inputs
     */
    async _call(model_inputs) {
        return new ImageMattingOutput(await super._call(model_inputs));
    }
}
//////////////////////////////////////////////////

//////////////////////////////////////////////////
export class MobileViTPreTrainedModel extends PreTrainedModel { }
export class MobileViTModel extends MobileViTPreTrainedModel { }
@@ -4827,7 +4895,9 @@ const MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES = new Map([
    ['audio-spectrogram-transformer', ['ASTForAudioClassification', ASTForAudioClassification]],
]);


const MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES = new Map([
    ['vitmatte', ['VitMatteForImageMatting', VitMatteForImageMatting]],
]);

const MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = new Map([
    ['swin2sr', ['Swin2SRForImageSuperResolution', Swin2SRForImageSuperResolution]],
@@ -4853,6 +4923,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
    [MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES, MODEL_TYPES.Vision2Seq],
    [MODEL_FOR_IMAGE_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
    [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
    [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
    [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
    [MODEL_FOR_DEPTH_ESTIMATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
    [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
@@ -5058,6 +5129,10 @@ export class AutoModelForDocumentQuestionAnswering extends PretrainedMixin {
    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_DOCUMENT_QUESTION_ANSWERING_MAPPING_NAMES];
}

export class AutoModelForImageMatting extends PretrainedMixin {
    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_MATTING_MAPPING_NAMES];
}

export class AutoModelForImageToImage extends PretrainedMixin {
    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES];
}
@@ -5177,3 +5252,14 @@ export class CausalLMOutputWithPast extends ModelOutput {
        this.past_key_values = past_key_values;
    }
}

export class ImageMattingOutput extends ModelOutput {
    /**
     * @param {Object} output The output of the model.
     * @param {Tensor} output.alphas Estimated alpha values, of shape `(batch_size, num_channels, height, width)`.
     */
    constructor({ alphas }) {
        super();
        this.alphas = alphas;
    }
}
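For downstream use, the `alphas` in `ImageMattingOutput` can be composited directly with the input image. A minimal sketch, assuming the `image` and `alphas` variables from the JSDoc example above (so the matte has the same height and width as the image); the per-pixel loop and the solid background colour are illustrative and not part of this commit:

```javascript
import { RawImage } from '@xenova/transformers';

// Composite the foreground over a solid background using the predicted matte:
// out = alpha * foreground + (1 - alpha) * background, per pixel and channel.
const { width, height, channels } = image;   // RawImage loaded in the example above
const alpha = alphas.data;                   // Float32Array of length height * width, values in [0, 1]
const background = [0, 177, 64];             // illustrative solid green background (RGB)

const composite = new Uint8ClampedArray(width * height * 3);
for (let i = 0; i < width * height; ++i) {
    const a = alpha[i];
    for (let c = 0; c < 3; ++c) {
        composite[i * 3 + c] = a * image.data[i * channels + c] + (1 - a) * background[c];
    }
}

// Wrap the raw pixels in a RawImage and save the result (Node.js).
const result = new RawImage(composite, width, height, 3);
await result.save('composite.png');
```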