diff --git a/README.md b/README.md
index 68bbe3180..4036a0a75 100644
--- a/README.md
+++ b/README.md
@@ -247,6 +247,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
| [Text-to-Image](https://huggingface.co/tasks/text-to-image) | `text-to-image` | Generates images from input text. | ❌ |
| [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | `visual-question-answering` | Answering open-ended questions based on an image. | ❌ |
| [Zero-Shot Image Classification](https://huggingface.co/tasks/zero-shot-image-classification) | `zero-shot-image-classification` | Classifying images into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&library=transformers.js) |
+| [Zero-Shot Object Detection](https://huggingface.co/tasks/zero-shot-object-detection) | `zero-shot-object-detection` | Identify objects of classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotObjectDetectionPipeline)<br>[(models)](https://huggingface.co/models?other=zero-shot-object-detection&library=transformers.js) |
#### Reinforcement Learning
@@ -300,6 +301,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
+1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
diff --git a/docs/snippets/5_supported-tasks.snippet b/docs/snippets/5_supported-tasks.snippet
index dee075808..5699f8f28 100644
--- a/docs/snippets/5_supported-tasks.snippet
+++ b/docs/snippets/5_supported-tasks.snippet
@@ -59,6 +59,7 @@
| [Text-to-Image](https://huggingface.co/tasks/text-to-image) | `text-to-image` | Generates images from input text. | ❌ |
| [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | `visual-question-answering` | Answering open-ended questions based on an image. | ❌ |
| [Zero-Shot Image Classification](https://huggingface.co/tasks/zero-shot-image-classification) | `zero-shot-image-classification` | Classifying images into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&library=transformers.js) |
+| [Zero-Shot Object Detection](https://huggingface.co/tasks/zero-shot-object-detection) | `zero-shot-object-detection` | Identify objects of classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotObjectDetectionPipeline)<br>[(models)](https://huggingface.co/models?other=zero-shot-object-detection&library=transformers.js) |
#### Reinforcement Learning
diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet
index 42c12bd2a..3572a5d95 100644
--- a/docs/snippets/6_supported-models.snippet
+++ b/docs/snippets/6_supported-models.snippet
@@ -42,6 +42,7 @@
1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
+1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
diff --git a/scripts/convert.py b/scripts/convert.py
index ffc999bfa..5f3c72e56 100644
--- a/scripts/convert.py
+++ b/scripts/convert.py
@@ -84,7 +84,7 @@
'vision-encoder-decoder': {
'per_channel': False,
'reduce_range': False,
- }
+ },
}
MODELS_WITHOUT_TOKENIZERS = [
@@ -326,6 +326,11 @@ def main():
with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
json.dump(tokenizer_json, fp, indent=4)
+ elif config.model_type == 'owlvit':
+        # Override the default batch size to 1; this is needed because non-maximum suppression is performed during export.
+ # For more information, see https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
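+        # (This has the same effect as passing --batch_size 1 to this script, as noted in scripts/supported_models.py.)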
+ export_kwargs['batch_size'] = 1
+
else:
pass # TODO
diff --git a/scripts/supported_models.py b/scripts/supported_models.py
index 79d0282c5..2c7edeef9 100644
--- a/scripts/supported_models.py
+++ b/scripts/supported_models.py
@@ -355,6 +355,13 @@
# (TODO conversational)
'PygmalionAI/pygmalion-350m',
],
+ 'owlvit': [
+ # Object detection (Zero-shot object detection)
+ # NOTE: Exported with --batch_size 1
+ 'google/owlvit-base-patch32',
+ 'google/owlvit-base-patch16',
+ 'google/owlvit-large-patch14',
+ ],
'resnet': [
# Image classification
'microsoft/resnet-18',
diff --git a/src/models.js b/src/models.js
index 94c2ba9ae..d56d4b4f6 100644
--- a/src/models.js
+++ b/src/models.js
@@ -3200,6 +3200,12 @@ export class MobileViTForImageClassification extends MobileViTPreTrainedModel {
//////////////////////////////////////////////////
+//////////////////////////////////////////////////
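+// OwlViT Models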
+export class OwlViTPreTrainedModel extends PreTrainedModel { }
+export class OwlViTModel extends OwlViTPreTrainedModel { }
+export class OwlViTForObjectDetection extends OwlViTPreTrainedModel { }
+//////////////////////////////////////////////////
+
//////////////////////////////////////////////////
// Beit Models
export class BeitPreTrainedModel extends PreTrainedModel { }
@@ -4010,6 +4016,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
['detr', ['DetrModel', DetrModel]],
['vit', ['ViTModel', ViTModel]],
['mobilevit', ['MobileViTModel', MobileViTModel]],
+ ['owlvit', ['OwlViTModel', OwlViTModel]],
['beit', ['BeitModel', BeitModel]],
['deit', ['DeiTModel', DeiTModel]],
['resnet', ['ResNetModel', ResNetModel]],
@@ -4171,6 +4178,10 @@ const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([
['yolos', ['YolosForObjectDetection', YolosForObjectDetection]],
]);
+const MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = new Map([
+ ['owlvit', ['OwlViTForObjectDetection', OwlViTForObjectDetection]],
+]);
+
const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
['detr', ['DetrForSegmentation', DetrForSegmentation]],
]);
@@ -4210,6 +4221,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
[MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+ [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
[MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
@@ -4380,6 +4392,11 @@ export class AutoModelForObjectDetection extends PretrainedMixin {
static MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES];
}
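+
+/**
+ * Helper class which is used to instantiate pretrained zero-shot object detection models with the `from_pretrained` function.
+ * The chosen model class is determined by the type specified in the model config.
+ */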
+export class AutoModelForZeroShotObjectDetection extends PretrainedMixin {
+ static MODEL_CLASS_MAPPINGS = [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES];
+}
+
+
/**
* Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function.
* The chosen model class is determined by the type specified in the model config.
diff --git a/src/pipelines.js b/src/pipelines.js
index 0c2e0b3a3..7bb125e5d 100644
--- a/src/pipelines.js
+++ b/src/pipelines.js
@@ -33,6 +33,7 @@ import {
AutoModelForImageClassification,
AutoModelForImageSegmentation,
AutoModelForObjectDetection,
+ AutoModelForZeroShotObjectDetection,
AutoModelForDocumentQuestionAnswering,
AutoModelForImageToImage,
// AutoModelForTextToWaveform,
@@ -50,6 +51,7 @@ import {
dispatchCallback,
pop,
product,
+ get_bounding_box,
} from './utils/core.js';
import {
softmax,
@@ -1753,28 +1755,148 @@ export class ObjectDetectionPipeline extends Pipeline {
return {
score: batch.scores[i],
label: id2label[batch.classes[i]],
- box: this._get_bounding_box(box, !percentage),
+ box: get_bounding_box(box, !percentage),
}
})
})
return isBatched ? result : result[0];
}
+}
+
+/**
+ * Zero-shot object detection pipeline. This pipeline predicts bounding boxes of
+ * objects when you provide an image and a set of `candidate_labels`.
+ *
+ * **Example:** Zero-shot object detection w/ `Xenova/owlvit-base-patch32`.
+ * ```javascript
+ * let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png';
+ * let candidate_labels = ['human face', 'rocket', 'helmet', 'american flag'];
+ * let detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');
+ * let output = await detector(url, candidate_labels);
+ * // [
+ * // {
+ * // score: 0.24392342567443848,
+ * // label: 'human face',
+ * // box: { xmin: 180, ymin: 67, xmax: 274, ymax: 175 }
+ * // },
+ * // {
+ * // score: 0.15129457414150238,
+ * // label: 'american flag',
+ * // box: { xmin: 0, ymin: 4, xmax: 106, ymax: 513 }
+ * // },
+ * // {
+ * // score: 0.13649864494800568,
+ * // label: 'helmet',
+ * // box: { xmin: 277, ymin: 337, xmax: 511, ymax: 511 }
+ * // },
+ * // {
+ * // score: 0.10262022167444229,
+ * // label: 'rocket',
+ * // box: { xmin: 352, ymin: -1, xmax: 463, ymax: 287 }
+ * // }
+ * // ]
+ * ```
+ *
+ * **Example:** Zero-shot object detection w/ `Xenova/owlvit-base-patch32` (returning top 4 matches and setting a threshold).
+ * ```javascript
+ * let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png';
+ * let candidate_labels = ['hat', 'book', 'sunglasses', 'camera'];
+ * let detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');
+ * let output = await detector(url, candidate_labels, { topk: 4, threshold: 0.05 });
+ * // [
+ * // {
+ * // score: 0.1606510728597641,
+ * // label: 'sunglasses',
+ * // box: { xmin: 347, ymin: 229, xmax: 429, ymax: 264 }
+ * // },
+ * // {
+ * // score: 0.08935828506946564,
+ * // label: 'hat',
+ * // box: { xmin: 38, ymin: 174, xmax: 258, ymax: 364 }
+ * // },
+ * // {
+ * // score: 0.08530698716640472,
+ * // label: 'camera',
+ * // box: { xmin: 187, ymin: 350, xmax: 260, ymax: 411 }
+ * // },
+ * // {
+ * // score: 0.08349756896495819,
+ * // label: 'book',
+ * // box: { xmin: 261, ymin: 280, xmax: 494, ymax: 425 }
+ * // }
+ * // ]
+ * ```
+ */
+export class ZeroShotObjectDetectionPipeline extends Pipeline {
/**
- * Helper function to convert list [xmin, xmax, ymin, ymax] into object { "xmin": xmin, ... }
- * @param {number[]} box The bounding box as a list.
- * @param {boolean} asInteger Whether to cast to integers.
- * @returns {Object} The bounding box as an object.
- * @private
+ * Create a new ZeroShotObjectDetectionPipeline.
+ * @param {Object} options An object containing the following properties:
+ * @param {string} [options.task] The task of the pipeline. Useful for specifying subtasks.
+ * @param {PreTrainedModel} [options.model] The model to use.
+ * @param {PreTrainedTokenizer} [options.tokenizer] The tokenizer to use.
+ * @param {Processor} [options.processor] The processor to use.
*/
- _get_bounding_box(box, asInteger) {
- if (asInteger) {
- box = box.map(x => x | 0);
+ constructor(options) {
+ super(options);
+ }
+
+ /**
+ * Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
+ * @param {Array} images The input images.
+ * @param {string[]} candidate_labels What the model should recognize in the image.
+     * @param {Object} options The options for the detection.
+     * @param {number} [options.threshold=0.1] The minimum probability required to make a prediction.
+     * @param {number} [options.topk=null] The number of top predictions that will be returned by the pipeline.
+     * If the provided number is `null` or higher than the number of predictions available, it will default
+     * to the number of predictions.
+     * @param {boolean} [options.percentage=false] Whether to return the box coordinates as percentages (true) or in pixels (false).
+     * @returns {Promise} An array of detections for each input image, or a single array of detections if only one input image is provided.
+ */
+ async _call(images, candidate_labels, {
+ threshold = 0.1,
+ topk = null,
+ percentage = false,
+ } = {}) {
+ const isBatched = Array.isArray(images);
+ images = await prepareImages(images);
+
+ // Run tokenization
+ const text_inputs = this.tokenizer(candidate_labels, {
+ padding: true,
+ truncation: true
+ });
+
+ // Run processor
+ const model_inputs = await this.processor(images);
+
+        // Since non-maximum suppression is performed during export (the model is exported
+        // with a batch size of 1), we need to process each image separately. For more information, see:
+ // https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
+ const toReturn = [];
+ for (let i = 0; i < images.length; ++i) {
+ const image = images[i];
+ const imageSize = [[image.height, image.width]];
+ const pixel_values = model_inputs.pixel_values[i].unsqueeze_(0);
+
+ // Run model with both text and pixel inputs
+ const output = await this.model({ ...text_inputs, pixel_values });
+
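+            // The final `true` argument tells the post-processor to use the zero-shot path (per-class sigmoid scores).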
+ // @ts-ignore
+ const processed = this.processor.feature_extractor.post_process_object_detection(output, threshold, imageSize, true)[0];
+ let result = processed.boxes.map((box, i) => ({
+ score: processed.scores[i],
+ label: candidate_labels[processed.classes[i]],
+ box: get_bounding_box(box, !percentage),
+ })).sort((a, b) => b.score - a.score);
+ if (topk !== null) {
+ result = result.slice(0, topk);
+ }
+            toReturn.push(result);
}
- const [xmin, ymin, xmax, ymax] = box;
- return { xmin, ymin, xmax, ymax };
+ return isBatched ? toReturn : toReturn[0];
}
}
@@ -2187,6 +2309,18 @@ const SUPPORTED_TASKS = {
},
"type": "multimodal",
},
+ "zero-shot-object-detection": {
+ "tokenizer": AutoTokenizer,
+ "pipeline": ZeroShotObjectDetectionPipeline,
+ "model": AutoModelForZeroShotObjectDetection,
+ "processor": AutoProcessor,
+ "default": {
+ // TODO: replace with original
+ // "model": "google/owlvit-base-patch32",
+ "model": "Xenova/owlvit-base-patch32",
+ },
+ "type": "multimodal",
+ },
"document-question-answering": {
"tokenizer": AutoTokenizer,
"pipeline": DocumentQuestionAnsweringPipeline,
@@ -2261,6 +2395,7 @@ const TASK_ALIASES = {
* - `"translation_xx_to_yy"`: will return a `TranslationPipeline`.
* - `"zero-shot-classification"`: will return a `ZeroShotClassificationPipeline`.
* - `"zero-shot-image-classification"`: will return a `ZeroShotImageClassificationPipeline`.
+ * - `"zero-shot-object-detection"`: will return a `ZeroShotObjectDetectionPipeline`.
* @param {string} [model=null] The name of the pre-trained model to use. If not specified, the default model for the task will be used.
* @param {import('./utils/hub.js').PretrainedOptions} [options] Optional parameters for the pipeline.
* @returns {Promise} A Pipeline object for the specified task.
diff --git a/src/processors.js b/src/processors.js
index dd7130156..07d47e7ac 100644
--- a/src/processors.js
+++ b/src/processors.js
@@ -64,9 +64,13 @@ function center_to_corners_format([centerX, centerY, width, height]) {
* @param {Object} outputs The outputs of the model that must be post-processed
* @param {Tensor} outputs.logits The logits
* @param {Tensor} outputs.pred_boxes The predicted boxes.
+ * @param {number} [threshold=0.5] The threshold to use for the scores.
+ * @param {number[][]} [target_sizes=null] The sizes of the original images.
+ * @param {boolean} [is_zero_shot=false] Whether zero-shot object detection was performed.
* @return {Object[]} An array of objects containing the post-processed outputs.
+ * @private
*/
-function post_process_object_detection(outputs, threshold = 0.5, target_sizes = null) {
+function post_process_object_detection(outputs, threshold = 0.5, target_sizes = null, is_zero_shot = false) {
const out_logits = outputs.logits;
const out_bbox = outputs.pred_boxes;
const [batch_size, num_boxes, num_classes] = out_logits.dims;
@@ -88,19 +92,33 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes =
for (let j = 0; j < num_boxes; ++j) {
let logit = logits[j];
- // Get most probable class
- let maxIndex = max(logit.data)[1];
+ let indices = [];
+ let probs;
+ if (is_zero_shot) {
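+                // In the zero-shot case, scores are independent per-class sigmoid probabilities,
+                // so there is no background class to skip and several classes may exceed the threshold.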
+ // Get indices of classes with high enough probability
+ probs = logit.sigmoid().data;
+ for (let k = 0; k < probs.length; ++k) {
+ if (probs[k] > threshold) {
+ indices.push(k);
+ }
+ }
- if (maxIndex === num_classes - 1) {
- // This is the background class, skip it
- continue;
+ } else {
+ // Get most probable class
+ let maxIndex = max(logit.data)[1];
+
+ if (maxIndex === num_classes - 1) {
+ // This is the background class, skip it
+ continue;
+ }
+ indices.push(maxIndex);
+
+ // Compute softmax over classes
+ probs = softmax(logit.data);
}
- // Compute softmax over classes
- let probs = softmax(logit.data);
+ for (const index of indices) {
- let score = probs[maxIndex];
- if (score > threshold) {
// Some class has a high enough probability
/** @type {number[]} */
let box = bbox[j].data;
@@ -112,8 +130,8 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes =
}
info.boxes.push(box);
- info.classes.push(maxIndex);
- info.scores.push(score);
+ info.classes.push(index);
+ info.scores.push(probs[index]);
}
}
toReturn.push(info);
@@ -513,6 +531,12 @@ export class CLIPFeatureExtractor extends ImageFeatureExtractor { }
export class ConvNextFeatureExtractor extends ImageFeatureExtractor { }
export class ViTFeatureExtractor extends ImageFeatureExtractor { }
export class MobileViTFeatureExtractor extends ImageFeatureExtractor { }
+export class OwlViTFeatureExtractor extends ImageFeatureExtractor {
+ /** @type {post_process_object_detection} */
+ post_process_object_detection(...args) {
+ return post_process_object_detection(...args);
+ }
+}
export class DeiTFeatureExtractor extends ImageFeatureExtractor { }
export class BeitFeatureExtractor extends ImageFeatureExtractor { }
export class DonutFeatureExtractor extends ImageFeatureExtractor {
@@ -1502,6 +1526,8 @@ export class SpeechT5Processor extends Processor {
}
}
+export class OwlViTProcessor extends Processor { }
+
//////////////////////////////////////////////////
/**
@@ -1539,6 +1565,7 @@ export class AutoProcessor {
WhisperFeatureExtractor,
ViTFeatureExtractor,
MobileViTFeatureExtractor,
+ OwlViTFeatureExtractor,
CLIPFeatureExtractor,
ConvNextFeatureExtractor,
BeitFeatureExtractor,
@@ -1558,6 +1585,7 @@ export class AutoProcessor {
Wav2Vec2ProcessorWithLM,
SamProcessor,
SpeechT5Processor,
+ OwlViTProcessor,
}
/**
diff --git a/src/utils/core.js b/src/utils/core.js
index 7de13625d..9ab11144c 100644
--- a/src/utils/core.js
+++ b/src/utils/core.js
@@ -184,3 +184,18 @@ export function product(...a) {
export function calculateReflectOffset(i, w) {
return Math.abs((i + w) % (2 * w) - w);
}
+
+/**
+ * Helper function to convert list [xmin, ymin, xmax, ymax] into object { "xmin": xmin, ... }
+ * @param {number[]} box The bounding box as a list.
+ * @param {boolean} asInteger Whether to cast to integers.
+ * @returns {Object} The bounding box as an object.
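+ * @example
+ * get_bounding_box([10.5, 20.5, 30.5, 40.5], true); // { xmin: 10, ymin: 20, xmax: 30, ymax: 40 }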
+ */
+export function get_bounding_box(box, asInteger) {
+ if (asInteger) {
+ box = box.map(x => x | 0);
+ }
+ const [xmin, ymin, xmax, ymax] = box;
+
+ return { xmin, ymin, xmax, ymax };
+}
diff --git a/tests/pipelines.test.js b/tests/pipelines.test.js
index d8f5e56c0..ed449c15b 100644
--- a/tests/pipelines.test.js
+++ b/tests/pipelines.test.js
@@ -1326,6 +1326,105 @@ describe('Pipelines', () => {
}, MAX_TEST_EXECUTION_TIME);
});
+ describe('Zero-shot object detection', () => {
+
+ // List all models which will be tested
+ const models = [
+ 'google/owlvit-base-patch32',
+ ];
+
+ it(models[0], async () => {
+ let detector = await pipeline('zero-shot-object-detection', m(models[0]));
+
+
+ // single (default)
+ {
+ let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png';
+ let candidate_labels = ['human face', 'rocket', 'helmet', 'american flag'];
+
+ let output = await detector(url, candidate_labels);
+
+ // let expected = [
+ // {
+ // score: 0.24392342567443848,
+ // label: 'human face',
+ // box: { xmin: 180, ymin: 67, xmax: 274, ymax: 175 }
+ // },
+ // {
+ // score: 0.15129457414150238,
+ // label: 'american flag',
+ // box: { xmin: 0, ymin: 4, xmax: 106, ymax: 513 }
+ // },
+ // {
+ // score: 0.13649864494800568,
+ // label: 'helmet',
+ // box: { xmin: 277, ymin: 337, xmax: 511, ymax: 511 }
+ // },
+ // {
+ // score: 0.10262022167444229,
+ // label: 'rocket',
+ // box: { xmin: 352, ymin: -1, xmax: 463, ymax: 287 }
+ // }
+ // ]
+
+ expect(output.length).toBeGreaterThan(0);
+ for (let cls of output) {
+ expect(typeof cls.score).toBe('number');
+ expect(typeof cls.label).toBe('string');
+ for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) {
+ expect(typeof cls.box[key]).toBe('number');
+ }
+ }
+ }
+
+ // topk + threshold + percentage
+ {
+ let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png';
+ let candidate_labels = ['hat', 'book', 'sunglasses', 'camera'];
+
+ let output = await detector(url, candidate_labels, {
+ topk: 4,
+ threshold: 0.05,
+ percentage: true,
+ });
+
+ // let expected = [
+ // {
+ // score: 0.1606510728597641,
+ // label: 'sunglasses',
+ // box: { xmin: 347, ymin: 229, xmax: 429, ymax: 264 }
+ // },
+ // {
+ // score: 0.08935828506946564,
+ // label: 'hat',
+ // box: { xmin: 38, ymin: 174, xmax: 258, ymax: 364 }
+ // },
+ // {
+ // score: 0.08530698716640472,
+ // label: 'camera',
+ // box: { xmin: 187, ymin: 350, xmax: 260, ymax: 411 }
+ // },
+ // {
+ // score: 0.08349756896495819,
+ // label: 'book',
+ // box: { xmin: 261, ymin: 280, xmax: 494, ymax: 425 }
+ // }
+ // ]
+
+ expect(output.length).toBeGreaterThan(0);
+ for (let cls of output) {
+ expect(typeof cls.score).toBe('number');
+ expect(typeof cls.label).toBe('string');
+ for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) {
+ expect(typeof cls.box[key]).toBe('number');
+ }
+ }
+ }
+
+ await detector.dispose();
+ }, MAX_TEST_EXECUTION_TIME);
+ });
+
describe('Image-to-image', () => {
// List all models which will be tested
diff --git a/tests/processors.test.js b/tests/processors.test.js
index 9e5d09f18..4c8680106 100644
--- a/tests/processors.test.js
+++ b/tests/processors.test.js
@@ -38,6 +38,7 @@ describe('Processors', () => {
beit: 'microsoft/beit-base-patch16-224-pt22k-ft22k',
detr: 'facebook/detr-resnet-50',
yolos: 'hustvl/yolos-small-300',
+ owlvit: 'google/owlvit-base-patch32',
clip: 'openai/clip-vit-base-patch16',
}
@@ -46,6 +47,7 @@ describe('Processors', () => {
checkerboard_8x8: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_8x8.png',
receipt: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/receipt.png',
tiger: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg',
+ cats: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg',
// grayscale image
skateboard: 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/ml-web-games/skateboard.png',
@@ -238,6 +240,23 @@ describe('Processors', () => {
}
}, MAX_TEST_EXECUTION_TIME);
+
+ // OwlViTFeatureExtractor
+ it(MODELS.owlvit, async () => {
+ const processor = await AutoProcessor.from_pretrained(m(MODELS.owlvit))
+
+ {
+ const image = await load_image(TEST_IMAGES.cats);
+ const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);
+
+ compare(pixel_values.dims, [1, 3, 768, 768]);
+ compare(avg(pixel_values.data), 0.250620447910435);
+
+ compare(original_sizes, [[480, 640]]);
+ compare(reshaped_input_sizes, [[768, 768]]);
+ }
+ });
+
// CLIPFeatureExtractor
// - tests center crop (do_center_crop=true, crop_size=224)
it(MODELS.clip, async () => {