Skip to content

Commit

Permalink
Refactor per-model unit testing (#1083)
Browse files Browse the repository at this point in the history
* Set up per-model unit tests

* Rename tests

* Do not modify original object when updating model file name

* Distribute unit tests across separate files

* Update comments

* Update tokenization test file names

* Refactor: use asset cache

* Destructuring for code deduplication

* Remove empty file

* Rename deberta-v2 -> deberta_v2

* Rename

* Support casting between number and bigint types

* Use fp32 tiny models

* Move image processing tests to separate folders + auto-detection
  • Loading branch information
xenova authored Dec 11, 2024
1 parent 14bf689 commit effa9a9
Show file tree
Hide file tree
Showing 87 changed files with 3,930 additions and 3,497 deletions.
80 changes: 50 additions & 30 deletions src/models.js
Original file line number Diff line number Diff line change
Expand Up @@ -3666,9 +3666,11 @@ export class CLIPModel extends CLIPPreTrainedModel { }
export class CLIPTextModel extends CLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'text_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'text_model',
...options,
});
}
}

Expand Down Expand Up @@ -3701,9 +3703,11 @@ export class CLIPTextModel extends CLIPPreTrainedModel {
export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'text_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'text_model',
...options,
});
}
}

Expand All @@ -3713,9 +3717,11 @@ export class CLIPTextModelWithProjection extends CLIPPreTrainedModel {
export class CLIPVisionModel extends CLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'vision_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'vision_model',
...options,
});
}
}

Expand Down Expand Up @@ -3748,9 +3754,11 @@ export class CLIPVisionModel extends CLIPPreTrainedModel {
export class CLIPVisionModelWithProjection extends CLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'vision_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'vision_model',
...options,
});
}
}
//////////////////////////////////////////////////
Expand Down Expand Up @@ -3834,9 +3842,11 @@ export class SiglipModel extends SiglipPreTrainedModel { }
export class SiglipTextModel extends SiglipPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'text_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'text_model',
...options,
});
}
}

Expand Down Expand Up @@ -3869,9 +3879,11 @@ export class SiglipTextModel extends SiglipPreTrainedModel {
export class SiglipVisionModel extends CLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'vision_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'vision_model',
...options,
});
}
}
//////////////////////////////////////////////////
Expand Down Expand Up @@ -3926,18 +3938,22 @@ export class JinaCLIPModel extends JinaCLIPPreTrainedModel {
export class JinaCLIPTextModel extends JinaCLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'text_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'text_model',
...options,
});
}
}

export class JinaCLIPVisionModel extends JinaCLIPPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'vision_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'vision_model',
...options,
});
}
}
//////////////////////////////////////////////////
Expand Down Expand Up @@ -6159,9 +6175,11 @@ export class ClapModel extends ClapPreTrainedModel { }
export class ClapTextModelWithProjection extends ClapPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'text_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'text_model',
...options,
});
}
}

Expand Down Expand Up @@ -6194,9 +6212,11 @@ export class ClapTextModelWithProjection extends ClapPreTrainedModel {
export class ClapAudioModelWithProjection extends ClapPreTrainedModel {
/** @type {typeof PreTrainedModel.from_pretrained} */
static async from_pretrained(pretrained_model_name_or_path, options = {}) {
// Update default model file name if not provided
options.model_file_name ??= 'audio_model';
return super.from_pretrained(pretrained_model_name_or_path, options);
return super.from_pretrained(pretrained_model_name_or_path, {
// Update default model file name if not provided
model_file_name: 'audio_model',
...options,
});
}
}
//////////////////////////////////////////////////
Expand Down
15 changes: 14 additions & 1 deletion src/utils/tensor.js
Original file line number Diff line number Diff line change
Expand Up @@ -772,8 +772,21 @@ export class Tensor {
if (!DataTypeMap.hasOwnProperty(type)) {
throw new Error(`Unsupported type: ${type}`);
}

// Handle special cases where a mapping function is needed (e.g., where one type is a bigint and the other is a number)
let map_fn;
const is_source_bigint = ['int64', 'uint64'].includes(this.type);
const is_dest_bigint = ['int64', 'uint64'].includes(type);
if (is_source_bigint && !is_dest_bigint) {
// TypeError: Cannot convert a BigInt value to a number
map_fn = Number;
} else if (!is_source_bigint && is_dest_bigint) {
// TypeError: Cannot convert [x] to a BigInt
map_fn = BigInt;
}

// @ts-ignore
return new Tensor(type, DataTypeMap[type].from(this.data), this.dims);
return new Tensor(type, DataTypeMap[type].from(this.data, map_fn), this.dims);
}
}

Expand Down
43 changes: 43 additions & 0 deletions tests/asset_cache.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import { RawImage } from "../src/transformers.js";

const BASE_URL = "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/";
const TEST_IMAGES = Object.freeze({
white_image: BASE_URL + "white-image.png",
pattern_3x3: BASE_URL + "pattern_3x3.png",
pattern_3x5: BASE_URL + "pattern_3x5.png",
checkerboard_8x8: BASE_URL + "checkerboard_8x8.png",
checkerboard_64x32: BASE_URL + "checkerboard_64x32.png",
gradient_1280x640: BASE_URL + "gradient_1280x640.png",
receipt: BASE_URL + "receipt.png",
tiger: BASE_URL + "tiger.jpg",
paper: BASE_URL + "nougat_paper.png",
cats: BASE_URL + "cats.jpg",

// grayscale image
skateboard: "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/ml-web-games/skateboard.png",

vitmatte_image: BASE_URL + "vitmatte_image.png",
vitmatte_trimap: BASE_URL + "vitmatte_trimap.png",

beetle: BASE_URL + "beetle.png",
book_cover: BASE_URL + "book-cover.png",
});

/** @type {Map<string, RawImage>} */
const IMAGE_CACHE = new Map();
const load_image = async (url) => {
const cached = IMAGE_CACHE.get(url);
if (cached) {
return cached;
}
const image = await RawImage.fromURL(url);
IMAGE_CACHE.set(url, image);
return image;
};

/**
* Load a cached image.
* @param {keyof typeof TEST_IMAGES} name The name of the image to load.
* @returns {Promise<RawImage>} The loaded image.
*/
export const load_cached_image = (name) => load_image(TEST_IMAGES[name]);
58 changes: 58 additions & 0 deletions tests/init.js
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,66 @@ export function init() {
registerBackend("test", onnxruntimeBackend, Number.POSITIVE_INFINITY);
}

export const MAX_PROCESSOR_LOAD_TIME = 10_000; // 10 seconds
export const MAX_MODEL_LOAD_TIME = 15_000; // 15 seconds
export const MAX_TEST_EXECUTION_TIME = 60_000; // 60 seconds
export const MAX_MODEL_DISPOSE_TIME = 1_000; // 1 second

export const MAX_TEST_TIME = MAX_MODEL_LOAD_TIME + MAX_TEST_EXECUTION_TIME + MAX_MODEL_DISPOSE_TIME;

export const DEFAULT_MODEL_OPTIONS = {
dtype: "fp32",
};

expect.extend({
toBeCloseToNested(received, expected, numDigits = 2) {
const compare = (received, expected, path = "") => {
if (typeof received === "number" && typeof expected === "number" && !Number.isInteger(received) && !Number.isInteger(expected)) {
const pass = Math.abs(received - expected) < Math.pow(10, -numDigits);
return {
pass,
message: () => (pass ? `✓ At path '${path}': expected ${received} not to be close to ${expected} with tolerance of ${numDigits} decimal places` : `✗ At path '${path}': expected ${received} to be close to ${expected} with tolerance of ${numDigits} decimal places`),
};
} else if (Array.isArray(received) && Array.isArray(expected)) {
if (received.length !== expected.length) {
return {
pass: false,
message: () => `✗ At path '${path}': array lengths differ. Received length ${received.length}, expected length ${expected.length}`,
};
}
for (let i = 0; i < received.length; i++) {
const result = compare(received[i], expected[i], `${path}[${i}]`);
if (!result.pass) return result;
}
} else if (typeof received === "object" && typeof expected === "object" && received !== null && expected !== null) {
const receivedKeys = Object.keys(received);
const expectedKeys = Object.keys(expected);
if (receivedKeys.length !== expectedKeys.length) {
return {
pass: false,
message: () => `✗ At path '${path}': object keys length differ. Received keys: ${JSON.stringify(receivedKeys)}, expected keys: ${JSON.stringify(expectedKeys)}`,
};
}
for (const key of receivedKeys) {
if (!expected.hasOwnProperty(key)) {
return {
pass: false,
message: () => `✗ At path '${path}': key '${key}' found in received but not in expected`,
};
}
const result = compare(received[key], expected[key], `${path}.${key}`);
if (!result.pass) return result;
}
} else {
const pass = received === expected;
return {
pass,
message: () => (pass ? `✓ At path '${path}': expected ${JSON.stringify(received)} not to equal ${JSON.stringify(expected)}` : `✗ At path '${path}': expected ${JSON.stringify(received)} to equal ${JSON.stringify(expected)}`),
};
}
return { pass: true };
};

return compare(received, expected);
},
});
Loading

0 comments on commit effa9a9

Please sign in to comment.