Skip to content

Commit

Permalink
Add RawImage.split() function to split images into channels; Improv…
Browse files Browse the repository at this point in the history
…ed documentation and tests (#978)

* Add tests for original slice method.

* Add vslice and tests to retrieve the entire length of a column.

* Add a test for slicing every other column.

* Add method to return each channel as a separate array.

* Add documentation.
Fix TypeScript error for unsure type.

* Remove vslice as it doesn't work as it should.
Update documentation.
Update tests.

* Optimize `RawImage.split()` function

* Use dummy test image

* Update tensor unit tests

* Wrap `.split()` result in `RawImage`

* Update JSDoc

* Update JSDoc

* Update comments

---------

Co-authored-by: Joshua Lochner <[email protected]>
  • Loading branch information
BritishWerewolf and xenova authored Nov 26, 2024
1 parent 1768b8b commit 8896dc7
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 1 deletion.
30 changes: 30 additions & 0 deletions src/utils/image.js
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,36 @@ export class RawImage {
return clonedCanvas;
}

/**
* Split this image into individual bands. This method returns an array of individual image bands from an image.
* For example, splitting an "RGB" image creates three new images each containing a copy of one of the original bands (red, green, blue).
*
* Inspired by PIL's `Image.split()` [function](https://pillow.readthedocs.io/en/latest/reference/Image.html#PIL.Image.Image.split).
* @returns {RawImage[]} An array containing bands.
*/
split() {
const { data, width, height, channels } = this;

/** @type {typeof Uint8Array | typeof Uint8ClampedArray} */
const data_type = /** @type {any} */(data.constructor);
const per_channel_length = data.length / channels;

// Pre-allocate buffers for each channel
const split_data = Array.from(
{ length: channels },
() => new data_type(per_channel_length),
);

// Write pixel data
for (let i = 0; i < per_channel_length; ++i) {
const data_offset = channels * i;
for (let j = 0; j < channels; ++j) {
split_data[j][i] = data[data_offset + j];
}
}
return split_data.map((data) => new RawImage(data, width, height, 1));
}

/**
* Helper method to update the image data.
* @param {Uint8ClampedArray} data The new image data.
Expand Down
34 changes: 33 additions & 1 deletion src/utils/tensor.js
Original file line number Diff line number Diff line change
Expand Up @@ -340,10 +340,43 @@ export class Tensor {
return this;
}

/**
* Creates a deep copy of the current Tensor.
* @returns {Tensor} A new Tensor with the same type, data, and dimensions as the original.
*/
clone() {
return new Tensor(this.type, this.data.slice(), this.dims.slice());
}

/**
* Performs a slice operation on the Tensor along specified dimensions.
*
* Consider a Tensor that has a dimension of [4, 7]:
* ```
* [ 1, 2, 3, 4, 5, 6, 7]
* [ 8, 9, 10, 11, 12, 13, 14]
* [15, 16, 17, 18, 19, 20, 21]
* [22, 23, 24, 25, 26, 27, 28]
* ```
* We can slice against the two dims of row and column, for instance in this
* case we can start at the second element, and return to the second last,
* like this:
* ```
* tensor.slice([1, -1], [1, -1]);
* ```
* which would return:
* ```
* [ 9, 10, 11, 12, 13 ]
* [ 16, 17, 18, 19, 20 ]
* ```
*
* @param {...(number|number[]|null)} slices The slice specifications for each dimension.
* - If a number is given, then a single element is selected.
* - If an array of two numbers is given, then a range of elements [start, end (exclusive)] is selected.
* - If null is given, then the entire dimension is selected.
* @returns {Tensor} A new Tensor containing the selected elements.
* @throws {Error} If the slice input is invalid.
*/
slice(...slices) {
// This allows for slicing with ranges and numbers
const newTensorDims = [];
Expand Down Expand Up @@ -413,7 +446,6 @@ export class Tensor {
data[i] = this_data[originalIndex];
}
return new Tensor(this.type, data, newTensorDims);

}

/**
Expand Down
27 changes: 27 additions & 0 deletions tests/utils/tensor.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,33 @@ describe("Tensor operations", () => {
// TODO add tests for errors
});

describe("slice", () => {
it("should return a given row dim", async () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice(1);
const target = new Tensor("float32", [3, 4], [2]);

compare(t2, target);
});

it("should return a range of rows", async () => {
const t1 = new Tensor("float32", [1, 2, 3, 4, 5, 6], [3, 2]);
const t2 = t1.slice([1, 3]);
const target = new Tensor("float32", [3, 4, 5, 6], [2, 2]);

compare(t2, target);
});

it("should return a crop", async () => {
const t1 = new Tensor("float32", Array.from({ length: 28 }, (_, i) => i + 1), [4, 7]);
const t2 = t1.slice([1, -1], [1, -1]);

const target = new Tensor("float32", [9, 10, 11, 12, 13, 16, 17, 18, 19, 20], [2, 5]);

compare(t2, target);
});
});

describe("stack", () => {
const t1 = new Tensor("float32", [0, 1, 2, 3, 4, 5], [1, 3, 2]);

Expand Down
23 changes: 23 additions & 0 deletions tests/utils/utils.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,34 @@ describe("Utilities", () => {
});

describe("Image utilities", () => {
const [width, height, channels] = [2, 2, 3];
const data = Uint8Array.from({ length: width * height * channels }, (_, i) => i % 5);
const tiny_image = new RawImage(data, width, height, channels);

let image;
beforeAll(async () => {
image = await RawImage.fromURL("https://picsum.photos/300/200");
});

it("Can split image into separate channels", async () => {
const image_data = tiny_image.split().map(x => x.data);

const target = [
new Uint8Array([0, 3, 1, 4]), // Reds
new Uint8Array([1, 4, 2, 0]), // Greens
new Uint8Array([2, 0, 3, 1]), // Blues
];

compare(image_data, target);
});

it("Can splits channels for grayscale", async () => {
const image_data = tiny_image.grayscale().split().map(x => x.data);
const target = [new Uint8Array([1, 3, 2, 1])];

compare(image_data, target);
});

it("Read image from URL", async () => {
expect(image.width).toBe(300);
expect(image.height).toBe(200);
Expand Down

0 comments on commit 8896dc7

Please sign in to comment.