diff --git a/README.md b/README.md
index 68bbe3180..4036a0a75 100644
--- a/README.md
+++ b/README.md
@@ -247,6 +247,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 | [Text-to-Image](https://huggingface.co/tasks/text-to-image) | `text-to-image` | Generates images from input text. | ❌ |
 | [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | `visual-question-answering` | Answering open-ended questions based on an image. | ❌ |
 | [Zero-Shot Image Classification](https://huggingface.co/tasks/zero-shot-image-classification) | `zero-shot-image-classification` | Classifying images into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&library=transformers.js) |
+| [Zero-Shot Object Detection](https://huggingface.co/tasks/zero-shot-object-detection) | `zero-shot-object-detection` | Identifying objects of classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotObjectDetectionPipeline)<br>[(models)](https://huggingface.co/models?other=zero-shot-object-detection&library=transformers.js) |
 
 #### Reinforcement Learning
 
@@ -300,6 +301,7 @@ You can refine your search by selecting the task you're interested in (e.g., [te
 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
+1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
diff --git a/docs/snippets/5_supported-tasks.snippet b/docs/snippets/5_supported-tasks.snippet
index dee075808..5699f8f28 100644
--- a/docs/snippets/5_supported-tasks.snippet
+++ b/docs/snippets/5_supported-tasks.snippet
@@ -59,6 +59,7 @@
 | [Text-to-Image](https://huggingface.co/tasks/text-to-image) | `text-to-image` | Generates images from input text. | ❌ |
 | [Visual Question Answering](https://huggingface.co/tasks/visual-question-answering) | `visual-question-answering` | Answering open-ended questions based on an image. | ❌ |
 | [Zero-Shot Image Classification](https://huggingface.co/tasks/zero-shot-image-classification) | `zero-shot-image-classification` | Classifying images into classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotImageClassificationPipeline)<br>[(models)](https://huggingface.co/models?pipeline_tag=zero-shot-image-classification&library=transformers.js) |
+| [Zero-Shot Object Detection](https://huggingface.co/tasks/zero-shot-object-detection) | `zero-shot-object-detection` | Identifying objects of classes that are unseen during training. | ✅ [(docs)](https://huggingface.co/docs/transformers.js/api/pipelines#module_pipelines.ZeroShotObjectDetectionPipeline)<br>[(models)](https://huggingface.co/models?other=zero-shot-object-detection&library=transformers.js) |
 
 #### Reinforcement Learning
 
diff --git a/docs/snippets/6_supported-models.snippet b/docs/snippets/6_supported-models.snippet
index 42c12bd2a..3572a5d95 100644
--- a/docs/snippets/6_supported-models.snippet
+++ b/docs/snippets/6_supported-models.snippet
@@ -42,6 +42,7 @@
 1. **[MT5](https://huggingface.co/docs/transformers/model_doc/mt5)** (from Google AI) released with the paper [mT5: A massively multilingual pre-trained text-to-text transformer](https://arxiv.org/abs/2010.11934) by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya Siddhant, Aditya Barua, Colin Raffel.
 1. **[NLLB](https://huggingface.co/docs/transformers/model_doc/nllb)** (from Meta) released with the paper [No Language Left Behind: Scaling Human-Centered Machine Translation](https://arxiv.org/abs/2207.04672) by the NLLB team.
 1. **[OPT](https://huggingface.co/docs/transformers/master/model_doc/opt)** (from Meta AI) released with the paper [OPT: Open Pre-trained Transformer Language Models](https://arxiv.org/abs/2205.01068) by Susan Zhang, Stephen Roller, Naman Goyal, Mikel Artetxe, Moya Chen, Shuohui Chen et al.
+1. **[OWL-ViT](https://huggingface.co/docs/transformers/model_doc/owlvit)** (from Google AI) released with the paper [Simple Open-Vocabulary Object Detection with Vision Transformers](https://arxiv.org/abs/2205.06230) by Matthias Minderer, Alexey Gritsenko, Austin Stone, Maxim Neumann, Dirk Weissenborn, Alexey Dosovitskiy, Aravindh Mahendran, Anurag Arnab, Mostafa Dehghani, Zhuoran Shen, Xiao Wang, Xiaohua Zhai, Thomas Kipf, and Neil Houlsby.
 1. **[ResNet](https://huggingface.co/docs/transformers/model_doc/resnet)** (from Microsoft Research) released with the paper [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385) by Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun.
 1. **[RoBERTa](https://huggingface.co/docs/transformers/model_doc/roberta)** (from Facebook), released together with the paper [RoBERTa: A Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott, Naman Goyal, Jingfei Du, Mandar Joshi, Danqi Chen, Omer Levy, Mike Lewis, Luke Zettlemoyer, Veselin Stoyanov.
 1. **[SpeechT5](https://huggingface.co/docs/transformers/model_doc/speecht5)** (from Microsoft Research) released with the paper [SpeechT5: Unified-Modal Encoder-Decoder Pre-Training for Spoken Language Processing](https://arxiv.org/abs/2110.07205) by Junyi Ao, Rui Wang, Long Zhou, Chengyi Wang, Shuo Ren, Yu Wu, Shujie Liu, Tom Ko, Qing Li, Yu Zhang, Zhihua Wei, Yao Qian, Jinyu Li, Furu Wei.
diff --git a/scripts/convert.py b/scripts/convert.py
index ffc999bfa..5f3c72e56 100644
--- a/scripts/convert.py
+++ b/scripts/convert.py
@@ -84,7 +84,7 @@
     'vision-encoder-decoder': {
         'per_channel': False,
         'reduce_range': False,
-    }
+    },
 }
 
 MODELS_WITHOUT_TOKENIZERS = [
@@ -326,6 +326,11 @@ def main():
         with open(os.path.join(output_model_folder, 'tokenizer.json'), 'w', encoding='utf-8') as fp:
            json.dump(tokenizer_json, fp, indent=4)
 
+    elif config.model_type == 'owlvit':
+        # Override the default batch size to 1, needed because non-maximum suppression is performed during export.
+        # For more information, see https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
+        export_kwargs['batch_size'] = 1
+
     else:
         pass # TODO
 
diff --git a/scripts/supported_models.py b/scripts/supported_models.py
index 79d0282c5..2c7edeef9 100644
--- a/scripts/supported_models.py
+++ b/scripts/supported_models.py
@@ -355,6 +355,13 @@
         # (TODO conversational)
         'PygmalionAI/pygmalion-350m',
     ],
+    'owlvit': [
+        # Object detection (Zero-shot object detection)
+        # NOTE: Exported with --batch_size 1
+        'google/owlvit-base-patch32',
+        'google/owlvit-base-patch16',
+        'google/owlvit-large-patch14',
+    ],
     'resnet': [
         # Image classification
         'microsoft/resnet-18',
diff --git a/src/models.js b/src/models.js
index 94c2ba9ae..d56d4b4f6 100644
--- a/src/models.js
+++ b/src/models.js
@@ -3200,6 +3200,12 @@ export class MobileViTForImageClassification extends MobileViTPreTrainedModel {
 
 //////////////////////////////////////////////////
 
+//////////////////////////////////////////////////
+export class OwlViTPreTrainedModel extends PreTrainedModel { }
+export class OwlViTModel extends OwlViTPreTrainedModel { }
+export class OwlViTForObjectDetection extends OwlViTPreTrainedModel { }
+//////////////////////////////////////////////////
+
 //////////////////////////////////////////////////
 // Beit Models
 export class BeitPreTrainedModel extends PreTrainedModel { }
@@ -4010,6 +4016,7 @@ const MODEL_MAPPING_NAMES_ENCODER_ONLY = new Map([
     ['detr', ['DetrModel', DetrModel]],
     ['vit', ['ViTModel', ViTModel]],
     ['mobilevit', ['MobileViTModel', MobileViTModel]],
+    ['owlvit', ['OwlViTModel', OwlViTModel]],
     ['beit', ['BeitModel', BeitModel]],
     ['deit', ['DeiTModel', DeiTModel]],
     ['resnet', ['ResNetModel', ResNetModel]],
@@ -4171,6 +4178,10 @@ const MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES = new Map([
     ['yolos', ['YolosForObjectDetection', YolosForObjectDetection]],
 ]);
 
+const MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES = new Map([
+    ['owlvit', ['OwlViTForObjectDetection', OwlViTForObjectDetection]],
+]);
+
 const MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES = new Map([
     ['detr', ['DetrForSegmentation', DetrForSegmentation]],
 ]);
@@ -4210,6 +4221,7 @@ const MODEL_CLASS_TYPE_MAPPING = [
     [MODEL_FOR_IMAGE_SEGMENTATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
+    [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_MASK_GENERATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_CTC_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
     [MODEL_FOR_AUDIO_CLASSIFICATION_MAPPING_NAMES, MODEL_TYPES.EncoderOnly],
@@ -4380,6 +4392,11 @@ export class AutoModelForObjectDetection extends PretrainedMixin {
     static MODEL_CLASS_MAPPINGS = [MODEL_FOR_OBJECT_DETECTION_MAPPING_NAMES];
 }
 
+export class AutoModelForZeroShotObjectDetection extends PretrainedMixin {
+    static MODEL_CLASS_MAPPINGS = [MODEL_FOR_ZERO_SHOT_OBJECT_DETECTION_MAPPING_NAMES];
+}
+
+
 /**
  * Helper class which is used to instantiate pretrained object detection models with the `from_pretrained` function.
  * The chosen model class is determined by the type specified in the model config.
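Reviewer note: the snippet below is a minimal sketch (not part of the diff) of how the newly added classes fit together when used directly, mirroring what `ZeroShotObjectDetectionPipeline._call` does internally in the next file. It assumes the converted `Xenova/owlvit-base-patch32` weights are published on the Hub; variable names are illustrative only.

```javascript
import {
    AutoTokenizer,
    AutoProcessor,
    AutoModelForZeroShotObjectDetection,
    RawImage,
} from '@xenova/transformers';

// Load the three components the new pipeline wires together.
const model_id = 'Xenova/owlvit-base-patch32';
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const processor = await AutoProcessor.from_pretrained(model_id);
const model = await AutoModelForZeroShotObjectDetection.from_pretrained(model_id);

// Text and image inputs (a single image, matching the batch-size-1 ONNX export).
const candidate_labels = ['human face', 'rocket', 'helmet', 'american flag'];
const text_inputs = tokenizer(candidate_labels, { padding: true, truncation: true });

const image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png');
const { pixel_values } = await processor(image);

// Run the model and post-process with the zero-shot flag enabled.
const output = await model({ ...text_inputs, pixel_values });
const [detections] = processor.feature_extractor.post_process_object_detection(
    output, 0.1, [[image.height, image.width]], true,
);
console.log(detections.boxes, detections.scores, detections.classes.map(i => candidate_labels[i]));
```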
diff --git a/src/pipelines.js b/src/pipelines.js
index 0c2e0b3a3..7bb125e5d 100644
--- a/src/pipelines.js
+++ b/src/pipelines.js
@@ -33,6 +33,7 @@ import {
     AutoModelForImageClassification,
     AutoModelForImageSegmentation,
     AutoModelForObjectDetection,
+    AutoModelForZeroShotObjectDetection,
     AutoModelForDocumentQuestionAnswering,
     AutoModelForImageToImage,
     // AutoModelForTextToWaveform,
@@ -50,6 +51,7 @@ import {
     dispatchCallback,
     pop,
     product,
+    get_bounding_box,
 } from './utils/core.js';
 import {
     softmax,
@@ -1753,28 +1755,148 @@ export class ObjectDetectionPipeline extends Pipeline {
                 return {
                     score: batch.scores[i],
                     label: id2label[batch.classes[i]],
-                    box: this._get_bounding_box(box, !percentage),
+                    box: get_bounding_box(box, !percentage),
                 }
             })
         })
 
         return isBatched ? result : result[0];
     }
+}
+
+/**
+ * Zero-shot object detection pipeline. This pipeline predicts bounding boxes of
+ * objects when you provide an image and a set of `candidate_labels`.
+ *
+ * **Example:** Zero-shot object detection w/ `Xenova/owlvit-base-patch32`.
+ * ```javascript
+ * let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png';
+ * let candidate_labels = ['human face', 'rocket', 'helmet', 'american flag'];
+ * let detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');
+ * let output = await detector(url, candidate_labels);
+ * // [
+ * //   {
+ * //     score: 0.24392342567443848,
+ * //     label: 'human face',
+ * //     box: { xmin: 180, ymin: 67, xmax: 274, ymax: 175 }
+ * //   },
+ * //   {
+ * //     score: 0.15129457414150238,
+ * //     label: 'american flag',
+ * //     box: { xmin: 0, ymin: 4, xmax: 106, ymax: 513 }
+ * //   },
+ * //   {
+ * //     score: 0.13649864494800568,
+ * //     label: 'helmet',
+ * //     box: { xmin: 277, ymin: 337, xmax: 511, ymax: 511 }
+ * //   },
+ * //   {
+ * //     score: 0.10262022167444229,
+ * //     label: 'rocket',
+ * //     box: { xmin: 352, ymin: -1, xmax: 463, ymax: 287 }
+ * //   }
+ * // ]
+ * ```
+ *
+ * **Example:** Zero-shot object detection w/ `Xenova/owlvit-base-patch32` (returning top 4 matches and setting a threshold).
+ * ```javascript
+ * let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png';
+ * let candidate_labels = ['hat', 'book', 'sunglasses', 'camera'];
+ * let detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');
+ * let output = await detector(url, candidate_labels, { topk: 4, threshold: 0.05 });
+ * // [
+ * //   {
+ * //     score: 0.1606510728597641,
+ * //     label: 'sunglasses',
+ * //     box: { xmin: 347, ymin: 229, xmax: 429, ymax: 264 }
+ * //   },
+ * //   {
+ * //     score: 0.08935828506946564,
+ * //     label: 'hat',
+ * //     box: { xmin: 38, ymin: 174, xmax: 258, ymax: 364 }
+ * //   },
+ * //   {
+ * //     score: 0.08530698716640472,
+ * //     label: 'camera',
+ * //     box: { xmin: 187, ymin: 350, xmax: 260, ymax: 411 }
+ * //   },
+ * //   {
+ * //     score: 0.08349756896495819,
+ * //     label: 'book',
+ * //     box: { xmin: 261, ymin: 280, xmax: 494, ymax: 425 }
+ * //   }
+ * // ]
+ * ```
+ */
+export class ZeroShotObjectDetectionPipeline extends Pipeline {
 
     /**
-     * Helper function to convert list [xmin, xmax, ymin, ymax] into object { "xmin": xmin, ... }
-     * @param {number[]} box The bounding box as a list.
-     * @param {boolean} asInteger Whether to cast to integers.
-     * @returns {Object} The bounding box as an object.
-     * @private
+     * Create a new ZeroShotObjectDetectionPipeline.
+     * @param {Object} options An object containing the following properties:
+     * @param {string} [options.task] The task of the pipeline. Useful for specifying subtasks.
+     * @param {PreTrainedModel} [options.model] The model to use.
+     * @param {PreTrainedTokenizer} [options.tokenizer] The tokenizer to use.
+     * @param {Processor} [options.processor] The processor to use.
      */
-    _get_bounding_box(box, asInteger) {
-        if (asInteger) {
-            box = box.map(x => x | 0);
+    constructor(options) {
+        super(options);
+    }
+
+    /**
+     * Detect objects (bounding boxes & classes) in the image(s) passed as inputs.
+     * @param {Array} images The input images.
+     * @param {string[]} candidate_labels What the model should recognize in the image.
+     * @param {Object} options The options for the detection.
+     * @param {number} [options.threshold] The probability necessary to make a prediction.
+     * @param {number} [options.topk] The number of top predictions that will be returned by the pipeline.
+     * If the provided number is `null` or higher than the number of predictions available, it will default
+     * to the number of predictions.
+     * @param {boolean} [options.percentage=false] Whether to return the boxes coordinates in percentage (true) or in pixels (false).
+     * @returns {Promise} An array of detected objects for each input image, or a single array if only one image is provided.
+     */
+    async _call(images, candidate_labels, {
+        threshold = 0.1,
+        topk = null,
+        percentage = false,
+    } = {}) {
+        const isBatched = Array.isArray(images);
+        images = await prepareImages(images);
+
+        // Run tokenization
+        const text_inputs = this.tokenizer(candidate_labels, {
+            padding: true,
+            truncation: true
+        });
+
+        // Run processor
+        const model_inputs = await this.processor(images);
+
+        // Since non-maximum suppression is performed during export, we need to
+        // process each image separately. For more information, see:
+        // https://github.com/huggingface/optimum/blob/e3b7efb1257c011db907ef40ab340e795cc5684c/optimum/exporters/onnx/model_configs.py#L1028-L1032
+        const toReturn = [];
+        for (let i = 0; i < images.length; ++i) {
+            const image = images[i];
+            const imageSize = percentage ? null : [[image.height, image.width]];
+            const pixel_values = model_inputs.pixel_values[i].unsqueeze_(0);
+
+            // Run model with both text and pixel inputs
+            const output = await this.model({ ...text_inputs, pixel_values });
+
+            // @ts-ignore
+            const processed = this.processor.feature_extractor.post_process_object_detection(output, threshold, imageSize, true)[0];
+            let result = processed.boxes.map((box, i) => ({
+                score: processed.scores[i],
+                label: candidate_labels[processed.classes[i]],
+                box: get_bounding_box(box, !percentage),
+            })).sort((a, b) => b.score - a.score);
+            if (topk !== null) {
+                result = result.slice(0, topk);
+            }
+            toReturn.push(result)
         }
-        const [xmin, ymin, xmax, ymax] = box;
 
-        return { xmin, ymin, xmax, ymax };
+        return isBatched ? toReturn : toReturn[0];
     }
 }
@@ -2187,6 +2309,18 @@ const SUPPORTED_TASKS = {
         },
         "type": "multimodal",
     },
+    "zero-shot-object-detection": {
+        "tokenizer": AutoTokenizer,
+        "pipeline": ZeroShotObjectDetectionPipeline,
+        "model": AutoModelForZeroShotObjectDetection,
+        "processor": AutoProcessor,
+        "default": {
+            // TODO: replace with original
+            // "model": "google/owlvit-base-patch32",
+            "model": "Xenova/owlvit-base-patch32",
+        },
+        "type": "multimodal",
+    },
     "document-question-answering": {
         "tokenizer": AutoTokenizer,
         "pipeline": DocumentQuestionAnsweringPipeline,
@@ -2261,6 +2395,7 @@ const TASK_ALIASES = {
  * - `"translation_xx_to_yy"`: will return a `TranslationPipeline`.
  * - `"zero-shot-classification"`: will return a `ZeroShotClassificationPipeline`.
  * - `"zero-shot-image-classification"`: will return a `ZeroShotImageClassificationPipeline`.
+ * - `"zero-shot-object-detection"`: will return a `ZeroShotObjectDetectionPipeline`.
  * @param {string} [model=null] The name of the pre-trained model to use. If not specified, the default model for the task will be used.
  * @param {import('./utils/hub.js').PretrainedOptions} [options] Optional parameters for the pipeline.
  * @returns {Promise} A Pipeline object for the specified task.
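Usage note for the options wired through the new pipeline (a sketch only; actual scores and boxes are not reproduced): `threshold` filters the per-class sigmoid scores, `topk` truncates the score-sorted results, and `percentage: true` keeps the box coordinates as fractions of the image size instead of scaling them to integer pixels.

```javascript
import { pipeline } from '@xenova/transformers';

const detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');

const url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png';
const candidate_labels = ['hat', 'book', 'sunglasses', 'camera'];

// Keep only the 4 best detections that score above 5%, with relative box coordinates.
const output = await detector(url, candidate_labels, {
    topk: 4,
    threshold: 0.05,
    percentage: true,
});
// output ≈ [{ score: ..., label: 'sunglasses', box: { xmin: ..., ymin: ..., xmax: ..., ymax: ... } }, ...]
```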
diff --git a/src/processors.js b/src/processors.js
index dd7130156..07d47e7ac 100644
--- a/src/processors.js
+++ b/src/processors.js
@@ -64,9 +64,13 @@ function center_to_corners_format([centerX, centerY, width, height]) {
  * @param {Object} outputs The outputs of the model that must be post-processed
  * @param {Tensor} outputs.logits The logits
  * @param {Tensor} outputs.pred_boxes The predicted boxes.
+ * @param {number} [threshold=0.5] The threshold to use for the scores.
+ * @param {number[][]} [target_sizes=null] The sizes of the original images.
+ * @param {boolean} [is_zero_shot=false] Whether zero-shot object detection was performed.
  * @return {Object[]} An array of objects containing the post-processed outputs.
+ * @private
  */
-function post_process_object_detection(outputs, threshold = 0.5, target_sizes = null) {
+function post_process_object_detection(outputs, threshold = 0.5, target_sizes = null, is_zero_shot = false) {
     const out_logits = outputs.logits;
     const out_bbox = outputs.pred_boxes;
     const [batch_size, num_boxes, num_classes] = out_logits.dims;
@@ -88,19 +92,33 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes =
         for (let j = 0; j < num_boxes; ++j) {
             let logit = logits[j];
 
-            // Get most probable class
-            let maxIndex = max(logit.data)[1];
+            let indices = [];
+            let probs;
+            if (is_zero_shot) {
+                // Get indices of classes with high enough probability
+                probs = logit.sigmoid().data;
+                for (let k = 0; k < probs.length; ++k) {
+                    if (probs[k] > threshold) {
+                        indices.push(k);
+                    }
+                }
 
-            if (maxIndex === num_classes - 1) {
-                // This is the background class, skip it
-                continue;
+            } else {
+                // Get most probable class
+                let maxIndex = max(logit.data)[1];
+
+                if (maxIndex === num_classes - 1) {
+                    // This is the background class, skip it
+                    continue;
+                }
+                indices.push(maxIndex);
+
+                // Compute softmax over classes
+                probs = softmax(logit.data);
             }
 
-            // Compute softmax over classes
-            let probs = softmax(logit.data);
+            for (const index of indices) {
 
-            let score = probs[maxIndex];
-            if (score > threshold) {
                 // Some class has a high enough probability
                 /** @type {number[]} */
                 let box = bbox[j].data;
@@ -112,8 +130,8 @@ function post_process_object_detection(outputs, threshold = 0.5, target_sizes =
                 }
 
                 info.boxes.push(box);
-                info.classes.push(maxIndex);
-                info.scores.push(score);
+                info.classes.push(index);
+                info.scores.push(probs[index]);
             }
         }
         toReturn.push(info);
@@ -513,6 +531,12 @@ export class CLIPFeatureExtractor extends ImageFeatureExtractor { }
 export class ConvNextFeatureExtractor extends ImageFeatureExtractor { }
 export class ViTFeatureExtractor extends ImageFeatureExtractor { }
 export class MobileViTFeatureExtractor extends ImageFeatureExtractor { }
+export class OwlViTFeatureExtractor extends ImageFeatureExtractor {
+    /** @type {post_process_object_detection} */
+    post_process_object_detection(...args) {
+        return post_process_object_detection(...args);
+    }
+}
 export class DeiTFeatureExtractor extends ImageFeatureExtractor { }
 export class BeitFeatureExtractor extends ImageFeatureExtractor { }
 export class DonutFeatureExtractor extends ImageFeatureExtractor {
@@ -1502,6 +1526,8 @@ export class SpeechT5Processor extends Processor {
     }
 }
 
+export class OwlViTProcessor extends Processor { }
+
 //////////////////////////////////////////////////
 
 /**
@@ -1539,6 +1565,7 @@ export class AutoProcessor {
         WhisperFeatureExtractor,
         ViTFeatureExtractor,
         MobileViTFeatureExtractor,
+        OwlViTFeatureExtractor,
         CLIPFeatureExtractor,
         ConvNextFeatureExtractor,
         BeitFeatureExtractor,
@@ -1558,6 +1585,7 @@ export class AutoProcessor {
         Wav2Vec2ProcessorWithLM,
         SamProcessor,
         SpeechT5Processor,
+        OwlViTProcessor,
     }
 
 /**
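To summarise the branch added to `post_process_object_detection` above, here is a self-contained illustration (a hypothetical standalone helper operating on a plain array of class logits, not code from this diff): the zero-shot path scores every class independently with a sigmoid and keeps all classes above the threshold for a given box, while the existing path takes a softmax over classes, drops the trailing background class, and keeps only the argmax.

```javascript
// Hypothetical helper for illustration only; `logit` is a plain number[] of per-class logits.
function scoreClasses(logit, { isZeroShot, threshold }) {
    if (isZeroShot) {
        // OWL-ViT: independent per-class sigmoid; several classes may survive for one box.
        const probs = logit.map(x => 1 / (1 + Math.exp(-x)));
        return probs
            .map((score, classId) => ({ classId, score }))
            .filter(x => x.score > threshold);
    }
    // DETR-style: softmax over classes, skip the final background class, keep the argmax.
    const max = Math.max(...logit);
    const exps = logit.map(x => Math.exp(x - max));
    const sum = exps.reduce((a, b) => a + b, 0);
    const probs = exps.map(x => x / sum);
    const backgroundId = probs.length - 1;
    const classId = probs.indexOf(Math.max(...probs));
    return (classId === backgroundId || probs[classId] <= threshold)
        ? []
        : [{ classId, score: probs[classId] }];
}
```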
diff --git a/src/utils/core.js b/src/utils/core.js
index 7de13625d..9ab11144c 100644
--- a/src/utils/core.js
+++ b/src/utils/core.js
@@ -184,3 +184,18 @@ export function product(...a) {
 export function calculateReflectOffset(i, w) {
     return Math.abs((i + w) % (2 * w) - w);
 }
+
+/**
+ * Helper function to convert list [xmin, ymin, xmax, ymax] into object { "xmin": xmin, ... }
+ * @param {number[]} box The bounding box as a list.
+ * @param {boolean} asInteger Whether to cast to integers.
+ * @returns {Object} The bounding box as an object.
+ */
+export function get_bounding_box(box, asInteger) {
+    if (asInteger) {
+        box = box.map(x => x | 0);
+    }
+    const [xmin, ymin, xmax, ymax] = box;
+
+    return { xmin, ymin, xmax, ymax };
+}
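A quick illustration of the relocated helper (input values invented; `get_bounding_box` is internal to `utils/core.js` and assumed to be in scope): with `asInteger` the coordinates are truncated via `| 0`, otherwise they are returned as-is, which is how the pipelines implement the pixels-vs-percentage switch.

```javascript
get_bounding_box([352.4, 0.6, 463.1, 287.9], true);  // { xmin: 352, ymin: 0, xmax: 463, ymax: 287 }
get_bounding_box([0.55, 0.12, 0.72, 0.56], false);   // { xmin: 0.55, ymin: 0.12, xmax: 0.72, ymax: 0.56 }
```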
diff --git a/tests/pipelines.test.js b/tests/pipelines.test.js
index d8f5e56c0..ed449c15b 100644
--- a/tests/pipelines.test.js
+++ b/tests/pipelines.test.js
@@ -1326,6 +1326,105 @@ describe('Pipelines', () => {
         }, MAX_TEST_EXECUTION_TIME);
     });
 
+    describe('Zero-shot object detection', () => {
+
+        // List all models which will be tested
+        const models = [
+            'google/owlvit-base-patch32',
+        ];
+
+        it(models[0], async () => {
+            let detector = await pipeline('zero-shot-object-detection', m(models[0]));
+
+
+            // single (default)
+            {
+                let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png';
+                let candidate_labels = ['human face', 'rocket', 'helmet', 'american flag'];
+
+                let output = await detector(url, candidate_labels);
+
+                // let expected = [
+                //     {
+                //         score: 0.24392342567443848,
+                //         label: 'human face',
+                //         box: { xmin: 180, ymin: 67, xmax: 274, ymax: 175 }
+                //     },
+                //     {
+                //         score: 0.15129457414150238,
+                //         label: 'american flag',
+                //         box: { xmin: 0, ymin: 4, xmax: 106, ymax: 513 }
+                //     },
+                //     {
+                //         score: 0.13649864494800568,
+                //         label: 'helmet',
+                //         box: { xmin: 277, ymin: 337, xmax: 511, ymax: 511 }
+                //     },
+                //     {
+                //         score: 0.10262022167444229,
+                //         label: 'rocket',
+                //         box: { xmin: 352, ymin: -1, xmax: 463, ymax: 287 }
+                //     }
+                // ]
+
+                expect(output.length).toBeGreaterThan(0);
+                for (let cls of output) {
+                    expect(typeof cls.score).toBe('number');
+                    expect(typeof cls.label).toBe('string');
+                    for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) {
+                        expect(typeof cls.box[key]).toBe('number');
+                    }
+                }
+            }
+
+            // topk + threshold + percentage
+            {
+                let url = 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png';
+                let candidate_labels = ['hat', 'book', 'sunglasses', 'camera'];
+
+                let output = await detector(url, candidate_labels, {
+                    topk: 4,
+                    threshold: 0.05,
+                    percentage: true,
+                });
+
+                // let expected = [
+                //     {
+                //         score: 0.1606510728597641,
+                //         label: 'sunglasses',
+                //         box: { xmin: 347, ymin: 229, xmax: 429, ymax: 264 }
+                //     },
+                //     {
+                //         score: 0.08935828506946564,
+                //         label: 'hat',
+                //         box: { xmin: 38, ymin: 174, xmax: 258, ymax: 364 }
+                //     },
+                //     {
+                //         score: 0.08530698716640472,
+                //         label: 'camera',
+                //         box: { xmin: 187, ymin: 350, xmax: 260, ymax: 411 }
+                //     },
+                //     {
+                //         score: 0.08349756896495819,
+                //         label: 'book',
+                //         box: { xmin: 261, ymin: 280, xmax: 494, ymax: 425 }
+                //     }
+                // ]
+
+                expect(output.length).toBeGreaterThan(0);
+                for (let cls of output) {
+                    expect(typeof cls.score).toBe('number');
+                    expect(typeof cls.label).toBe('string');
+                    for (let key of ['xmin', 'ymin', 'xmax', 'ymax']) {
+                        expect(typeof cls.box[key]).toBe('number');
+                    }
+                }
+            }
+
+            await detector.dispose();
+        }, MAX_TEST_EXECUTION_TIME);
+    });
+
     describe('Image-to-image', () => {
 
         // List all models which will be tested
diff --git a/tests/processors.test.js b/tests/processors.test.js
index 9e5d09f18..4c8680106 100644
--- a/tests/processors.test.js
+++ b/tests/processors.test.js
@@ -38,6 +38,7 @@ describe('Processors', () => {
         beit: 'microsoft/beit-base-patch16-224-pt22k-ft22k',
         detr: 'facebook/detr-resnet-50',
         yolos: 'hustvl/yolos-small-300',
+        owlvit: 'google/owlvit-base-patch32',
         clip: 'openai/clip-vit-base-patch16',
     }
 
@@ -46,6 +47,7 @@ describe('Processors', () => {
         checkerboard_8x8: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/checkerboard_8x8.png',
         receipt: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/receipt.png',
         tiger: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/tiger.jpg',
+        cats: 'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg',
 
         // grayscale image
         skateboard: 'https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/ml-web-games/skateboard.png',
@@ -238,6 +240,23 @@ describe('Processors', () => {
             }
         }, MAX_TEST_EXECUTION_TIME);
+
+        // OwlViTFeatureExtractor
+        it(MODELS.owlvit, async () => {
+            const processor = await AutoProcessor.from_pretrained(m(MODELS.owlvit))
+
+            {
+                const image = await load_image(TEST_IMAGES.cats);
+                const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);
+
+                compare(pixel_values.dims, [1, 3, 768, 768]);
+                compare(avg(pixel_values.data), 0.250620447910435);
+
+                compare(original_sizes, [[480, 640]]);
+                compare(reshaped_input_sizes, [[768, 768]]);
+            }
+        }, MAX_TEST_EXECUTION_TIME);
+
 
         // CLIPFeatureExtractor
         // - tests center crop (do_center_crop=true, crop_size=224)
         it(MODELS.clip, async () => {
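Finally, since the new tests only pass single images: when given an array of inputs, the pipeline keeps `isBatched` true and returns one array of detections per image, looping internally because the exported graph is fixed to batch size 1. A short sketch (image URLs reused from the docs dataset; output values not shown):

```javascript
import { pipeline } from '@xenova/transformers';

const detector = await pipeline('zero-shot-object-detection', 'Xenova/owlvit-base-patch32');

const urls = [
    'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/astronaut.png',
    'https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/beach.png',
];

// Batched input: one array of detections per image, in input order.
const results = await detector(urls, ['human face', 'hat']);
console.log(results.length);        // 2
console.log(results[0][0]?.label);  // best match in the first image, if any detection clears the threshold
```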