Add zero-shot-object-detection w/ OwlViT (#392)
* Set `batch_size=1` for owlvit exports

* Add support for owlvit models

* Update default quantization settings

* Add list of supported models

* Revert update of owlvit quantization settings

* Add `OwlViTProcessor`

* Move `get_bounding_box` to utils

* Add `ZeroShotObjectDetectionPipeline`

* Add unit tests

* Add owlvit processor test

* Add listed support for `zero-shot-object-detection`

* Add OWL-ViT to list of supported models

* Update README.md

* Fix typo from merge
xenova authored Nov 20, 2023
1 parent b5ef835 commit 7cf8a2c
Showing 11 changed files with 353 additions and 24 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -247,6 +247,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
| [Text-to-Image](https://huggingface.co/tasks/text-to-image) | `text-to-image` | Generates images from input text. ||
| [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | `visual-question-answering` | Answering open-ended questions based on an image. ||
| [Zero-Shot Image Classification](https://huggingface.co/tasks/zero-shot-image-classification) | `zero-shot-image-classification` | Classifying images into classes that are unseen during training. |[(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&library=transformers.js) |
| [Zero-Shot Object Detection](https://huggingface.co/tasks/zero-shot-object-detection) | `zero-shot-object-detection` | Identify objects of classes that are unseen during training. |[(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotObjectDetectionPipeline)<br>[(models)](https://huggingface.co/models?other=zero-shot-object-detection&library=transformers.js) |


#### Reinforcement Learning
@@ -300,6 +301,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
1 change: 1 addition & 0 deletions docs/snippets/5_supported-tasks.snippet
@@ -59,6 +59,7 @@
| [Text-to-Image](https://huggingface.co/tasks/text-to-image) | `text-to-image` | Generates images from input text. | ❌ |
| [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | `visual-question-answering` | Answering open-ended questions based on an image. | ❌ |
| [Zero-Shot Image Classification](https://huggingface.co/tasks/zero-shot-image-classification) | `zero-shot-image-classification` | Classifying images into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&library=transformers.js) |
| [Zero-Shot Object Detection](https://huggingface.co/tasks/zero-shot-object-detection) | `zero-shot-object-detection` | Identify objects of classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotObjectDetectionPipeline)<br>[(models)](https://huggingface.co/models?other=zero-shot-object-detection&library=transformers.js) |


#### Reinforcement Learning
1 change: 1 addition & 0 deletions docs/snippets/6_supported-models.snippet
@@ -42,6 +42,7 @@
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
7 changes: 6 additions & 1 deletion scripts/convert.py
@@ -84,7 +84,7 @@
'vision-encoder-decoder': {
'per_channel': False,
'reduce_range': False,
-}
+},
}

MODELS_WITHOUT_TOKENIZERS = [
@@ -326,6 +326,11 @@ def main():
with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
json.dump(tokenizer_json, fp, indent=4)

elif config.model_type == 'owlvit':
# Override default batch size to 1, needed because non-maximum suppression is performed during export.
# For more information, see https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
export_kwargs['batch_size'] = 1

else:
pass # TODO

7 changes: 7 additions & 0 deletions scripts/supported_models.py
@@ -355,6 +355,13 @@
# (TODO conversational)
'PygmalionAI/pygmalion-350m',
],
'owlvit': [
# Object detection (Zero-shot object detection)
# NOTE: Exported with --batch_size 1
'google/owlvit-base-patch32',
'google/owlvit-base-patch16',
'google/owlvit-large-patch14',
],
'resnet': [
# Image classification
'microsoft/resnet-18',
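The checkpoints listed above are exported with `--batch_size 1`, as noted in the `convert.py` change. As a rough sketch (assuming the corresponding `Xenova/...` conversions of these checkpoints are published on the Hugging Face Hub), any of them can be selected by passing its converted id to the pipeline factory:

```javascript
import { pipeline } from '@xenova/transformers';

// The repository name below is an assumption, following the usual
// `Xenova/<original model name>` convention for converted checkpoints.
const detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch16');

const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png';
const output = await detector(url, ['human face', 'rocket', 'helmet']);
console.log(output); // e.g. [{ score, label, box: { xmin, ymin, xmax, ymax } }, ...]
```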
17 changes: 17 additions & 0 deletions src/models.js
@@ -3200,6 +3200,12 @@ export class MobileViTForImageClassification extends MobileViTPreTrainedModel {

//////////////////////////////////////////////////

//////////////////////////////////////////////////
export class OwlViTPreTrainedModel extends PreTrainedModel { }
export class OwlViTModel extends OwlViTPreTrainedModel { }
export class OwlViTForObjectDetection extends OwlViTPreTrainedModel { }
//////////////////////////////////////////////////

//////////////////////////////////////////////////
// Beit Models
export class BeitPreTrainedModel extends PreTrainedModel { }
@@ -4010,6 +4016,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['detr', ['DetrModel', DetrModel]],
['vit', ['ViTModel', ViTModel]],
['mobilevit', ['MobileViTModel', MobileViTModel]],
['owlvit', ['OwlViTModel', OwlViTModel]],
['beit', ['BeitModel', BeitModel]],
['deit', ['DeiTModel', DeiTModel]],
['resnet', ['ResNetModel', ResNetModel]],
@@ -4171,6 +4178,10 @@ const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([
['yolos', ['YolosForObjectDetection', YolosForObjectDetection]],
]);

const MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = new Map([
['owlvit', ['OwlViTForObjectDetection', OwlViTForObjectDetection]],
]);

const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
['detr', ['DetrForSegmentation', DetrForSegmentation]],
]);
@@ -4210,6 +4221,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
@@ -4380,6 +4392,11 @@ export class AutoModelForObjectDetection extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES];
}

export class AutoModelForZeroShotObjectDetection extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES];
}


/**
* Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function.
* The chosen model class is determined by the type specified in the model config.
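The new `OwlViTModel` and `OwlViTForObjectDetection` classes are also reachable without the pipeline wrapper. The following is a minimal sketch that mirrors what `ZeroShotObjectDetectionPipeline._call` (added in the next file) does internally; the root-level exports and the `post_process_object_detection` call are assumed from that implementation rather than from separate documentation:

```javascript
import {
    AutoTokenizer,
    AutoProcessor,
    AutoModelForZeroShotObjectDetection,
    RawImage,
} from '@xenova/transformers';

const model_id = 'Xenova/owlvit-base-patch32';
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForZeroShotObjectDetection.from_pretrained(model_id);

// Tokenize the candidate labels and preprocess a single image.
// The ONNX export is fixed to batch size 1, so images are handled one at a time.
const candidate_labels = ['human face', 'rocket', 'helmet'];
const text_inputs = tokenizer(candidate_labels, { padding: true, truncation: true });

const image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png');
const { pixel_values } = await processor([image]);

// Forward pass with both text and pixel inputs.
const output = await model({ ...text_inputs, pixel_values });

// Convert raw outputs into thresholded detections, as the pipeline does.
const detections = processor.feature_extractor.post_process_object_detection(
    output, 0.1, [[image.height, image.width]], true,
)[0];
console.log(detections); // { boxes, classes, scores }
```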
157 changes: 146 additions & 11 deletions src/pipelines.js
@@ -33,6 +33,7 @@ import {
AutoModelForImageClassification,
AutoModelForImageSegmentation,
AutoModelForObjectDetection,
AutoModelForZeroShotObjectDetection,
AutoModelForDocumentQuestionAnswering,
AutoModelForImageToImage,
// AutoModelForTextToWaveform,
@@ -50,6 +51,7 @@ import {
dispatchCallback,
pop,
product,
get_bounding_box,
} from './utils/core.js';
import {
softmax,
@@ -1753,28 +1755,148 @@ export class ObjectDetectionPipeline extends Pipeline {
return {
score: batch.scores[i],
label: id2label[batch.classes[i]],
-box: this._get_bounding_box(box, !percentage),
+box: get_bounding_box(box, !percentage),
}
})
})

return isBatched ? result : result[0];
}
}

/**
* Zero-shot object detection pipeline. This pipeline predicts bounding boxes of
* objects when you provide an image and a set of `candidate_labels`.
*
* **Example:** Zero-shot object detection w/ `Xenova/owlvit-base-patch32`.
* ```javascript
* let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png';
* let candidate_labels = ['human face', 'rocket', 'helmet', 'american flag'];
* let detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');
* let output = await detector(url, candidate_labels);
* // [
* // {
* // score: 0.24392342567443848,
* // label: 'human face',
* // box: { xmin: 180, ymin: 67, xmax: 274, ymax: 175 }
* // },
* // {
* // score: 0.15129457414150238,
* // label: 'american flag',
* // box: { xmin: 0, ymin: 4, xmax: 106, ymax: 513 }
* // },
* // {
* // score: 0.13649864494800568,
* // label: 'helmet',
* // box: { xmin: 277, ymin: 337, xmax: 511, ymax: 511 }
* // },
* // {
* // score: 0.10262022167444229,
* // label: 'rocket',
* // box: { xmin: 352, ymin: -1, xmax: 463, ymax: 287 }
* // }
* // ]
* ```
*
* **Example:** Zero-shot object detection w/ `Xenova/owlvit-base-patch32` (returning top 4 matches and setting a threshold).
* ```javascript
* let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png';
* let candidate_labels = ['hat', 'book', 'sunglasses', 'camera'];
* let detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');
* let output = await detector(url, candidate_labels, { topk: 4, threshold: 0.05 });
* // [
* // {
* // score: 0.1606510728597641,
* // label: 'sunglasses',
* // box: { xmin: 347, ymin: 229, xmax: 429, ymax: 264 }
* // },
* // {
* // score: 0.08935828506946564,
* // label: 'hat',
* // box: { xmin: 38, ymin: 174, xmax: 258, ymax: 364 }
* // },
* // {
* // score: 0.08530698716640472,
* // label: 'camera',
* // box: { xmin: 187, ymin: 350, xmax: 260, ymax: 411 }
* // },
* // {
* // score: 0.08349756896495819,
* // label: 'book',
* // box: { xmin: 261, ymin: 280, xmax: 494, ymax: 425 }
* // }
* // ]
* ```
*/
export class ZeroShotObjectDetectionPipeline extends Pipeline {

/**
-* Helper function to convert list [xmin, xmax, ymin, ymax] into object { "xmin": xmin, ... }
-* @param {number[]} box The bounding box as a list.
-* @param {boolean} asInteger Whether to cast to integers.
-* @returns {Object} The bounding box as an object.
-* @private
+* Create a new ZeroShotObjectDetectionPipeline.
+* @param {Object} options An object containing the following properties:
+* @param {string} [options.task] The task of the pipeline. Useful for specifying subtasks.
+* @param {PreTrainedModel} [options.model] The model to use.
+* @param {PreTrainedTokenizer} [options.tokenizer] The tokenizer to use.
+* @param {Processor} [options.processor] The processor to use.
*/
-_get_bounding_box(box, asInteger) {
-if (asInteger) {
-box = box.map(x => x | 0);
+constructor(options) {
+super(options);
}

/**
* Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
* @param {Array} images The input images.
* @param {string[]} candidate_labels What the model should recognize in the image.
* @param {Object} options The options for the detection.
* @param {number} [options.threshold] The probability necessary to make a prediction.
* @param {number} [options.topk] The number of top predictions that will be returned by the pipeline.
* If the provided number is `null` or higher than the number of predictions available, it will default
* to the number of predictions.
* @param {boolean} [options.percentage=false] Whether to return the boxes coordinates in percentage (true) or in pixels (false).
* @returns {Promise<any>} An array of detections for each input image, or a single array of detections if only one input image is provided.
*/
async _call(images, candidate_labels, {
threshold = 0.1,
topk = null,
percentage = false,
} = {}) {
const isBatched = Array.isArray(images);
images = await prepareImages(images);

// Run tokenization
const text_inputs = this.tokenizer(candidate_labels, {
padding: true,
truncation: true
});

// Run processor
const model_inputs = await this.processor(images);

// Since non-maximum suppression is performed during export, we need to
// process each image separately. For more information, see:
// https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
const toReturn = [];
for (let i = 0; i < images.length; ++i) {
const image = images[i];
const imageSize = [[image.height, image.width]];
const pixel_values = model_inputs.pixel_values[i].unsqueeze_(0);

// Run model with both text and pixel inputs
const output = await this.model({ ...text_inputs, pixel_values });

// @ts-ignore
const processed = this.processor.feature_extractor.post_process_object_detection(output, threshold, imageSize, true)[0];
let result = processed.boxes.map((box, i) => ({
score: processed.scores[i],
label: candidate_labels[processed.classes[i]],
box: get_bounding_box(box, !percentage),
})).sort((a, b) => b.score - a.score);
if (topk !== null) {
result = result.slice(0, topk);
}
toReturn.push(result)
}
-const [xmin, ymin, xmax, ymax] = box;

-return { xmin, ymin, xmax, ymax };
+return isBatched ? toReturn : toReturn[0];
}
}
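The `get_bounding_box` helper imported above lives in `src/utils/core.js`, whose diff is not shown on this page. Reconstructed from the removed `_get_bounding_box` method, it presumably looks roughly like the following sketch:

```javascript
/**
 * Helper function to convert a bounding box list [xmin, ymin, xmax, ymax]
 * into an object { xmin, ymin, xmax, ymax }.
 * @param {number[]} box The bounding box as a list.
 * @param {boolean} asInteger Whether to cast the coordinates to integers.
 * @returns {Object} The bounding box as an object.
 */
export function get_bounding_box(box, asInteger) {
    if (asInteger) {
        box = box.map(x => x | 0); // truncate towards zero
    }
    const [xmin, ymin, xmax, ymax] = box;
    return { xmin, ymin, xmax, ymax };
}
```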

@@ -2187,6 +2309,18 @@ const SUPPORTED_TASKS = {
},
"type": "multimodal",
},
"zero-shot-object-detection": {
"tokenizer": AutoTokenizer,
"pipeline": ZeroShotObjectDetectionPipeline,
"model": AutoModelForZeroShotObjectDetection,
"processor": AutoProcessor,
"default": {
// TODO: replace with original
// "model": "google/owlvit-base-patch32",
"model": "Xenova/owlvit-base-patch32",
},
"type": "multimodal",
},
"document-question-answering": {
"tokenizer": AutoTokenizer,
"pipeline": DocumentQuestionAnsweringPipeline,
@@ -2261,6 +2395,7 @@ const TASK_ALIASES = {
* - `"translation_xx_to_yy"`: will return a `TranslationPipeline`.
* - `"zero-shot-classification"`: will return a `ZeroShotClassificationPipeline`.
* - `"zero-shot-image-classification"`: will return a `ZeroShotImageClassificationPipeline`.
* - `"zero-shot-object-detection"`: will return a `ZeroShotObjectDetectionPipeline`.
* @param {string} [model=null] The name of the pre-trained model to use. If not specified, the default model for the task will be used.
* @param {import('./utils/hub.js').PretrainedOptions} [options] Optional parameters for the pipeline.
* @returns {Promise<Pipeline>} A Pipeline object for the specified task.
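With the task registered in `SUPPORTED_TASKS`, the factory resolves `'zero-shot-object-detection'` to the new pipeline and, by default, to the `Xenova/owlvit-base-patch32` conversion. A hedged sketch of batched input and of the `threshold`, `topk` and `percentage` options defined in `_call` (actual scores and boxes will differ):

```javascript
import { pipeline } from '@xenova/transformers';

// No model id given, so the default registered above is used.
const detector = await pipeline('zero-shot-object-detection');

const urls = [
    'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png',
    'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png',
];

// Passing an array of images returns one array of detections per image.
// `percentage: true` returns box coordinates relative to the image size instead of pixels.
const results = await detector(urls, ['hat', 'rocket', 'sunglasses'], {
    threshold: 0.05,
    topk: 3,
    percentage: true,
});
console.log(results.length); // 2 (one result array per input image)
```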
