Skip to content

Commit

Permalink
LSP + other typings for ImageFeatureExtractor / SamImageProcessor / D…
Browse files Browse the repository at this point in the history
…etrFeatureExtractor
  • Loading branch information
kungfooman committed Sep 14, 2023
1 parent 4cff2f5 commit c441353
Showing 1 changed file with 41 additions and 12 deletions.
53 changes: 41 additions & 12 deletions src/processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,15 @@ export class FeatureExtractor extends Callable {
}
}

/**
* @typedef {object} ImageFeatureExtractorResult
* @property {Tensor} pixel_values
* @property {HeightWidth[]} original_sizes Array of two-element tuples
* like [[480, 640]].
* @property {HeightWidth[]} reshaped_input_sizes Array of two-element
* tuples like [[1000, 1330]].
*/

/**
* Feature extractor for image models.
*
Expand Down Expand Up @@ -341,22 +350,23 @@ export class ImageFeatureExtractor extends FeatureExtractor {
* Calls the feature extraction process on an array of image
* URLs, preprocesses each image, and concatenates the resulting
* features into a single Tensor.
* @param {any} images The URL(s) of the image(s) to extract features from.
* @returns {Promise<Object>} An object containing the concatenated pixel values (and other metadata) of the preprocessed images.
* @param {any[]} images The URL(s) of the image(s) to extract features from.
* @param {...any} unused Ignored. Accepted only so that subclasses can add extra parameters to `_call` without causing Liskov Substitution Principle type errors.
* @returns {Promise<ImageFeatureExtractorResult>} An object containing the concatenated pixel values (and other metadata) of the preprocessed images.
*/
async _call(images) {
async _call(images, ...unused) {
if (!Array.isArray(images)) {
images = [images];
}

let imageData = await Promise.all(images.map(x => this.preprocess(x)));
/** @type {PreprocessedImage[]} */
const imageData = await Promise.all(images.map(x => this.preprocess(x)));

// TODO:

// Concatenate pixel values
// TEMP: Add batch dimension so that concat works
imageData.forEach(x => x.pixel_values.dims = [1, ...x.pixel_values.dims]);
let pixel_values = cat(imageData.map(x => x.pixel_values));
const pixel_values = cat(imageData.map(x => x.pixel_values));

return {
pixel_values: pixel_values,
Expand All @@ -377,6 +387,12 @@ export class MobileViTFeatureExtractor extends ImageFeatureExtractor { }
export class DeiTFeatureExtractor extends ImageFeatureExtractor { }
export class BeitFeatureExtractor extends ImageFeatureExtractor { }

/**
* @typedef {object} DetrFeatureExtractorResultProps
* @property {Tensor} pixel_mask
* @typedef {ImageFeatureExtractorResult & DetrFeatureExtractorResultProps} DetrFeatureExtractorResult
*/

/**
* Detr Feature Extractor.
*
Expand All @@ -386,24 +402,24 @@ export class DetrFeatureExtractor extends ImageFeatureExtractor {
/**
* Calls the feature extraction process on an array of image URLs, preprocesses
* each image, and concatenates the resulting features into a single Tensor.
* @param {any} urls The URL(s) of the image(s) to extract features from.
* @returns {Promise<Object>} An object containing the concatenated pixel values of the preprocessed images.
* @param {any[]} urls The URL(s) of the image(s) to extract features from.
* @returns {Promise<DetrFeatureExtractorResult>} An object containing the concatenated pixel values of the preprocessed images, plus a `pixel_mask` tensor.
*/
async _call(urls) {
let result = await super._call(urls);
const result = await super._call(urls);

// TODO support differently-sized images, for now assume all images are the same size.
// TODO support different mask sizes (not just 64x64)
// Currently, just fill pixel mask with 1s
let maskSize = [result.pixel_values.dims[0], 64, 64];
result.pixel_mask = new Tensor(
const maskSize = [result.pixel_values.dims[0], 64, 64];
const pixel_mask = new Tensor(
'int64',
// TODO: fix error below
new BigInt64Array(maskSize.reduce((a, b) => a * b)).fill(1n),
maskSize
);

return result;
return { ...result, pixel_mask };
}

/**
Expand Down Expand Up @@ -703,7 +719,20 @@ export class YolosFeatureExtractor extends ImageFeatureExtractor {
}
}

/**
* @typedef {object} SamImageProcessorResult
* @property {Tensor} pixel_values
* @property {HeightWidth[]} original_sizes
* @property {HeightWidth[]} reshaped_input_sizes
* @property {Tensor} input_points
*/

export class SamImageProcessor extends ImageFeatureExtractor {
/**
* @param {any[]} images The URL(s) of the image(s) to extract features from.
* @param {*} input_points
* @returns {Promise<SamImageProcessorResult>}
*/
async _call(images, input_points) {
let {
pixel_values,
Expand Down

0 comments on commit c441353

Please sign in to comment.