Skip to content

Commit

Permalink
Add support for processing non-square images w/ `ConvNextFeatureExtra…
Browse files Browse the repository at this point in the history
…ctor` (#503)

* Abstract resize function

* Fix tolerance comparison

* Update `ConvNextFeatureExtractor`

* Update ConvNext unit test
  • Loading branch information
xenova authored Jan 10, 2024
1 parent f6555dc commit 4d1d4d3
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 78 deletions.
201 changes: 128 additions & 73 deletions src/processors.js
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,88 @@ export class ImageFeatureExtractor extends FeatureExtractor {
}
}

/**
* Find the target (width, height) dimension of the output image after
* resizing given the input image and the desired size.
* @param {RawImage} image The image to resize.
* @param {any} size The size to use for resizing the image.
* @returns {[number, number]} The target (width, height) dimension of the output image after resizing.
*/
get_resize_output_image_size(image, size) {
// `size` comes in many forms, so we need to handle them all here:
// 1. `size` is an integer, in which case we resize the image to be a square

const [srcWidth, srcHeight] = image.size;

let shortest_edge;
let longest_edge;

if (this.do_thumbnail) {
// NOTE: custom logic for `Donut` models
const { height, width } = size;
shortest_edge = Math.min(height, width)
}
// Support both formats for backwards compatibility
else if (Number.isInteger(size)) {
shortest_edge = size;
longest_edge = this.config.max_size ?? shortest_edge;

} else if (size !== undefined) {
// Extract known properties from `size`
shortest_edge = size.shortest_edge;
longest_edge = size.longest_edge;
}

// If `longest_edge` and `shortest_edge` are set, maintain aspect ratio and resize to `shortest_edge`
// while keeping the largest dimension <= `longest_edge`
if (shortest_edge !== undefined || longest_edge !== undefined) {
// http://opensourcehacker.com/2011/12/01/calculate-aspect-ratio-conserving-resize-for-images-in-javascript/
// Try resize so that shortest edge is `shortest_edge` (target)
const shortResizeFactor = shortest_edge === undefined
? 1 // If `shortest_edge` is not set, don't upscale
: Math.max(shortest_edge / srcWidth, shortest_edge / srcHeight);

const newWidth = srcWidth * shortResizeFactor;
const newHeight = srcHeight * shortResizeFactor;

// The new width and height might be greater than `longest_edge`, so
// we downscale again to ensure the largest dimension is `longest_edge`
const longResizeFactor = longest_edge === undefined
? 1 // If `longest_edge` is not set, don't downscale
: Math.min(longest_edge / newWidth, longest_edge / newHeight);

// To avoid certain floating point precision issues, we round to 2 decimal places
const finalWidth = Math.floor(Number((newWidth * longResizeFactor).toFixed(2)));
const finalHeight = Math.floor(Number((newHeight * longResizeFactor).toFixed(2)));

return [finalWidth, finalHeight];

} else if (size !== undefined && size.width !== undefined && size.height !== undefined) {
// If `width` and `height` are set, resize to those dimensions
return [size.width, size.height];

} else if (this.size_divisibility !== undefined) {
// Rounds the height and width down to the closest multiple of size_divisibility
const newWidth = Math.floor(srcWidth / this.size_divisibility) * this.size_divisibility;
const newHeight = Math.floor(srcHeight / this.size_divisibility) * this.size_divisibility;
return [newWidth, newHeight];
} else {
throw new Error(`Could not resize image due to unsupported \`this.size\` option in config: ${JSON.stringify(size)}`);
}
}

/**
* Resizes the image.
* @param {RawImage} image The image to resize.
* @returns {Promise<RawImage>} The resized image.
*/
async resize(image) {
const [newWidth, newHeight] = this.get_resize_output_image_size(image, this.size);
return await image.resize(newWidth, newHeight, {
resample: this.resample,
});
}

/**
* @typedef {object} PreprocessedImage
* @property {HeightWidth} original_size The original size of the image.
Expand All @@ -433,8 +515,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
image = await this.crop_margin(image);
}

const srcWidth = image.width; // original width
const srcHeight = image.height; // original height
const [srcWidth, srcHeight] = image.size; // original image size

// Convert image to RGB if specified in config.
if (do_convert_rgb ?? this.do_convert_rgb) {
Expand All @@ -443,77 +524,12 @@ export class ImageFeatureExtractor extends FeatureExtractor {
image = image.grayscale();
}

// TODO:
// For efficiency reasons, it might be best to merge the resize and center crop operations into one.

// Resize all images
if (this.do_resize) {
// TODO:
// For efficiency reasons, it might be best to merge the resize and center crop operations into one.

// `this.size` comes in many forms, so we need to handle them all here:
// 1. `this.size` is an integer, in which case we resize the image to be a square

let shortest_edge;
let longest_edge;

if (this.do_thumbnail) {
// NOTE: custom logic for `Donut` models
const { height, width } = this.size;
shortest_edge = Math.min(height, width)
}
// Support both formats for backwards compatibility
else if (Number.isInteger(this.size)) {
shortest_edge = this.size;
longest_edge = this.config.max_size ?? shortest_edge;

} else if (this.size !== undefined) {
// Extract known properties from `this.size`
shortest_edge = this.size.shortest_edge;
longest_edge = this.size.longest_edge;
}

// If `longest_edge` and `shortest_edge` are set, maintain aspect ratio and resize to `shortest_edge`
// while keeping the largest dimension <= `longest_edge`
if (shortest_edge !== undefined || longest_edge !== undefined) {
// http://opensourcehacker.com/2011/12/01/calculate-aspect-ratio-conserving-resize-for-images-in-javascript/
// Try resize so that shortest edge is `this.shortest_edge` (target)
const shortResizeFactor = shortest_edge === undefined
? 1 // If `shortest_edge` is not set, don't upscale
: Math.max(shortest_edge / srcWidth, shortest_edge / srcHeight);

const newWidth = srcWidth * shortResizeFactor;
const newHeight = srcHeight * shortResizeFactor;

// The new width and height might be greater than `this.longest_edge`, so
// we downscale again to ensure the largest dimension is `this.longest_edge`
const longResizeFactor = longest_edge === undefined
? 1 // If `longest_edge` is not set, don't downscale
: Math.min(longest_edge / newWidth, longest_edge / newHeight);

// To avoid certain floating point precision issues, we round to 2 decimal places
const finalWidth = Math.floor(Number((newWidth * longResizeFactor).toFixed(2)));
const finalHeight = Math.floor(Number((newHeight * longResizeFactor).toFixed(2)));

// Perform resize
image = await image.resize(finalWidth, finalHeight, {
resample: this.resample,
});

} else if (this.size !== undefined && this.size.width !== undefined && this.size.height !== undefined) {
// If `width` and `height` are set, resize to those dimensions
image = await image.resize(this.size.width, this.size.height, {
resample: this.resample,
});

} else if (this.size_divisibility !== undefined) {
// Rounds the height and width down to the closest multiple of size_divisibility
const newWidth = Math.floor(srcWidth / this.size_divisibility) * this.size_divisibility;
const newHeight = Math.floor(srcHeight / this.size_divisibility) * this.size_divisibility;
image = await image.resize(newWidth, newHeight, {
resample: this.resample,
});

} else {
throw new Error(`Could not resize image due to unsupported \`this.size\` option in config: ${JSON.stringify(this.size)}`);
}
image = await this.resize(image);
}

// Resize the image using thumbnail method.
Expand All @@ -537,7 +553,7 @@ export class ImageFeatureExtractor extends FeatureExtractor {
}

/** @type {HeightWidth} */
let reshaped_input_size = [image.height, image.width];
const reshaped_input_size = [image.height, image.width];

let pixelData = Float32Array.from(image.data);
let imgDims = [image.height, image.width, image.channels];
Expand Down Expand Up @@ -689,7 +705,46 @@ export class GLPNFeatureExtractor extends ImageFeatureExtractor { }
// The processors below declare no members of their own: each one is an empty
// subclass that inherits its entire behavior from `ImageFeatureExtractor`.
export class CLIPFeatureExtractor extends ImageFeatureExtractor { }
export class ChineseCLIPFeatureExtractor extends ImageFeatureExtractor { }
export class SiglipImageProcessor extends ImageFeatureExtractor { }
export class ConvNextFeatureExtractor extends ImageFeatureExtractor { }
export class ConvNextFeatureExtractor extends ImageFeatureExtractor {
    /**
     * @param {object} config Processor configuration, forwarded unchanged to the
     * `ImageFeatureExtractor` constructor. `config.crop_pct` is optional.
     */
    constructor(config) {
        super(config);

        /**
         * Crop ratio applied by `resize`: the shortest edge is first resized to
         * `shortest_edge / crop_pct` and then center-cropped back to `shortest_edge`.
         * Only has an effect when `this.size.shortest_edge` is below 384.
         * Falls back to 224 / 256 when the config does not provide a value.
         */
        this.crop_pct = this.config.crop_pct ?? (224 / 256);
    }

    /**
     * Resizes the image to a square of side `this.size.shortest_edge`.
     *
     * Below 384 the aspect ratio is kept during the resize and the result is
     * center-cropped; at 384 or larger the image is warped directly to the
     * target square without cropping.
     * @param {RawImage} image The image to resize.
     * @returns {Promise<RawImage>} The resized (and possibly center-cropped) image.
     * @throws {Error} If `this.size` does not provide `shortest_edge`.
     */
    async resize(image) {
        const targetEdge = this.size?.shortest_edge;
        if (targetEdge === undefined) {
            throw new Error(`Size dictionary must contain 'shortest_edge' key.`);
        }

        // Shared resize call so both paths use the configured resampling filter.
        const resizeTo = (source, width, height) => source.resize(width, height, {
            resample: this.resample,
        });

        if (!(targetEdge < 384)) {
            // warping (no cropping) when evaluated at 384 or larger
            return await resizeTo(image, targetEdge, targetEdge);
        }

        // maintain same ratio, resizing shortest edge to shortest_edge/crop_pct
        const enlargedEdge = Math.floor(targetEdge / this.crop_pct);
        const [scaledWidth, scaledHeight] = this.get_resize_output_image_size(image, {
            shortest_edge: enlargedEdge,
        });
        const scaled = await resizeTo(image, scaledWidth, scaledHeight);

        // then crop to (shortest_edge, shortest_edge)
        return await scaled.center_crop(targetEdge, targetEdge);
    }
}
// Empty subclass: adds nothing of its own and inherits everything from `ConvNextFeatureExtractor`.
export class ConvNextImageProcessor extends ConvNextFeatureExtractor { } // NOTE extends ConvNextFeatureExtractor
// Empty subclasses: both inherit their entire behavior from `ImageFeatureExtractor`.
export class ViTFeatureExtractor extends ImageFeatureExtractor { }
export class ViTImageProcessor extends ImageFeatureExtractor { }
Expand Down
4 changes: 4 additions & 0 deletions src/utils/image.js
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ export class RawImage {
this.channels = channels;
}

/**
* Returns the size of the image (width, height).
* @returns {[number, number]} The size of the image (width, height).
*/
get size() {
return [this.width, this.height];
}
Expand Down
6 changes: 3 additions & 3 deletions tests/processors.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -123,11 +123,11 @@ describe('Processors', () => {
const image = await load_image(TEST_IMAGES.tiger);
const { pixel_values, original_sizes, reshaped_input_sizes } = await processor(image);

compare(pixel_values.dims, [1, 3, 224, 336]);
compare(avg(pixel_values.data), -0.27736667280600913);
compare(pixel_values.dims, [1, 3, 224, 224]);
compare(avg(pixel_values.data), 0.06262318789958954);

compare(original_sizes, [[408, 612]]);
compare(reshaped_input_sizes, [[224, 336]]);
compare(reshaped_input_sizes, [[224, 224]]);
}
}, MAX_TEST_EXECUTION_TIME);

Expand Down
4 changes: 2 additions & 2 deletions tests/test_utils.js
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ export function compare(val1, val2, tol = 0.1) {
expect(Object.keys(val1)).toHaveLength(Object.keys(val2).length);

for (let key in val1) {
compare(val1[key], val2[key]);
compare(val1[key], val2[key], tol);
}
}

Expand All @@ -63,7 +63,7 @@ export function compare(val1, val2, tol = 0.1) {

if (typeof val1 === 'number' && (!Number.isInteger(val1) || !Number.isInteger(val2))) {
// If both are numbers and at least one of them is not an integer
expect(val1).toBeCloseTo(val2, tol);
expect(val1).toBeCloseTo(val2, -Math.log10(tol));
} else {
// Perform equality test
expect(val1).toEqual(val2);
Expand Down

0 comments on commit 4d1d4d3

Please sign in to comment.