Add image segmentation pipeline unit tests
xenova committed Dec 13, 2024
1 parent 02ab6cb commit 95ec71c
Showing 2 changed files with 117 additions and 1 deletion.
3 changes: 2 additions & 1 deletion tests/asset_cache.js
@@ -23,10 +23,11 @@ const TEST_IMAGES = Object.freeze({
beetle: BASE_URL + "beetle.png",
book_cover: BASE_URL + "book-cover.png",
corgi: BASE_URL + "corgi.jpg",
+ man_on_car: BASE_URL + "young-man-standing-and-leaning-on-car.jpg",
});

const TEST_AUDIOS = {
mlk: "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/mlk.npy",
mlk: BASE_URL + "mlk.npy",
};

/** @type {Map<string, RawImage>} */
115 changes: 115 additions & 0 deletions tests/pipelines/test_pipelines_image_segmentation.js
@@ -0,0 +1,115 @@
import { pipeline, ImageSegmentationPipeline } from "../../src/transformers.js";

import { MAX_MODEL_LOAD_TIME, MAX_TEST_EXECUTION_TIME, MAX_MODEL_DISPOSE_TIME, DEFAULT_MODEL_OPTIONS } from "../init.js";
import { load_cached_image } from "../asset_cache.js";

const PIPELINE_ID = "image-segmentation";

export default () => {
describe("Image Segmentation", () => {
describe("Panoptic Segmentation", () => {
const model_id = "Xenova/detr-resnet-50-panoptic";
/** @type {ImageSegmentationPipeline} */
let pipe;
beforeAll(async () => {
pipe = await pipeline(PIPELINE_ID, model_id, DEFAULT_MODEL_OPTIONS);
}, MAX_MODEL_LOAD_TIME);

it(
"single",
async () => {
const image = await load_cached_image("cats");

const output = await pipe(image);

// First, check mask shapes
for (const item of output) {
expect(item.mask.width).toEqual(image.width);
expect(item.mask.height).toEqual(image.height);
expect(item.mask.channels).toEqual(1);
delete item.mask; // No longer needed
}

// Next, compare scores and labels
const target = [
{
score: 0.9918501377105713,
label: "cat",
},
{
score: 0.9985815286636353,
label: "remote",
},
{
score: 0.999537467956543,
label: "remote",
},
{
score: 0.9919270277023315,
label: "couch",
},
{
score: 0.9993696808815002,
label: "cat",
},
];

expect(output).toBeCloseToNested(target, 2);
},
MAX_TEST_EXECUTION_TIME,
);

afterAll(async () => {
await pipe.dispose();
}, MAX_MODEL_DISPOSE_TIME);
});

describe("Semantic Segmentation", () => {
const model_id = "Xenova/segformer_b0_clothes";
/** @type {ImageSegmentationPipeline} */
let pipe;
beforeAll(async () => {
pipe = await pipeline(PIPELINE_ID, model_id, DEFAULT_MODEL_OPTIONS);
}, MAX_MODEL_LOAD_TIME);

it(
"single",
async () => {
const image = await load_cached_image("man_on_car");

const output = await pipe(image);

// First, check mask shapes
for (const item of output) {
expect(item.mask.width).toEqual(image.width);
expect(item.mask.height).toEqual(image.height);
expect(item.mask.channels).toEqual(1);
delete item.mask; // No longer needed
}

// Next, compare scores and labels
const target = [
{ score: null, label: "Background" },
{ score: null, label: "Hair" },
{ score: null, label: "Upper-clothes" },
{ score: null, label: "Pants" },
{ score: null, label: "Left-shoe" },
{ score: null, label: "Right-shoe" },
{ score: null, label: "Face" },
{ score: null, label: "Right-leg" },
{ score: null, label: "Left-arm" },
{ score: null, label: "Right-arm" },
{ score: null, label: "Bag" },
];

expect(output).toBeCloseToNested(target, 2);
},
MAX_TEST_EXECUTION_TIME,
);

afterAll(async () => {
await pipe.dispose();
}, MAX_MODEL_DISPOSE_TIME);
});
});
};
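
For context, here is a minimal sketch of how the image-segmentation pipeline exercised by these tests might be called directly. The package name, model id, and image URL below are illustrative assumptions (the URL is inferred from the asset URLs in the diff above), not part of this commit:

// Minimal usage sketch (assumptions noted in the surrounding text).
import { pipeline } from "@huggingface/transformers";

// Create an image-segmentation pipeline with one of the models used in the tests above.
const segmenter = await pipeline("image-segmentation", "Xenova/segformer_b0_clothes");

// Run segmentation on an example image (assumed URL).
const url = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/young-man-standing-and-leaning-on-car.jpg";
const output = await segmenter(url);

// Each entry has a label, a score (null for semantic segmentation), and a single-channel
// RawImage mask — the same fields the tests above assert on.
for (const { label, score, mask } of output) {
  console.log(label, score, `${mask.width}x${mask.height}x${mask.channels}`);
}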
