From b8719b12dd7b61f99340a7d27e32f105538daf2e Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Sun, 19 Nov 2023 08:06:49 +0200
Subject: [PATCH] Ensure WASM fallback does not crash in GH actions (#402)

* Ensure WASM fallback does not crash in GH actions

* Add unit test for WordPiece `max_input_chars_per_word`

* Cleanup

* Set max test concurrency to 1
---
 package.json             |  2 +-
 tests/generation.test.js |  4 ++--
 tests/init.js            |  6 ++++++
 tests/tensor.test.js     |  5 +----
 tests/tokenizers.test.js | 13 +++++++++++--
 5 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/package.json b/package.json
index ce18ca9f3..b6262671c 100644
--- a/package.json
+++ b/package.json
@@ -10,7 +10,7 @@
     "dev": "webpack serve --no-client-overlay",
     "build": "webpack && npm run typegen",
     "generate-tests": "python -m tests.generate_tests",
-    "test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose",
+    "test": "node --experimental-vm-modules node_modules/jest/bin/jest.js --verbose --maxConcurrency 1",
     "readme": "python ./docs/scripts/build_readme.py",
     "docs-api": "node ./docs/scripts/generate.js",
     "docs-preview": "doc-builder preview transformers.js ./docs/source/ --not_python_module",
diff --git a/tests/generation.test.js b/tests/generation.test.js
index 2effaf94d..eb6b87f49 100644
--- a/tests/generation.test.js
+++ b/tests/generation.test.js
@@ -9,8 +9,8 @@ describe('Generation parameters', () => {
 
     // List all models which will be tested
     const models = [
-        'Xenova/LaMini-Flan-T5-77M', // encoder-decoder
-        'Xenova/LaMini-GPT-124M', // decoder-only
+        'MBZUAI/LaMini-Flan-T5-77M', // encoder-decoder
+        'MBZUAI/LaMini-GPT-124M', // decoder-only
     ];
 
     // encoder-decoder model
diff --git a/tests/init.js b/tests/init.js
index 6eb3c9a12..b01fe1000 100644
--- a/tests/init.js
+++ b/tests/init.js
@@ -9,6 +9,12 @@ import { onnxruntimeBackend } from "onnxruntime-node/dist/backend";
 import ONNX_COMMON from "onnxruntime-common";
 
 export function init() {
+    // In rare cases (specifically when running unit tests with GitHub actions), possibly due to
+    // a large number of concurrent executions, onnxruntime might fall back to the WASM backend.
+    // In this case, we set the number of threads to 1 to avoid errors like:
+    //  - `TypeError: The worker script or module filename must be an absolute path or a relative path starting with './' or '../'. Received "blob:nodedata:..."`
+    ONNX_COMMON.env.wasm.numThreads = 1;
+
     // A workaround to define a new backend for onnxruntime, which
     // will not throw an error when running tests with jest.
     // For more information, see: https://github.com/jestjs/jest/issues/11864#issuecomment-1261468011
diff --git a/tests/tensor.test.js b/tests/tensor.test.js
index c2dc3374c..0d328a9ef 100644
--- a/tests/tensor.test.js
+++ b/tests/tensor.test.js
@@ -1,9 +1,6 @@
-import { AutoProcessor, Tensor } from '../src/transformers.js';
-
-import { MAX_TEST_EXECUTION_TIME, m } from './init.js';
+import { Tensor } from '../src/transformers.js';
 
 import { compare } from './test_utils.js';
-
 import { cat, mean, stack } from '../src/utils/tensor.js';
 
 describe('Tensor operations', () => {
diff --git a/tests/tokenizers.test.js b/tests/tokenizers.test.js
index 0f9e1a57c..27a7e5f9b 100644
--- a/tests/tokenizers.test.js
+++ b/tests/tokenizers.test.js
@@ -3,6 +3,7 @@ import { AutoTokenizer } from '../src/transformers.js';
 import { getFile } from '../src/utils/hub.js';
 
 import { m, MAX_TEST_EXECUTION_TIME } from './init.js';
+import { compare } from './test_utils.js';
 
 // Load test data generated by the python tests
 // TODO do this dynamically?
@@ -41,10 +42,18 @@ describe('Tokenizers', () => {
 describe('Edge cases', () => {
     it('should not crash when encoding a very long string', async () => {
-        let tokenizer = await AutoTokenizer.from_pretrained('t5-small');
+        let tokenizer = await AutoTokenizer.from_pretrained('Xenova/t5-small');
 
         let text = String.prototype.repeat.call('Hello world! ', 50000);
 
-        let encoded = await tokenizer(text);
+        let encoded = tokenizer(text);
         expect(encoded.input_ids.data.length).toBeGreaterThan(100000);
     }, MAX_TEST_EXECUTION_TIME);
+
+    it('should not take too long', async () => {
+        let tokenizer = await AutoTokenizer.from_pretrained('Xenova/all-MiniLM-L6-v2');
+
+        let text = String.prototype.repeat.call('a', 50000);
+        let token_ids = tokenizer.encode(text);
+        compare(token_ids, [101, 100, 102])
+    }, 5000); // NOTE: 5 seconds
 });
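
Note: the `ONNX_COMMON.env.wasm.numThreads = 1` line in tests/init.js works because
multi-threaded WASM inference spawns worker scripts from `blob:` URLs, which the
Node/Jest worker loader rejects with the quoted `TypeError`; a single thread never
creates those workers. A minimal sketch of the same workaround in ordinary
application code, assuming only the `onnxruntime-common` package and its exported
`env` object (this file is hypothetical and not part of the patch):

    // force-single-thread.mjs -- illustrative only
    import { env } from 'onnxruntime-common';

    // Keep the WASM execution provider single-threaded so it never spawns
    // worker scripts from blob: URLs (which crash under Node/Jest).
    env.wasm.numThreads = 1;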
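
Note: in the new `should not take too long` test, the expected ids `[101, 100, 102]`
are `[CLS] [UNK] [SEP]` in a BERT-style vocabulary: the 50,000-character word exceeds
WordPiece's `max_input_chars_per_word` limit (100 in BERT's reference implementation),
so it is mapped to a single unknown token instead of being sub-word tokenized, which
is what keeps the encode inside the 5-second timeout. A minimal runnable sketch of
that guard inside a simplified greedy WordPiece (illustrative names, not the
library's actual implementation):

    // wordpiece-sketch.mjs -- illustrative only
    function wordPiece(word, vocab, unkToken = '[UNK]', maxInputCharsPerWord = 100) {
        // The guard exercised by the new test: overly long words are mapped
        // straight to the unknown token, skipping sub-word matching entirely.
        if (word.length > maxInputCharsPerWord) return [unkToken];

        const tokens = [];
        let start = 0;
        while (start < word.length) {
            // Greedy longest-match-first: find the longest vocabulary entry
            // starting at `start`, prefixing non-initial pieces with '##'.
            let end = word.length;
            let match = null;
            while (start < end) {
                const piece = (start > 0 ? '##' : '') + word.slice(start, end);
                if (vocab.has(piece)) { match = piece; break; }
                --end;
            }
            if (match === null) return [unkToken]; // unmatchable character
            tokens.push(match);
            start = end;
        }
        return tokens;
    }

    const vocab = new Set(['hello', '##world']);
    console.log(wordPiece('a'.repeat(50000), vocab)); // [ '[UNK]' ] -- length guard
    console.log(wordPiece('helloworld', vocab));      // [ 'hello', '##world' ]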