Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functionality to split RawImage into channels; Update slice documentation and tests #978

Merged
merged 14 commits into from
Nov 26, 2024
Merged
42 changes: 29 additions & 13 deletions src/utils/image.js
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@

/**
* @file Helper module for image processing.
*
* These functions and classes are only used internally,
* @file Helper module for image processing.
*
* These functions and classes are only used internally,
* meaning an end-user shouldn't need to access anything here.
*
*
* @module utils/image
*/

Expand Down Expand Up @@ -91,7 +91,7 @@ export class RawImage {
this.channels = channels;
}

/**
/**
* Returns the size of the image (width, height).
* @returns {[number, number]} The size of the image (width, height).
*/
Expand All @@ -101,9 +101,9 @@ export class RawImage {

/**
* Helper method for reading an image from a variety of input types.
* @param {RawImage|string|URL} input
* @param {RawImage|string|URL} input
* @returns The image object.
*
*
* **Example:** Read image from a URL.
* ```javascript
* let image = await RawImage.read('https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/football-match.jpg');
Expand Down Expand Up @@ -181,7 +181,7 @@ export class RawImage {

/**
* Helper method to create a new Image from a tensor
* @param {Tensor} tensor
* @param {Tensor} tensor
*/
static fromTensor(tensor, channel_format = 'CHW') {
if (tensor.dims.length !== 3) {
Expand Down Expand Up @@ -355,7 +355,7 @@ export class RawImage {
case 'nearest':
case 'bilinear':
case 'bicubic':
// Perform resizing using affine transform.
// Perform resizing using affine transform.
// This matches how the python Pillow library does it.
img = img.affine([width / this.width, 0, 0, height / this.height], {
interpolator: resampleMethod
Expand All @@ -368,7 +368,7 @@ export class RawImage {
img = img.resize({
width, height,
fit: 'fill',
kernel: 'lanczos3', // PIL Lanczos uses a kernel size of 3
kernel: 'lanczos3', // PIL Lanczos uses a kernel size of 3
});
break;

Expand Down Expand Up @@ -447,7 +447,7 @@ export class RawImage {
// Create canvas object for this image
const canvas = this.toCanvas();

// Create a new canvas of the desired size. This is needed since if the
// Create a new canvas of the desired size. This is needed since if the
// image is too small, we need to pad it with black pixels.
const ctx = createCanvasFunction(crop_width, crop_height).getContext('2d');

Expand Down Expand Up @@ -495,7 +495,7 @@ export class RawImage {
// Create canvas object for this image
const canvas = this.toCanvas();

// Create a new canvas of the desired size. This is needed since if the
// Create a new canvas of the desired size. This is needed since if the
// image is too small, we need to pad it with black pixels.
const ctx = createCanvasFunction(crop_width, crop_height).getContext('2d');

Expand Down Expand Up @@ -637,6 +637,22 @@ export class RawImage {
return clonedCanvas;
}

/**
* Splits the image data into separate channels.
* The number of elements in the array corresponds to the number of channels in the image.
* @returns {Array<Uint8ClampedArray|Uint8Array>} An array containing separate channel data.
*/
toChannels() {
// Split each channel into a separate entry in a `channels` array.
const channels = [];
for (let c = 0; c < this.channels; c++) {
channels.push(
this.data.filter((_, i) => i % this.channels === c)
);
}
return channels;
}
xenova marked this conversation as resolved.
Show resolved Hide resolved

/**
* Helper method to update the image data.
* @param {Uint8ClampedArray} data The new image data.
Expand Down Expand Up @@ -742,4 +758,4 @@ export class RawImage {
}
});
}
}
}
34 changes: 33 additions & 1 deletion src/utils/tensor.js
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,43 @@ export class Tensor {
return this;
}

/**
* Creates a deep copy of the current Tensor.
* @returns {Tensor} A new Tensor with the same type, data, and dimensions as the original.
*/
clone() {
return new Tensor(this.type, this.data.slice(), this.dims.slice());
}

/**
* Performs a slice operation on the Tensor along specified dimensions.
*
* Consider a Tensor that has a dimension of [4, 7]:
* ```
* [ 1, 2, 3, 4, 5, 6, 7]
* [ 8, 9, 10, 11, 12, 13, 14]
* [15, 16, 17, 18, 19, 20, 21]
* [22, 23, 24, 25, 26, 27, 28]
* ```
* We can slice against the two dims of row and column, for instance in this
* case we can start at the second element, and return to the second last,
* like this:
* ```
* tensor.slice([1, -1], [1, -1]);
* ```
* which would return:
* ```
* [ 9, 10, 11, 12, 13 ]
* [ 16, 17, 18, 19, 20 ]
* ```
*
* @param {...(number|number[]|null)} slices - The slice specifications for each dimension.
* - If a number is given, then a single element is selected.
* - If an array of two numbers is given, then a range of elements [start, end (exclusive)] is selected.
* - If null is given, then the entire dimension is selected.
* @returns {Tensor} A new Tensor containing the selected elements.
* @throws {Error} If the slice input is invalid.
*/
slice(...slices) {
// This allows for slicing with ranges and numbers
const newTensorDims = [];
Expand Down Expand Up @@ -413,7 +446,6 @@ export class Tensor {
data[i] = this_data[originalIndex];
}
return new Tensor(this.type, data, newTensorDims);

}

/**
Expand Down
29 changes: 29 additions & 0 deletions tests/utils/tensor.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,35 @@ describe("Tensor operations", () => {
// TODO add tests for errors
});

describe("slice", () => {
it("should return a given row dim", async () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(1);
const target = new Tensor("float32", [3, 4], [2]);

compare(t2, target);
});

it("should return a range of rows", async () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
// The end index is not included.
const t2 = t1.slice([1, 3]);
const target = new Tensor("float32", [3, 4, 5, 6], [2, 2]);

compare(t2, target);
});

it("should return a crop", async () => {
// Create 21 nodes.
const t1 = new Tensor("float32", Array.from({ length: 28 }, (v, i) => v = ++i), [4, 7]);
const t2 = t1.slice([1, -1], [1, -1]);

const target = new Tensor("float32", [9, 10, 11, 12, 13, 16, 17, 18, 19, 20], [2, 5]);

compare(t2, target);
});
});
xenova marked this conversation as resolved.
Show resolved Hide resolved

describe("stack", () => {
const t1 = new Tensor("float32", [0, 1, 2, 3, 4, 5], [1, 3, 2]);

Expand Down
29 changes: 29 additions & 0 deletions tests/utils/utils.test.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { AutoProcessor, hamming, hanning, mel_filter_bank } from "../../src/transformers.js";
import { getFile } from "../../src/utils/hub.js";
import { RawImage } from "../../src/utils/image.js";

import { MAX_TEST_EXECUTION_TIME } from "../init.js";
import { compare } from "../test_utils.js";
Expand Down Expand Up @@ -59,4 +60,32 @@ describe("Utilities", () => {
expect(await data.text()).toBe("Hello, world!");
});
});

describe("Image utilities", () => {
it("Can split image into separate channels", async () => {
const url = './examples/demo-site/public/images/cats.jpg';
const image = await RawImage.fromURL(url);
// Rather than test the entire image, we'll just test the first 3 pixels;
// ensuring that these match.
const image_data = image.toChannels().map(c => c.slice(0, 3));

const target = [
new Uint8Array([140, 144, 145]), // Reds
new Uint8Array([25, 25, 25]), // Greens
new Uint8Array([56, 67, 73]), // Blues
];

compare (image_data, target);
});

it("Can splits channels for grayscale", async () => {
const url = './examples/demo-site/public/images/cats.jpg';
const image = (await RawImage.fromURL(url)).grayscale();

const image_data = image.toChannels().map(c => c.slice(0, 3));
const target = [new Uint8Array([63, 65, 66])];

compare (image_data, target);
});
});
xenova marked this conversation as resolved.
Show resolved Hide resolved
});