Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Dec 6, 2024
1 parent edbf767 commit 4ca3554
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
36 changes: 36 additions & 0 deletions tests/processors.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ const MODELS = {
florence2: "Xenova/tiny-random-Florence2ForConditionalGeneration",
qwen2_vl: "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration",
idefics3: "hf-internal-testing/tiny-random-Idefics3ForConditionalGeneration",
paligemma: "hf-internal-testing/tiny-random-PaliGemmaForConditionalGeneration",
};

const BASE_URL = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/";
Expand Down Expand Up @@ -1196,5 +1197,40 @@ describe("Processors", () => {
},
MAX_TEST_TIME,
);

describe(
"PaliGemmaProcessor",
() => {
/** @type {import('../src/transformers.js').PaliGemmaProcessor} */
let processor;
let images = {};

beforeAll(async () => {
processor = await AutoProcessor.from_pretrained(MODELS.paligemma);
images = {
white_image: await load_image(TEST_IMAGES.white_image),
};
});

it("Image-only (default text)", async () => {
const { input_ids, pixel_values } = await processor(images.white_image);
compare(input_ids.dims, [1, 258]);
compare(pixel_values.dims, [1, 3, 224, 224]);
});

it("Single image & text", async () => {
const { input_ids, pixel_values } = await processor(images.white_image, "<image>What is on the flower?");
compare(input_ids.dims, [1, 264]);
compare(pixel_values.dims, [1, 3, 224, 224]);
});

it("Multiple images & text", async () => {
const { input_ids, pixel_values } = await processor([images.white_image, images.white_image], "<image><image>Describe the images.");
compare(input_ids.dims, [1, 518]);
compare(pixel_values.dims, [2, 3, 224, 224]);
});
},
MAX_TEST_TIME,
);
});
});
54 changes: 54 additions & 0 deletions tests/tiny_random.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {
Processor,
Florence2Processor,
Idefics3Processor,
PaliGemmaProcessor,

// Models
LlamaForCausalLM,
Expand Down Expand Up @@ -54,6 +55,7 @@ import {
VisionEncoderDecoderModel,
Florence2ForConditionalGeneration,
Qwen2VLForConditionalGeneration,
PaliGemmaForConditionalGeneration,
MarianMTModel,
PatchTSTModel,
PatchTSTForPrediction,
Expand Down Expand Up @@ -1072,6 +1074,58 @@ describe("Tiny random models", () => {
});
});

describe("paligemma", () => {
const text = "<image>What is on the flower?";

// Empty white image
const dims = [224, 224, 3];
const image = new RawImage(new Uint8ClampedArray(dims[0] * dims[1] * dims[2]).fill(255), ...dims);

describe("PaliGemmaForConditionalGeneration", () => {
const model_id = "hf-internal-testing/tiny-random-PaliGemmaForConditionalGeneration";

/** @type {PaliGemmaForConditionalGeneration} */
let model;
/** @type {PaliGemmaProcessor} */
let processor;
beforeAll(async () => {
model = await PaliGemmaForConditionalGeneration.from_pretrained(model_id, {
// TODO move to config
...DEFAULT_MODEL_OPTIONS,
});
processor = await AutoProcessor.from_pretrained(model_id);
}, MAX_MODEL_LOAD_TIME);

it(
"forward",
async () => {
const inputs = await processor(image, text);

const { logits } = await model(inputs);
expect(logits.dims).toEqual([1, 264, 257216]);
expect(logits.mean().item()).toBeCloseTo(-0.0023024685215204954, 6);
},
MAX_TEST_EXECUTION_TIME,
);

it(
"batch_size=1",
async () => {
const inputs = await processor(image, text);
const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 });

const new_tokens = generate_ids.slice(null, [inputs.input_ids.dims.at(-1), null]);
expect(new_tokens.tolist()).toEqual([[91711n, 24904n, 144054n, 124983n, 83862n, 124983n, 124983n, 124983n, 141236n, 124983n]]);
},
MAX_TEST_EXECUTION_TIME,
);

afterAll(async () => {
await model?.dispose();
}, MAX_MODEL_DISPOSE_TIME);
});
});

describe("vision-encoder-decoder", () => {
describe("VisionEncoderDecoderModel", () => {
const model_id = "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2";
Expand Down

0 comments on commit 4ca3554

Please sign in to comment.