From 4ca3554b8f73aa899352e0224d2192f89b69de9f Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Fri, 6 Dec 2024 18:21:33 +0000
Subject: [PATCH] Add unit tests

---
 tests/processors.test.js  | 36 ++++++++++++++++++++++++++
 tests/tiny_random.test.js | 54 +++++++++++++++++++++++++++++++++++++++
 2 files changed, 90 insertions(+)

diff --git a/tests/processors.test.js b/tests/processors.test.js
index a17cd4fc3..8e3133563 100644
--- a/tests/processors.test.js
+++ b/tests/processors.test.js
@@ -48,6 +48,7 @@ const MODELS = {
   florence2: "Xenova/tiny-random-Florence2ForConditionalGeneration",
   qwen2_vl: "hf-internal-testing/tiny-random-Qwen2VLForConditionalGeneration",
   idefics3: "hf-internal-testing/tiny-random-Idefics3ForConditionalGeneration",
+  paligemma: "hf-internal-testing/tiny-random-PaliGemmaForConditionalGeneration",
 };
 
 const BASE_URL = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/";
@@ -1196,5 +1197,40 @@ describe("Processors", () => {
     },
     MAX_TEST_TIME,
   );
+
+  describe(
+    "PaliGemmaProcessor",
+    () => {
+      /** @type {import('../src/transformers.js').PaliGemmaProcessor} */
+      let processor;
+      let images = {};
+
+      beforeAll(async () => {
+        processor = await AutoProcessor.from_pretrained(MODELS.paligemma);
+        images = {
+          white_image: await load_image(TEST_IMAGES.white_image),
+        };
+      });
+
+      it("Image-only (default text)", async () => {
+        const { input_ids, pixel_values } = await processor(images.white_image);
+        compare(input_ids.dims, [1, 258]);
+        compare(pixel_values.dims, [1, 3, 224, 224]);
+      });
+
+      it("Single image & text", async () => {
+        const { input_ids, pixel_values } = await processor(images.white_image, "What is on the flower?");
+        compare(input_ids.dims, [1, 264]);
+        compare(pixel_values.dims, [1, 3, 224, 224]);
+      });
+
+      it("Multiple images & text", async () => {
+        const { input_ids, pixel_values } = await processor([images.white_image, images.white_image], "Describe the images.");
+        compare(input_ids.dims, [1, 518]);
+        compare(pixel_values.dims, [2, 3, 224, 224]);
+      });
+    },
+    MAX_TEST_TIME,
+  );
   });
 });
diff --git a/tests/tiny_random.test.js b/tests/tiny_random.test.js
index 72f0e5c53..d3deb5a09 100644
--- a/tests/tiny_random.test.js
+++ b/tests/tiny_random.test.js
@@ -20,6 +20,7 @@ import {
   Processor,
   Florence2Processor,
   Idefics3Processor,
+  PaliGemmaProcessor,
 
   // Models
   LlamaForCausalLM,
@@ -54,6 +55,7 @@ import {
   VisionEncoderDecoderModel,
   Florence2ForConditionalGeneration,
   Qwen2VLForConditionalGeneration,
+  PaliGemmaForConditionalGeneration,
   MarianMTModel,
   PatchTSTModel,
   PatchTSTForPrediction,
@@ -1072,6 +1074,58 @@ describe("Tiny random models", () => {
     });
   });
 
+  describe("paligemma", () => {
+    const text = "What is on the flower?";
+
+    // Empty white image
+    const dims = [224, 224, 3];
+    const image = new RawImage(new Uint8ClampedArray(dims[0] * dims[1] * dims[2]).fill(255), ...dims);
+
+    describe("PaliGemmaForConditionalGeneration", () => {
+      const model_id = "hf-internal-testing/tiny-random-PaliGemmaForConditionalGeneration";
+
+      /** @type {PaliGemmaForConditionalGeneration} */
+      let model;
+      /** @type {PaliGemmaProcessor} */
+      let processor;
+      beforeAll(async () => {
+        model = await PaliGemmaForConditionalGeneration.from_pretrained(model_id, {
+          // TODO move to config
+          ...DEFAULT_MODEL_OPTIONS,
+        });
+        processor = await AutoProcessor.from_pretrained(model_id);
+      }, MAX_MODEL_LOAD_TIME);
+
+      it(
+        "forward",
+        async () => {
+          const inputs = await processor(image, text);
+
+          const { logits } = await model(inputs);
+          expect(logits.dims).toEqual([1, 264, 257216]);
+          expect(logits.mean().item()).toBeCloseTo(-0.0023024685215204954, 6);
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
+      it(
+        "batch_size=1",
+        async () => {
+          const inputs = await processor(image, text);
+          const generate_ids = await model.generate({ ...inputs, max_new_tokens: 10 });
+
+          const new_tokens = generate_ids.slice(null, [inputs.input_ids.dims.at(-1), null]);
+          expect(new_tokens.tolist()).toEqual([[91711n, 24904n, 144054n, 124983n, 83862n, 124983n, 124983n, 124983n, 141236n, 124983n]]);
+        },
+        MAX_TEST_EXECUTION_TIME,
+      );
+
+      afterAll(async () => {
+        await model?.dispose();
+      }, MAX_MODEL_DISPOSE_TIME);
+    });
+  });
+
   describe("vision-encoder-decoder", () => {
     describe("VisionEncoderDecoderModel", () => {
       const model_id = "hf-internal-testing/tiny-random-VisionEncoderDecoderModel-vit-gpt2";